Skip to content

Commit be2e589

Browse files
committed
Merge branch 'AIasm' of https://github.com/SOHAMPAL23/graph_weather into AIasm
2 parents e7ff8b9 + bf3db23 commit be2e589

File tree

6 files changed

+147
-155
lines changed

6 files changed

+147
-155
lines changed

graph_weather/models/ai_assimilation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
This package provides neural networks that learn to produce optimal analysis states
88
by minimizing the 3D-Var cost function in a self-supervised manner, without requiring
99
ground-truth labels.
10-
"""
10+
"""

graph_weather/models/ai_assimilation/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,4 @@ def create_observation_operator(
293293
if 0 <= idx < state_size:
294294
H[i, idx] = 1.0
295295

296-
return H
296+
return H

graph_weather/models/ai_assimilation/loss.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,13 @@ def forward(
6464
bg_term = torch.sum(bg_quadratic, dim=-1)
6565
else:
6666
# Simplified: assume identity covariance (sum of squares)
67-
bg_term = torch.sum(bg_diff ** 2, dim=-1)
67+
bg_term = torch.sum(bg_diff**2, dim=-1)
6868

6969
# Observation term: (y - H x_a)^T R^{-1} (y - H x_a)
7070
if self.observation_operator is not None:
7171
# Apply observation operator
7272
hx = torch.matmul(
73-
analysis.unsqueeze(1),
74-
self.observation_operator.transpose(-1, -2)
73+
analysis.unsqueeze(1), self.observation_operator.transpose(-1, -2)
7574
).squeeze(1)
7675
else:
7776
# Identity observation operator (direct comparison)
@@ -85,7 +84,7 @@ def forward(
8584
obs_term = torch.sum(obs_quadratic, dim=-1)
8685
else:
8786
# Simplified: assume identity covariance (sum of squares)
88-
obs_term = torch.sum(obs_diff ** 2, dim=-1)
87+
obs_term = torch.sum(obs_diff**2, dim=-1)
8988

9089
# Combine terms with equal weighting (can be adjusted)
9190
total_cost = 0.5 * (torch.mean(bg_term) + torch.mean(obs_term))
@@ -157,17 +156,17 @@ def forward(
157156

158157
# Weighted combination
159158
total_loss = (
160-
self.three_d_var_weight * three_d_var_loss +
161-
self.smoothness_weight * smoothness_loss +
162-
self.conservation_weight * conservation_loss
159+
self.three_d_var_weight * three_d_var_loss
160+
+ self.smoothness_weight * smoothness_loss
161+
+ self.conservation_weight * conservation_loss
163162
)
164163

165164
# Return components for monitoring
166165
components = {
167-
'three_d_var': three_d_var_loss.item(),
168-
'smoothness': smoothness_loss.item(),
169-
'conservation': conservation_loss.item(),
170-
'total': total_loss.item()
166+
"three_d_var": three_d_var_loss.item(),
167+
"smoothness": smoothness_loss.item(),
168+
"conservation": conservation_loss.item(),
169+
"total": total_loss.item(),
171170
}
172171

173-
return total_loss, components
172+
return total_loss, components

graph_weather/models/ai_assimilation/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,4 +203,4 @@ def forward(self, batch_size: int, device: torch.device = None) -> torch.Tensor:
203203
(batch_size, self.state_size), self.init_value, device=device, dtype=torch.float32
204204
)
205205

206-
return first_guess
206+
return first_guess

0 commit comments

Comments
 (0)