Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions verl/utils/megatron/router_replay_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import types
import warnings
from enum import Enum
from functools import wraps

import torch

Expand Down Expand Up @@ -237,6 +240,30 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
return routing_probs, routing_map


def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float:
"""Return the aux loss coeff for the given auxiliary loss type.
If the auxiliary loss type is not found, return 0.0.
"""
if isinstance(_self.routing_type, str):
if _self.routing_type == aux_loss_type:
return _self.config.moe_aux_loss_coeff
if isinstance(_self.routing_type, list):
try:
idx = _self.routing_type.index(aux_loss_type)
return _self.config.moe_aux_loss_coeff[idx]
except (ValueError, IndexError):
return 0.0
return 0.0


def _is_aux_loss_enabled(_self) -> bool:
    """Return True when any auxiliary-loss coefficient is positive."""
    return any(
        _get_aux_loss_coeff(_self, loss_type) > 0
        for loss_type in ("aux_loss", "seq_aux_loss", "global_aux_loss")
    )


def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
"""Top-k routing function

Expand Down Expand Up @@ -282,6 +309,8 @@ def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
)

if not hasattr(self, "is_aux_loss_enabled"):
self.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, self)
# Apply each aux loss type and attach aux loss autograd function to probs
if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled():
# Calculate scores and routing_map for aux loss
Expand Down Expand Up @@ -311,26 +340,50 @@ def apply_router_replay_patch():
# Clear router instances to avoid state leakage between model initializations.
RouterReplay.router_instances.clear()
# Step 1: Patch TransformerConfig to include the feature flag
if not hasattr(TransformerConfig, "enable_routing_replay"):
# Add class attribute with default value
TransformerConfig.enable_routing_replay = False
try:
from megatron.training import get_args

global_args = get_args()
except Exception:
global_args = None

try:
sig = inspect.signature(TransformerConfig.__init__)
native_params = sig.parameters
except Exception:
native_params = {}

ext_attrs = ["enable_routing_replay", "moe_router_fusion"]

for attr in ext_attrs:
val = getattr(global_args, attr, False) if global_args else False

if not hasattr(TransformerConfig, attr):
setattr(TransformerConfig, attr, val)

if not hasattr(TransformerConfig, "_verl_router_patched"):
# Store original __init__ method
original_tf_config_init = TransformerConfig.__init__

# Define new __init__ method that safely handles enable_routing_replay parameter
@wraps(original_tf_config_init)
def patched_tf_config_init(self, *args, **kwargs):
# Simple solution: remove the unknown parameter before calling original constructor
enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
if "enable_routing_replay" not in native_params:
enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)

if "moe_router_fusion" not in native_params:
moe_router_fusion = kwargs.pop("moe_router_fusion", TransformerConfig.moe_router_fusion)
# Call original constructor with remaining kwargs
original_tf_config_init(self, *args, **kwargs)

# Set the instance attribute
self.enable_routing_replay = enable_routing_replay
self.moe_router_fusion = moe_router_fusion

# Apply the patch
TransformerConfig.__init__ = patched_tf_config_init
TransformerConfig._verl_router_patched = True

# Step 2: Patch TopKRouter only once to ensure idempotency.
if hasattr(TopKRouter, "_router_replay_patched"):
Expand Down
9 changes: 8 additions & 1 deletion verl/utils/megatron/router_replay_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Utilities for handling router replay functionality in Megatron models.
"""

import inspect
import warnings
from typing import Optional

Expand Down Expand Up @@ -330,7 +331,13 @@ def get_current_rank_layer_info(tf_config, vp_rank=None):
if vp_rank is None:
vp_rank = 0
num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_rank)
offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)

sig = inspect.signature(get_transformer_layer_offset)

if "vp_stage" in sig.parameters:
offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)
else:
offset = get_transformer_layer_offset(tf_config)
local = {}
local["start"] = offset
local["end"] = offset + num_layers_to_build
Expand Down
7 changes: 6 additions & 1 deletion verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,12 @@ async def generate(
max_tokens = sampling_params.pop("max_new_tokens")
else:
# Default to a calculation that considers configured lengths
max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids)
if self.config.enable_rollout_routing_replay:
# When routing replay is enabled, we strictly use the configured response length
max_tokens = self.config.response_length
else:
# Otherwise, use a dynamic calculation based on prompt length
max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids)

# Clamp max_tokens to the valid range [0, max_possible_tokens]
max_tokens = max(0, min(max_tokens, max_possible_tokens))
Expand Down
Loading