
Commit ea042c2

[doc, worker] feat: Enable Megatron-Bridge for MTP (#5323)
### What does this PR do?

There's nothing specific in Megatron-Bridge that stops MTP support. NVIDIA-NeMo/Megatron-Bridge#2387 adds MiMo dense MTP model bridge support, so that `examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron.sh` can also be used together with Megatron-Bridge (setting `vanilla_mbridge` to `False`).

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [X] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

---------

Signed-off-by: Hollow Man <hollowman@opensuse.org>
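The change boils down to one config switch. A hypothetical launch fragment (the script path is from the PR description; how `vanilla_mbridge` is actually passed through to the script depends on your verl config layout and is illustrative only):

```shell
# Illustrative only: the exact mechanism for overriding vanilla_mbridge
# depends on your verl config; the script path comes from the PR description.
VANILLA_MBRIDGE=False \
bash examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron.sh
```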
1 parent 3481d6e commit ea042c2

File tree

4 files changed

+117
-64
lines changed


docs/advance/mtp.md

Lines changed: 5 additions & 3 deletions
```diff
@@ -2,19 +2,21 @@
 
 **Author**: `https://github.com/meituan-search`
 
-Last updated: 01/30/2026
+Last updated: 02/15/2026
 
 # 1. Scope of Support
 
 Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek series models based on the MTP architecture. The support rules for training and inference engines are as follows:
 
-- **Training Engine**: Only supports the `mbridge + megatron` combination; other training engines are not compatible at this time;
+- **Training Engine**: Only supports the `mbridge/Megatron-Bridge + megatron` combination; other training engines are not compatible at this time;
 
 - **Inference Engine**: Compatible with all engines, but the model must be in the corresponding engine's compatibility list;
 
 - **Dependency Versions**:
 
-  - mbridge: Use the specified branch: [https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp](https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp) (will be merged into the main branch in the future);
+  - mbridge: Apply the patches and review suggestions from PR [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (will be merged into the main branch in the future);
+
+  - Megatron-Bridge: Apply the patches and review suggestions from PR [#2387](https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2387) if you want to try out mimo-7B-RL (will be merged into the main branch in the future);
 
 - megatron: Use the latest dev version (commit: [23e092f41ec8bc659020e401ddac9576c1cfed7e](https://github.com/NVIDIA/Megatron-LM/tree/23e092f41ec8bc659020e401ddac9576c1cfed7e)), which supports MTP + CP training methods.
```
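The Megatron-LM dependency above can be pinned directly from the referenced commit. A sketch (the install method is illustrative; a source checkout may be preferable for development):

```shell
# Pin Megatron-LM to the dev commit referenced in the docs (illustrative;
# adapt to your environment and build requirements).
pip install "git+https://github.com/NVIDIA/Megatron-LM.git@23e092f41ec8bc659020e401ddac9576c1cfed7e"
```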

verl/models/mcore/mtp_patch.py

Lines changed: 107 additions & 60 deletions
```diff
@@ -20,11 +20,7 @@
 import torch
 from megatron.core import parallel_state
 from megatron.core.models.gpt.gpt_model import GPTModel
-from megatron.core.transformer.multi_token_prediction import (
-    MTPLossAutoScaler,
-    MTPLossLoggingHelper,
-    roll_tensor,
-)
+from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor
 
 try:
     from megatron.core.utils import unwrap_model
@@ -78,19 +74,45 @@ def _megatron_gptmodel_postprocess(
     runtime_gather_output=None,
     extra_block_kwargs=None,
     inference_context=None,
+    **kwargs,
 ):
-    """Postprocesses decoder hidden states to generate logits or compute loss.
+    """Compatibility patch for GPTModel._postprocess.
 
-    Applies Multi-Token Prediction if enabled, generates output logits through
-    the output layer, and computes language model loss when labels are provided.
+    For inference (`labels is None`), delegate to the upstream implementation to stay
+    aligned with Megatron-Core updates.
+
+    For training (`labels is not None`), keep VERL's MTP behavior and always return
+    logits (instead of CE loss) so PPO paths can compute custom losses from logits.
     """
+    # Keep inference path aligned with whatever upstream Megatron currently expects.
+    if labels is None:
+        return self._postprocess_backup(
+            hidden_states=hidden_states,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            labels=labels,
+            rotary_pos_emb=rotary_pos_emb,
+            rotary_pos_cos=rotary_pos_cos,
+            rotary_pos_sin=rotary_pos_sin,
+            mtp_in_postprocess=mtp_in_postprocess,
+            loss_mask=loss_mask,
+            decoder_input=decoder_input,
+            attention_mask=attention_mask,
+            inference_params=inference_params,
+            packed_seq_params=packed_seq_params,
+            sequence_len_offset=sequence_len_offset,
+            runtime_gather_output=runtime_gather_output,
+            extra_block_kwargs=extra_block_kwargs,
+            inference_context=inference_context,
+            **kwargs,
+        )
 
-    # logits and loss
+    # Training path: keep logits for external loss computation.
     output_weight = None
     if self.share_embeddings_and_output_weights:
         output_weight = self.shared_embedding_or_output_weight()
 
-    if mtp_in_postprocess and labels is not None:
+    if mtp_in_postprocess:
         hidden_states = self.mtp(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -109,60 +131,85 @@
     if not self.post_process:
         return hidden_states
 
-    # Skip when mtp_num_layers is None or 0
-    if self.config.mtp_num_layers and labels is not None:
-        mtp_labels = labels.clone()
-
-        hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
-        hidden_states = hidden_states_list[0]
-        if loss_mask is None:
-            # if loss_mask is not provided, use all ones as loss_mask
-            loss_mask = torch.ones_like(mtp_labels)
-        for mtp_layer_number in range(self.config.mtp_num_layers):
-            # Calc loss for the current Multi-Token Prediction (MTP) layers.
-            mtp_labels, _ = roll_tensor(
-                mtp_labels,
-                shifts=-1,
-                dims=-1,
-                cp_group=self.cp_group,
-                packed_seq_params=packed_seq_params,
-            )
-            loss_mask, num_tokens = roll_tensor(
-                loss_mask,
-                shifts=-1,
-                dims=-1,
-                cp_group=self.cp_group,
-                packed_seq_params=packed_seq_params,
-            )
-
-            # Compute mtp loss without storing logits to save memory.
-            mtp_loss = self.compute_output_layer_and_language_model_loss(
-                hidden_states_list[mtp_layer_number + 1],
-                labels=mtp_labels,
-                weight=self.shared_embedding_or_output_weight(),
-                sequence_parallel_enabled=self.output_layer.sequence_parallel,
-                column_parallel_linear=self.output_layer,
-                col_linear_kwargs={
-                    "weight": output_weight,
-                    "runtime_gather_output": runtime_gather_output,
-                },
-            )
-
-            mtp_loss = loss_mask * mtp_loss
-            if self.training:
-                # TODO(shifangx): remove the use of parallel_state here
-                # after moving loss logging to loss_func in pretrain_gpt.py
-                MTPLossLoggingHelper.save_loss_to_tracker(
-                    torch.sum(mtp_loss) / num_tokens,
-                    mtp_layer_number,
-                    self.config.mtp_num_layers,
-                    avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-                )
-            mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
-            if self.config.calculate_per_token_loss:
-                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
-            else:
-                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
+    # Skip when mtp_num_layers is None or 0.
+    if self.config.mtp_num_layers:
+        cp_group = None
+        if getattr(self, "pg_collection", None) is not None:
+            cp_group = self.pg_collection.cp
+        elif hasattr(self, "cp_group"):
+            cp_group = self.cp_group
+
+        # Prefer upstream helper when available (newer Megatron-LM).
+        try:
+            from megatron.core.transformer.multi_token_prediction import process_mtp_loss
+
+            hidden_states = process_mtp_loss(
+                hidden_states=hidden_states,
+                labels=labels,
+                loss_mask=loss_mask,
+                output_layer=self.output_layer,
+                output_weight=output_weight,
+                runtime_gather_output=runtime_gather_output,
+                is_training=self.training,
+                compute_language_model_loss=self.compute_language_model_loss,
+                config=self.config,
+                cp_group=cp_group,
+                packed_seq_params=packed_seq_params,
+            )
+        except (ImportError, AttributeError, TypeError):
+            # Fallback for older Megatron-LM versions without process_mtp_loss API.
+            mtp_labels = labels.clone()
+
+            hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+            hidden_states = hidden_states_list[0]
+            if loss_mask is None:
+                # if loss_mask is not provided, use all ones as loss_mask
+                loss_mask = torch.ones_like(mtp_labels)
+            for mtp_layer_number in range(self.config.mtp_num_layers):
+                # Calc loss for the current Multi-Token Prediction (MTP) layers.
+                mtp_labels, _ = roll_tensor(
+                    mtp_labels,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                loss_mask, num_tokens = roll_tensor(
+                    loss_mask,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+
+                # Compute mtp loss without storing logits to save memory.
+                mtp_loss = self.compute_output_layer_and_language_model_loss(
+                    hidden_states_list[mtp_layer_number + 1],
+                    labels=mtp_labels,
+                    weight=self.shared_embedding_or_output_weight(),
+                    sequence_parallel_enabled=self.output_layer.sequence_parallel,
+                    column_parallel_linear=self.output_layer,
+                    col_linear_kwargs={
+                        "weight": output_weight,
+                        "runtime_gather_output": runtime_gather_output,
+                    },
+                )
+
+                mtp_loss = loss_mask * mtp_loss
+                if self.training:
+                    # TODO(shifangx): remove the use of parallel_state here
+                    # after moving loss logging to loss_func in pretrain_gpt.py
+                    MTPLossLoggingHelper.save_loss_to_tracker(
+                        torch.sum(mtp_loss) / num_tokens,
+                        mtp_layer_number,
+                        self.config.mtp_num_layers,
+                        avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+                    )
+                mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+                if self.config.calculate_per_token_loss:
+                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
+                else:
+                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
 
     logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
     # [s b h] => [b s h]
```
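The fallback path in the patch rolls the label tensor once per MTP layer, so layer k is supervised on the token k+1 positions ahead. A minimal pure-Python sketch of that shifting (`roll_left` is a hypothetical stand-in for Megatron's `roll_tensor`, which additionally handles context parallelism and packed sequences):

```python
def roll_left(seq, fill=0):
    """Shift a token sequence one step left, padding the tail.

    Hypothetical stand-in for roll_tensor(shifts=-1, dims=-1) on a single
    unpadded sequence with no context parallelism.
    """
    return seq[1:] + [fill]

# Each MTP layer k predicts token t+1+k, so labels are rolled once per layer.
labels = [11, 12, 13, 14, 15]
mtp_num_layers = 2
mtp_labels = list(labels)
per_layer_labels = []
for _ in range(mtp_num_layers):
    mtp_labels = roll_left(mtp_labels)
    per_layer_labels.append(mtp_labels)

print(per_layer_labels)  # [[12, 13, 14, 15, 0], [13, 14, 15, 0, 0]]
```

The loss mask is rolled in lockstep in the real code, so the padded tail positions never contribute to the MTP loss.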

verl/workers/actor/megatron_actor.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -139,6 +139,11 @@ def __init__(
         assert self.mtp_config.enable, "MTP requires mtp_config.enable to be True"
 
         self.use_fused_kernels = self.config.get("use_fused_kernels", False)
+        if getattr(self.mtp_config, "enable", False) and self.use_fused_kernels:
+            self.use_fused_kernels = False
+            logger.warning_once(
+                "MTP is not compatible with fused kernels for now. Automatically disable use_fused_kernels."
+            )
         if self.use_fused_kernels and not getattr(self.config, "overlap_moe_expert_parallel_comm", False):
             # do not patch if overlap_moe_expert_parallel_comm is enabled
             logger.warning_once(
```
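The guard added here follows a warn-and-degrade pattern: an incompatible flag combination is corrected with a warning instead of failing later at runtime. A minimal sketch (the function name and logger are hypothetical; the real code mutates `self.use_fused_kernels` in place):

```python
import logging

logger = logging.getLogger("sketch")

def resolve_fused_kernels(mtp_enabled: bool, use_fused_kernels: bool) -> bool:
    """Return the effective fused-kernels setting: MTP wins, with a warning."""
    if mtp_enabled and use_fused_kernels:
        logger.warning(
            "MTP is not compatible with fused kernels for now. "
            "Automatically disable use_fused_kernels."
        )
        return False
    return use_fused_kernels

print(resolve_fused_kernels(True, True))   # False
print(resolve_fused_kernels(False, True))  # True
```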

verl/workers/megatron_workers.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -155,7 +155,6 @@ def _init_hf_config_and_tf_config(
         if enable_mtp:
             assert hf_config.num_nextn_predict_layers > 0, "MTP requires at least one nextn_predict_layer"
             assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True"
-            assert megatron_config.vanilla_mbridge, "MTP requires vanilla_mbridge to be True"
             override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor
         else:
             if hasattr(hf_config, "num_nextn_predict_layers"):
```
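After this deletion, the MTP precondition no longer requires `vanilla_mbridge`, so the Megatron-Bridge path (`vanilla_mbridge=False`) passes the check. A sketch of the relaxed guard (the function name is hypothetical; the assertions mirror the two that remain):

```python
def check_mtp_preconditions(num_nextn_predict_layers: int, use_mbridge: bool) -> None:
    """Relaxed MTP precondition after this commit: vanilla_mbridge is no
    longer asserted, so vanilla_mbridge=False (Megatron-Bridge) is accepted."""
    assert num_nextn_predict_layers > 0, "MTP requires at least one nextn_predict_layer"
    assert use_mbridge, "MTP requires use_mbridge to be True"

# Megatron-Bridge configs (vanilla_mbridge=False) no longer trip an assertion here.
check_mtp_preconditions(num_nextn_predict_layers=1, use_mbridge=True)
```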
