Skip to content

Commit 5d76fbc

Browse files
committed
[vllm] feat: implement chunked weight handling in vllm rollout for large tensors
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
1 parent f56c893 commit 5d76fbc

File tree

3 files changed

+236
-9
lines changed

3 files changed

+236
-9
lines changed
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import numpy as np
17+
import pytest
18+
import ray
19+
from omegaconf import DictConfig
20+
21+
from verl.checkpoint_engine import CheckpointEngineManager
22+
from verl.experimental.agent_loop.agent_loop import AgentLoopManager
23+
from verl.protocol import DataProto
24+
from verl.single_controller.ray import (
25+
RayClassWithInitArgs,
26+
RayResourcePool,
27+
RayWorkerGroup,
28+
)
29+
from verl.single_controller.ray.base import create_colocated_worker_cls
30+
from verl.utils.device import get_device_name
31+
from verl.utils.tokenizer import hf_tokenizer
32+
from verl.workers.engine_workers import ActorRolloutRefWorker
33+
34+
35+
@pytest.fixture
def init_config() -> DictConfig:
    """Compose the ``ppo_trainer`` Hydra config with overrides for this test.

    Overrides select the rollout backend from the ``ROLLOUT_NAME`` env var,
    enable the naive checkpoint engine, and size the weight-update bucket to
    256 MB so large tensors exercise the chunked transfer path.
    """
    from hydra import compose, initialize_config_dir

    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
        cfg = compose(config_name="ppo_trainer")

    # One node with 8 devices, colocated actor + rollout.
    cfg.trainer.n_gpus_per_node = 8
    cfg.trainer.nnodes = 1

    actor_rollout_ref = cfg.actor_rollout_ref
    actor_rollout_ref.actor.use_dynamic_bsz = True
    actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen3-VL-2B-Instruct")

    rollout = actor_rollout_ref.rollout
    rollout.name = os.environ["ROLLOUT_NAME"]
    rollout.skip_tokenizer_init = False
    rollout.max_num_seqs = 256
    rollout.gpu_memory_utilization = 0.8
    rollout.agent.num_workers = 2
    rollout.checkpoint_engine.backend = "naive"
    rollout.checkpoint_engine.update_weights_bucket_megabytes = 256
    rollout.enforce_eager = True

    return cfg
56+
57+
58+
@pytest.mark.skip(reason="This test costs too much to run in CI.")
def test_server_adapter_colocated_weight_update(init_config):
    """End-to-end check that colocated rollout replicas pick up weight updates.

    Flow: build a colocated actor/rollout worker group, attach an
    ``AgentLoopManager`` and a ``CheckpointEngineManager``, then repeatedly
    push weights and generate sequences, asserting each decoded response shows
    the model followed the prompt (i.e. the pushed weights were applied).

    Fixes vs. original:
      * Removed the stray ``@pytest.mark.asyncio`` marker — this is a plain
        synchronous test function, and pytest-asyncio rejects the marker on
        non-async tests.
      * ``ray.shutdown()`` now runs in a ``finally`` block so the Ray cluster
        is torn down even when setup or an assertion fails.
    """
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
                "VLLM_DISABLE_COMPILE_CACHE": "1",
                "HCCL_HOST_SOCKET_PORT_RANGE": "60000-60050",
                "HCCL_NPU_SOCKET_PORT_RANGE": "61000-61050",
            }
        }
    )
    try:
        # 0. init actor rollout worker group
        resource_pool = RayResourcePool(
            process_on_nodes=[init_config.trainer.n_gpus_per_node] * init_config.trainer.nnodes,
            max_colocate_count=3,
        )
        actor_rollout_cls = ray.remote(ActorRolloutRefWorker)
        cls_dict = {
            "actor_rollout": RayClassWithInitArgs(
                cls=actor_rollout_cls, config=init_config.actor_rollout_ref, role="actor_rollout"
            )
        }
        ray_cls_with_init = create_colocated_worker_cls(cls_dict)
        wg_dict = RayWorkerGroup(
            resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name=get_device_name()
        )
        spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())
        actor_rollout_wg = spawn_wg["actor_rollout"]
        actor_rollout_wg.init_model()

        # 1. create AgentLoopManager
        agent_loop_manager = AgentLoopManager(
            config=init_config,
            worker_group=actor_rollout_wg,
            rollout_resource_pool=resource_pool,
        )

        # 2. create CheckpointEngineManager
        checkpoint_manager = CheckpointEngineManager(
            backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend,
            trainer=actor_rollout_wg,
            replicas=agent_loop_manager.rollout_replicas,
        )
        checkpoint_manager.sleep_replicas()

        # 3. generate prompts — two identical probes so batching is exercised
        raw_prompts = [
            [
                {
                    "role": "user",
                    "content": "This is a test for weight update. If the weight has been correctly "
                    'updated and you understand my meaning, please respond with "Test Passed".',
                }
            ]
            for _ in range(2)
        ]
        batch = DataProto(
            non_tensor_batch={
                "raw_prompt": np.array(raw_prompts),
                "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)),
                "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
                "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
            },
        )

        # 4. update weights and generate sequences, check if the responses are correct
        tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
        for _ in range(3):
            checkpoint_manager.update_weights()
            result = agent_loop_manager.generate_sequences(batch)
            checkpoint_manager.sleep_replicas()

            # Check every response of this round, not just the last one.
            responses = result.batch["responses"]
            response_mask = result.batch["response_mask"]
            for i in range(len(responses)):
                valid_tokens = responses[i][response_mask[i].bool()]
                response = tokenizer.decode(valid_tokens)
                assert "test passed" in response.lower(), f"Response does not contain 'test passed': {response}"

                print("=========================")
                print("[OUTPUT]:", response)
                print("---")
    finally:
        # Always tear down the Ray cluster, even when the test fails.
        ray.shutdown()

verl/workers/rollout/vllm_rollout/utils.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
236236
patch_vllm_moe_model_weight_loader(self.model_runner.model)
237237

238238
# receive bucket and update weights
239+
# Buffer to collect chunks for weights that were sliced
240+
pending_chunks = {} # name -> {chunk_idx: tensor, ...}
239241
while True:
240242
metadata = socket.recv_pyobj()
241243
weights, tensor = [], None
@@ -250,14 +252,40 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
250252
tensor = tensor.clone()
251253
else:
252254
tensor = tensor.to(self.device)
253-
weights.append((name, tensor))
255+
256+
# Check if this is a chunk of a sliced weight
257+
if "chunk_idx" in meta and "total_chunks" in meta:
258+
# This is a chunk, store it for later merging
259+
original_name = meta["name"]
260+
chunk_idx = meta["chunk_idx"]
261+
if original_name not in pending_chunks:
262+
pending_chunks[original_name] = {}
263+
pending_chunks[original_name][chunk_idx] = tensor
264+
265+
# Check if we have all chunks for this weight
266+
if len(pending_chunks[original_name]) == meta["total_chunks"]:
267+
# Merge all chunks back into one tensor
268+
chunks_dict = pending_chunks[original_name]
269+
sorted_chunks = [chunks_dict[i] for i in range(meta["total_chunks"])]
270+
merged_tensor = torch.cat(sorted_chunks, dim=0)
271+
weights.append((original_name, merged_tensor))
272+
del pending_chunks[original_name]
273+
else:
274+
weights.append((name, tensor))
275+
254276
get_torch_device().synchronize()
255277
socket.send(b"")
256278
self._update_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done)
257279
del weights, tensor
258280
if metadata["is_last"]:
259281
break
260282

283+
# Check if there are any remaining chunks that weren't processed
284+
if pending_chunks:
285+
raise RuntimeError(
286+
f"Received chunks for weights {list(pending_chunks.keys())} but did not receive all chunks for them."
287+
)
288+
261289
if self._is_qat_model:
262290
# QAT: call process_weights_after_loading AFTER all buckets are received
263291
from verl.utils.qat import manual_process_weights_after_loading

verl/workers/rollout/vllm_rollout/vllm_rollout.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import logging
3131
import os
3232
import time
33+
from functools import reduce
3334
from typing import Any, Generator, Optional
3435

3536
import ray
@@ -198,27 +199,71 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
198199
# transfer volume.
199200
# weight = weight.to(dtype, non_blocking=True)
200201

202+
# Check if the weight needs to be sliced into chunks
203+
# (e.g., large embedding layer that exceeds bucket_size)
204+
weight_size = weight.nbytes
205+
if weight_size > bucket_size:
206+
# Slice the weight along the first dimension into chunks
207+
dtype_size = weight.element_size()
208+
numel_per_chunk = bucket_size // dtype_size
209+
210+
# Calculate chunk size along the first dimension
211+
first_dim_size = weight.shape[0]
212+
chunk_dim_size = numel_per_chunk // reduce(lambda x, y: x * y, weight.shape[1:], 1)
213+
214+
num_chunks = (first_dim_size + chunk_dim_size - 1) // chunk_dim_size
215+
logger.info(
216+
f"Slicing weight {name} ({weight.shape}, {weight.dtype}, {weight_size} bytes) "
217+
f"into {num_chunks} chunks"
218+
)
219+
220+
start_idx = 0
221+
for chunk_idx in range(num_chunks):
222+
end_idx = min(start_idx + chunk_dim_size, first_dim_size)
223+
224+
# Extract chunk along first dimension
225+
chunk = weight[start_idx:end_idx]
226+
chunk_size = chunk.nbytes
227+
228+
# Fill bucket with chunk
229+
if offset + chunk_size > bucket_size:
230+
get_torch_device().synchronize()
231+
s.send_pyobj({"bucket_meta": bucket_meta, "is_last": False})
232+
s.recv()
233+
bucket_meta = {}
234+
offset = 0
235+
236+
bucket_meta[f"{name}_chunk_{chunk_idx}"] = {
237+
"name": name,
238+
"shape": chunk.shape,
239+
"dtype": chunk.dtype,
240+
"offset": offset,
241+
"chunk_idx": chunk_idx,
242+
"total_chunks": num_chunks,
243+
}
244+
buffer[offset : offset + chunk_size].copy_(chunk.view(-1).view(torch.uint8), non_blocking=True)
245+
offset += chunk_size
246+
247+
start_idx = end_idx
248+
249+
continue
250+
201251
# fill the tensor bucket
202-
if offset + weight.nbytes > bucket_size:
252+
if offset + weight_size > bucket_size:
203253
get_torch_device().synchronize()
204254
s.send_pyobj({"bucket_meta": bucket_meta, "is_last": False})
205255
s.recv()
206256
bucket_meta = {}
207257
offset = 0
208258

209-
# TODO: slice embedding layer weight into chunks
210-
assert offset + weight.nbytes <= bucket_size, (
211-
f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket."
212-
f"Please increase rollout.update_weights_bucket_megabytes({bucket_size_mb} MB)."
213-
)
214259
bucket_meta[name] = {
215260
"name": name,
216261
"shape": weight.shape,
217262
"dtype": weight.dtype,
218263
"offset": offset,
219264
}
220-
buffer[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
221-
offset += weight.nbytes
265+
buffer[offset : offset + weight_size].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
266+
offset += weight_size
222267

223268
# send the last bucket
224269
get_torch_device().synchronize()

0 commit comments

Comments
 (0)