@@ -442,6 +442,11 @@ def offload_megatron_model_to_cpu(models):
442442 # if the grad_data size is already zero, we assume that it is already offloaded
443443 buffer .grad_data_size = buffer .grad_data .storage ().size ()
444444 buffer .grad_data .storage ().resize_ (0 )
445+ # Offload frozen parameters not in DDP buffers (e.g. base model in LoRA/PEFT)
446+ # DDP buffers only contain requires_grad=True params, so frozen params must be offloaded separately.
447+ for param in model_chunk .module .parameters ():
448+ if not param .requires_grad and param .device .type != "cpu" :
449+ param .data = param .data .to ("cpu" , non_blocking = True )
445450 else :
446451 # we need this for ref module
447452 for _ , param in model_chunk .named_parameters ():
@@ -453,7 +458,14 @@ def offload_megatron_model_to_cpu(models):
453458
454459
455460@torch .no_grad ()
456- def load_megatron_model_to_gpu (models , load_grad = True ):
461+ def load_megatron_model_to_gpu (models , load_grad = True , load_frozen_params = True ):
462+ """
463+ Load megatron model to GPU.
464+ Args:
465+ models: The model to load.
466+ load_grad: Whether to load gradients.
467+ load_frozen_params: Whether to load frozen parameters.
468+ """
457469 for model_chunk in models :
458470 if isinstance (model_chunk , DDP ):
459471 model_chunk_all_buffers = [model_chunk .buffers , model_chunk .expert_parallel_buffers ]
@@ -468,6 +480,13 @@ def load_megatron_model_to_gpu(models, load_grad=True):
468480 buffer .param_data .storage ().resize_ (buffer .param_data_size )
469481 # copy data from cpu to cuda
470482 buffer .param_data .copy_ (buffer .param_data .cpu_data , non_blocking = True )
483+
484+ # Load frozen parameters that were offloaded (e.g. base model in LoRA/PEFT)
485+ if load_frozen_params :
486+ device_id = get_device_id ()
487+ for param in model_chunk .module .parameters ():
488+ if not param .requires_grad and param .device .type == "cpu" :
489+ param .data = param .data .to (device_id , non_blocking = True )
471490 else :
472491 # we need this for ref module
473492 device_id = get_device_id ()
0 commit comments