-
Notifications
You must be signed in to change notification settings - Fork 3.3k
[BREAKING][megatron] feat: support Megatron-FSDP as a new training backend #5423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| set -x | ||
|
|
||
| export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping | ||
| unset ROCR_VISIBLE_DEVICES | ||
| # export VLLM_USE_V1=1 | ||
| # export VLLM_ALLREDUCE_USE_SYMM_MEM=0 | ||
|
|
||
| rollout_mode="async" | ||
| export VLLM_USE_V1=1 | ||
| return_raw_chat="True" | ||
|
|
||
| gsm8k_train_path=$HOME/data/gsm8k/train.parquet | ||
| gsm8k_test_path=$HOME/data/gsm8k/test.parquet | ||
| math_train_path=$HOME/data/math/train.parquet | ||
| math_test_path=$HOME/data/math/test.parquet | ||
|
|
||
| train_files="['$gsm8k_train_path', '$math_train_path']" | ||
| test_files="['$gsm8k_test_path', '$math_test_path']" | ||
|
|
||
| USE_FUSED_KERNELS=False | ||
|
|
||
| python3 -m verl.trainer.main_ppo --config-path=config \ | ||
| --config-name='ppo_megatron_trainer.yaml'\ | ||
| algorithm.adv_estimator=grpo \ | ||
| data.train_files="$train_files" \ | ||
| data.val_files="$test_files" \ | ||
| data.return_raw_chat=$return_raw_chat \ | ||
| data.train_batch_size=32 \ | ||
| data.max_prompt_length=512 \ | ||
| data.max_response_length=512 \ | ||
| data.filter_overlong_prompts=True \ | ||
| data.truncation='error' \ | ||
| actor_rollout_ref.model.path=/root/models/Qwen2.5-3B-Instruct \ | ||
| actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \ | ||
| actor_rollout_ref.actor.optim.lr=1e-6 \ | ||
| actor_rollout_ref.actor.ppo_mini_batch_size=16 \ | ||
| actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ | ||
| actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \ | ||
| actor_rollout_ref.actor.megatron.tensor_model_parallel_size=1 \ | ||
| actor_rollout_ref.actor.megatron.use_mbridge=True \ | ||
| actor_rollout_ref.actor.megatron.vanilla_mbridge=True \ | ||
| +actor_rollout_ref.actor.megatron.use_megatron_fsdp=True \ | ||
| actor_rollout_ref.actor.use_kl_loss=True \ | ||
| actor_rollout_ref.actor.kl_loss_coef=0.001 \ | ||
| actor_rollout_ref.actor.kl_loss_type=low_var_kl \ | ||
| actor_rollout_ref.actor.entropy_coeff=0 \ | ||
| actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ | ||
| actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ | ||
| actor_rollout_ref.rollout.name=vllm \ | ||
| actor_rollout_ref.rollout.mode=$rollout_mode \ | ||
| actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ | ||
| actor_rollout_ref.rollout.n=2 \ | ||
| actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ | ||
| actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \ | ||
| actor_rollout_ref.ref.megatron.tensor_model_parallel_size=1 \ | ||
| algorithm.use_kl_in_reward=False \ | ||
| trainer.critic_warmup=0 \ | ||
| trainer.logger='["console","wandb"]' \ | ||
| trainer.project_name='verl_grpo_example_gsm8k_math' \ | ||
| trainer.experiment_name='qwen2_7b_megatron' \ | ||
| trainer.n_gpus_per_node=8 \ | ||
| trainer.nnodes=1 \ | ||
| trainer.save_freq=20 \ | ||
| trainer.test_freq=5 \ | ||
| trainer.total_epochs=15 $@ | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -175,6 +175,7 @@ class McoreModuleWrapperConfig: | |||||||||||||||||||||||||||
| share_embeddings_and_output_weights: bool = False | ||||||||||||||||||||||||||||
| wrap_with_ddp: bool = True | ||||||||||||||||||||||||||||
| use_distributed_optimizer: bool = True | ||||||||||||||||||||||||||||
| use_megatron_fsdp: bool = False | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def make_megatron_module( | ||||||||||||||||||||||||||||
|
|
@@ -276,12 +277,26 @@ def peft_pre_wrap_hook(model): | |||||||||||||||||||||||||||
| # Extract TransformerConfig from the created model | ||||||||||||||||||||||||||||
| tf_config = get_model_config(model[0] if isinstance(model, list) else model) | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| ddp_config = {} | ||||||||||||||||||||||||||||
| if override_ddp_config is not None: | ||||||||||||||||||||||||||||
| ddp_config.update(override_ddp_config) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if wrap_config.use_megatron_fsdp: | ||||||||||||||||||||||||||||
| ddp_config.setdefault("use_distributed_optimizer", True) | ||||||||||||||||||||||||||||
| ddp_config.setdefault("check_for_nan_in_grad", True) | ||||||||||||||||||||||||||||
| ddp_config.setdefault("use_megatron_fsdp", True) | ||||||||||||||||||||||||||||
| ddp_config.setdefault("data_parallel_sharding_strategy", "optim_grads_params") | ||||||||||||||||||||||||||||
| ddp_config.setdefault("overlap_grad_reduce", True) | ||||||||||||||||||||||||||||
| wrap_config.wrap_with_ddp = True | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| model = bridge.get_model( | ||||||||||||||||||||||||||||
| post_model_creation_callbacks=post_model_creation_callbacks, | ||||||||||||||||||||||||||||
| wrap_with_ddp=wrap_config.wrap_with_ddp, | ||||||||||||||||||||||||||||
| fp16=tf_config.fp16, | ||||||||||||||||||||||||||||
| bf16=tf_config.bf16, | ||||||||||||||||||||||||||||
| ddp_config=override_ddp_config, | ||||||||||||||||||||||||||||
| use_megatron_fsdp=wrap_config.use_megatron_fsdp, | ||||||||||||||||||||||||||||
| ddp_config=ddp_config, | ||||||||||||||||||||||||||||
| data_parallel_random_init=False, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if isinstance(tf_config, MLATransformerConfig): | ||||||||||||||||||||||||||||
|
|
@@ -316,7 +331,13 @@ def megatron_model_provider(pre_process, post_process, vp_stage=None): | |||||||||||||||||||||||||||
| return model, tf_config | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) | ||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||
| from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as _MegatronFSDP | ||||||||||||||||||||||||||||
| from megatron.core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp import MegatronFSDP | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module, _MegatronFSDP, MegatronFSDP) | ||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||
| ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): | ||||||||||||||||||||||||||||
|
|
@@ -334,6 +355,25 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): | |||||||||||||||||||||||||||
| return unwrapped_model | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def synchronize_megatron_fsdp_params(model_chunks: list) -> bool: | ||||||||||||||||||||||||||||
| """Synchronize FSDP parameter state from raw sharded tensors back to DTensors. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Returns True if synchronization was performed. | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
| for model_chunk in model_chunks: | ||||||||||||||||||||||||||||
| fsdp = model_chunk.module | ||||||||||||||||||||||||||||
| if getattr(fsdp, "data_parallel_sharding_strategy", None) == "optim_grads_params": | ||||||||||||||||||||||||||||
| fsdp.synchronize_param_gather() | ||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||
|
Comment on lines
+363
to
+368
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function
Suggested change
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def restore_megatron_fsdp_params(model_chunks: list): | ||||||||||||||||||||||||||||
| """Restore FSDP parameters to raw sharded state for training.""" | ||||||||||||||||||||||||||||
| for model_chunk in model_chunks: | ||||||||||||||||||||||||||||
| model_chunk.start_param_sync() | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig: | ||||||||||||||||||||||||||||
| """[Deprecated] convert config | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The model path is hardcoded to
/root/models/Qwen2.5-3B-Instruct. This makes the example script not portable and difficult for other users to run without modification. It's better to use an environment variable for the model path to make the script more generic and easier to use.For example, you could add
MODEL_PATH=${MODEL_PATH:-/path/to/your/model}at the top of the script and then use$MODEL_PATHhere.