
[algo] feat: support router replay in MegatronEngine #5219

Open

xhx1022 wants to merge 5 commits into verl-project:main from xhx1022:r2_engine

Conversation

@xhx1022 (Collaborator) commented Feb 6, 2026

What does this PR do?

This PR introduces Router Replay support in the MegatronEngine, enabling the router decisions computed during compute_log_prob to be reused by update_actor.

Signed-off-by: xhx1022 <1737006628@qq.com>
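
Conceptually, router replay records the top-k expert indices each MoE router chooses during the log-prob pass and forces the same routing during the update pass. The sketch below illustrates the idea only; the class and method names are made up and are not the actual MegatronEngine API.

import torch

class ReplayableRouter(torch.nn.Module):
    """Illustrative top-k MoE router that can record its decisions and replay them."""

    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = torch.nn.Linear(hidden_size, num_experts)
        self.top_k = top_k
        self.mode = "record"  # "record" during log-prob, "replay" during update
        self.recorded: list[torch.Tensor] = []

    def forward(self, hidden: torch.Tensor):
        probs = self.gate(hidden).softmax(dim=-1)
        if self.mode == "record":
            # First pass (log-prob computation): take the router's own top-k
            # choice and remember it.
            weights, topk_idx = torch.topk(probs, self.top_k, dim=-1)
            self.recorded.append(topk_idx)
        else:
            # Second pass (actor update): reuse the recorded indices so tokens
            # are dispatched to exactly the same experts as in the first pass.
            topk_idx = self.recorded.pop(0)
            weights = probs.gather(dim=-1, index=topk_idx)
        return weights, topk_idx

# Usage: record on the first pass, then replay on the second.
router = ReplayableRouter(hidden_size=8, num_experts=4, top_k=2)
x = torch.randn(3, 8)
router.mode = "record"
w1, idx1 = router(x)
router.mode = "replay"
w2, idx2 = router(x)
assert torch.equal(idx1, idx2)  # identical expert assignment across passes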

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces router replay support in the MegatronEngine, a valuable feature for Mixture-of-Experts models that enables recording and reusing router decisions to ensure deterministic training. The changes are well-structured, touching the PPO trainer, Megatron engine, and supporting utility files, and correctly add support for nested tensors and virtual pipeline parallelism. I've identified one critical issue in the pp_gather utility function related to handling uneven layer distribution in pipeline parallelism, which could lead to runtime errors. My feedback is focused on addressing this to improve the robustness of the implementation.
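
As background on the nested-tensor aspect: PyTorch's torch.nested can hold ragged per-stage router maps without padding. A minimal illustration with made-up shapes, not code from this PR:

import torch

# Two pipeline stages with different layer counts produce ragged router maps;
# a nested tensor can hold them without padding.
stage0 = torch.randint(0, 8, (3, 16, 2))  # 3 layers x 16 tokens x top-2 experts
stage1 = torch.randint(0, 8, (2, 16, 2))  # a shorter stage with only 2 layers
nt = torch.nested.nested_tensor([stage0, stage1])
for t in nt.unbind():
    print(t.shape)  # torch.Size([3, 16, 2]) then torch.Size([2, 16, 2])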

Comment on lines +378 to +391
# Pre-allocate one receive buffer per pipeline-parallel rank, all assumed to
# share the local tensor's shape.
layers_topk_idx_global_list = [
    torch.empty(
        size=local_layers_router_map.shape,
        dtype=local_layers_router_map.dtype,
        device=local_layers_router_map.device,
    )
    for _ in range(world_size)
]
# Gather every rank's router map; all_gather requires identical shapes on all ranks.
torch.distributed.all_gather(
    tensor=local_layers_router_map,
    tensor_list=layers_topk_idx_global_list,
    group=pp_group,
    async_op=False,
)

critical

The use of torch.distributed.all_gather assumes that local_layers_router_map has the same shape across all pipeline parallel ranks. However, with uneven pipeline parallelism (which is supported by get_num_layers_to_build), the number of layers per rank can differ, leading to different tensor shapes. This will cause all_gather to fail with a shape mismatch error.

This issue is also hinted at in the TODO at line 355. To make this function robust to uneven layer distributions, torch.distributed.all_gather_object should be used instead, as it can handle tensors of varying sizes. Note that the subsequent torch.cat at line 413 will also fail with tensors of different shapes and will need to be adjusted to handle this case when VPP is not enabled.

        layers_topk_idx_global_list = [None] * world_size
        torch.distributed.all_gather_object(layers_topk_idx_global_list, local_layers_router_map, pp_group)
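
For reference, a self-contained sketch of the suggested pattern, runnable on CPU with torchrun --nproc_per_node=2 (the uneven per-rank shapes are synthetic stand-ins for uneven layer counts, and the default world group stands in for pp_group):

import torch
import torch.distributed as dist

# Launch with: torchrun --nproc_per_node=2 demo.py
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()

# Simulate uneven pipeline stages: even ranks hold 3 layers, odd ranks hold 2.
num_local_layers = 3 - (rank % 2)
local_map = torch.arange(num_local_layers * 4).reshape(num_local_layers, 4)

# all_gather_object tolerates per-rank shape differences (it pickles each
# object), whereas all_gather would fail on the shape mismatch.
gathered = [None] * world_size
dist.all_gather_object(gathered, local_map)

# Concatenate along the layer dimension; equal shapes are not assumed.
full_map = torch.cat(gathered, dim=0)
if rank == 0:
    print(full_map.shape)  # torch.Size([5, 4]) with two ranks

dist.destroy_process_group()

One trade-off worth noting: all_gather_object pickles its inputs and moves them through host memory, so it is slower than a tensor-based all_gather; padding per-rank maps to a common layer count before gathering is a possible alternative if this path is performance-sensitive.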

xhx1022 changed the title from "[WIP algo] feat: support router replay in MegatronEngine" to "[algo] feat: support router replay in MegatronEngine" on Feb 6, 2026