Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions examples/grpo_trainer/run_qwen2-7b_math_megatron_fsdp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
set -x

export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
unset ROCR_VISIBLE_DEVICES
# export VLLM_USE_V1=1
# export VLLM_ALLREDUCE_USE_SYMM_MEM=0

rollout_mode="async"
export VLLM_USE_V1=1
return_raw_chat="True"

# Model path is overridable from the environment so the example is portable
# across machines; the default preserves the original example's location.
MODEL_PATH=${MODEL_PATH:-/root/models/Qwen2.5-3B-Instruct}

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

USE_FUSED_KERNELS=False

python3 -m verl.trainer.main_ppo --config-path=config \
    --config-name='ppo_megatron_trainer.yaml'\
    algorithm.adv_estimator=grpo \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.return_raw_chat=$return_raw_chat \
    data.train_batch_size=32 \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=1 \
    actor_rollout_ref.actor.megatron.use_mbridge=True \
    actor_rollout_ref.actor.megatron.vanilla_mbridge=True \
    actor_rollout_ref.actor.megatron.use_megatron_fsdp=True \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode=$rollout_mode \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.rollout.n=2 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=1 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_grpo_example_gsm8k_math' \
    trainer.experiment_name='qwen2_7b_megatron' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 "$@"
44 changes: 42 additions & 2 deletions verl/utils/megatron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class McoreModuleWrapperConfig:
share_embeddings_and_output_weights: bool = False
wrap_with_ddp: bool = True
use_distributed_optimizer: bool = True
use_megatron_fsdp: bool = False


def make_megatron_module(
Expand Down Expand Up @@ -276,12 +277,26 @@ def peft_pre_wrap_hook(model):
# Extract TransformerConfig from the created model
tf_config = get_model_config(model[0] if isinstance(model, list) else model)
else:
ddp_config = {}
if override_ddp_config is not None:
ddp_config.update(override_ddp_config)

if wrap_config.use_megatron_fsdp:
ddp_config.setdefault("use_distributed_optimizer", True)
ddp_config.setdefault("check_for_nan_in_grad", True)
ddp_config.setdefault("use_megatron_fsdp", True)
ddp_config.setdefault("data_parallel_sharding_strategy", "optim_grads_params")
ddp_config.setdefault("overlap_grad_reduce", True)
wrap_config.wrap_with_ddp = True

model = bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=wrap_config.wrap_with_ddp,
fp16=tf_config.fp16,
bf16=tf_config.bf16,
ddp_config=override_ddp_config,
use_megatron_fsdp=wrap_config.use_megatron_fsdp,
ddp_config=ddp_config,
data_parallel_random_init=False,
)

if isinstance(tf_config, MLATransformerConfig):
Expand Down Expand Up @@ -316,7 +331,13 @@ def megatron_model_provider(pre_process, post_process, vp_stage=None):
return model, tf_config


ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
try:
    # Megatron-FSDP wrapper classes are only present in recent Megatron-Core
    # releases; import them when available so unwrap_model can peel FSDP
    # wrappers in addition to DDP/Float16Module.
    from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as _MegatronFSDP
    from megatron.core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp import MegatronFSDP

    ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module, _MegatronFSDP, MegatronFSDP)
except ImportError:
    # Older Megatron-Core without FSDP support: fall back to the original
    # wrapper tuple so the rest of this module keeps working unchanged.
    ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)


def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
Expand All @@ -334,6 +355,25 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
return unwrapped_model


def synchronize_megatron_fsdp_params(model_chunks: list) -> bool:
    """Synchronize FSDP parameter state from raw sharded tensors back to DTensors.

    Iterates over *all* model chunks (there may be several, e.g. with virtual
    pipeline parallelism) and triggers ``synchronize_param_gather()`` on every
    FSDP-wrapped module whose sharding strategy is ``"optim_grads_params"``.

    Args:
        model_chunks: List of wrapped model chunks; each chunk's FSDP module is
            expected on its ``.module`` attribute.

    Returns:
        True if at least one chunk was synchronized, False otherwise.
    """
    synchronized = False
    for model_chunk in model_chunks:
        fsdp = model_chunk.module
        # Only the fully-sharded strategy keeps params in raw sharded form and
        # needs an explicit gather before weights can be exported.
        if getattr(fsdp, "data_parallel_sharding_strategy", None) == "optim_grads_params":
            fsdp.synchronize_param_gather()
            synchronized = True
    # Do NOT return inside the loop: an early return would leave later chunks
    # unsynchronized and silently corrupt inference with pipeline parallelism.
    return synchronized
Comment on lines +363 to +368
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The function synchronize_megatron_fsdp_params returns immediately after synchronizing the first FSDP module it finds. If model_chunks can contain multiple FSDP-wrapped modules (e.g., with pipeline parallelism), this will result in only the first one being synchronized, potentially leading to inconsistent model states and silent correctness issues during inference. The function should iterate through all model chunks and synchronize all applicable FSDP modules before returning.

Suggested change
for model_chunk in model_chunks:
fsdp = model_chunk.module
if getattr(fsdp, "data_parallel_sharding_strategy", None) == "optim_grads_params":
fsdp.synchronize_param_gather()
return True
return False
synchronized = False
for model_chunk in model_chunks:
fsdp = model_chunk.module
if getattr(fsdp, "data_parallel_sharding_strategy", None) == "optim_grads_params":
fsdp.synchronize_param_gather()
synchronized = True
return synchronized



def restore_megatron_fsdp_params(model_chunks: list):
    """Restore FSDP parameters to raw sharded state for training.

    Side-effect only: triggers ``start_param_sync()`` on every entry of
    *model_chunks* and returns ``None``.
    """
    # NOTE(review): unlike synchronize_megatron_fsdp_params, this calls the
    # chunk itself rather than chunk.module — confirm both paths are intended.
    for wrapped_chunk in model_chunks:
        wrapped_chunk.start_param_sync()


def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:
"""[Deprecated] convert config

Expand Down
2 changes: 2 additions & 0 deletions verl/workers/config/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class McoreEngineConfig(EngineConfig):
override_ddp_config (dict[str, Any]): Override configuration for DDP.
override_transformer_config (dict[str, Any]): Override configuration for transformer.
use_mbridge (bool): Whether to use MBridge for communication.
use_megatron_fsdp (bool): Whether to use Megatron-FSDP (Zero-3 sharding).
dtype (str): Mixed precision training param dtype, default "bfloat16"
"""

Expand All @@ -165,6 +166,7 @@ class McoreEngineConfig(EngineConfig):
override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
use_mbridge: bool = True
vanilla_mbridge: bool = True
use_megatron_fsdp: bool = False
strategy: str = "megatron"

def __post_init__(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions verl/workers/engine/megatron/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def _build_megatron_module(self):
share_embeddings_and_output_weights=self.model_config.share_embeddings_and_output_weights,
wrap_with_ddp=wrap_with_ddp,
use_distributed_optimizer=self.engine_config.use_distributed_optimizer,
use_megatron_fsdp=getattr(self.engine_config, "use_megatron_fsdp", False),
)
module, updated_tf_config = make_megatron_module(
wrap_config=wrap_config,
Expand Down
13 changes: 13 additions & 0 deletions verl/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
offload_megatron_optimizer,
per_tensor_generator,
register_megatron_training_hooks,
restore_megatron_fsdp_params,
synchronize_megatron_fsdp_params,
)
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights
Expand Down Expand Up @@ -406,6 +408,7 @@ def _build_model_optimizer(
share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
wrap_with_ddp=True,
use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
use_megatron_fsdp=self.config.actor.megatron.get("use_megatron_fsdp", False),
)
actor_module, updated_tf_config = make_megatron_module(
wrap_config=wrap_config,
Expand Down Expand Up @@ -700,6 +703,11 @@ async def rollout_mode(self):
self.rollout.sleep_level != 1 and self.config.rollout.free_cache_engine
)

if self.config.actor.megatron.get("use_megatron_fsdp", False):
self._fsdp_needs_param_sync_restore = synchronize_megatron_fsdp_params(self.actor.actor_module)
if hasattr(self, "ref_module") and self.ref_module is not None:
synchronize_megatron_fsdp_params(self.ref_module)

if self.bridge is not None:
if self.vanilla_bridge:
per_tensor_param = self.bridge.export_weights(self.actor.actor_module)
Expand Down Expand Up @@ -747,6 +755,11 @@ async def rollout_mode(self):
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
assert self._is_actor

if getattr(self, "_fsdp_needs_param_sync_restore", False):
restore_megatron_fsdp_params(self.actor.actor_module)
self._fsdp_needs_param_sync_restore = False

if self._is_offload_param:
load_megatron_model_to_gpu(self.actor_module)
log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger)
Expand Down