Skip to content

Commit d9611d9

Browse files
committed
submit an RL script
1 parent 44f5a1d commit d9611d9

File tree

3 files changed

+182
-4
lines changed

3 files changed

+182
-4
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
#!/usr/bin/env bash
# Qwen3.5-35B-A3B MoE GRPO RL with Megatron (single node, 8 GPUs, geo3k dataset)
#
# Notes on vLLM:
#   As of 20260225, the latest vLLM nightly does not support Qwen3.5 rollout. To use this script you need to
#   1. wait until vLLM supports Qwen3.5 officially, and build a verl docker image with that version of vLLM, or
#   2. build a verl docker image yourself with vLLM from source with Qwen3.5 support (main branch 20260225 is OK).
#   I succeeded in running this script with the main branch of vLLM on 20260225, yet there are still some minor
#   issues in vLLM's Qwen3.5 initialization that need to be fixed. Also, cuda_graph is somehow not working and
#   needs to be fixed, either by the verl team with support for vllm 0.16, or by the vLLM team.
#
# Requirements:
#   - 8 GPUs (80GB each, e.g. 1x8 H100/H200)
#   - Additional packages on top of the base image:
#       pip install --upgrade transformers
#       pip install flash-linear-attention
#       pip install -U git+https://github.com/ISEEKYAN/mbridge.git
#   - Megatron-LM dev branch with Qwen3.5 GDN support
#
# Qwen3.5 architecture notes:
#   Qwen3.5 uses Gated Delta Net (GDN) linear attention which currently does
#   NOT support packed sequences (THD format) in Megatron-LM. Therefore:
#     - actor.megatron.use_remove_padding=False (forces bshd compute format)
#     - model.use_remove_padding=True (keeps NestedTensor in data pipeline)
#     - actor.use_dynamic_bsz=False (required for bshd mode)
#   Once Megatron-LM adds THD support for Qwen3.5 GDN, use_remove_padding
#   can be set to True for better performance.
#
# Tested parallelism config (8 GPUs / 1 node):
#   TP=2 PP=1 CP=1 EP=8 ETP=1 GEN_TP=8

# Serialize kernel launches per device; required by Megatron when overlapping comm/compute.
export CUDA_DEVICE_MAX_CONNECTIONS=1
export VLLM_USE_V1=1
export VLLM_ALLREDUCE_USE_SYMM_MEM=0

set -xeuo pipefail

########################### Quick Config ###########################

# Training-side model parallelism (overridable via environment).
TP=${TP:-2}        # tensor parallel
PP=${PP:-1}        # pipeline parallel
CP=${CP:-1}        # context parallel
EP=${EP:-8}        # expert parallel (MoE)
ETP=${ETP:-1}      # expert tensor parallel
GEN_TP=${GEN_TP:-8}  # rollout (vLLM) tensor parallel

# Offload params/grads/optimizer to host memory to fit the 35B MoE on 8x80GB.
ALL_OFFLOAD=${ALL_OFFLOAD:-True}

rollout_name="vllm"
project_name='verl_grpo_qwen3_5_35b_geo3k'
exp_name='qwen3_5_35b_megatron'
adv_estimator=grpo

HF_MODEL_PATH=${HF_MODEL_PATH:-"Qwen3.5-35B-A3B"}
train_path=${train_path:-"$HOME/data/geo3k/train.parquet"}
test_path=${test_path:-"$HOME/data/geo3k/test.parquet"}

########################### Parameter Arrays ###########################

DATA=(
    data.train_files="${train_path}"
    data.val_files="${test_path}"
    data.train_batch_size=32
    data.max_prompt_length=1024
    data.max_response_length=2048
    data.truncation='error'
    data.filter_overlong_prompts=True
)

MODEL=(
    actor_rollout_ref.model.path="${HF_MODEL_PATH}"
    actor_rollout_ref.model.trust_remote_code=True
    actor_rollout_ref.model.use_remove_padding=True
)

ACTOR=(
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=32
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4096
    actor_rollout_ref.actor.use_dynamic_bsz=False
    actor_rollout_ref.actor.use_kl_loss=True
    actor_rollout_ref.actor.kl_loss_coef=0.01
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.actor.megatron.vanilla_mbridge=True
    actor_rollout_ref.actor.megatron.use_remove_padding=False
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size="${TP}"
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size="${PP}"
    actor_rollout_ref.actor.megatron.context_parallel_size="${CP}"
    actor_rollout_ref.actor.megatron.expert_model_parallel_size="${EP}"
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size="${ETP}"
    actor_rollout_ref.actor.megatron.param_offload="${ALL_OFFLOAD}"
    actor_rollout_ref.actor.megatron.optimizer_offload="${ALL_OFFLOAD}"
    actor_rollout_ref.actor.megatron.grad_offload="${ALL_OFFLOAD}"
    actor_rollout_ref.actor.megatron.dtype=bfloat16
    # Hydra override syntax: '++' force-adds (add or override), '+' adds a new key.
    ++actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=auto
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_aux_loss_coeff=0.01
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_z_loss_coeff=0.001
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1
    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True
)

ROLLOUT=(
    actor_rollout_ref.rollout.name="${rollout_name}"
    actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}"
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6
    actor_rollout_ref.rollout.n=5
    actor_rollout_ref.rollout.mode=async
    # cuda_graph is currently broken for Qwen3.5 in vLLM (see header notes), so force eager mode.
    actor_rollout_ref.rollout.enforce_eager=True
    actor_rollout_ref.rollout.dtype=bfloat16
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096
)

REF=(
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=4096
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size="${TP}"
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size="${PP}"
    actor_rollout_ref.ref.megatron.context_parallel_size="${CP}"
    actor_rollout_ref.ref.megatron.expert_model_parallel_size="${EP}"
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size="${ETP}"
    actor_rollout_ref.ref.megatron.param_offload="${ALL_OFFLOAD}"
)

ALGORITHM=(
    algorithm.adv_estimator="${adv_estimator}"
    algorithm.use_kl_in_reward=False
)

TRAINER=(
    trainer.critic_warmup=0
    trainer.logger='["console","wandb"]'
    trainer.project_name="${project_name}"
    trainer.experiment_name="${exp_name}"
    trainer.n_gpus_per_node=8
    trainer.nnodes=1
    trainer.save_freq=20
    trainer.val_before_train=False
    trainer.test_freq=5
    trainer.total_epochs=15
)

########################### Launch ###########################

python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${ROLLOUT[@]}" \
    "${ACTOR[@]}" \
    "${REF[@]}" \
    "${TRAINER[@]}" \
    "$@"

verl/models/mcore/model_forward.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,22 +122,33 @@ def model_forward(
122122
When using the bshd format, we have to add paddings to the input_ids to meet the longest sequence length,
123123
so it is recommended to disable dynamic batch size and set batch size to 1
124124
"""
125-
assert not vision_model, "vision model does not support bshd format"
126125
assert fp8 is None, "fp8 is not supported for bshd format yet"
127126

128127
batch_size, sequence_length = attention_mask.shape[:2]
128+
position_ids_for_preprocess = (
129+
torch.arange(sequence_length, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
130+
if vision_model
131+
else position_ids
132+
)
133+
pre_process_for_bshd = True if vision_model else pre_process
129134
new_input_ids, new_attention_mask, new_position_ids = preprocess_bshd(
130-
input_ids, attention_mask, position_ids, sequence_parallel=sp, pre_process=pre_process
135+
input_ids,
136+
attention_mask,
137+
position_ids_for_preprocess,
138+
sequence_parallel=sp,
139+
pre_process=pre_process_for_bshd,
131140
)
132141
output_orig = model(
133142
input_ids=new_input_ids,
134-
position_ids=new_position_ids,
143+
position_ids=None if vision_model else new_position_ids,
135144
attention_mask=new_attention_mask,
136145
**model_kwargs,
137146
)
138147
if post_process and logits_processor is not None:
139148
args = {
140-
k: preprocess_bshd(v, attention_mask, position_ids, sequence_parallel=sp, pre_process=True)[0]
149+
k: preprocess_bshd(
150+
v, attention_mask, position_ids_for_preprocess, sequence_parallel=sp, pre_process=True
151+
)[0]
141152
for k, v in logits_processor_args.items()
142153
}
143154
output_dict = logits_processor(output_orig, **args)

verl/models/mcore/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class SupportedVLM(Enum):
3030
QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
3131
QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
3232
QWEN3_VL = "Qwen3VLForConditionalGeneration"
33+
QWEN3_5_MOE_VL = "Qwen3_5MoeForConditionalGeneration"
3334

3435

3536
supported_vlm = [member.value for member in SupportedVLM]

0 commit comments

Comments
 (0)