diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py
index 50eb4a128d3..d47c97f7752 100644
--- a/verl/utils/megatron/router_replay_patch.py
+++ b/verl/utils/megatron/router_replay_patch.py
@@ -11,8 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+import types
 import warnings
 from enum import Enum
+from functools import wraps
 
 import torch
 
@@ -237,6 +240,30 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
     return routing_probs, routing_map
 
 
+def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float:
+    """Return the aux loss coeff for the given auxiliary loss type.
+    If the auxiliary loss type is not found, return 0.0.
+    """
+    if isinstance(_self.routing_type, str):
+        if _self.routing_type == aux_loss_type:
+            return _self.config.moe_aux_loss_coeff
+    if isinstance(_self.routing_type, list):
+        try:
+            idx = _self.routing_type.index(aux_loss_type)
+            return _self.config.moe_aux_loss_coeff[idx]
+        except (ValueError, IndexError):
+            return 0.0
+    return 0.0
+
+
+def _is_aux_loss_enabled(_self) -> bool:
+    """Check if the auxiliary loss is enabled."""
+    for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]:
+        if _get_aux_loss_coeff(_self, aux_loss_type) > 0:
+            return True
+    return False
+
+
 def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
     """Top-k routing function
 
@@ -282,6 +309,8 @@
             pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
         )
 
+    if not hasattr(self, "is_aux_loss_enabled"):
+        self.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, self)
     # Apply each aux loss type and attach aux loss autograd function to probs
     if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled():
         # Calculate scores and routing_map for aux loss
@@ -311,26 +340,50 @@
 def apply_router_replay_patch():
     # Clear router instances to avoid state leakage between model initializations.
     RouterReplay.router_instances.clear()
 
     # Step 1: Patch TransformerConfig to include the feature flag
-    if not hasattr(TransformerConfig, "enable_routing_replay"):
-        # Add class attribute with default value
-        TransformerConfig.enable_routing_replay = False
+    try:
+        from megatron.training import get_args
+
+        global_args = get_args()
+    except Exception:
+        global_args = None
+
+    try:
+        sig = inspect.signature(TransformerConfig.__init__)
+        native_params = sig.parameters
+    except Exception:
+        native_params = {}
+
+    ext_attrs = ["enable_routing_replay", "moe_router_fusion"]
+
+    for attr in ext_attrs:
+        val = getattr(global_args, attr, False) if global_args else False
+
+        if not hasattr(TransformerConfig, attr):
+            setattr(TransformerConfig, attr, val)
 
+    if not hasattr(TransformerConfig, "_verl_router_patched"):
         # Store original __init__ method
         original_tf_config_init = TransformerConfig.__init__
 
         # Define new __init__ method that safely handles enable_routing_replay parameter
+        @wraps(original_tf_config_init)
         def patched_tf_config_init(self, *args, **kwargs):
             # Simple solution: remove the unknown parameter before calling original constructor
-            enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
+            if "enable_routing_replay" not in native_params:
+                enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
+            if "moe_router_fusion" not in native_params:
+                moe_router_fusion = kwargs.pop("moe_router_fusion", TransformerConfig.moe_router_fusion)
             # Call original constructor with remaining kwargs
             original_tf_config_init(self, *args, **kwargs)
             # Set the instance attribute
-            self.enable_routing_replay = enable_routing_replay
+            if "enable_routing_replay" not in native_params:
+                self.enable_routing_replay = enable_routing_replay
+            if "moe_router_fusion" not in native_params:
+                self.moe_router_fusion = moe_router_fusion
 
         # Apply the patch
         TransformerConfig.__init__ = patched_tf_config_init
+        TransformerConfig._verl_router_patched = True
 
     # Step 2: Patch TopKRouter only once to ensure idempotency.
     if hasattr(TopKRouter, "_router_replay_patched"):
diff --git a/verl/utils/megatron/router_replay_utils.py b/verl/utils/megatron/router_replay_utils.py
index 3aec85c24b7..c0fadd3505b 100644
--- a/verl/utils/megatron/router_replay_utils.py
+++ b/verl/utils/megatron/router_replay_utils.py
@@ -17,6 +17,7 @@
 Utilities for handling router replay functionality in Megatron models.
""" +import inspect import warnings from typing import Optional @@ -330,7 +331,13 @@ def get_current_rank_layer_info(tf_config, vp_rank=None): if vp_rank is None: vp_rank = 0 num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_rank) - offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank) + + sig = inspect.signature(get_transformer_layer_offset) + + if "vp_stage" in sig.parameters: + offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank) + else: + offset = get_transformer_layer_offset(tf_config) local = {} local["start"] = offset local["end"] = offset + num_layers_to_build diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 4aefb526c51..c3611dd6880 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -521,7 +521,12 @@ async def generate( max_tokens = sampling_params.pop("max_new_tokens") else: # Default to a calculation that considers configured lengths - max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids) + if self.config.enable_rollout_routing_replay: + # When routing replay is enabled, we strictly use the configured response length + max_tokens = self.config.response_length + else: + # Otherwise, use a dynamic calculation based on prompt length + max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids) # Clamp max_tokens to the valid range [0, max_possible_tokens] max_tokens = max(0, min(max_tokens, max_possible_tokens))