From 95f9b6b34fff5363d657f2e30f4523a47b5f51f2 Mon Sep 17 00:00:00 2001
From: conver334
Date: Thu, 26 Feb 2026 21:40:15 -0800
Subject: [PATCH] Support Megatron-FSDP

Add an opt-in `use_megatron_fsdp` flag to the Megatron engine and worker
path. When enabled, the model is wrapped with Megatron-FSDP instead of
plain Megatron DDP, with optimizer state, gradients, and parameters fully
sharded ("optim_grads_params" strategy). Before rollout weight export the
sharded parameters are gathered back into DTensors, and they are re-sharded
before the next actor update. An example GRPO script is included.

---
 .../run_qwen2-7b_math_megatron_fsdp.sh        | 65 +++++++++++++++++++
 verl/utils/megatron_utils.py                  | 45 ++++++++++++--
 verl/workers/config/engine.py                 |  2 +
 .../engine/megatron/transformer_impl.py       |  1 +
 verl/workers/megatron_workers.py              | 15 ++++
 5 files changed, 126 insertions(+), 2 deletions(-)
 create mode 100644 examples/grpo_trainer/run_qwen2-7b_math_megatron_fsdp.sh

diff --git a/examples/grpo_trainer/run_qwen2-7b_math_megatron_fsdp.sh b/examples/grpo_trainer/run_qwen2-7b_math_megatron_fsdp.sh
new file mode 100644
index 00000000000..6e6296d9cba
--- /dev/null
+++ b/examples/grpo_trainer/run_qwen2-7b_math_megatron_fsdp.sh
@@ -0,0 +1,65 @@
+set -x
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1 # for Megatron communication/computation overlap
+unset ROCR_VISIBLE_DEVICES
+# export VLLM_USE_V1=1
+# export VLLM_ALLREDUCE_USE_SYMM_MEM=0
+
+rollout_mode="async"
+export VLLM_USE_V1=1
+return_raw_chat="True"
+
+gsm8k_train_path=$HOME/data/gsm8k/train.parquet
+gsm8k_test_path=$HOME/data/gsm8k/test.parquet
+math_train_path=$HOME/data/math/train.parquet
+math_test_path=$HOME/data/math/test.parquet
+
+train_files="['$gsm8k_train_path', '$math_train_path']"
+test_files="['$gsm8k_test_path', '$math_test_path']"
+
+USE_FUSED_KERNELS=False
+
+python3 -m verl.trainer.main_ppo --config-path=config \
+    --config-name='ppo_megatron_trainer.yaml' \
+    algorithm.adv_estimator=grpo \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.return_raw_chat=$return_raw_chat \
+    data.train_batch_size=32 \
+    data.max_prompt_length=512 \
+    data.max_response_length=512 \
+    data.filter_overlong_prompts=True \
+    data.truncation='error' \
+    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
+    actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \
+    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=1 \
+    actor_rollout_ref.actor.megatron.use_mbridge=True \
+    actor_rollout_ref.actor.megatron.vanilla_mbridge=True \
+    +actor_rollout_ref.actor.megatron.use_megatron_fsdp=True \
+    actor_rollout_ref.actor.use_kl_loss=True \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.mode=$rollout_mode \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
+    actor_rollout_ref.rollout.n=2 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
+    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \
+    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=1 \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger='["console","wandb"]' \
+    trainer.project_name='verl_grpo_example_gsm8k_math' \
+    trainer.experiment_name='qwen2_7b_megatron_fsdp' \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=20 \
+    trainer.test_freq=5 \
+    trainer.total_epochs=15 "$@"
diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py
index 9572fb91962..da13a5ef8db 100644
--- a/verl/utils/megatron_utils.py
+++ b/verl/utils/megatron_utils.py
@@ -175,6 +175,7 @@ class McoreModuleWrapperConfig:
     share_embeddings_and_output_weights: bool = False
     wrap_with_ddp: bool = True
     use_distributed_optimizer: bool = True
+    use_megatron_fsdp: bool = False
 
 
 def make_megatron_module(
@@ -274,12 +275,26 @@ def peft_pre_wrap_hook(model):
         # Extract TransformerConfig from the created model
         tf_config = get_model_config(model[0] if isinstance(model, list) else model)
     else:
+        ddp_config = {}
+        if override_ddp_config is not None:
+            ddp_config.update(override_ddp_config)
+
+        if wrap_config.use_megatron_fsdp:
+            ddp_config.setdefault("use_distributed_optimizer", True)
+            ddp_config.setdefault("check_for_nan_in_grad", True)
+            ddp_config.setdefault("use_megatron_fsdp", True)
+            ddp_config.setdefault("data_parallel_sharding_strategy", "optim_grads_params")
+            ddp_config.setdefault("overlap_grad_reduce", True)
+            wrap_config.wrap_with_ddp = True
+
         model = bridge.get_model(
             post_model_creation_callbacks=post_model_creation_callbacks,
             wrap_with_ddp=wrap_config.wrap_with_ddp,
             fp16=tf_config.fp16,
             bf16=tf_config.bf16,
-            ddp_config=override_ddp_config,
+            use_megatron_fsdp=wrap_config.use_megatron_fsdp,
+            ddp_config=ddp_config,
+            data_parallel_random_init=False,
         )
 
         if isinstance(tf_config, MLATransformerConfig):
@@ -314,7 +329,13 @@ def megatron_model_provider(pre_process, post_process, vp_stage=None):
     return model, tf_config
 
 
-ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
+try:
+    from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as _MegatronFSDP
+    from megatron.core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp import MegatronFSDP
+
+    ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module, _MegatronFSDP, MegatronFSDP)
+except ImportError:
+    ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
 
 
 def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
@@ -332,6 +353,26 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
     return unwrapped_model
 
 
+def synchronize_megatron_fsdp_params(model_chunks: list) -> bool:
+    """Synchronize FSDP parameter state from raw sharded tensors back to DTensors.
+
+    Returns True if synchronization was performed on any model chunk.
+    """
+    synchronized = False
+    for model_chunk in model_chunks:
+        fsdp = model_chunk.module
+        if getattr(fsdp, "data_parallel_sharding_strategy", None) == "optim_grads_params":
+            fsdp.synchronize_param_gather()
+            synchronized = True
+    return synchronized
+
+
+def restore_megatron_fsdp_params(model_chunks: list):
+    """Restore FSDP parameters to the raw sharded state for training."""
+    for model_chunk in model_chunks:
+        model_chunk.start_param_sync()
+
+
 def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:
     """[Deprecated] convert config
 
diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py
index feb559b374c..28b299c43c5 100644
--- a/verl/workers/config/engine.py
+++ b/verl/workers/config/engine.py
@@ -140,6 +140,7 @@ class McoreEngineConfig(EngineConfig):
         override_ddp_config (dict[str, Any]): Override configuration for DDP.
         override_transformer_config (dict[str, Any]): Override configuration for transformer.
        use_mbridge (bool): Whether to use MBridge for communication.
+        use_megatron_fsdp (bool): Whether to use Megatron-FSDP (ZeRO-3 sharding).
        dtype (str): Mixed precision training param dtype, default "bfloat16"
     """
 
@@ -164,6 +165,7 @@ class McoreEngineConfig(EngineConfig):
     override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
     use_mbridge: bool = True
     vanilla_mbridge: bool = True
+    use_megatron_fsdp: bool = False
     strategy: str = "megatron"
 
     def __post_init__(self) -> None:
diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 0e3f7ff6a29..41a97fa968e 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -214,6 +214,7 @@ def _build_megatron_module(self):
             share_embeddings_and_output_weights=self.model_config.share_embeddings_and_output_weights,
             wrap_with_ddp=wrap_with_ddp,
             use_distributed_optimizer=self.engine_config.use_distributed_optimizer,
+            use_megatron_fsdp=getattr(self.engine_config, "use_megatron_fsdp", False),
         )
         module, updated_tf_config = make_megatron_module(
             wrap_config=wrap_config,
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 3baf4020810..6b44bdc4a92 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -61,6 +61,8 @@
     offload_megatron_optimizer,
     per_tensor_generator,
     register_megatron_training_hooks,
+    restore_megatron_fsdp_params,
+    synchronize_megatron_fsdp_params,
 )
 from verl.utils.memory_utils import aggressive_empty_cache
 from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights
@@ -403,6 +405,7 @@ def _build_model_optimizer(
             share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
             wrap_with_ddp=True,
             use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
+            use_megatron_fsdp=self.config.actor.megatron.get("use_megatron_fsdp", False),
         )
         actor_module, updated_tf_config = make_megatron_module(
             wrap_config=wrap_config,
@@ -696,6 +699,12 @@ async def rollout_mode(self):
             self.rollout.sleep_level != 1 and self.config.rollout.free_cache_engine
         )
 
+        # With Megatron-FSDP, gather the raw sharded params back into DTensors before weight export
+        if self.config.actor.megatron.get("use_megatron_fsdp", False):
+            self._fsdp_needs_param_sync_restore = synchronize_megatron_fsdp_params(self.actor.actor_module)
+            if hasattr(self, "ref_module") and self.ref_module is not None:
+                synchronize_megatron_fsdp_params(self.ref_module)
+
         if self.bridge is not None:
             if self.vanilla_bridge:
                 per_tensor_param = self.bridge.export_weights(self.actor.actor_module)
@@ -743,6 +752,12 @@
     @DistProfiler.annotate(color="red", role="actor_update")
     def update_actor(self, data: DataProto):
         assert self._is_actor
+
+        # Undo the rollout-time DTensor gather: re-shard Megatron-FSDP params before training
+        if getattr(self, "_fsdp_needs_param_sync_restore", False):
+            restore_megatron_fsdp_params(self.actor.actor_module)
+            self._fsdp_needs_param_sync_restore = False
+
         if self._is_offload_param:
             load_megatron_model_to_gpu(self.actor_module)
             log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger)
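--
Reviewer note, not part of the patch: the two new helpers are meant to be
called in pairs across the rollout -> train transition. The sketch below is
hypothetical driver code illustrating that lifecycle; `FsdpParamLifecycle`
and `actor_module` are made-up names for illustration only (in the patch
itself, the worker plays this role via `self.actor.actor_module` and the
`_fsdp_needs_param_sync_restore` flag).

    from verl.utils.megatron_utils import (
        restore_megatron_fsdp_params,
        synchronize_megatron_fsdp_params,
    )

    class FsdpParamLifecycle:
        """Tracks whether Megatron-FSDP params are currently gathered as DTensors."""

        def __init__(self, actor_module: list):
            # `actor_module` is assumed to be the list of Megatron-FSDP-wrapped
            # model chunks held by the worker.
            self.actor_module = actor_module
            self.needs_restore = False

        def before_rollout_export(self):
            # Gather raw shards back into DTensors so weight export sees full
            # parameters; remember that a re-shard is owed before training.
            self.needs_restore = synchronize_megatron_fsdp_params(self.actor_module)

        def before_train_step(self):
            # Re-shard (the "optim_grads_params" state) before the optimizer runs.
            if self.needs_restore:
                restore_megatron_fsdp_params(self.actor_module)
                self.needs_restore = False

Deferring the restore to the next update_actor, rather than re-sharding right
after export, appears intentional: it keeps the gathered weights available for
the whole rollout phase, matching the worker changes above.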