[algo] feat: support router replay in MegatronEngine #5219
xhx1022 wants to merge 5 commits into verl-project:main
Conversation
Code Review
This pull request introduces router replay support in the MegatronEngine, a valuable feature for Mixture-of-Experts models that enables recording and reusing router decisions to ensure deterministic training. The changes are well-structured, touching the PPO trainer, Megatron engine, and supporting utility files, and correctly add support for nested tensors and virtual pipeline parallelism. I've identified one critical issue in the pp_gather utility function related to handling uneven layer distribution in pipeline parallelism, which could lead to runtime errors. My feedback is focused on addressing this to improve the robustness of the implementation.
```python
layers_topk_idx_global_list = [
    torch.empty(
        size=local_layers_router_map.shape,
        dtype=local_layers_router_map.dtype,
        device=local_layers_router_map.device,
    )
    for _ in range(world_size)
]
torch.distributed.all_gather(
    tensor=local_layers_router_map,
    tensor_list=layers_topk_idx_global_list,
    group=pp_group,
    async_op=False,
)
```
The use of torch.distributed.all_gather assumes that local_layers_router_map has the same shape across all pipeline parallel ranks. However, with uneven pipeline parallelism (which is supported by get_num_layers_to_build), the number of layers per rank can differ, leading to different tensor shapes. This will cause all_gather to fail with a shape mismatch error.
This issue is also hinted at in the TODO at line 355. To make this function robust to uneven layer distributions, torch.distributed.all_gather_object should be used instead, as it can handle tensors of varying sizes. Note that the subsequent torch.cat at line 413 will also fail with tensors of different shapes and will need to be adjusted to handle this case when VPP is not enabled.
```python
layers_topk_idx_global_list = [None] * world_size
torch.distributed.all_gather_object(layers_topk_idx_global_list, local_layers_router_map, pp_group)
```

Signed-off-by: xhx1022 <1737006628@qq.com>
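Building on that suggestion, the downstream concatenation also has to cope with per-rank tensors of different sizes once the gather no longer enforces a uniform shape. The sketch below is a minimal illustration of one way to do this, assuming the layer dimension is the leading axis; the helper name `gather_router_maps_uneven` and the CPU round-trip are assumptions made for the example, not the PR's actual code.

```python
import torch
import torch.distributed as dist


def gather_router_maps_uneven(local_layers_router_map: torch.Tensor, pp_group) -> torch.Tensor:
    """Gather per-rank router maps whose leading (layer) dimension may differ
    across pipeline ranks, then concatenate them along that dimension.

    Hypothetical helper: names and the layer-major layout are assumptions.
    """
    world_size = dist.get_world_size(group=pp_group)

    # all_gather_object serializes arbitrary Python objects, so each rank may
    # contribute a tensor of a different shape (unlike all_gather).
    gathered: list = [None] * world_size
    dist.all_gather_object(gathered, local_layers_router_map.cpu(), group=pp_group)

    # Concatenate along the layer dimension; entries may hold different
    # numbers of layers under uneven pipeline partitioning.
    return torch.cat([t.to(local_layers_router_map.device) for t in gathered], dim=0)
```

A shape-agnostic path like this could also retire the TODO about uneven pipeline partitioning, at the cost of the extra serialization that all_gather_object performs.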
What does this PR do?
This PR introduces Router Replay support within the MegatronEngine, enabling the router decisions computed in compute_log_prob to be reused by update_actor.
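At a high level, router replay records the top-k expert indices chosen by each MoE router during the log-prob forward pass and feeds them back during the policy update, so both passes see identical expert assignments. The sketch below is a conceptual illustration of that record/replay pattern only, assuming a per-layer buffer keyed by layer id; the class and function names are hypothetical and do not reflect the MegatronEngine implementation.

```python
import torch


class RouterReplayBuffer:
    """Illustrative record/replay of MoE top-k routing decisions."""

    def __init__(self):
        self._topk_idx_per_layer = {}  # layer_id -> recorded top-k expert indices
        self.replay_enabled = False

    def record(self, layer_id: int, topk_idx: torch.Tensor) -> None:
        # Called during the log-prob forward pass: remember which experts
        # each token was routed to.
        self._topk_idx_per_layer[layer_id] = topk_idx.detach()

    def replay(self, layer_id: int) -> torch.Tensor:
        # Called during the policy update: reuse the recorded routing so the
        # update forward pass sees exactly the same expert assignment.
        return self._topk_idx_per_layer[layer_id]


def route_tokens(router_logits: torch.Tensor, top_k: int,
                 buffer: RouterReplayBuffer, layer_id: int) -> torch.Tensor:
    """Return top-k expert indices, either freshly computed or replayed."""
    if buffer.replay_enabled:
        return buffer.replay(layer_id)
    topk_idx = router_logits.topk(top_k, dim=-1).indices
    buffer.record(layer_id, topk_idx)
    return topk_idx
```

A buffer like this would typically be switched into replay mode between the log-prob phase and update_actor, which is the determinism guarantee the review summary describes.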