Skip to content

Commit ba76c5e

Browse files
authored
[misc] feat: delete unnecessary base class in agent loop worker and vLLMHttpServer (#4838)
1 parent b12eb3b commit ba76c5e

File tree

4 files changed

+9
-47
lines changed

4 files changed

+9
-47
lines changed

verl/experimental/agent_loop/agent_loop.py

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -341,7 +341,7 @@ def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:
341341
return decorator
342342

343343

344-
class AgentLoopWorkerBase:
344+
class AgentLoopWorker:
345345
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
346346

347347
def __init__(
@@ -351,10 +351,10 @@ def __init__(
351351
reward_router_address: str = None,
352352
):
353353
"""Initialize agent loop manager.
354-
355354
Args:
356355
config (DictConfig): YAML config.
357356
server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
357+
reward_router_address (str): reward router address.
358358
"""
359359
self.config = config
360360

@@ -804,22 +804,6 @@ def create_transferqueue_client(
804804
)
805805

806806

807-
@ray.remote
808-
class AgentLoopWorker(AgentLoopWorkerBase):
809-
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
810-
811-
def __init__(
812-
self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], reward_router_address: str = None
813-
):
814-
"""Initialize agent loop manager.
815-
Args:
816-
config (DictConfig): YAML config.
817-
server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
818-
reward_router_address (str): reward router address.
819-
"""
820-
super().__init__(config, server_handles, reward_router_address)
821-
822-
823807
async def get_trajectory_info(step, index, validate):
824808
"""Get trajectory info.
825809
@@ -869,7 +853,7 @@ def __init__(
869853
if not hasattr(self, "rollout_replica_class"):
870854
self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)
871855
if not hasattr(self, "agent_loop_workers_class"):
872-
self.agent_loop_workers_class = AgentLoopWorker
856+
self.agent_loop_workers_class = ray.remote(AgentLoopWorker)
873857

874858
self._initialize_llm_servers()
875859
self._init_agent_loop_workers()

verl/experimental/fully_async_policy/agent_loop/agent_loop.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from verl.experimental.agent_loop.agent_loop import (
2525
AgentLoopManager,
2626
AgentLoopOutput,
27-
AgentLoopWorkerBase,
27+
AgentLoopWorker,
2828
AsyncLLMServerManager,
2929
DictConfigWrap,
3030
_agent_loop_registry,
@@ -77,7 +77,7 @@ async def generate_for_partial(
7777

7878

7979
@ray.remote
80-
class FullyAsyncAgentLoopWorker(AgentLoopWorkerBase):
80+
class FullyAsyncAgentLoopWorker(AgentLoopWorker):
8181
def __init__(
8282
self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], reward_router_address: str = None
8383
):

verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
from verl.workers.rollout.replica import RolloutMode
2626
from verl.workers.rollout.vllm_rollout.vllm_async_server import (
2727
_qwen2_5_vl_dedup_image_tokens,
28-
vLLMHttpServerBase,
28+
vLLMHttpServer,
2929
vLLMReplica,
3030
)
3131

@@ -34,7 +34,7 @@
3434

3535

3636
@ray.remote(num_cpus=1)
37-
class vLLMHttpServerForPartial(vLLMHttpServerBase):
37+
class vLLMHttpServerForPartial(vLLMHttpServer):
3838
def __init__(
3939
self,
4040
config: RolloutConfig,

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 2 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,7 @@ def check_health(self):
168168
return
169169

170170

171-
class vLLMHttpServerBase:
171+
class vLLMHttpServer:
172172
"""vLLM http server in single node, this is equivalent to launch server with command line:
173173
```
174174
vllm serve --tensor-parallel-size=8 ...
@@ -663,28 +663,6 @@ async def abort_request(self, request_id: str, reset_prefix_cache: bool = True)
663663
return {"aborted": False, "request_id": request_id, "error": str(e)}
664664

665665

666-
@ray.remote(num_cpus=1)
667-
class vLLMHttpServer(vLLMHttpServerBase):
668-
"""vLLM http server in single node, this is equivalent to launch server with command line:
669-
```
670-
vllm serve --tensor-parallel-size=8 ...
671-
```
672-
"""
673-
674-
def __init__(
675-
self,
676-
config: RolloutConfig,
677-
model_config: HFModelConfig,
678-
rollout_mode: RolloutMode,
679-
workers: list[ActorHandle],
680-
replica_rank: int,
681-
node_rank: int,
682-
gpus_per_node: int,
683-
nnodes: int,
684-
):
685-
super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)
686-
687-
688666
_rollout_worker_actor_cls = ray.remote(vLLMAsyncRollout)
689667

690668

@@ -698,7 +676,7 @@ def __init__(
698676
is_reward_model: bool = False,
699677
):
700678
super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model)
701-
self.server_class = vLLMHttpServer
679+
self.server_class = ray.remote(vLLMHttpServer)
702680

703681
def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:
704682
"""Get rollout worker actor class for colocated and standalone mode."""

0 commit comments

Comments
 (0)