Skip to content

Commit 3ade764

Browse files
committed
ruff checks have passed
1 parent 570a721 commit 3ade764

File tree

4 files changed

+408
-167
lines changed

4 files changed

+408
-167
lines changed

graph_weather/models/ai_assimilation/data.py

Lines changed: 84 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,9 @@
1+
"""
2+
Data Module for AI-based Data Assimilation
13
4+
Handles the loading and preprocessing of first-guess states and observations
5+
for the AI-based assimilation approach.
6+
"""
27

38
import warnings
49
from typing import Dict, Optional, Tuple
@@ -11,14 +16,28 @@
1116

1217

1318
class AIAssimilationDataset(Dataset):
19+
"""
20+
Dataset for AI-based data assimilation.
21+
22+
Each sample contains a first-guess state and corresponding observations.
23+
The dataset is designed to work with self-supervised learning where
24+
no ground-truth analysis is required.
25+
"""
1426

1527
def __init__(
1628
self,
1729
first_guess_states: torch.Tensor,
1830
observations: torch.Tensor,
1931
observation_locations: Optional[torch.Tensor] = None,
2032
):
21-
33+
"""
34+
Initialize the AI assimilation dataset.
35+
36+
Args:
37+
first_guess_states: First-guess states (background) [num_samples, state_size]
38+
observations: Observation values [num_samples, obs_size]
39+
observation_locations: Optional tensor indicating observation locations
40+
"""
2241
self.first_guess_states = first_guess_states
2342
self.observations = observations
2443
self.observation_locations = observation_locations
@@ -28,10 +47,19 @@ def __init__(
2847
assert first_guess_states.shape[0] == observations.shape[0], msg
2948

3049
def __len__(self) -> int:
50+
"""Return the number of samples in the dataset."""
3151
return len(self.first_guess_states)
3252

3353
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
34-
54+
"""
55+
Get a sample from the dataset.
56+
57+
Args:
58+
idx: Index of the sample
59+
60+
Returns:
61+
Dictionary containing first_guess, observations, and optionally locations
62+
"""
3563
sample = {
3664
"first_guess": self.first_guess_states[idx],
3765
"observations": self.observations[idx],
@@ -52,7 +80,21 @@ def generate_synthetic_assimilation_data(
5280
spatial_correlation: bool = False,
5381
grid_shape: Optional[Tuple[int, int]] = None,
5482
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
55-
83+
"""
84+
Generate synthetic data for AI-based data assimilation experiments.
85+
86+
Args:
87+
num_samples: Number of samples to generate
88+
state_size: Size of the state vector
89+
obs_fraction: Fraction of state variables that have observations
90+
bg_error_std: Standard deviation of background (first-guess) errors
91+
obs_error_std: Standard deviation of observation errors
92+
spatial_correlation: Whether to add spatial correlation to the data
93+
grid_shape: Shape of spatial grid if applicable (h, w)
94+
95+
Returns:
96+
Tuple of (first_guess, observations, true_state) tensors
97+
"""
5698
# Generate a true state with possible spatial correlation
5799
if spatial_correlation and grid_shape is not None:
58100
h, w = grid_shape
@@ -116,6 +158,11 @@ def generate_synthetic_assimilation_data(
116158

117159

118160
class AIAssimilationDataModule:
161+
"""
162+
Data module for AI-based assimilation following PyTorch Lightning pattern.
163+
164+
Handles data splits and provides train/val/test loaders.
165+
"""
119166

120167
def __init__(
121168
self,
@@ -131,7 +178,22 @@ def __init__(
131178
spatial_correlation: bool = False,
132179
grid_shape: Optional[Tuple[int, int]] = None,
133180
):
134-
181+
"""
182+
Initialize the AI assimilation data module.
183+
184+
Args:
185+
num_samples: Number of total samples
186+
state_size: Size of state vector
187+
obs_fraction: Fraction of observed values
188+
bg_error_std: Background error standard deviation
189+
obs_error_std: Observation error standard deviation
190+
batch_size: Batch size for data loaders
191+
train_ratio: Fraction for training
192+
val_ratio: Fraction for validation
193+
test_ratio: Fraction for testing
194+
spatial_correlation: Whether to include spatial correlation
195+
grid_shape: Shape of spatial grid if applicable
196+
"""
135197
self.num_samples = num_samples
136198
self.state_size = state_size
137199
self.obs_fraction = obs_fraction
@@ -153,6 +215,12 @@ def __init__(
153215
self.test_loader = None
154216

155217
def setup(self, stage: Optional[str] = None):
218+
"""
219+
Setup the datasets and data loaders.
220+
221+
Args:
222+
stage: Stage of training (fit, validate, test, predict)
223+
"""
156224
# Generate synthetic data
157225
first_guess, observations, true_state = generate_synthetic_assimilation_data(
158226
num_samples=self.num_samples,
@@ -198,7 +266,17 @@ def test_dataloader(self) -> DataLoader:
198266
def create_observation_operator(
199267
state_size: int, obs_size: int, obs_locations: Optional[np.ndarray] = None
200268
) -> torch.Tensor:
201-
269+
"""
270+
Create an observation operator matrix H that maps state space to observation space.
271+
272+
Args:
273+
state_size: Size of the state vector
274+
obs_size: Size of the observation vector
275+
obs_locations: Specific locations of observations (indices in state vector)
276+
277+
Returns:
278+
Observation operator H [obs_size, state_size]
279+
"""
202280
if obs_locations is None:
203281
# Randomly select observation locations
204282
obs_indices = np.random.choice(state_size, size=obs_size, replace=False)
@@ -215,4 +293,4 @@ def create_observation_operator(
215293
if 0 <= idx < state_size:
216294
H[i, idx] = 1.0
217295

218-
return H
296+
return H

0 commit comments

Comments
 (0)