From 59c2206f5c57380b63853df45a7e34145d3ebb18 Mon Sep 17 00:00:00 2001 From: 755651978 <755651978@qq.com> Date: Thu, 12 Feb 2026 14:49:10 +0800 Subject: [PATCH 1/8] Support routing replay for NPU --- .../agent_loop/tool_agent_loop.py | 3 + verl/utils/megatron/router_replay_patch.py | 56 +++++++++++++++++-- verl/utils/megatron/router_replay_utils.py | 9 ++- .../rollout/vllm_rollout/vllm_async_server.py | 2 +- 4 files changed, 63 insertions(+), 7 deletions(-) diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index f98485a6781..e96076d88fb 100644 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -193,11 +193,14 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu multi_modal_data["images"] = agent_data.image_data if agent_data.video_data is not None: multi_modal_data["videos"] = agent_data.video_data + + routed_experts = getattr(agent_data, "routed_experts", None) output = AgentLoopOutput( prompt_ids=prompt_ids, response_ids=response_ids[: self.response_length], response_mask=agent_data.response_mask[: self.response_length], multi_modal_data=multi_modal_data, + routed_experts=routed_experts, response_logprobs=agent_data.response_logprobs[: self.response_length] if agent_data.response_logprobs else None, diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index 50eb4a128d3..1fe15c68466 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -15,6 +15,8 @@ from enum import Enum import torch +import types +import inspect try: from megatron.core.transformer.moe.moe_utils import ( @@ -236,7 +238,27 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): return routing_probs, routing_map - +def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float: + """获取给定辅助损失类型的系数。""" + # 逻辑保持不变 + if isinstance(_self.routing_type, str): + if _self.routing_type == aux_loss_type: + return _self.config.moe_aux_loss_coeff + if isinstance(_self.routing_type, list): + try: + idx = _self.routing_type.index(aux_loss_type) + return _self.config.moe_aux_loss_coeff[idx] + except (ValueError, IndexError): + return 0.0 + return 0.0 + +def _is_aux_loss_enabled(_self) -> bool: + """检查是否启用了任何辅助损失。""" + for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]: + # 注意这里调用的是同在模块级别的另一个辅助函数 + if _get_aux_loss_coeff(_self, aux_loss_type) > 0: + return True + return False def patched_routing(self, logits: torch.Tensor, *args, **kwargs): """Top-k routing function @@ -282,6 +304,8 @@ def patched_routing(self, logits: torch.Tensor, *args, **kwargs): pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, ) + if not hasattr(self, "is_aux_loss_enabled"): + self.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, self) # Apply each aux loss type and attach aux loss autograd function to probs if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled(): # Calculate scores and routing_map for aux loss @@ -311,26 +335,48 @@ def apply_router_replay_patch(): # Clear router instances to avoid state leakage between model initializations. RouterReplay.router_instances.clear() # Step 1: Patch TransformerConfig to include the feature flag - if not hasattr(TransformerConfig, "enable_routing_replay"): - # Add class attribute with default value - TransformerConfig.enable_routing_replay = False + try: + global_args = get_args() + except Exception: + global_args = None + + try: + sig = inspect.signature(TransformerConfig.__init__) + native_params = sig.parameters + except Exception: + native_params = [] + + ext_attrs = ["enable_routing_replay", "moe_router_fusion"] + + for attr in ext_attrs: + val = getattr(global_args, attr, False) if global_args else False + + if not hasattr(TransformerConfig, attr): + setattr(TransformerConfig, attr, val) + if not hasattr(TransformerConfig, "_verl_router_patched"): # Store original __init__ method original_tf_config_init = TransformerConfig.__init__ # Define new __init__ method that safely handles enable_routing_replay parameter + @wraps(original_tf_config_init) def patched_tf_config_init(self, *args, **kwargs): # Simple solution: remove the unknown parameter before calling original constructor - enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay) + if "enable_routing_replay" not in native_params: + enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay) + if "moe_router_fusion" not in native_params: + moe_router_fusion = kwargs.pop("enable_routing_replay", TransformerConfig.moe_router_fusion) # Call original constructor with remaining kwargs original_tf_config_init(self, *args, **kwargs) # Set the instance attribute self.enable_routing_replay = enable_routing_replay + self.moe_router_fusion = moe_router_fusion # Apply the patch TransformerConfig.__init__ = patched_tf_config_init + TransformerConfig._verl_router_patched = True # Step 2: Patch TopKRouter only once to ensure idempotency. if hasattr(TopKRouter, "_router_replay_patched"): diff --git a/verl/utils/megatron/router_replay_utils.py b/verl/utils/megatron/router_replay_utils.py index 3aec85c24b7..6b8c12ec07e 100644 --- a/verl/utils/megatron/router_replay_utils.py +++ b/verl/utils/megatron/router_replay_utils.py @@ -21,6 +21,7 @@ from typing import Optional import torch +import inspect try: from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage @@ -330,7 +331,13 @@ def get_current_rank_layer_info(tf_config, vp_rank=None): if vp_rank is None: vp_rank = 0 num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_rank) - offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank) + + sig = inspect.signature(get_transformer_layer_offset) + + if 'vp_stage' in sig.parameters: + offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank) + else: + offset = get_transformer_layer_offset(tf_config) local = {} local["start"] = offset local["end"] = offset + num_layers_to_build diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index cf5ab342888..7bbc3075cae 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -507,7 +507,7 @@ async def generate( max_tokens = sampling_params.pop("max_new_tokens") else: # Default to a calculation that considers configured lengths - max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids) + max_tokens = self.config.response_length # Clamp max_tokens to the valid range [0, max_possible_tokens] max_tokens = max(0, min(max_tokens, max_possible_tokens)) From 2b5acf55eef097bac7913c6cfbc217e26ad6464e Mon Sep 17 00:00:00 2001 From: 755651978 <755651978@qq.com> Date: Thu, 12 Feb 2026 15:43:23 +0800 Subject: [PATCH 2/8] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- verl/utils/megatron/router_replay_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index 1fe15c68466..eb6934213be 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -366,7 +366,7 @@ def patched_tf_config_init(self, *args, **kwargs): enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay) if "moe_router_fusion" not in native_params: - moe_router_fusion = kwargs.pop("enable_routing_replay", TransformerConfig.moe_router_fusion) + moe_router_fusion = kwargs.pop("moe_router_fusion", TransformerConfig.moe_router_fusion) # Call original constructor with remaining kwargs original_tf_config_init(self, *args, **kwargs) From 77af5c70a4a50aa98bbefaf0af0c5093a49a9a7d Mon Sep 17 00:00:00 2001 From: 755651978 <755651978@qq.com> Date: Thu, 12 Feb 2026 22:02:13 +0800 Subject: [PATCH 3/8] fix: import wraps in router_replay_patch --- verl/utils/megatron/router_replay_patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index 1fe15c68466..72e2c9c3d5f 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -17,6 +17,7 @@ import torch import types import inspect +from functools import wraps try: from megatron.core.transformer.moe.moe_utils import ( From b44262f4d79579dbc960cabce9daf16fd37049e6 Mon Sep 17 00:00:00 2001 From: 755651978 <755651978@qq.com> Date: Tue, 24 Feb 2026 14:21:07 +0800 Subject: [PATCH 4/8] feat: support routing replay in vllm rollout and update code standards - Added dynamic max_tokens logic and attribute safety checks. - Synchronized code comments to English and fixed import dependencies. --- verl/utils/megatron/router_replay_patch.py | 9 +++++---- verl/workers/rollout/vllm_rollout/vllm_async_server.py | 7 ++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index bec7282682a..02c235d067d 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -31,6 +31,7 @@ MoEAlltoAllTokenDispatcher = None from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.training import get_args # https://github.com/THUDM/slime/blob/main/slime/utils/routing_replay.py @@ -240,8 +241,9 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): return routing_probs, routing_map def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float: - """获取给定辅助损失类型的系数。""" - # 逻辑保持不变 + """Return the aux loss coeff for the given auxiliary loss type. + If the auxiliary loss type is not found, return 0.0. + """ if isinstance(_self.routing_type, str): if _self.routing_type == aux_loss_type: return _self.config.moe_aux_loss_coeff @@ -254,9 +256,8 @@ def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float: return 0.0 def _is_aux_loss_enabled(_self) -> bool: - """检查是否启用了任何辅助损失。""" + """Check if the auxiliary loss is enabled.""" for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]: - # 注意这里调用的是同在模块级别的另一个辅助函数 if _get_aux_loss_coeff(_self, aux_loss_type) > 0: return True return False diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 7bbc3075cae..cd3e95b3598 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -507,7 +507,12 @@ async def generate( max_tokens = sampling_params.pop("max_new_tokens") else: # Default to a calculation that considers configured lengths - max_tokens = self.config.response_length + if self.config.enable_rollout_routing_replay: + # When routing replay is enabled, we strictly use the configured response length + max_tokens = self.config.response_length + else: + # Otherwise, use a dynamic calculation based on prompt length + max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids) # Clamp max_tokens to the valid range [0, max_possible_tokens] max_tokens = max(0, min(max_tokens, max_possible_tokens)) From 1e035b575b656b9d17f8d7fae480e533874358b0 Mon Sep 17 00:00:00 2001 From: 755651978 <755651978@qq.com> Date: Tue, 24 Feb 2026 14:49:14 +0800 Subject: [PATCH 5/8] fix: port routed_experts fix from official c6255ae --- verl/experimental/agent_loop/tool_agent_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index e96076d88fb..ee6176775e0 100644 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -88,6 +88,8 @@ def __init__( # Temporary state for tool calls self.tool_calls: list[FunctionCall] = [] + self.routed_experts = None + # Extra fields for dynamic addition, e.g., tool session data self.extra_fields: dict[str, Any] = {} @@ -193,19 +195,17 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu multi_modal_data["images"] = agent_data.image_data if agent_data.video_data is not None: multi_modal_data["videos"] = agent_data.video_data - - routed_experts = getattr(agent_data, "routed_experts", None) output = AgentLoopOutput( prompt_ids=prompt_ids, response_ids=response_ids[: self.response_length], response_mask=agent_data.response_mask[: self.response_length], multi_modal_data=multi_modal_data, - routed_experts=routed_experts, response_logprobs=agent_data.response_logprobs[: self.response_length] if agent_data.response_logprobs else None, num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, metrics=agent_data.metrics, + routed_experts=agent_data.routed_experts, extra_fields={}, ) output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards}) From 4c03e9bbf0e800114140cace38fac2a8e070a648 Mon Sep 17 00:00:00 2001 From: 755651978 <755651978@qq.com> Date: Wed, 25 Feb 2026 14:21:02 +0800 Subject: [PATCH 6/8] fix: resolve ModuleNotFoundError for megatron in CI environments --- verl/utils/megatron/router_replay_patch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index 02c235d067d..850347af000 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -31,7 +31,6 @@ MoEAlltoAllTokenDispatcher = None from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training import get_args # https://github.com/THUDM/slime/blob/main/slime/utils/routing_replay.py @@ -338,6 +337,7 @@ def apply_router_replay_patch(): RouterReplay.router_instances.clear() # Step 1: Patch TransformerConfig to include the feature flag try: + from megatron.training import get_args global_args = get_args() except Exception: global_args = None @@ -346,7 +346,7 @@ def apply_router_replay_patch(): sig = inspect.signature(TransformerConfig.__init__) native_params = sig.parameters except Exception: - native_params = [] + native_params = {} ext_attrs = ["enable_routing_replay", "moe_router_fusion"] From cda2cd79a713125377f85ce479a8341454cca451 Mon Sep 17 00:00:00 2001 From: 755651978 <755651978@qq.com> Date: Thu, 26 Feb 2026 18:59:42 +0800 Subject: [PATCH 7/8] style: fix linting issues with pre-commit --- verl/utils/megatron/router_replay_patch.py | 10 +++++++--- verl/utils/megatron/router_replay_utils.py | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index 850347af000..074430dff61 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect +import types import warnings from enum import Enum +from functools import wraps import torch -import types -import inspect -from functools import wraps try: from megatron.core.transformer.moe.moe_utils import ( @@ -239,6 +239,7 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): return routing_probs, routing_map + def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float: """Return the aux loss coeff for the given auxiliary loss type. If the auxiliary loss type is not found, return 0.0. @@ -254,12 +255,15 @@ def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float: return 0.0 return 0.0 + def _is_aux_loss_enabled(_self) -> bool: """Check if the auxiliary loss is enabled.""" for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]: if _get_aux_loss_coeff(_self, aux_loss_type) > 0: return True return False + + def patched_routing(self, logits: torch.Tensor, *args, **kwargs): """Top-k routing function diff --git a/verl/utils/megatron/router_replay_utils.py b/verl/utils/megatron/router_replay_utils.py index 6b8c12ec07e..c0fadd3505b 100644 --- a/verl/utils/megatron/router_replay_utils.py +++ b/verl/utils/megatron/router_replay_utils.py @@ -17,11 +17,11 @@ Utilities for handling router replay functionality in Megatron models. """ +import inspect import warnings from typing import Optional import torch -import inspect try: from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage @@ -334,7 +334,7 @@ def get_current_rank_layer_info(tf_config, vp_rank=None): sig = inspect.signature(get_transformer_layer_offset) - if 'vp_stage' in sig.parameters: + if "vp_stage" in sig.parameters: offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank) else: offset = get_transformer_layer_offset(tf_config) From dc41f8e46dbffc320fd65e5e99789689dd34641d Mon Sep 17 00:00:00 2001 From: 755651978 <755651978@qq.com> Date: Fri, 27 Feb 2026 08:52:40 +0800 Subject: [PATCH 8/8] style: fix linting issues with pre-commit --- verl/utils/megatron/router_replay_patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index 074430dff61..d47c97f7752 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -342,6 +342,7 @@ def apply_router_replay_patch(): # Step 1: Patch TransformerConfig to include the feature flag try: from megatron.training import get_args + global_args = get_args() except Exception: global_args = None