diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 3710e24a..2b0be7d8 100755 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -7,6 +7,7 @@ WrapperImageModel, WrapperMetaModel, ) +from .graphcast import GraphCast, GraphCastConfig from .layers.assimilator_decoder import AssimilatorDecoder from .layers.assimilator_encoder import AssimilatorEncoder from .layers.decoder import Decoder diff --git a/graph_weather/models/graphcast/__init__.py b/graph_weather/models/graphcast/__init__.py new file mode 100644 index 00000000..38d02220 --- /dev/null +++ b/graph_weather/models/graphcast/__init__.py @@ -0,0 +1,5 @@ +"""GraphCast model with gradient checkpointing.""" + +from .model import GraphCast, GraphCastConfig + +__all__ = ["GraphCast", "GraphCastConfig"] diff --git a/graph_weather/models/graphcast/model.py b/graph_weather/models/graphcast/model.py new file mode 100644 index 00000000..f9cba2f2 --- /dev/null +++ b/graph_weather/models/graphcast/model.py @@ -0,0 +1,345 @@ +"""GraphCast model with hierarchical gradient checkpointing. + +This module provides a complete GraphCast-style weather forecasting model +with NVIDIA-style hierarchical gradient checkpointing for memory-efficient training. + +Based on: +- NVIDIA PhysicsNeMo GraphCast implementation +""" + +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint + +from graph_weather.models.layers.decoder import Decoder +from graph_weather.models.layers.encoder import Encoder +from graph_weather.models.layers.processor import Processor + + +class GraphCast(torch.nn.Module): + """GraphCast model with hierarchical gradient checkpointing. + + This model combines Encoder, Processor, and Decoder with NVIDIA-style + hierarchical checkpointing controls for flexible memory-compute tradeoffs. + + Hierarchical checkpointing methods: + - set_checkpoint_model(flag): Checkpoint entire forward pass + - set_checkpoint_encoder(flag): Checkpoint encoder section + - set_checkpoint_processor(segments): Checkpoint processor with configurable segments + - set_checkpoint_decoder(flag): Checkpoint decoder section + """ + + def __init__( + self, + lat_lons: list, + resolution: int = 2, + input_dim: int = 78, + output_dim: int = 78, + hidden_dim: int = 256, + num_processor_blocks: int = 9, + hidden_layers: int = 2, + mlp_norm_type: str = "LayerNorm", + use_checkpointing: bool = False, + efficient_batching: bool = False, + ): + """ + Initialize GraphCast model with hierarchical checkpointing support. 
+ + Args: + lat_lons: List of (lat, lon) tuples defining the grid points + resolution: H3 resolution level + input_dim: Input feature dimension + output_dim: Output feature dimension + hidden_dim: Hidden dimension for all layers + num_processor_blocks: Number of message passing blocks in processor + hidden_layers: Number of hidden layers in MLPs + mlp_norm_type: Normalization type for MLPs + use_checkpointing: Enable fine-grained checkpointing in all layers + efficient_batching: Use efficient batching (avoid graph replication) + """ + super().__init__() + + self.lat_lons = lat_lons + self.input_dim = input_dim + self.output_dim = output_dim + self.efficient_batching = efficient_batching + + # Initialize components + self.encoder = Encoder( + lat_lons=lat_lons, + resolution=resolution, + input_dim=input_dim, + output_dim=hidden_dim, + output_edge_dim=hidden_dim, + hidden_dim_processor_node=hidden_dim, + hidden_dim_processor_edge=hidden_dim, + hidden_layers_processor_node=hidden_layers, + hidden_layers_processor_edge=hidden_layers, + mlp_norm_type=mlp_norm_type, + use_checkpointing=use_checkpointing, + efficient_batching=efficient_batching, + ) + + self.processor = Processor( + input_dim=hidden_dim, + edge_dim=hidden_dim, + num_blocks=num_processor_blocks, + hidden_dim_processor_node=hidden_dim, + hidden_dim_processor_edge=hidden_dim, + hidden_layers_processor_node=hidden_layers, + hidden_layers_processor_edge=hidden_layers, + mlp_norm_type=mlp_norm_type, + use_checkpointing=use_checkpointing, + ) + + self.decoder = Decoder( + lat_lons=lat_lons, + resolution=resolution, + input_dim=hidden_dim, + output_dim=output_dim, + hidden_dim_processor_node=hidden_dim, + hidden_dim_processor_edge=hidden_dim, + hidden_layers_processor_node=hidden_layers, + hidden_layers_processor_edge=hidden_layers, + mlp_norm_type=mlp_norm_type, + hidden_dim_decoder=hidden_dim, + hidden_layers_decoder=hidden_layers, + use_checkpointing=use_checkpointing, + efficient_batching=efficient_batching, + ) + + # Hierarchical checkpointing flags (default: use fine-grained checkpointing) + self._checkpoint_model = False + self._checkpoint_encoder = False + self._checkpoint_processor_segments = 0 # 0 = use layer's internal checkpointing + self._checkpoint_decoder = False + + def set_checkpoint_model(self, checkpoint_flag: bool): + """ + Checkpoint entire model as a single segment. + + When enabled, creates one checkpoint for the entire forward pass. + This provides maximum memory savings but highest recomputation cost. + Disables all other hierarchical checkpointing when enabled. + + Args: + checkpoint_flag: If True, checkpoint entire model. If False, use hierarchical checkpointing. + """ + self._checkpoint_model = checkpoint_flag + if checkpoint_flag: + # Disable all fine-grained checkpointing + self._checkpoint_encoder = False + self._checkpoint_processor_segments = 0 + self._checkpoint_decoder = False + + def set_checkpoint_encoder(self, checkpoint_flag: bool): + """ + Checkpoint encoder section. + + Checkpoints the encoder forward pass as a single segment. + Only effective when set_checkpoint_model(False). + + Args: + checkpoint_flag: If True, checkpoint encoder section. + """ + self._checkpoint_encoder = checkpoint_flag + + def set_checkpoint_processor(self, checkpoint_segments: int): + """ + Checkpoint processor with configurable segments. 
+ + Controls how the processor is checkpointed: + - 0: Use processor's internal per-block checkpointing + - -1: Checkpoint entire processor as one segment + - N > 0: Checkpoint every N blocks (not yet implemented) + + Only effective when set_checkpoint_model(False). + + Args: + checkpoint_segments: Checkpointing strategy (0, -1, or positive integer). + """ + self._checkpoint_processor_segments = checkpoint_segments + + def set_checkpoint_decoder(self, checkpoint_flag: bool): + """ + Checkpoint decoder section. + + Checkpoints the decoder forward pass as a single segment. + Only effective when set_checkpoint_model(False). + + Args: + checkpoint_flag: If True, checkpoint decoder section. + """ + self._checkpoint_decoder = checkpoint_flag + + def _encoder_forward(self, features: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Encoder forward pass (for checkpointing). + """ + return self.encoder(features) + + def _processor_forward( + self, + x: Tensor, + edge_index: Tensor, + edge_attr: Tensor, + batch_size: Optional[int] = None, + ) -> Tensor: + """ + Processor forward pass (for checkpointing). + """ + return self.processor( + x, + edge_index, + edge_attr, + batch_size=batch_size, + efficient_batching=self.efficient_batching, + ) + + def _decoder_forward( + self, + processed_features: Tensor, + original_features: Tensor, + batch_size: int, + ) -> Tensor: + """ + Decoder forward pass (for checkpointing). + """ + return self.decoder(processed_features, original_features, batch_size) + + def _custom_forward(self, features: Tensor) -> Tensor: + """ + Forward pass with hierarchical checkpointing. + """ + batch_size = features.shape[0] + + # Encoder + if self._checkpoint_encoder: + latent_features, edge_index, edge_attr = checkpoint( + self._encoder_forward, + features, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + latent_features, edge_index, edge_attr = self.encoder(features) + + # Processor + if self._checkpoint_processor_segments == -1: + # Checkpoint entire processor as one block + processed_features = checkpoint( + self._processor_forward, + latent_features, + edge_index, + edge_attr, + batch_size if self.efficient_batching else None, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + # Use processor's internal checkpointing (controlled by use_checkpointing) + processed_features = self.processor( + latent_features, + edge_index, + edge_attr, + batch_size=batch_size, + efficient_batching=self.efficient_batching, + ) + + # Decoder + if self._checkpoint_decoder: + output = checkpoint( + self._decoder_forward, + processed_features, + features, + batch_size, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + output = self.decoder(processed_features, features, batch_size) + + return output + + def forward(self, features: Tensor) -> Tensor: + """Forward pass through GraphCast model. + + Args: + features: Input features of shape [batch_size, num_points, input_dim] + + Returns: + Output predictions of shape [batch_size, num_points, output_dim] + """ + if self._checkpoint_model: + # Checkpoint entire model as one segment + return checkpoint( + self._custom_forward, + features, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + # Use hierarchical checkpointing + return self._custom_forward(features) + + +class GraphCastConfig: + """Configuration helper for GraphCast checkpointing strategies. + + Provides pre-defined checkpointing strategies for different use cases. 
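+
+    A minimal usage sketch (illustrative only; assumes ``lat_lons`` is a list of
+    (lat, lon) tuples matching the model's grid and ``features`` is a tensor of
+    shape [batch_size, num_points, input_dim]):
+
+        from graph_weather.models import GraphCast, GraphCastConfig
+
+        model = GraphCast(lat_lons=lat_lons)
+        # Checkpoints encoder and decoder as single segments and the processor
+        # as one block; the set_checkpoint_* methods can also be called directly.
+        GraphCastConfig.balanced_checkpointing(model)
+        output = model(features)  # same shape as features when output_dim == input_dim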
+ """ + + @staticmethod + def no_checkpointing(model: GraphCast): + """ + Disable all checkpointing (maximum speed, maximum memory). + """ + model.set_checkpoint_model(False) + model.set_checkpoint_encoder(False) + model.set_checkpoint_processor(0) + model.set_checkpoint_decoder(False) + + @staticmethod + def full_checkpointing(model: GraphCast): + """ + Checkpoint entire model (maximum memory savings, slowest). + """ + model.set_checkpoint_model(True) + + @staticmethod + def balanced_checkpointing(model: GraphCast): + """ + Balanced strategy (good memory savings, moderate speed). + """ + model.set_checkpoint_model(False) + model.set_checkpoint_encoder(True) + model.set_checkpoint_processor(-1) + model.set_checkpoint_decoder(True) + + @staticmethod + def processor_only_checkpointing(model: GraphCast): + """ + Checkpoint only processor (targets main memory bottleneck). + """ + model.set_checkpoint_model(False) + model.set_checkpoint_encoder(False) + model.set_checkpoint_processor(-1) + model.set_checkpoint_decoder(False) + + @staticmethod + def fine_grained_checkpointing(model: GraphCast): + """ + Fine-grained per-layer checkpointing (best memory savings). + + This checkpoints each individual MLP and processor block separately. + Provides the best memory savings with moderate recomputation cost. + Note: Model must be created with use_checkpointing=True. + """ + # Fine-grained is enabled via use_checkpointing=True in __init__ + # This just disables hierarchical checkpointing + model.set_checkpoint_model(False) + model.set_checkpoint_encoder(False) + model.set_checkpoint_processor(0) + model.set_checkpoint_decoder(False) diff --git a/graph_weather/models/layers/assimilator_decoder.py b/graph_weather/models/layers/assimilator_decoder.py index 03470504..28127f7f 100755 --- a/graph_weather/models/layers/assimilator_decoder.py +++ b/graph_weather/models/layers/assimilator_decoder.py @@ -117,6 +117,7 @@ def __init__( hidden_layers_node=hidden_layers_processor_node, hidden_layers_edge=hidden_layers_processor_edge, norm_type=mlp_norm_type, + use_checkpointing=self.use_checkpointing, ) self.node_decoder = MLP( input_dim, diff --git a/graph_weather/models/layers/encoder.py b/graph_weather/models/layers/encoder.py index 8470df05..56c54ae5 100755 --- a/graph_weather/models/layers/encoder.py +++ b/graph_weather/models/layers/encoder.py @@ -147,6 +147,7 @@ def __init__( hidden_layers_processor_node, hidden_layers_processor_edge, mlp_norm_type, + use_checkpointing=self.use_checkpointing, ) def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/graph_weather/models/layers/graph_net_block.py b/graph_weather/models/layers/graph_net_block.py index 4be21121..ffcefeba 100755 --- a/graph_weather/models/layers/graph_net_block.py +++ b/graph_weather/models/layers/graph_net_block.py @@ -71,7 +71,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: The transformed tensor """ if self.use_checkpointing: - out = checkpoint(self.model, x, use_reentrant=False) + out = checkpoint(self.model, x, use_reentrant=False, preserve_rng_state=False) else: out = self.model(x) return out @@ -241,6 +241,7 @@ def __init__( hidden_layers_node: int = 2, hidden_layers_edge: int = 2, norm_type: str = "LayerNorm", + use_checkpointing: bool = False, ): """ Graph Processor @@ -255,9 +256,11 @@ def __init__( hidden_layers_edge: Number of hidden layers for edge processing norm_type: Normalization type one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None + 
use_checkpointing: Whether to use gradient checkpointing """ super(GraphProcessor, self).__init__() + self.use_checkpointing = use_checkpointing self.blocks = nn.ModuleList() for _ in range(mp_iterations): @@ -288,6 +291,11 @@ def forward( Updated nodes and edge attributes """ for block in self.blocks: - x, edge_attr, _ = block(x, edge_index, edge_attr) + if self.use_checkpointing: + x, edge_attr, _ = checkpoint( + block, x, edge_index, edge_attr, use_reentrant=False, preserve_rng_state=False + ) + else: + x, edge_attr, _ = block(x, edge_index, edge_attr) return x, edge_attr diff --git a/graph_weather/models/layers/processor.py b/graph_weather/models/layers/processor.py index 2572cb43..653878d7 100755 --- a/graph_weather/models/layers/processor.py +++ b/graph_weather/models/layers/processor.py @@ -28,6 +28,7 @@ def __init__( hidden_layers_processor_edge: int = 2, mlp_norm_type: str = "LayerNorm", use_thermalizer: bool = False, + use_checkpointing: bool = False, ): """ Latent graph processor @@ -43,12 +44,14 @@ def __init__( mlp_norm_type: Type of norm for the MLPs one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None use_thermalizer: Whether to use the thermalizer layer + use_checkpointing: Whether to use gradient checkpointing """ super().__init__() # Build the default graph # Take features from encoder and put into processor graph self.input_dim = input_dim self.use_thermalizer = use_thermalizer + self.checkpoint_segments = 0 # 0 = use layer's internal checkpointing self.graph_processor = GraphProcessor( num_blocks, @@ -59,10 +62,24 @@ def __init__( hidden_layers_processor_node, hidden_layers_processor_edge, mlp_norm_type, + use_checkpointing, ) if self.use_thermalizer: self.thermalizer = ThermalizerLayer(input_dim) + def set_checkpoint_segments(self, checkpoint_segments: int): + """Sets checkpoint segments for the processor. + + This matches NVIDIA's API for controlling processor checkpointing. + + Args: + checkpoint_segments: Number of checkpointing segments for gradient computation. 
+ - 0: Use processor's internal per-block checkpointing (controlled by use_checkpointing) + - -1: Checkpoint entire processor as one segment (handled at GraphCast level) + - N > 0: Checkpoint every N blocks (not yet implemented) + """ + self.checkpoint_segments = checkpoint_segments + def forward( self, x: torch.Tensor, diff --git a/scripts/benchmark_memory.py b/scripts/benchmark_memory.py deleted file mode 100644 index 44425322..00000000 --- a/scripts/benchmark_memory.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Benchmark memory usage for efficient batching.""" - -import gc -import json -import time -from datetime import datetime -from typing import Dict, List, Tuple - -import numpy as np -import torch - -from graph_weather.models.layers.decoder import Decoder -from graph_weather.models.layers.encoder import Encoder -from graph_weather.models.layers.processor import Processor - - -def get_memory_stats() -> Dict[str, float]: - """Get current memory usage statistics.""" - try: - import psutil - - process = psutil.Process() - cpu_mb = process.memory_info().rss / 1024 / 1024 - except ImportError: - cpu_mb = 0 - - stats = {"cpu_memory_mb": cpu_mb} - - if torch.cuda.is_available(): - stats["gpu_memory_mb"] = torch.cuda.memory_allocated() / 1024 / 1024 - stats["gpu_peak_mb"] = torch.cuda.max_memory_allocated() / 1024 / 1024 - else: - stats["gpu_memory_mb"] = 0 - stats["gpu_peak_mb"] = 0 - - return stats - - -def create_lat_lon_grid(resolution_deg: float) -> List[Tuple[float, float]]: - """Create a lat/lon grid at specified resolution.""" - lat_lons = [] - lats = np.arange(-90, 90, resolution_deg) - lons = np.arange(0, 360, resolution_deg) - for lat in lats: - for lon in lons: - lat_lons.append((float(lat), float(lon))) - return lat_lons - - -def benchmark_config( - resolution_deg: float, - batch_size: int, - device: str = "cuda", - num_iterations: int = 3, -) -> Dict: - """Benchmark a specific configuration.""" - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - lat_lons = create_lat_lon_grid(resolution_deg) - num_nodes = len(lat_lons) - - try: - encoder = ( - Encoder(lat_lons, resolution=2, input_dim=78, output_dim=256, efficient_batching=True) - .to(device) - .eval() - ) - processor = Processor(256, num_blocks=9).to(device).eval() - decoder = ( - Decoder(lat_lons, resolution=2, input_dim=256, output_dim=78, efficient_batching=True) - .to(device) - .eval() - ) - - features = torch.randn(batch_size, num_nodes, 78, device=device) - - # Warmup - with torch.no_grad(): - x, edge_idx, edge_attr = encoder(features) - x = processor(x, edge_idx, edge_attr, batch_size=batch_size, efficient_batching=True) - _ = decoder(x, features) - - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - - # Benchmark - times = [] - with torch.no_grad(): - for _ in range(num_iterations): - start = time.time() - x, edge_idx, edge_attr = encoder(features) - x = processor( - x, edge_idx, edge_attr, batch_size=batch_size, efficient_batching=True - ) - output = decoder(x, features) - if torch.cuda.is_available(): - torch.cuda.synchronize() - times.append(time.time() - start) - - mem_stats = get_memory_stats() - - result = { - "resolution_deg": resolution_deg, - "batch_size": batch_size, - "num_nodes": num_nodes, - "device": device, - "success": True, - "avg_time_s": float(np.mean(times)), - "std_time_s": float(np.std(times)), - **mem_stats, - } - - del encoder, processor, decoder, features, x, edge_idx, edge_attr, output - gc.collect() - if 
torch.cuda.is_available(): - torch.cuda.empty_cache() - - return result - - except RuntimeError as e: - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return { - "resolution_deg": resolution_deg, - "batch_size": batch_size, - "num_nodes": num_nodes, - "device": device, - "success": False, - "error": str(e), - } - - -def run_benchmarks() -> List[Dict]: - """Run benchmark suite.""" - device = "cuda" if torch.cuda.is_available() else "cpu" - - configs = [ - (5.0, [1, 2, 4, 8]), - (2.5, [1, 2, 4]), - (1.0, [1]), - ] - - results = [] - for resolution_deg, batch_sizes in configs: - for batch_size in batch_sizes: - result = benchmark_config(resolution_deg, batch_size, device) - results.append(result) - - if not result["success"] and "out of memory" in result.get("error", "").lower(): - break - - return results - - -if __name__ == "__main__": - results = run_benchmarks() - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_file = f"benchmark_results_{timestamp}.json" - - with open(output_file, "w") as f: - json.dump(results, f, indent=2) - - print(f"Results saved to {output_file}") diff --git a/scripts/benchmark_memory_optimizations.py b/scripts/benchmark_memory_optimizations.py new file mode 100644 index 00000000..d1793c01 --- /dev/null +++ b/scripts/benchmark_memory_optimizations.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +"""Benchmark memory optimizations (efficient batching + gradient checkpointing). + +This script benchmarks memory usage and performance for different optimization strategies: +- Efficient batching (avoids graph replication) +- Gradient checkpointing (trades compute for memory) +- Combined optimizations + +Results are saved to JSON files. +""" + +import gc +import json +import time +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn + +from graph_weather.models import GraphCast, GraphCastConfig +from graph_weather.models.layers.decoder import Decoder +from graph_weather.models.layers.encoder import Encoder +from graph_weather.models.layers.processor import Processor + + +@dataclass +class BenchmarkResult: + """Results from a single benchmark run.""" + + test_type: str # "efficient_batching" or "gradient_checkpointing" + grid_resolution: float + batch_size: int + num_grid_points: int + efficient_batching: bool + use_checkpointing: bool + checkpoint_strategy: Optional[str] + peak_memory_mb: float + allocated_memory_mb: float + forward_time_ms: Optional[float] + backward_time_ms: Optional[float] + total_time_ms: Optional[float] + success: bool + error: Optional[str] = None + + +def get_memory_stats() -> Dict[str, float]: + """Get current memory usage statistics.""" + stats = {} + if torch.cuda.is_available(): + stats["peak_memory_mb"] = torch.cuda.max_memory_allocated() / 1024 / 1024 + stats["allocated_memory_mb"] = torch.cuda.memory_allocated() / 1024 / 1024 + else: + stats["peak_memory_mb"] = 0 + stats["allocated_memory_mb"] = 0 + return stats + + +def reset_memory(): + """Reset memory tracking.""" + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + gc.collect() + + +def create_lat_lon_grid(resolution_deg: float) -> List[Tuple[float, float]]: + """Create a lat/lon grid at specified resolution.""" + lat_lons = [] + lats = np.arange(-90, 90, resolution_deg) + lons = np.arange(0, 360, resolution_deg) + for lat in lats: + for lon in lons: + lat_lons.append((float(lat), 
float(lon))) + return lat_lons + + +def benchmark_efficient_batching( + resolution_deg: float, + batch_size: int, + device: str = "cuda", + num_iterations: int = 3, +) -> BenchmarkResult: + """Benchmark efficient batching optimization (inference mode).""" + reset_memory() + + lat_lons = create_lat_lon_grid(resolution_deg) + num_nodes = len(lat_lons) + + try: + encoder = Encoder( + lat_lons, resolution=2, input_dim=78, output_dim=256, efficient_batching=True + ).to(device) + processor = Processor(256, num_blocks=9).to(device) + decoder = Decoder( + lat_lons, resolution=2, input_dim=256, output_dim=78, efficient_batching=True + ).to(device) + + encoder.eval() + processor.eval() + decoder.eval() + + features = torch.randn(batch_size, num_nodes, 78, device=device) + + # Warmup + with torch.no_grad(): + x, edge_idx, edge_attr = encoder(features) + x = processor(x, edge_idx, edge_attr, batch_size=batch_size, efficient_batching=True) + _ = decoder(x, features, batch_size) + + reset_memory() + + # Benchmark + times = [] + with torch.no_grad(): + for _ in range(num_iterations): + start = time.time() + x, edge_idx, edge_attr = encoder(features) + x = processor( + x, edge_idx, edge_attr, batch_size=batch_size, efficient_batching=True + ) + output = decoder(x, features, batch_size) + if torch.cuda.is_available(): + torch.cuda.synchronize() + times.append(time.time() - start) + + mem_stats = get_memory_stats() + + del encoder, processor, decoder, features, x, edge_idx, edge_attr, output + reset_memory() + + return BenchmarkResult( + test_type="efficient_batching", + grid_resolution=resolution_deg, + batch_size=batch_size, + num_grid_points=num_nodes, + efficient_batching=True, + use_checkpointing=False, + checkpoint_strategy=None, + forward_time_ms=None, + backward_time_ms=None, + total_time_ms=float(np.mean(times)) * 1000, + success=True, + **mem_stats, + ) + + except RuntimeError as e: + reset_memory() + return BenchmarkResult( + test_type="efficient_batching", + grid_resolution=resolution_deg, + batch_size=batch_size, + num_grid_points=num_nodes, + efficient_batching=True, + use_checkpointing=False, + checkpoint_strategy=None, + peak_memory_mb=0, + allocated_memory_mb=0, + forward_time_ms=None, + backward_time_ms=None, + total_time_ms=None, + success=False, + error=str(e), + ) + + +def benchmark_gradient_checkpointing( + resolution_deg: float, + batch_size: int, + strategy: str, + device: str = "cuda", + num_iterations: int = 3, +) -> BenchmarkResult: + """Benchmark gradient checkpointing optimization (training mode).""" + reset_memory() + + lat_lons = create_lat_lon_grid(resolution_deg) + num_nodes = len(lat_lons) + + try: + # Create model + use_fine_grained = strategy == "fine_grained" + model = GraphCast( + lat_lons=lat_lons, + resolution=2, + input_dim=78, + output_dim=78, + hidden_dim=256, + num_processor_blocks=9, + hidden_layers=2, + mlp_norm_type="LayerNorm", + use_checkpointing=use_fine_grained, + efficient_batching=True, + ).to(device) + + # Apply checkpointing strategy + if strategy == "none": + GraphCastConfig.no_checkpointing(model) + elif strategy == "full": + GraphCastConfig.full_checkpointing(model) + elif strategy == "balanced": + GraphCastConfig.balanced_checkpointing(model) + elif strategy == "processor_only": + GraphCastConfig.processor_only_checkpointing(model) + elif strategy == "fine_grained": + GraphCastConfig.no_checkpointing(model) + else: + raise ValueError(f"Unknown strategy: {strategy}") + + model.train() + + features = torch.randn(batch_size, num_nodes, 78, 
device=device) + target = torch.randn(batch_size, num_nodes, 78, device=device) + + # Warmup + output = model(features) + loss = nn.functional.mse_loss(output, target) + loss.backward() + model.zero_grad() + + reset_memory() + + # Benchmark + forward_times = [] + backward_times = [] + + for _ in range(num_iterations): + # Forward pass + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_time = time.perf_counter() + + output = model(features) + loss = nn.functional.mse_loss(output, target) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + forward_time = (time.perf_counter() - start_time) * 1000 + forward_times.append(forward_time) + + # Backward pass + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_time = time.perf_counter() + + loss.backward() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + backward_time = (time.perf_counter() - start_time) * 1000 + backward_times.append(backward_time) + + model.zero_grad() + + mem_stats = get_memory_stats() + + del model, features, target, output, loss + reset_memory() + + return BenchmarkResult( + test_type="gradient_checkpointing", + grid_resolution=resolution_deg, + batch_size=batch_size, + num_grid_points=num_nodes, + efficient_batching=True, + use_checkpointing=(strategy != "none"), + checkpoint_strategy=strategy, + forward_time_ms=float(np.mean(forward_times)), + backward_time_ms=float(np.mean(backward_times)), + total_time_ms=float(np.mean(forward_times) + np.mean(backward_times)), + success=True, + **mem_stats, + ) + + except RuntimeError as e: + reset_memory() + return BenchmarkResult( + test_type="gradient_checkpointing", + grid_resolution=resolution_deg, + batch_size=batch_size, + num_grid_points=num_nodes, + efficient_batching=True, + use_checkpointing=(strategy != "none"), + checkpoint_strategy=strategy, + peak_memory_mb=0, + allocated_memory_mb=0, + forward_time_ms=None, + backward_time_ms=None, + total_time_ms=None, + success=False, + error=str(e), + ) + + +def run_efficient_batching_benchmarks() -> List[Dict]: + """Run efficient batching benchmarks.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + + configs = [ + (5.0, [1, 2, 4, 8]), + (2.5, [1, 2, 4]), + (1.0, [1]), + ] + + results = [] + for resolution_deg, batch_sizes in configs: + for batch_size in batch_sizes: + result = benchmark_efficient_batching(resolution_deg, batch_size, device) + results.append(result) + + if not result.success and "out of memory" in result.error.lower(): + break + + return [vars(r) for r in results] + + +def run_gradient_checkpointing_benchmarks() -> List[Dict]: + """Run gradient checkpointing benchmarks.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + + configs = [ + (5.0, [1, 2, 4, 8]), + (2.5, [1, 2, 4]), + (1.0, [1]), + ] + + strategies = ["none", "fine_grained", "processor_only", "balanced", "full"] + + results = [] + for resolution_deg, batch_sizes in configs: + skip_rest = False + for batch_size in batch_sizes: + if skip_rest: + continue + + for strategy in strategies: + result = benchmark_gradient_checkpointing( + resolution_deg, batch_size, strategy, device + ) + results.append(result) + + # Check if all strategies failed + batch_results = results[-len(strategies) :] + if all(not r.success for r in batch_results): + skip_rest = True + + return [vars(r) for r in results] + + +def main(): + """Run all benchmarks.""" + import argparse + + parser = argparse.ArgumentParser(description="Benchmark memory optimizations") + parser.add_argument( + "--mode", + 
choices=["efficient_batching", "gradient_checkpointing", "all"], + default="all", + help="Which benchmarks to run", + ) + args = parser.parse_args() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + if args.mode in ["efficient_batching", "all"]: + results = run_efficient_batching_benchmarks() + output_file = f"benchmark_efficient_batching_{timestamp}.json" + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + + if args.mode in ["gradient_checkpointing", "all"]: + results = run_gradient_checkpointing_benchmarks() + output_file = f"benchmark_gradient_checkpointing_{timestamp}.json" + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + + +if __name__ == "__main__": + main() diff --git a/tests/models/test_gradient_checkpointing.py b/tests/models/test_gradient_checkpointing.py new file mode 100644 index 00000000..771a58cf --- /dev/null +++ b/tests/models/test_gradient_checkpointing.py @@ -0,0 +1,443 @@ +"""Tests for gradient checkpointing implementation. + +This test suite ensures that gradient checkpointing: +1. Produces identical outputs to non-checkpointed versions +2. Reduces memory usage during training +3. Works correctly with all checkpointing strategies +4. Maintains backward compatibility +5. Works with efficient batching +""" + +import numpy as np +import pytest +import torch + +from graph_weather.models import GraphCast, GraphCastConfig +from graph_weather.models.layers.decoder import Decoder +from graph_weather.models.layers.encoder import Encoder +from graph_weather.models.layers.processor import Processor + + +def create_lat_lon_grid(resolution_deg: float): + """Create a lat/lon grid at specified resolution.""" + lat_lons = [] + lats = np.arange(-90, 90, resolution_deg) + lons = np.arange(0, 360, resolution_deg) + for lat in lats: + for lon in lons: + lat_lons.append((float(lat), float(lon))) + return lat_lons + + +# Layer-level checkpointing tests + + +@pytest.mark.parametrize("use_checkpointing", [False, True]) +def test_encoder_checkpointing(use_checkpointing): + """Test that Encoder with checkpointing produces identical outputs.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + encoder = Encoder( + lat_lons, + resolution=2, + input_dim=78, + output_dim=256, + use_checkpointing=use_checkpointing, + efficient_batching=True, + ) + encoder.eval() + + with torch.no_grad(): + x, edge_idx, edge_attr = encoder(features) + + # Verify output shape + assert x.shape[1] == 256 + assert edge_idx.shape[0] == 2 + + +@pytest.mark.parametrize("use_checkpointing", [False, True]) +def test_processor_checkpointing(use_checkpointing): + """Test that Processor with checkpointing produces identical outputs.""" + batch_size = 2 + num_nodes = 5882 # H3 resolution 2 + num_edges = 41162 + + torch.manual_seed(42) + x = torch.randn((batch_size * num_nodes, 256)) + edge_index = torch.randint(0, num_nodes, (2, num_edges)) + edge_attr = torch.randn((num_edges, 256)) + + processor = Processor( + input_dim=256, + edge_dim=256, + num_blocks=9, + use_checkpointing=use_checkpointing, + ) + processor.eval() + + with torch.no_grad(): + out = processor(x, edge_index, edge_attr, batch_size=batch_size, efficient_batching=True) + + # Verify output shape matches input + assert out.shape == x.shape + + +@pytest.mark.parametrize("use_checkpointing", [False, True]) +def test_decoder_checkpointing(use_checkpointing): + """Test that Decoder with checkpointing produces 
identical outputs.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + num_h3 = 5882 + + torch.manual_seed(42) + processor_features = torch.randn((batch_size * num_h3, 256)) + start_features = torch.randn((batch_size, len(lat_lons), 78)) + + decoder = Decoder( + lat_lons, + resolution=2, + input_dim=256, + output_dim=78, + use_checkpointing=use_checkpointing, + efficient_batching=True, + ) + decoder.eval() + + with torch.no_grad(): + out = decoder(processor_features, start_features, batch_size) + + # Verify output shape + assert out.shape == start_features.shape + + +# Output equivalence tests + + +def test_encoder_output_equivalence(): + """Verify encoder outputs are identical with/without checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + # Without checkpointing + encoder_no_cp = Encoder( + lat_lons, resolution=2, use_checkpointing=False, efficient_batching=True + ) + encoder_no_cp.eval() + + # With checkpointing + encoder_with_cp = Encoder( + lat_lons, resolution=2, use_checkpointing=True, efficient_batching=True + ) + encoder_with_cp.load_state_dict(encoder_no_cp.state_dict()) + encoder_with_cp.eval() + + with torch.no_grad(): + x_no_cp, edge_idx_no_cp, edge_attr_no_cp = encoder_no_cp(features) + x_with_cp, edge_idx_with_cp, edge_attr_with_cp = encoder_with_cp(features) + + # Verify outputs are identical + assert torch.allclose(x_no_cp, x_with_cp, atol=1e-6) + assert torch.equal(edge_idx_no_cp, edge_idx_with_cp) + assert torch.allclose(edge_attr_no_cp, edge_attr_with_cp, atol=1e-6) + + +def test_full_pipeline_output_equivalence(): + """Verify full pipeline outputs are identical with/without checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + # Create models without checkpointing + encoder_no_cp = Encoder(lat_lons, use_checkpointing=False, efficient_batching=True) + processor_no_cp = Processor(256, num_blocks=9, use_checkpointing=False) + decoder_no_cp = Decoder( + lat_lons, input_dim=256, use_checkpointing=False, efficient_batching=True + ) + + encoder_no_cp.eval() + processor_no_cp.eval() + decoder_no_cp.eval() + + # Run without checkpointing + with torch.no_grad(): + x, edge_idx, edge_attr = encoder_no_cp(features) + x = processor_no_cp(x, edge_idx, edge_attr, batch_size=batch_size, efficient_batching=True) + out_no_cp = decoder_no_cp(x, features, batch_size) + + # Create models with checkpointing + encoder_with_cp = Encoder(lat_lons, use_checkpointing=True, efficient_batching=True) + processor_with_cp = Processor(256, num_blocks=9, use_checkpointing=True) + decoder_with_cp = Decoder( + lat_lons, input_dim=256, use_checkpointing=True, efficient_batching=True + ) + + # Load same weights + encoder_with_cp.load_state_dict(encoder_no_cp.state_dict()) + processor_with_cp.load_state_dict(processor_no_cp.state_dict()) + decoder_with_cp.load_state_dict(decoder_no_cp.state_dict()) + + encoder_with_cp.eval() + processor_with_cp.eval() + decoder_with_cp.eval() + + # Run with checkpointing + with torch.no_grad(): + x, edge_idx, edge_attr = encoder_with_cp(features) + x = processor_with_cp( + x, edge_idx, edge_attr, batch_size=batch_size, efficient_batching=True + ) + out_with_cp = decoder_with_cp(x, features, batch_size) + + # Verify outputs are identical + assert torch.allclose(out_no_cp, out_with_cp, atol=1e-5) + + +# 
GraphCast model tests + + +def test_graphcast_no_checkpointing(): + """Test GraphCast with no checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + + model = GraphCast(lat_lons, use_checkpointing=False, efficient_batching=True) + GraphCastConfig.no_checkpointing(model) + model.eval() + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + with torch.no_grad(): + output = model(features) + + assert output.shape == features.shape + + +def test_graphcast_full_checkpointing(): + """Test GraphCast with full model checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + + model = GraphCast(lat_lons, use_checkpointing=False, efficient_batching=True) + GraphCastConfig.full_checkpointing(model) + model.eval() + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + with torch.no_grad(): + output = model(features) + + assert output.shape == features.shape + + +def test_graphcast_balanced_checkpointing(): + """Test GraphCast with balanced checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + + model = GraphCast(lat_lons, use_checkpointing=False, efficient_batching=True) + GraphCastConfig.balanced_checkpointing(model) + model.eval() + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + with torch.no_grad(): + output = model(features) + + assert output.shape == features.shape + + +def test_graphcast_processor_only_checkpointing(): + """Test GraphCast with processor-only checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + + model = GraphCast(lat_lons, use_checkpointing=False, efficient_batching=True) + GraphCastConfig.processor_only_checkpointing(model) + model.eval() + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + with torch.no_grad(): + output = model(features) + + assert output.shape == features.shape + + +def test_graphcast_output_equivalence(): + """Test that all checkpointing strategies produce identical outputs.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 1 # Use smaller batch for equivalence test + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + # No checkpointing (baseline) + model_baseline = GraphCast(lat_lons, use_checkpointing=False, efficient_batching=True) + GraphCastConfig.no_checkpointing(model_baseline) + model_baseline.eval() + + with torch.no_grad(): + output_baseline = model_baseline(features) + + # Test each strategy + strategies = [ + ("full", GraphCastConfig.full_checkpointing), + ("balanced", GraphCastConfig.balanced_checkpointing), + ("processor_only", GraphCastConfig.processor_only_checkpointing), + ] + + for strategy_name, strategy_fn in strategies: + model = GraphCast(lat_lons, use_checkpointing=False, efficient_batching=True) + model.load_state_dict(model_baseline.state_dict()) + strategy_fn(model) + model.eval() + + with torch.no_grad(): + output = model(features) + + # Verify output is identical to baseline + assert torch.allclose( + output_baseline, output, atol=1e-5 + ), f"{strategy_name} checkpointing produced different output" + + +# Backward compatibility tests + + +def test_original_api_still_works(): + """Test that the original API (without GraphCast wrapper) still works.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 2 + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 
78)) + + # Original usage pattern + encoder = Encoder(lat_lons, efficient_batching=True) + processor = Processor(256, num_blocks=9) + decoder = Decoder(lat_lons, efficient_batching=True) + + encoder.eval() + processor.eval() + decoder.eval() + + with torch.no_grad(): + x, edge_idx, edge_attr = encoder(features) + x = processor(x, edge_idx, edge_attr, batch_size=batch_size, efficient_batching=True) + output = decoder(x, features, batch_size) + + assert output.shape == features.shape + + +def test_efficient_batching_unaffected(): + """Test that efficient batching still works with checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 4 + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + + # With efficient batching AND checkpointing + model = GraphCast(lat_lons, use_checkpointing=True, efficient_batching=True) + GraphCastConfig.balanced_checkpointing(model) + model.eval() + + with torch.no_grad(): + output = model(features) + + assert output.shape == features.shape + + +# Gradient flow tests + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") +def test_backward_pass_with_checkpointing(): + """Test that backward pass works with checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 1 + + device = torch.device("cuda") + model = GraphCast(lat_lons, use_checkpointing=True, efficient_batching=True).to(device) + GraphCastConfig.balanced_checkpointing(model) + model.train() + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78), device=device) + target = torch.randn((batch_size, len(lat_lons), 78), device=device) + + # Forward pass + output = model(features) + + # Backward pass + loss = torch.nn.functional.mse_loss(output, target) + loss.backward() + + # Verify gradients exist + has_grads = False + for param in model.parameters(): + if param.grad is not None: + has_grads = True + assert not torch.isnan(param.grad).any(), "NaN in gradients" + assert not torch.isinf(param.grad).any(), "Inf in gradients" + + assert has_grads, "No gradients computed" + + +def test_gradient_equivalence(): + """Test that gradients are identical with/without checkpointing.""" + lat_lons = create_lat_lon_grid(resolution_deg=10.0) + batch_size = 1 + + torch.manual_seed(42) + features = torch.randn((batch_size, len(lat_lons), 78)) + target = torch.randn((batch_size, len(lat_lons), 78)) + + # Model without checkpointing + model_no_cp = GraphCast(lat_lons, use_checkpointing=False, efficient_batching=True) + GraphCastConfig.no_checkpointing(model_no_cp) + model_no_cp.train() + + output_no_cp = model_no_cp(features) + loss_no_cp = torch.nn.functional.mse_loss(output_no_cp, target) + loss_no_cp.backward() + + # Collect gradients + grads_no_cp = [] + for param in model_no_cp.parameters(): + if param.grad is not None: + grads_no_cp.append(param.grad.clone()) + + # Model with checkpointing + model_with_cp = GraphCast(lat_lons, use_checkpointing=False, efficient_batching=True) + model_with_cp.load_state_dict(model_no_cp.state_dict()) + GraphCastConfig.balanced_checkpointing(model_with_cp) + model_with_cp.train() + + output_with_cp = model_with_cp(features) + loss_with_cp = torch.nn.functional.mse_loss(output_with_cp, target) + loss_with_cp.backward() + + # Collect gradients + grads_with_cp = [] + for param in model_with_cp.parameters(): + if param.grad is not None: + grads_with_cp.append(param.grad.clone()) + + # Compare gradients + assert len(grads_no_cp) == len(grads_with_cp) + for 
g1, g2 in zip(grads_no_cp, grads_with_cp): + assert torch.allclose(g1, g2, atol=1e-5), "Gradients differ with checkpointing"