Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/experimental/agent_loop/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerG
config=config,
rm_resource_pool=rm_resource_pool,
)
agent_loop_manager = AgentLoopManager(
agent_loop_manager = AgentLoopManager.create(
config=config,
worker_group=actor_rollout_wg,
reward_loop_worker_handles=reward_loop_manager.reward_loop_workers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ async def test_agent_loop_extra_fields_schema_stable_for_training_concat_on_cpu(
# Minimal config surface used by the agent loops.
config = OmegaConf.create(
{
"actor_rollout_ref": {"rollout": {"prompt_length": 16, "response_length": 16}},
"actor_rollout_ref": {
"rollout": {"prompt_length": 16, "response_length": 16, "multi_turn": {"tool_config_path": None}},
"model": {},
},
"data": {
"tool_config_path": None,
"apply_chat_template_kwargs": {},
Expand All @@ -160,23 +163,23 @@ async def test_agent_loop_extra_fields_schema_stable_for_training_concat_on_cpu(
processor = None

trainer_config = DictConfigWrap(config)
dataset_config = DictConfigWrap(config.data)
data_config = DictConfigWrap(config.data)

single_turn = SingleTurnAgentLoop(
trainer_config=trainer_config,
server_manager=server_manager,
tokenizer=tokenizer,
processor=processor,
dataset_cls=RLHFDataset,
dataset_config=dataset_config,
data_config=data_config,
)
partial_single_turn = PartialSingleTurnAgentLoop(
trainer_config=trainer_config,
server_manager=server_manager,
tokenizer=tokenizer,
processor=processor,
dataset_cls=RLHFDataset,
dataset_config=dataset_config,
data_config=data_config,
)

raw_prompt = [{"role": "user", "content": "hi"}]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def test_agent_reward_loop_standalone():
)
actor_rollout_wg.init_model()

agent_loop_manager = AgentLoopManager(config, worker_group=actor_rollout_wg)
agent_loop_manager = AgentLoopManager.create(
config=config,
worker_group=actor_rollout_wg,
)
# sleep rollout replicas
checkpoint_manager = CheckpointEngineManager(
config=omega_conf_to_dataclass(config.actor_rollout_ref.rollout.checkpoint_engine),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_agent_reward_loop_standalone():
config.actor_rollout_ref.rollout.prompt_length = 1024
config.actor_rollout_ref.rollout.response_length = 4096
config.actor_rollout_ref.rollout.skip_tokenizer_init = True
config.actor_rollout_ref.rollout.nnodes = 1
config.trainer.n_gpus_per_node = 4
config.trainer.nnodes = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is `config.trainer.nnodes = 1` still needed for `test_agent_reward_loop_standalone`?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, `config.trainer.nnodes` is not used in standalone mode.


Expand All @@ -76,8 +77,9 @@ def test_agent_reward_loop_standalone():

# 1. init reward model manager
reward_loop_manager = RewardLoopManager(config)
agent_loop_manager = AgentLoopManager(
config=config, reward_loop_worker_handles=reward_loop_manager.reward_loop_workers
agent_loop_manager = AgentLoopManager.create(
config=config,
reward_loop_worker_handles=reward_loop_manager.reward_loop_workers,
)

# 2. init test data
Expand Down
211 changes: 117 additions & 94 deletions verl/experimental/agent_loop/agent_loop.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions verl/experimental/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class SingleTurnAgentLoop(AgentLoopBase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
self.response_length = self.config.actor_rollout_ref.rollout.response_length
self.prompt_length = self.rollout_config.prompt_length
self.response_length = self.rollout_config.response_length

async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])
Expand Down
39 changes: 13 additions & 26 deletions verl/experimental/agent_loop/tool_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,10 @@

import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from verl.experimental.agent_loop.agent_loop import (
AgentLoopBase,
AgentLoopOutput,
AsyncLLMServerManager,
DictConfigWrap,
register,
)
from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser
Expand Down Expand Up @@ -96,37 +93,27 @@ def __init__(

@register("tool_agent")
class ToolAgentLoop(AgentLoopBase):
def __init__(
self,
trainer_config: DictConfigWrap,
server_manager: AsyncLLMServerManager,
tokenizer: AutoTokenizer,
processor: AutoProcessor,
**kwargs,
):
super().__init__(trainer_config, server_manager, tokenizer, processor, **kwargs)
config = trainer_config.config
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Initialize tools from config file
self.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns
self.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns
self.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls
self.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length
self.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side
tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path
self.max_user_turns = self.rollout_config.multi_turn.max_user_turns
self.max_assistant_turns = self.rollout_config.multi_turn.max_assistant_turns
self.max_parallel_calls = self.rollout_config.multi_turn.max_parallel_calls
self.max_tool_response_length = self.rollout_config.multi_turn.max_tool_response_length
self.tool_response_truncate_side = self.rollout_config.multi_turn.tool_response_truncate_side
tool_config_path = self.rollout_config.multi_turn.tool_config_path
tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
self.tools = {tool.name: tool for tool in tool_list}
self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
self.tool_parser = ToolParser.get_tool_parser(
config.actor_rollout_ref.rollout.multi_turn.format, self.tokenizer
)
self.tool_parser_name = config.actor_rollout_ref.rollout.multi_turn.format
self.tool_parser = ToolParser.get_tool_parser(self.rollout_config.multi_turn.format, self.tokenizer)
self.tool_parser_name = self.rollout_config.multi_turn.format

self.prompt_length = config.actor_rollout_ref.rollout.prompt_length
self.response_length = config.actor_rollout_ref.rollout.response_length
self.prompt_length = self.rollout_config.prompt_length
self.response_length = self.rollout_config.response_length

# Initialize interactions from config file
self.interaction_config_file = config.actor_rollout_ref.rollout.multi_turn.interaction_config_path
self.interaction_config_file = self.rollout_config.multi_turn.interaction_config_path
if self.interaction_config_file:
self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(
self.interaction_config_file
Expand Down
69 changes: 7 additions & 62 deletions verl/experimental/fully_async_policy/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
AsyncLLMServerManager,
DictConfigWrap,
_agent_loop_registry,
_get_rollout_and_model_config,
get_trajectory_info,
)
from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
from verl.protocol import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup
from verl.utils.rollout_trace import (
rollout_trace_attr,
rollout_trace_op,
Expand Down Expand Up @@ -102,7 +102,7 @@ async def generate_sequences_no_post(
Returns:
list[AgentLoopOutput]: List of agent loop outputs, one per sample in the batch.
"""
config = self.config.actor_rollout_ref.rollout
config = self.rollout_config
sampling_params = dict(
temperature=config.temperature,
top_p=config.top_p,
Expand Down Expand Up @@ -191,7 +191,7 @@ async def _partial_run_agent_loop(
tokenizer=self.tokenizer,
processor=self.processor,
dataset_cls=self.dataset_cls,
dataset_config=DictConfigWrap(config=self.config.data),
data_config=DictConfigWrap(config=self.config.data),
)
output: AgentLoopOutput = await agent_loop.run(
sampling_params, cancellation_event=self.cancellation_event, **kwargs
Expand Down Expand Up @@ -219,15 +219,17 @@ def __init__(
self,
config: DictConfig,
worker_group: RayWorkerGroup = None,
rollout_resource_pool: RayResourcePool = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
self.config = config
self.rollout_config, self.model_config = _get_rollout_and_model_config(config)
self.worker_group = worker_group
self.reward_loop_worker_handles = reward_loop_worker_handles
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker

# Select rollout replica class based on rollout name
rollout_name = config.actor_rollout_ref.rollout.name
rollout_name = self.rollout_config.name
if rollout_name == "sglang":
from verl.experimental.fully_async_policy.sglang_rollout.sglang_async_server import FullyAsyncSGLangReplica

Expand All @@ -246,63 +248,6 @@ def __init__(
self.server_addresses = None
self.agent_loop_workers = None

@classmethod
async def create(
cls,
config: DictConfig,
worker_group: RayWorkerGroup = None,
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
):
instance = cls(config, worker_group, reward_loop_worker_handles)
await instance._async_init()
return instance

async def _async_init(self):
await self._initialize_llm_servers_async()
self._init_agent_loop_workers()

async def _initialize_llm_servers_async(self):
rollout_world_size = (
self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
* self.config.actor_rollout_ref.rollout.data_parallel_size
* self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size
)
world_size = (
self.worker_group.world_size
if self.worker_group
else self.config.rollout.n_gpus_per_node * self.config.rollout.nnodes
)
num_replicas = world_size // rollout_world_size

rollout_config = self.config.actor_rollout_ref.rollout
model_config = self.config.actor_rollout_ref.model
self.rollout_replicas = [
self.rollout_replica_class(
replica_rank=replica_rank,
config=rollout_config,
model_config=model_config,
gpus_per_node=self.config.rollout.n_gpus_per_node,
)
for replica_rank in range(num_replicas)
]

if self.worker_group:
await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
else:
await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas])

self.server_handles = [server._server_handle for server in self.rollout_replicas]
self.server_addresses = [server._server_address for server in self.rollout_replicas]

print(f"AgentLoopManager: {self.server_addresses}")
# Update Prometheus configuration with server addresses
if rollout_config.prometheus.enable:
if rollout_config.disable_log_stats:
raise ValueError("PROMETHEUS needs disable_log_stats==False, but it is currently True.")
await asyncio.to_thread(
update_prometheus_config, rollout_config.prometheus, self.server_addresses, rollout_config.name
)

async def generate_single_sample_async(
self,
sample: DataProto,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class PartialSingleTurnAgentLoop(AgentLoopBase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
self.response_length = self.config.actor_rollout_ref.rollout.response_length
self.apply_chat_template_kwargs = self.config.data.get("apply_chat_template_kwargs", {})
self.prompt_length = self.rollout_config.prompt_length
self.response_length = self.rollout_config.response_length
self.apply_chat_template_kwargs = self.data_config.get("apply_chat_template_kwargs", {})

async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
output: Optional[AgentLoopOutput] = kwargs.get("output", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class AsyncPartialToolAgentLoop(ToolAgentLoop):

"""

def __init__(self, trainer_config, **kwargs):
super().__init__(trainer_config, **kwargs)
self.enable_partial_rollout = trainer_config.config.async_training.get("partial_rollout", False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.enable_partial_rollout = self.config.async_training.get("partial_rollout", False)

# async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
async def run(
Expand Down
3 changes: 3 additions & 0 deletions verl/experimental/fully_async_policy/fully_async_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ def main(config):
from time import time

start_time = time()
# TODO: unify rollout config with actor_rollout_ref
config.actor_rollout_ref.rollout.nnodes = config.rollout.nnodes
config.actor_rollout_ref.rollout.n_gpus_per_node = config.rollout.n_gpus_per_node
run_ppo(config, task_runner_class=FullyAsyncTaskRunner)
print(f"total time: {time() - start_time:.2f} seconds")

Expand Down
50 changes: 0 additions & 50 deletions verl/experimental/one_step_off_policy/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
import ray

from verl.experimental.agent_loop.agent_loop import AgentLoopManager
from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
from verl.protocol import DataProto
from verl.single_controller.ray import RayResourcePool

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
Expand Down Expand Up @@ -56,54 +54,6 @@ async def generate_sequences_async(self, prompts: DataProto) -> DataProto:
output.meta_info = {"timing": timing, **outputs[0].meta_info}
return output

def _initialize_llm_servers(self, rollout_resource_pool: RayResourcePool):
rollout_world_size = (
self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
* self.config.actor_rollout_ref.rollout.data_parallel_size
* self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size
)
world_size = (
self.worker_group.world_size
if self.worker_group
else self.config.rollout.n_gpus_per_node * self.config.rollout.nnodes
)
num_replicas = world_size // rollout_world_size

rollout_config = self.config.actor_rollout_ref.rollout
model_config = self.config.actor_rollout_ref.model
self.rollout_replicas = [
self.rollout_replica_class(
replica_rank=replica_rank,
config=rollout_config,
model_config=model_config,
gpus_per_node=self.config.rollout.n_gpus_per_node,
)
for replica_rank in range(num_replicas)
]

if self.worker_group and rollout_config.name != "trtllm":
self._run_all([server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
elif self.worker_group and rollout_config.name == "trtllm":
self._run_all(
[
server.init_hybrid_colocated(self.worker_group, rollout_resource_pool)
for server in self.rollout_replicas
]
)
else:
self._run_all([server.init_standalone() for server in self.rollout_replicas])

self.server_handles = [server._server_handle for server in self.rollout_replicas]
self.server_addresses = [server._server_address for server in self.rollout_replicas]

print(f"AgentLoopManager: {self.server_addresses}")

# Update Prometheus configuration with server addresses
if rollout_config.prometheus.enable:
if rollout_config.disable_log_stats:
raise ValueError("PROMETHEUS needs disable_log_stats==False, but it is currently True.")
update_prometheus_config(rollout_config.prometheus, self.server_addresses, rollout_config.name)

async def wake_up(self):
await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas])

Expand Down
4 changes: 4 additions & 0 deletions verl/experimental/one_step_off_policy/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def main(config):
# Automatically set `config.trainer.device = npu` when running on Ascend NPU.
auto_set_device(config)

# TODO: unify rollout config with actor_rollout_ref
config.actor_rollout_ref.rollout.nnodes = config.rollout.nnodes
config.actor_rollout_ref.rollout.n_gpus_per_node = config.rollout.n_gpus_per_node

run_ppo(config, task_runner_class=OneStepTaskRunner)
print(f"total time: {time() - start_time:.2f} seconds")

Expand Down
2 changes: 1 addition & 1 deletion verl/experimental/one_step_off_policy/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _init_async_rollout_manager(self):
from verl.experimental.one_step_off_policy.agent_loop import OneStepOffAgentLoopManager

self.async_rollout_mode = True
self.async_rollout_manager = OneStepOffAgentLoopManager(
self.async_rollout_manager = OneStepOffAgentLoopManager.create(
config=self.config, reward_loop_worker_handles=reward_loop_worker_handles
)

Expand Down
Loading
Loading