Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions verl/experimental/agent_loop/tool_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,14 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
multi_modal_data["images"] = agent_data.image_data
if agent_data.video_data is not None:
multi_modal_data["videos"] = agent_data.video_data

routed_experts = getattr(agent_data, "routed_experts", None)
output = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[: self.response_length],
response_mask=agent_data.response_mask[: self.response_length],
multi_modal_data=multi_modal_data,
routed_experts=routed_experts,
response_logprobs=agent_data.response_logprobs[: self.response_length]
if agent_data.response_logprobs
else None,
Expand Down
56 changes: 51 additions & 5 deletions verl/utils/megatron/router_replay_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from enum import Enum

import torch
import types
import inspect

try:
from megatron.core.transformer.moe.moe_utils import (
Expand Down Expand Up @@ -236,7 +238,27 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):

return routing_probs, routing_map


def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float:
"""获取给定辅助损失类型的系数。"""
# 逻辑保持不变
if isinstance(_self.routing_type, str):
if _self.routing_type == aux_loss_type:
return _self.config.moe_aux_loss_coeff
if isinstance(_self.routing_type, list):
try:
idx = _self.routing_type.index(aux_loss_type)
return _self.config.moe_aux_loss_coeff[idx]
except (ValueError, IndexError):
return 0.0
return 0.0

def _is_aux_loss_enabled(_self) -> bool:
    """Return True when any auxiliary-loss coefficient is positive.

    Delegates to the module-level :func:`_get_aux_loss_coeff` for each of the
    known aux-loss kinds.
    """
    return any(
        _get_aux_loss_coeff(_self, kind) > 0
        for kind in ("aux_loss", "seq_aux_loss", "global_aux_loss")
    )
def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
"""Top-k routing function

Expand Down Expand Up @@ -282,6 +304,8 @@ def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
)

if not hasattr(self, "is_aux_loss_enabled"):
self.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, self)
# Apply each aux loss type and attach aux loss autograd function to probs
if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled():
# Calculate scores and routing_map for aux loss
Expand Down Expand Up @@ -311,26 +335,48 @@ def apply_router_replay_patch():
# Clear router instances to avoid state leakage between model initializations.
RouterReplay.router_instances.clear()
# Step 1: Patch TransformerConfig to include the feature flag
if not hasattr(TransformerConfig, "enable_routing_replay"):
# Add class attribute with default value
TransformerConfig.enable_routing_replay = False
try:
global_args = get_args()
except Exception:
global_args = None

try:
sig = inspect.signature(TransformerConfig.__init__)
native_params = sig.parameters
except Exception:
native_params = []

ext_attrs = ["enable_routing_replay", "moe_router_fusion"]

for attr in ext_attrs:
val = getattr(global_args, attr, False) if global_args else False

if not hasattr(TransformerConfig, attr):
setattr(TransformerConfig, attr, val)

if not hasattr(TransformerConfig, "_verl_router_patched"):
# Store original __init__ method
original_tf_config_init = TransformerConfig.__init__

# Define new __init__ method that safely handles enable_routing_replay parameter
@wraps(original_tf_config_init)
def patched_tf_config_init(self, *args, **kwargs):
    """Wrapper around ``TransformerConfig.__init__`` that tolerates the
    verl extension kwargs ``enable_routing_replay`` and ``moe_router_fusion``.

    When the native constructor does not accept one of these parameters, it
    is popped from ``kwargs`` before delegating (falling back to the class
    default); otherwise it is read without removal so the native constructor
    still sees it. Either way the value is mirrored onto the instance.
    """
    if "enable_routing_replay" not in native_params:
        enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay)
    else:
        enable_routing_replay = kwargs.get("enable_routing_replay", TransformerConfig.enable_routing_replay)

    # BUG FIX: this previously popped "enable_routing_replay" again (copy-paste
    # error), leaving "moe_router_fusion" in kwargs and binding the wrong value.
    if "moe_router_fusion" not in native_params:
        moe_router_fusion = kwargs.pop("moe_router_fusion", TransformerConfig.moe_router_fusion)
    else:
        moe_router_fusion = kwargs.get("moe_router_fusion", TransformerConfig.moe_router_fusion)

    # Call original constructor with remaining kwargs.
    original_tf_config_init(self, *args, **kwargs)

    # Mirror the extension flags onto the instance. Both locals are now bound
    # on every path (the original raised UnboundLocalError when a flag was a
    # native constructor parameter).
    self.enable_routing_replay = enable_routing_replay
    self.moe_router_fusion = moe_router_fusion

# Apply the patch
TransformerConfig.__init__ = patched_tf_config_init
TransformerConfig._verl_router_patched = True

# Step 2: Patch TopKRouter only once to ensure idempotency.
if hasattr(TopKRouter, "_router_replay_patched"):
Expand Down
9 changes: 8 additions & 1 deletion verl/utils/megatron/router_replay_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Optional

import torch
import inspect

try:
from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage
Expand Down Expand Up @@ -330,7 +331,13 @@ def get_current_rank_layer_info(tf_config, vp_rank=None):
if vp_rank is None:
vp_rank = 0
num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_rank)
offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)

sig = inspect.signature(get_transformer_layer_offset)

if 'vp_stage' in sig.parameters:
offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank)
else:
offset = get_transformer_layer_offset(tf_config)
local = {}
local["start"] = offset
local["end"] = offset + num_layers_to_build
Expand Down
2 changes: 1 addition & 1 deletion verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ async def generate(
max_tokens = sampling_params.pop("max_new_tokens")
else:
# Default to the configured response length
max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids)
max_tokens = self.config.response_length

# Clamp max_tokens to the valid range [0, max_possible_tokens]
max_tokens = max(0, min(max_tokens, max_possible_tokens))
Expand Down