Skip to content

Commit 44f5a1d

Browse files
committed
add qwen3.5 megatron sft example
1 parent 7cf31bf commit 44f5a1d

File tree

8 files changed

+264
-31
lines changed

8 files changed

+264
-31
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#!/usr/bin/env bash
# Qwen3.5-397B-A17B SFT with Megatron backend + mbridge
#
# Requirements:
#   - 128+ GPUs (80GB each, e.g. 16x8 H100/H200)
#   - Docker: verlai/verl:vllm015 (or equivalent)
#   - Additional packages on top of the base image:
#       pip install --upgrade transformers
#       pip install flash-linear-attention
#       pip install -U git+https://github.com/ISEEKYAN/mbridge.git
#   - Megatron-LM dev branch with Qwen3.5 GDN support
#
# Qwen3.5 architecture notes:
#   Qwen3.5 uses Gated Delta Net (GDN) linear attention which currently does
#   NOT support packed sequences (THD format) in Megatron-LM. Therefore:
#     - engine.use_remove_padding=False (forces bshd compute format)
#     - model.use_remove_padding=True  (keeps NestedTensor in data pipeline)
#     - data.use_dynamic_bsz=False     (required for bshd mode)
#
#   Once https://github.com/NVIDIA/Megatron-LM/pull/2644 is merged, THD
#   format will be supported and engine.use_remove_padding can be set to True
#   for better performance.
#
# Tested parallelism config (128 GPUs / 16 nodes):
#   TP=2 PP=4 EP=32 CP=1

# Exit on error, undefined variables, and pipeline failures; trace commands.
set -xeuo pipefail

# ============================================================
# Distributed
# ============================================================
NUM_GPUS=${NUM_GPUS:-8}
MASTER_ADDR=${MASTER_ADDR:-localhost}
MASTER_PORT=${MASTER_PORT:-29500}
NNODES=${NNODES:-16}
NODE_RANK=${NODE_RANK:-0}

# ============================================================
# Data
# ============================================================
DATASET_DIR=${DATASET_DIR:-~/dataset}
TRAIN_FILES=${TRAIN_FILES:-${DATASET_DIR}/train.parquet}

# ============================================================
# Model
# ============================================================
MODEL_PATH=${MODEL_PATH:-Qwen/Qwen3.5-397B-A17B}

# ============================================================
# Parallelism
# ============================================================
TP_SIZE=${TP_SIZE:-2}
PP_SIZE=${PP_SIZE:-4}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}
EP_SIZE=${EP_SIZE:-32}
ETP_SIZE=${ETP_SIZE:-1}

# ============================================================
# Training
# ============================================================
TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-128}
MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-2}
MAX_LENGTH=${MAX_LENGTH:-2048}
LR=${LR:-2e-5}
MIN_LR=${MIN_LR:-2e-6}
DTYPE=${DTYPE:-bfloat16}

BACKEND=megatron
RESUME_MODE=${RESUME_MODE:-disable}

project_name=verl_sft_qwen3_5
exp_name=qwen3_5-${BACKEND}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-ep${EP_SIZE}
ckpts_home=${ckpts_home:-~/verl/checkpoints/${project_name}/${exp_name}}
mkdir -p "${ckpts_home}"

# ============================================================
# Engine config
# ============================================================
# Key Qwen3.5 settings:
#   engine.use_remove_padding=False - GDN requires bshd format (no THD)
#   engine.vanilla_mbridge=True     - use mbridge (not megatron-bridge)
ENGINE_CONFIG="\
    engine=${BACKEND} \
    optim=${BACKEND} \
    optim.lr=${LR} \
    optim.min_lr=${MIN_LR} \
    optim.lr_warmup_steps=10 \
    optim.weight_decay=0.1 \
    optim.betas='[0.9,0.95]' \
    optim.clip_grad=1.0 \
    optim.lr_warmup_init=0 \
    optim.lr_decay_style=cosine \
    +optim.override_optimizer_config.optimizer_offload_fraction=1 \
    +optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
    +optim.override_optimizer_config.use_precision_aware_optimizer=True \
    +optim.override_optimizer_config.optimizer_cpu_offload=True \
    engine.tensor_model_parallel_size=${TP_SIZE} \
    engine.pipeline_model_parallel_size=${PP_SIZE} \
    engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
    engine.context_parallel_size=${CP_SIZE} \
    engine.expert_model_parallel_size=${EP_SIZE} \
    engine.expert_tensor_parallel_size=${ETP_SIZE} \
    engine.use_mbridge=True \
    engine.vanilla_mbridge=True \
    engine.dtype=${DTYPE} \
    engine.use_remove_padding=False \
    engine.override_transformer_config.attention_backend=auto \
    +engine.override_transformer_config.recompute_method=uniform \
    +engine.override_transformer_config.recompute_granularity=full \
    +engine.override_transformer_config.recompute_num_layers=1"

# ============================================================
# Launch
# ============================================================
torchrun \
    --nproc_per_node=${NUM_GPUS} \
    --nnodes=${NNODES} \
    --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    -m verl.trainer.sft_trainer \
    data.train_files="${TRAIN_FILES}" \
    data.train_batch_size=${TRAIN_BATCH_SIZE} \
    data.micro_batch_size_per_gpu=${MICRO_BATCH_SIZE} \
    data.max_length=${MAX_LENGTH} \
    data.pad_mode=no_padding \
    data.truncation=error \
    data.use_dynamic_bsz=False \
    data.max_token_len_per_gpu=${MAX_LENGTH} \
    data.messages_key=messages \
    model.path=${MODEL_PATH} \
    model.use_remove_padding=True \
    model.trust_remote_code=True \
    ${ENGINE_CONFIG} \
    trainer.test_freq=-1 \
    trainer.save_freq=500 \
    trainer.logger="['console']" \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${ckpts_home}" \
    trainer.resume_mode=${RESUME_MODE}

verl/models/mcore/model_forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def gptmodel_forward_no_padding(
258258
output_orig = model(
259259
input_ids=input_ids_bshd,
260260
attention_mask=attention_mask_bshd,
261-
position_ids=position_ids_bshd,
261+
position_ids=None if vision_model else position_ids_bshd,
262262
**model_kwargs,
263263
)
264264
if post_process and logits_processor is not None:

verl/utils/dataset/multiturn_sft_dataset.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from functools import wraps
2323
from typing import Any, Optional
2424

25+
import jinja2
2526
import numpy as np
2627
import pandas as pd
2728
import torch
@@ -215,20 +216,52 @@ def _process_single_message(
215216
Returns:
216217
Tuple of (input_ids, loss_mask, attention_mask, dict[str, torch.Tensor])
217218
"""
218-
processor = self.processor if self.processor is not None else self.tokenizer
219+
has_visual_content = isinstance(message.get("content"), list) and any(
220+
isinstance(c, dict) and c.get("type") in ("image", "video") for c in message["content"]
221+
)
222+
processor = self.processor if self.processor is not None and has_visual_content else self.tokenizer
219223
apply_chat_template_kwargs = {**self.apply_chat_template_kwargs}
220224
if enable_thinking is not None:
221225
apply_chat_template_kwargs["enable_thinking"] = enable_thinking
222226

223-
inputs = processor.apply_chat_template(
224-
[message],
225-
tools=tools,
226-
add_generation_prompt=False,
227-
tokenize=True,
228-
return_dict=True,
229-
return_tensors="pt",
230-
**apply_chat_template_kwargs,
231-
)
227+
try:
228+
inputs = processor.apply_chat_template(
229+
[message],
230+
tools=tools,
231+
add_generation_prompt=False,
232+
tokenize=True,
233+
return_dict=True,
234+
return_tensors="pt",
235+
**apply_chat_template_kwargs,
236+
)
237+
except (jinja2.exceptions.TemplateError, Exception) as e:
238+
if "No user query" not in str(e):
239+
raise
240+
# Chat templates that require a user message (e.g. Qwen3.5) fail
241+
# when tokenising a single non-user message. Fallback: tokenise the
242+
# conversation up to this turn and subtract the prefix.
243+
inputs_full = processor.apply_chat_template(
244+
full_message[: index + 1],
245+
tools=tools,
246+
add_generation_prompt=False,
247+
tokenize=True,
248+
return_dict=True,
249+
return_tensors="pt",
250+
**apply_chat_template_kwargs,
251+
)
252+
prefix_len = 0
253+
if index > 0:
254+
inputs_prev = processor.apply_chat_template(
255+
full_message[:index],
256+
tools=tools if index == 1 else None,
257+
add_generation_prompt=False,
258+
tokenize=True,
259+
return_dict=True,
260+
return_tensors="pt",
261+
**apply_chat_template_kwargs,
262+
)
263+
prefix_len = inputs_prev["input_ids"].shape[-1]
264+
inputs = {k: v[..., prefix_len:] for k, v in inputs_full.items()}
232265

233266
inputs = dict(inputs)
234267
input_ids = inputs.pop("input_ids")[0]
@@ -266,14 +299,16 @@ def _build_messages(self, example: dict):
266299

267300
image_offset, video_offset = 0, 0
268301
for message in messages:
269-
if self.image_key not in example and self.video_key not in example:
270-
continue
271-
assert self.processor is not None, "processor is needed to process image and video"
272-
273302
content = message["content"]
274303
if not isinstance(content, str):
275304
continue
276305

306+
if self.image_key not in example and self.video_key not in example:
307+
if self.processor is not None:
308+
message["content"] = [{"type": "text", "text": content}]
309+
continue
310+
assert self.processor is not None, "processor is needed to process image and video"
311+
277312
content_list = []
278313
segments = re.split("(<image>|<video>)", content)
279314
segments = [item for item in segments if item != ""]

verl/utils/megatron_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,15 @@ def load_megatron_model_to_gpu(models, load_grad=True):
459459
for buffer in buffers:
460460
# sometimes, we don't want to load grad for pure inference
461461
if load_grad and hasattr(buffer, "grad_data_size"):
462-
buffer.grad_data.storage().resize_(buffer.grad_data_size)
463-
buffer.grad_data.zero_()
462+
current_storage_size = buffer.grad_data.storage().size()
463+
if current_storage_size == 0 or current_storage_size == buffer.grad_data_size:
464+
buffer.grad_data.storage().resize_(buffer.grad_data_size)
465+
buffer.grad_data.zero_()
466+
else:
467+
# Non-standard layers (e.g. GatedDeltaNet) may have grad
468+
# buffers with mismatched storage size; skip resize and
469+
# zero in-place with current storage.
470+
buffer.grad_data.zero_()
464471

465472
if buffer.param_data.storage().size() == 0:
466473
buffer.param_data.storage().resize_(buffer.param_data_size)

verl/utils/model.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,24 @@
3030
AutoConfig,
3131
AutoModel,
3232
AutoModelForCausalLM,
33-
AutoModelForImageTextToText,
3433
AutoModelForSequenceClassification,
3534
AutoModelForTokenClassification,
36-
AutoModelForVision2Seq,
3735
GenerationConfig,
3836
MistralForSequenceClassification,
3937
PretrainedConfig,
4038
PreTrainedModel,
4139
)
40+
41+
try:
42+
from transformers import AutoModelForVision2Seq
43+
except ImportError:
44+
AutoModelForVision2Seq = None
45+
46+
try:
47+
from transformers import AutoModelForImageTextToText
48+
except ImportError:
49+
AutoModelForImageTextToText = AutoModelForVision2Seq
50+
4251
from transformers.modeling_outputs import CausalLMOutputWithPast
4352

4453
from verl.models.registry import ModelRegistry

verl/utils/tensordict_utils.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,20 +292,47 @@ def chunk_tensordict(td: TensorDict, chunks: int) -> list[TensorDict]:
292292
evenly divisible by chunks.
293293
294294
Note:
295-
This is a workaround for PyTorch issue #153238 where torch.chunk()
296-
doesn't support 3D jagged tensors (e.g., MRoPE position_ids).
297-
See: https://github.com/pytorch/pytorch/issues/153238
295+
PyTorch NestedTensor has issues with unbind/indexing on 2D and 3D
296+
jagged tensors: unbind() internally calls split_with_sizes() using the
297+
ragged lengths, but the underlying storage may be padded to a different
298+
size, causing a RuntimeError.
299+
- 3D+: https://github.com/pytorch/pytorch/issues/153238
300+
- 2D: select_int -> unbind -> split_with_sizes mismatch
301+
302+
For NestedTensors that can be chunked directly (regular batch dim with
303+
no ragged interaction), we use the standard TensorDict.chunk(). For
304+
those that cannot, we pad -> chunk -> unpad as a workaround.
298305
"""
299306
assert isinstance(td, TensorDict) and len(td) % chunks == 0, (
300307
f"expecting td with length divisible by chunks, but got {len(td)} and {chunks}"
301308
)
302309
chunk_size = len(td) // chunks
303-
keys = {key for key, val in td.items() if isinstance(val, torch.Tensor) and val.is_nested and val.dim() >= 3}
304-
new_td = TensorDict({k: v for k, v in td.items() if k not in keys}, batch_size=td.batch_size, device=td.device)
310+
nested_keys = {key for key, val in td.items() if isinstance(val, torch.Tensor) and val.is_nested}
311+
new_td = TensorDict(
312+
{k: v for k, v in td.items() if k not in nested_keys}, batch_size=td.batch_size, device=td.device
313+
)
305314

306315
tds = new_td.chunk(chunks=chunks)
307-
for key in keys:
308-
tensors = td[key].unbind(dim=0)
316+
for key in nested_keys:
317+
nt = td[key]
318+
# Try the fast path first: direct unbind works for some NestedTensor
319+
# layouts where the batch dim is not entangled with the ragged dim.
320+
try:
321+
tensors = nt.unbind(dim=0)
322+
except RuntimeError:
323+
# Fallback: pad -> chunk -> unpad. This avoids the PyTorch bug
324+
# where unbind/split_with_sizes fails because ragged lengths don't
325+
# match the (padded) storage size.
326+
padded = nt.to_padded_tensor(0)
327+
padded_chunks = padded.chunk(chunks, dim=0)
328+
offsets = nt.offsets()
329+
lengths = offsets.diff().tolist()
330+
for i, chunk_td in enumerate(tds):
331+
chunk_lengths = lengths[i * chunk_size : (i + 1) * chunk_size]
332+
chunk_tensors = [padded_chunks[i][j, :seq_len] for j, seq_len in enumerate(chunk_lengths)]
333+
chunk_td[key] = torch.nested.as_nested_tensor(chunk_tensors, layout=torch.jagged)
334+
continue
335+
309336
for i, chunk_td in enumerate(tds):
310337
chunk_td[key] = torch.nested.as_nested_tensor(
311338
tensors[i * chunk_size : (i + 1) * chunk_size], layout=torch.jagged

verl/workers/engine_workers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from tensordict import NonTensorData, TensorDict
2525
from torch.distributed.device_mesh import init_device_mesh
2626

27+
from verl.workers.config.engine import McoreEngineConfig
28+
2729
try:
2830
from verl.workers.engine.mindspeed.transformer_impl import repatch
2931
except ImportError:
@@ -98,9 +100,12 @@ def __init__(self, config: TrainingWorkerConfig):
98100
self.model_config, self.device_name
99101
)
100102

101-
# we use the one defined in model
102-
# TODO: this is not elegant and should refactor later
103-
self.engine_config.use_remove_padding = self.model_config.use_remove_padding
103+
# For Megatron engine, model.use_remove_padding (data pipeline) and
104+
# engine.use_remove_padding (compute format: thd vs bshd) may differ
105+
# (e.g. Qwen3.5 GDN requires bshd but still uses NestedTensor in data).
106+
# For other engines, keep the original behavior of syncing them.
107+
if not isinstance(self.engine_config, McoreEngineConfig):
108+
self.engine_config.use_remove_padding = self.model_config.use_remove_padding
104109
self.engine_config.use_fused_kernels = self.model_config.use_fused_kernels
105110

106111
if repatch is not None:

verl/workers/fsdp_workers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,17 @@ def _build_model_optimizer(
292292
AutoConfig,
293293
AutoModel,
294294
AutoModelForCausalLM,
295-
AutoModelForImageTextToText,
296-
AutoModelForVision2Seq,
297295
)
298296

297+
try:
298+
from transformers import AutoModelForVision2Seq
299+
except ImportError:
300+
AutoModelForVision2Seq = None
301+
try:
302+
from transformers import AutoModelForImageTextToText
303+
except ImportError:
304+
AutoModelForImageTextToText = AutoModelForVision2Seq
305+
299306
from verl.utils.model import get_generation_config, print_model_size, update_model_config
300307
from verl.utils.torch_dtypes import PrecisionType
301308

0 commit comments

Comments
 (0)