"""
Data Module for AI-based Data Assimilation

Handles the loading and preprocessing of first-guess states and observations
for the AI-based assimilation approach.
"""
27
38import warnings
49from typing import Dict , Optional , Tuple
1116
1217
1318class AIAssimilationDataset (Dataset ):
19+ """
20+ Dataset for AI-based data assimilation.
21+
22+ Each sample contains a first-guess state and corresponding observations.
23+ The dataset is designed to work with self-supervised learning where
24+ no ground-truth analysis is required.
25+ """
1426
1527 def __init__ (
1628 self ,
1729 first_guess_states : torch .Tensor ,
1830 observations : torch .Tensor ,
1931 observation_locations : Optional [torch .Tensor ] = None ,
2032 ):
21-
33+ """
34+ Initialize the AI assimilation dataset.
35+
36+ Args:
37+ first_guess_states: First-guess states (background) [num_samples, state_size]
38+ observations: Observation values [num_samples, obs_size]
39+ observation_locations: Optional tensor indicating observation locations
40+ """
2241 self .first_guess_states = first_guess_states
2342 self .observations = observations
2443 self .observation_locations = observation_locations
@@ -28,10 +47,19 @@ def __init__(
2847 assert first_guess_states .shape [0 ] == observations .shape [0 ], msg
2948
def __len__(self) -> int:
    """Return the number of samples in the dataset.

    The count is taken from ``self.first_guess_states`` via ``len()``, i.e.
    the number of entries along its leading (sample) axis.

    Returns:
        Number of first-guess states held by this dataset.
    """
    return len(self.first_guess_states)
3252
3353 def __getitem__ (self , idx : int ) -> Dict [str , torch .Tensor ]:
34-
54+ """
55+ Get a sample from the dataset.
56+
57+ Args:
58+ idx: Index of the sample
59+
60+ Returns:
61+ Dictionary containing first_guess, observations, and optionally locations
62+ """
3563 sample = {
3664 "first_guess" : self .first_guess_states [idx ],
3765 "observations" : self .observations [idx ],
@@ -52,7 +80,21 @@ def generate_synthetic_assimilation_data(
5280 spatial_correlation : bool = False ,
5381 grid_shape : Optional [Tuple [int , int ]] = None ,
5482) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
55-
83+ """
84+ Generate synthetic data for AI-based data assimilation experiments.
85+
86+ Args:
87+ num_samples: Number of samples to generate
88+ state_size: Size of the state vector
89+ obs_fraction: Fraction of state variables that have observations
90+ bg_error_std: Standard deviation of background (first-guess) errors
91+ obs_error_std: Standard deviation of observation errors
92+ spatial_correlation: Whether to add spatial correlation to the data
93+ grid_shape: Shape of spatial grid if applicable (h, w)
94+
95+ Returns:
96+ Tuple of (first_guess, observations, true_state) tensors
97+ """
5698 # Generate a true state with possible spatial correlation
5799 if spatial_correlation and grid_shape is not None :
58100 h , w = grid_shape
@@ -116,6 +158,11 @@ def generate_synthetic_assimilation_data(
116158
117159
118160class AIAssimilationDataModule :
161+ """
162+ Data module for AI-based assimilation following PyTorch Lightning pattern.
163+
164+ Handles data splits and provides train/val/test loaders.
165+ """
119166
120167 def __init__ (
121168 self ,
@@ -131,7 +178,22 @@ def __init__(
131178 spatial_correlation : bool = False ,
132179 grid_shape : Optional [Tuple [int , int ]] = None ,
133180 ):
134-
181+ """
182+ Initialize the AI assimilation data module.
183+
184+ Args:
185+ num_samples: Number of total samples
186+ state_size: Size of state vector
187+ obs_fraction: Fraction of observed values
188+ bg_error_std: Background error standard deviation
189+ obs_error_std: Observation error standard deviation
190+ batch_size: Batch size for data loaders
191+ train_ratio: Fraction for training
192+ val_ratio: Fraction for validation
193+ test_ratio: Fraction for testing
194+ spatial_correlation: Whether to include spatial correlation
195+ grid_shape: Shape of spatial grid if applicable
196+ """
135197 self .num_samples = num_samples
136198 self .state_size = state_size
137199 self .obs_fraction = obs_fraction
@@ -153,6 +215,12 @@ def __init__(
153215 self .test_loader = None
154216
155217 def setup (self , stage : Optional [str ] = None ):
218+ """
219+ Setup the datasets and data loaders.
220+
221+ Args:
222+ stage: Stage of training (fit, validate, test, predict)
223+ """
156224 # Generate synthetic data
157225 first_guess , observations , true_state = generate_synthetic_assimilation_data (
158226 num_samples = self .num_samples ,
@@ -198,7 +266,17 @@ def test_dataloader(self) -> DataLoader:
198266def create_observation_operator (
199267 state_size : int , obs_size : int , obs_locations : Optional [np .ndarray ] = None
200268) -> torch .Tensor :
201-
269+ """
270+ Create an observation operator matrix H that maps state space to observation space.
271+
272+ Args:
273+ state_size: Size of the state vector
274+ obs_size: Size of the observation vector
275+ obs_locations: Specific locations of observations (indices in state vector)
276+
277+ Returns:
278+ Observation operator H [obs_size, state_size]
279+ """
202280 if obs_locations is None :
203281 # Randomly select observation locations
204282 obs_indices = np .random .choice (state_size , size = obs_size , replace = False )
@@ -215,4 +293,4 @@ def create_observation_operator(
215293 if 0 <= idx < state_size :
216294 H [i , idx ] = 1.0
217295
218- return H
296+ return H
0 commit comments