
Commit c4ac00d

[megatron] feat: enhance model offloading and loading for frozen parameters
1 parent: e3b187a

3 files changed: +33 −4 lines

verl/utils/megatron_utils.py

Lines changed: 20 additions & 1 deletion
@@ -442,6 +442,11 @@ def offload_megatron_model_to_cpu(models):
                         # if the grad_data size is already zero, we assume that it is already offloaded
                         buffer.grad_data_size = buffer.grad_data.storage().size()
                         buffer.grad_data.storage().resize_(0)
+            # Offload frozen parameters not in DDP buffers (e.g. base model in LoRA/PEFT).
+            # DDP buffers only contain requires_grad=True params, so frozen params must be offloaded separately.
+            for param in model_chunk.module.parameters():
+                if not param.requires_grad and param.device.type != "cpu":
+                    param.data = param.data.to("cpu", non_blocking=True)
         else:
             # we need this for ref module
             for _, param in model_chunk.named_parameters():
@@ -453,7 +458,14 @@ def offload_megatron_model_to_cpu(models):
 
 
 @torch.no_grad()
-def load_megatron_model_to_gpu(models, load_grad=True):
+def load_megatron_model_to_gpu(models, load_grad=True, load_frozen_params=True):
+    """
+    Load megatron model to GPU.
+    Args:
+        models: The model to load.
+        load_grad: Whether to load gradients.
+        load_frozen_params: Whether to load frozen parameters.
+    """
     for model_chunk in models:
         if isinstance(model_chunk, DDP):
             model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]
@@ -468,6 +480,13 @@ def load_megatron_model_to_gpu(models, load_grad=True):
                     buffer.param_data.storage().resize_(buffer.param_data_size)
                     # copy data from cpu to cuda
                     buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True)
+
+            # Load frozen parameters that were offloaded (e.g. base model in LoRA/PEFT)
+            if load_frozen_params:
+                device_id = get_device_id()
+                for param in model_chunk.module.parameters():
+                    if not param.requires_grad and param.device.type == "cpu":
+                        param.data = param.data.to(device_id, non_blocking=True)
         else:
             # we need this for ref module
             device_id = get_device_id()
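
For context, the hunks above boil down to a simple pattern: frozen (requires_grad=False) parameters never enter Megatron's DDP param/grad buffers, so buffer-level offload misses them and they must be moved per-parameter. A minimal standalone sketch of that pattern, using a plain torch.nn module in place of a Megatron model chunk (the helper names and toy model are illustrative, not verl APIs):

import torch
import torch.nn as nn

def offload_frozen_params(module: nn.Module) -> None:
    # Frozen params are skipped by DDP buffers, so move them to CPU one by one.
    for param in module.parameters():
        if not param.requires_grad and param.device.type != "cpu":
            param.data = param.data.to("cpu", non_blocking=True)

def load_frozen_params(module: nn.Module, device: torch.device) -> None:
    # Inverse of the offload: only touch frozen params still resident on CPU.
    for param in module.parameters():
        if not param.requires_grad and param.device.type == "cpu":
            param.data = param.data.to(device, non_blocking=True)

if torch.cuda.is_available():
    model = nn.Linear(8, 8).cuda()
    model.weight.requires_grad_(False)  # simulate a frozen base weight, LoRA-style
    offload_frozen_params(model)
    assert model.weight.device.type == "cpu"   # frozen weight now on CPU
    assert model.bias.device.type == "cuda"    # trainable param left in place
    load_frozen_params(model, torch.device("cuda"))
    assert model.weight.device.type == "cuda"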

verl/workers/engine/megatron/transformer_impl.py

Lines changed: 4 additions & 2 deletions
@@ -602,12 +602,14 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
             return {}
 
     def get_per_tensor_param(self, base_sync_done=False, **kwargs):
-        load_megatron_model_to_gpu(self.module, load_grad=False)
         peft_config = None
         non_merge_lora_sync = self.peft_cls is not None and not self.model_config.lora.get("merge", False)
+        adapter_only = base_sync_done and non_merge_lora_sync
+        # when syncing the LoRA adapter only, load just the adapter weights once base sync is done; otherwise load all weights
+        load_megatron_model_to_gpu(self.module, load_grad=False, load_frozen_params=not adapter_only)
         if self.vanilla_bridge:
             per_tensor_param = self.bridge.export_weights(self.module)
-        elif base_sync_done and non_merge_lora_sync:
+        elif adapter_only:
             # Only export adapter weights
             peft_config = build_peft_config_for_vllm(self.model_config.lora)
             per_tensor_param = self.bridge.export_adapter_weights(self.module)
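
This refactor makes the load/export decision symmetric: once the base model has been synced to the rollout engine and LoRA runs in non-merged mode, only the adapter needs to move. A hypothetical sketch of that gating logic (decide_weight_load and its arguments are illustrative names, not verl functions):

def decide_weight_load(base_sync_done: bool, has_peft: bool, lora_merged: bool) -> dict:
    # Mirrors the adapter_only condition in get_per_tensor_param.
    non_merge_lora_sync = has_peft and not lora_merged
    adapter_only = base_sync_done and non_merge_lora_sync
    return {
        "load_frozen_params": not adapter_only,  # skip reloading the frozen base after the first sync
        "export_adapter_only": adapter_only,     # later syncs move just the small adapter tensors
    }

# First sync loads and exports everything; subsequent syncs touch only the adapter.
assert decide_weight_load(False, True, False) == {"load_frozen_params": True, "export_adapter_only": False}
assert decide_weight_load(True, True, False) == {"load_frozen_params": False, "export_adapter_only": True}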

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 9 additions & 1 deletion
@@ -618,7 +618,15 @@ async def sleep(self):
 
         if self.rollout_mode == RolloutMode.HYBRID:
             # Don't use engine.sleep(level=2) here
-            await self.engine.collective_rpc("sleep", kwargs={"level": 2})
+            # LoRA only updates adapter weights, so set sleep level to 1
+            lora_as_adapter = (
+                self.model_config.lora_rank > 0 or self.model_config.lora.get("rank", 0) > 0
+            ) and not self.model_config.lora.get("merge", False)
+            if lora_as_adapter:
+                sleep_level = 1
+            else:
+                sleep_level = 2
+            await self.engine.collective_rpc("sleep", kwargs={"level": sleep_level})
 
         # clear encoder cache: https://github.com/vllm-project/vllm/pull/33452
         # await self.engine.reset_encoder_cache()
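
The sleep level matters because of how vLLM's sleep mode treats weights: level 1 offloads model weights to CPU so they survive the sleep, while level 2 discards them. With non-merged LoRA the base weights are never re-synced from the trainer, so they must survive; with full-weight sync they are rewritten on wake anyway. A sketch of the decision under that assumption (choose_sleep_level is an illustrative helper, not part of verl or vLLM):

def choose_sleep_level(lora_rank: int, lora_cfg: dict) -> int:
    # Non-merged LoRA keeps the base weights resident in the rollout engine,
    # so use level 1 (weights offloaded to CPU, restored on wake).
    # Otherwise level 2 discards weights; the next full sync rewrites them.
    lora_as_adapter = (lora_rank > 0 or lora_cfg.get("rank", 0) > 0) and not lora_cfg.get("merge", False)
    return 1 if lora_as_adapter else 2

assert choose_sleep_level(16, {}) == 1               # non-merged LoRA -> level 1
assert choose_sleep_level(0, {"rank": 8}) == 1       # rank set via the lora config dict
assert choose_sleep_level(16, {"merge": True}) == 2  # merged LoRA behaves like a full sync
assert choose_sleep_level(0, {}) == 2                # no LoRA -> level 2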
