Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions graph_weather/models/data_assimilation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Data assimilation module initialization."""

from .data_assimilation_base import DataAssimilationBase, EnsembleGenerator
from .kalman_filter_da import KalmanFilterDA
from .particle_filter_da import ParticleFilterDA
from .variational_da import VariationalDA

__all__ = [
"DataAssimilationBase",
"EnsembleGenerator",
"KalmanFilterDA",
"ParticleFilterDA",
"VariationalDA",
]
110 changes: 110 additions & 0 deletions graph_weather/models/data_assimilation/data_assimilation_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Base classes for data assimilation modules."""

import abc
from typing import Any, Dict, Union

import torch
from torch_geometric.data import Data


class EnsembleGenerator:
"""Class to generate ensemble members from a background state."""

def __init__(self, noise_std: float = 0.1, method: str = "gaussian"):
self.noise_std = noise_std
self.method = method

def generate_ensemble(self, state: Union[torch.Tensor, Data], num_members: int):
if isinstance(state, torch.Tensor):
return self._generate_tensor_ensemble(state, num_members)
elif isinstance(state, Data):
return self._generate_graph_ensemble(state, num_members)
else:
raise TypeError(f"Unsupported state type: {type(state)}")

def _generate_tensor_ensemble(self, state: torch.Tensor, num_members: int) -> torch.Tensor:
batch_size, nodes, features = state.shape
ensemble = torch.zeros(batch_size, num_members, nodes, features, device=state.device)

for i in range(num_members):
if self.method == "gaussian":
noise = torch.randn_like(state) * self.noise_std
ensemble[:, i] = state + noise
elif self.method == "dropout":
mask = torch.bernoulli(torch.ones_like(state) * 0.9) # Keep 90% of values
noise = torch.randn_like(state) * self.noise_std * 0.1
ensemble[:, i] = (state * mask) + noise
elif self.method == "perturbation":
perturbation = (
torch.randn_like(state)
* self.noise_std
* torch.linspace(0.1, 1.0, num_members)[i]
)
ensemble[:, i] = state + perturbation
else:
raise ValueError(f"Unknown ensemble generation method: {self.method}")

return ensemble

def _generate_graph_ensemble(self, state: Data, num_members: int) -> Data:
x_expanded = torch.zeros(
state.x.shape[0], num_members, state.x.shape[1], device=state.x.device
)

for i in range(num_members):
if self.method == "gaussian":
noise = torch.randn_like(state.x) * self.noise_std
x_expanded[:, i] = state.x + noise
elif self.method == "dropout":
mask = torch.bernoulli(torch.ones_like(state.x) * 0.9)
noise = torch.randn_like(state.x) * self.noise_std * 0.1
x_expanded[:, i] = (state.x * mask) + noise
elif self.method == "perturbation":
perturbation = (
torch.randn_like(state.x)
* self.noise_std
* torch.linspace(0.1, 1.0, num_members)[i]
)
x_expanded[:, i] = state.x + perturbation
else:
raise ValueError(f"Unknown ensemble generation method: {self.method}")

new_state = Data(
x=x_expanded,
edge_index=state.edge_index,
edge_attr=getattr(state, "edge_attr", None),
pos=getattr(state, "pos", None),
)

return new_state


class DataAssimilationBase(abc.ABC):
"""Abstract base class for data assimilation modules."""

def __init__(self, config: Dict[str, Any]):
self.config = config
self.ensemble_generator = EnsembleGenerator(
noise_std=config.get("noise_std", 0.1), method=config.get("ensemble_method", "gaussian")
)

@abc.abstractmethod
def initialize_ensemble(self, background_state: Union[torch.Tensor, Data], num_members: int):
pass

@abc.abstractmethod
def assimilate(self, ensemble: Union[torch.Tensor, Data], observations: torch.Tensor):
pass

@abc.abstractmethod
def _compute_analysis(self, ensemble: Union[torch.Tensor, Data]) -> Union[torch.Tensor, Data]:
pass

def forward(
self, state: Union[torch.Tensor, Data], observations: torch.Tensor, num_ensemble: int = 10
):
ensemble = self.initialize_ensemble(state, num_ensemble)
updated_ensemble = self.assimilate(ensemble, observations)
analysis = self._compute_analysis(updated_ensemble)

return updated_ensemble, analysis
99 changes: 99 additions & 0 deletions tests/models/data_assimilation/test_data_assimilation_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
import torch
from torch_geometric.data import Data

import sys

sys.path.insert(0, "../../../graph_weather/models/data_assimilation")

# Execute modules directly to avoid import issues
exec(open("graph_weather/models/data_assimilation/data_assimilation_base.py").read())


class MockDA(DataAssimilationBase):
"""Mock implementation of DataAssimilationBase for testing purposes."""

def initialize_ensemble(self, background_state, num_members):
return self.ensemble_generator.generate_ensemble(background_state, num_members)

def assimilate(self, ensemble, observations):
return ensemble # Return unchanged for testing

def _compute_analysis(self, ensemble):
if isinstance(ensemble, torch.Tensor):
return torch.mean(ensemble, dim=1)
elif isinstance(ensemble, Data):
return ensemble # Return as is for testing
else:
raise TypeError(f"Unsupported ensemble type: {type(ensemble)}")


def test_ensemble_generator_tensor():
"""Test ensemble generation for tensor inputs."""
generator = EnsembleGenerator(noise_std=0.1, method="gaussian")

# Test tensor input
state = torch.randn(2, 5, 3) # [batch, nodes, features]
ensemble = generator.generate_ensemble(state, 4)

assert ensemble.shape == (2, 4, 5, 3) # [batch, members, nodes, features]
assert not torch.equal(state, ensemble[:, 0]) # Should have noise added


def test_ensemble_generator_graph():
"""Test ensemble generation for graph inputs."""
generator = EnsembleGenerator(noise_std=0.1, method="gaussian")

# Test graph input
x = torch.randn(10, 4) # Node features
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
graph_state = Data(x=x, edge_index=edge_index)

ensemble = generator.generate_ensemble(graph_state, 3)

# Check that ensemble preserves structure
assert hasattr(ensemble, "x")
assert hasattr(ensemble, "edge_index")
assert ensemble.x.shape[1] == 3 # Ensemble dimension


def test_data_assimilation_base_abstract_methods():
"""Test that abstract methods are properly defined."""
config = {"param": "value"}
da_module = MockDA(config)

assert da_module.config == config

# Test ensemble generation
state = torch.randn(2, 5, 3)
ensemble = da_module.initialize_ensemble(state, 4)
assert ensemble.shape == (2, 4, 5, 3)


def test_compute_analysis_tensor():
"""Test analysis computation for tensor ensembles."""
da_module = MockDA({})

# Create ensemble: [batch, members, nodes, features]
ensemble = torch.stack(
[
torch.ones(2, 5, 3), # First member
2 * torch.ones(2, 5, 3), # Second member
3 * torch.ones(2, 5, 3), # Third member
],
dim=1,
) # Shape: [2, 3, 5, 3]

analysis = da_module._compute_analysis(ensemble)

# Mean should be (1 + 2 + 3) / 3 = 2
expected = 2 * torch.ones(2, 5, 3)
assert torch.allclose(analysis, expected)


if __name__ == "__main__":
test_ensemble_generator_tensor()
test_ensemble_generator_graph()
test_data_assimilation_base_abstract_methods()
test_compute_analysis_tensor()
print("All tests passed!")