def _unique_zmq_handle():
    """Return a fresh ZMQ ``ipc://`` endpoint under /tmp so concurrent tests never collide."""
    token = uuid.uuid4().hex
    return f"ipc:///tmp/test-bwt-{token}.sock"


def _generate_weights(weight_specs, seed):
    """Build the reference weights for a test, reproducibly, on device 0.

    The device RNG is re-seeded on every call, so two calls with the same
    ``weight_specs`` and ``seed`` draw the same random sequence. This is how
    the parent process regenerates what the sender process produced.

    Args:
        weight_specs: list of (name, shape, dtype) tuples
        seed: random seed for reproducibility
    Returns:
        list of (name, tensor_on_device) tuples, in the order of ``weight_specs``
    """
    target = torch.device(f"{get_device_name()}:0")
    get_torch_device().manual_seed(seed)
    # Draw in float32 and cast afterwards: torch.randn doesn't support all dtypes.
    return [
        (name, torch.randn(shape, dtype=torch.float32, device=target).to(dtype))
        for name, shape, dtype in weight_specs
    ]


# ---------------------------------------------------------------------------
# Process entry points (must be module-level for pickling with spawn)
# ---------------------------------------------------------------------------
def _sender_fn(zmq_handle, weight_specs, seed, bucket_size_mb, use_shm):
    """Sender process entry point: generate the weights on device and send them."""
    from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightSender

    sender = BucketedWeightSender(
        zmq_handle=zmq_handle,
        bucket_size_mb=bucket_size_mb,
        use_shm=use_shm,
    )
    payload = _generate_weights(weight_specs, seed)
    asyncio.run(sender.async_send_weights(iter(payload)))


def _receiver_fn(zmq_handle, use_shm, result_queue):
    """Receiver process entry point: collect all weights, report a summary per tensor.

    Puts a list of (name, dtype, shape, checksum) tuples on ``result_queue``,
    where checksum is the float32 sum of the tensor's elements.
    """
    from verl.utils.device import get_device_name
    from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightReceiver

    receiver = BucketedWeightReceiver(
        zmq_handle=zmq_handle,
        device=torch.device(f"{get_device_name()}:0"),
        use_shm=use_shm,
    )
    collected = []
    receiver.receive_weights(on_bucket_received=collected.extend)

    # Only lightweight metadata + checksum travel back through the queue,
    # never the tensors themselves.
    report = []
    for name, tensor in collected:
        report.append((name, tensor.dtype, tuple(tensor.shape), tensor.float().sum().item()))
    result_queue.put(report)
# ---------------------------------------------------------------------------
# Test helper
# ---------------------------------------------------------------------------
def _transfer_and_validate(weight_specs, bucket_size_mb, use_shm):
    """Spawn sender + receiver processes, then validate received tensors.

    The sender and receiver each run in their own ``spawn``-ed process. The
    receiver reports (name, dtype, shape, checksum) per tensor through a
    queue, and the parent compares that report against weights regenerated
    locally from the same seed.

    Args:
        weight_specs: list of (name, shape, dtype) tuples to transfer.
        bucket_size_mb: size of the sender's communication bucket in MB.
        use_shm: True for the shared-memory path, False for CUDA IPC.

    Raises:
        AssertionError: if a child process times out or exits non-zero, or if
            any received name / shape / dtype / checksum differs.
    """
    zmq_handle = _unique_zmq_handle()
    seed = 42
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()

    sender_p = ctx.Process(
        target=_sender_fn,
        args=(zmq_handle, weight_specs, seed, bucket_size_mb, use_shm),
    )
    receiver_p = ctx.Process(
        target=_receiver_fn,
        args=(zmq_handle, use_shm, result_queue),
    )
    children = (("Sender", sender_p), ("Receiver", receiver_p))

    try:
        # Start sender first (it binds), then receiver (it connects)
        sender_p.start()
        receiver_p.start()

        for _, proc in children:
            proc.join(timeout=PROCESS_TIMEOUT)

        for label, proc in children:
            # exitcode stays None when join() returned because of the timeout.
            assert proc.exitcode is not None, f"{label} process timed out after {PROCESS_TIMEOUT}s"
            assert proc.exitcode == 0, f"{label} process failed with exit code {proc.exitcode}"

        # The receiver only queues small per-tensor summaries, so reading the
        # queue after join() is safe here.
        summaries = result_queue.get(timeout=5)
    finally:
        # Never leave a hung child behind: it would keep holding the device
        # and the socket and could stall the rest of the test session.
        for _, proc in children:
            if proc.is_alive():
                proc.terminate()
                proc.join(timeout=5)

    # Regenerate expected weights on device with the same seed
    expected = _generate_weights(weight_specs, seed)

    assert len(summaries) == len(expected), f"Expected {len(expected)} weights, got {len(summaries)}"

    # Lengths were asserted equal above, so strict zipping cannot raise; it
    # just documents that no element may be silently dropped.
    for (exp_name, exp_tensor), (recv_name, recv_dtype, recv_shape, recv_cksum) in zip(
        expected, summaries, strict=True
    ):
        assert exp_name == recv_name, f"Name mismatch: expected {exp_name}, got {recv_name}"
        assert tuple(exp_tensor.shape) == recv_shape, (
            f"Shape mismatch for {exp_name}: expected {tuple(exp_tensor.shape)}, got {recv_shape}"
        )
        assert exp_tensor.dtype == recv_dtype, (
            f"Dtype mismatch for {exp_name}: expected {exp_tensor.dtype}, got {recv_dtype}"
        )
        exp_sum = exp_tensor.float().sum().item()
        assert exp_sum == recv_cksum, f"Data mismatch for {exp_name}"
# ---------------------------------------------------------------------------
# Transport-agnostic cases, run once per transport by the subclasses below
# ---------------------------------------------------------------------------
class _BucketedWeightTransferCases:
    """Shared BucketedWeightSender/Receiver test cases.

    The class name deliberately does not start with ``Test`` so pytest's
    default collection skips it; the concrete subclasses inherit every
    ``test_*`` method and only choose the transport through ``use_shm``.
    """

    # Set by each concrete subclass: True = shared memory, False = CUDA IPC.
    use_shm: bool

    def _check(self, specs):
        """Transfer ``specs`` through a 1 MB bucket on this class's transport."""
        _transfer_and_validate(specs, bucket_size_mb=1, use_shm=self.use_shm)

    def test_single_small_weight(self):
        self._check([("layer.weight", (32, 16), torch.float32)])

    def test_multiple_weights_single_bucket(self):
        self._check(
            [
                ("layer0.weight", (16, 16), torch.float32),
                ("layer0.bias", (16,), torch.float32),
                ("layer1.weight", (16, 8), torch.bfloat16),
            ]
        )

    def test_multiple_buckets(self):
        # ~64 KB each x 20 = ~1.25 MB, bucket = 1 MB => spans 2 buckets
        self._check([(f"layer{i}.weight", (128, 128), torch.float32) for i in range(20)])

    def test_mixed_dtypes(self):
        self._check(
            [
                ("fp32_param", (64, 64), torch.float32),
                ("bf16_param", (64, 64), torch.bfloat16),
                ("fp16_param", (32, 32), torch.float16),
            ]
        )

    def test_empty_weights(self):
        self._check([])

    def test_exact_bucket_boundary(self):
        # 1 MB bucket = 1048576 bytes; float32 = 4 bytes => 262144 elements
        numel = (1 << 20) // 4
        self._check([("exact_fit", (numel,), torch.float32)])


# ---------------------------------------------------------------------------
# Shared memory tests
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not HAS_ACCELERATOR, reason="Requires CUDA or NPU")
class TestBucketedWeightTransferSHM(_BucketedWeightTransferCases):
    """Test BucketedWeightSender/Receiver via shared memory path."""

    use_shm = True


# ---------------------------------------------------------------------------
# CUDA IPC tests (CUDA only — IPC is not supported on NPU)
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not HAS_CUDA, reason="Requires CUDA (IPC not supported on NPU)")
class TestBucketedWeightTransferIPC(_BucketedWeightTransferCases):
    """Test BucketedWeightSender/Receiver via CUDA IPC path."""

    use_shm = False
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Bucketed weight transfer via ZMQ + IPC (or shared memory fallback). + +Not recommended depending on vllm for this file. +""" + +import gc +import logging +import os +from multiprocessing import shared_memory +from typing import Callable, TypedDict + +import torch +import zmq +from torch.multiprocessing.reductions import reduce_tensor + +from verl.utils.device import get_device_id, get_device_name, get_torch_device + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +class TensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + offset: int + + +# copy from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/rlhf_utils.py +def rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: + func, args = handle + list_args = list(args) + if device_id is not None: + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +def create_shared_memory(size: int, name: str): + """Create shared memory for weight transfer. 
If already exists, attach to it.""" + try: + shm = shared_memory.SharedMemory(name=name, create=True, size=size) + except FileExistsError: + shm = shared_memory.SharedMemory(name=name) + assert shm.size >= size, f"Stale shm segment '{name}': expected {size} bytes, got {shm.size}" + return shm + + +def rebuild_shared_memory(name: str, size: int, dtype=torch.uint8): + """Rebuild tensor from shared memory.""" + shm = shared_memory.SharedMemory(name=name) + tensor = torch.frombuffer(shm.buf[:size], dtype=dtype) + + return tensor, shm + + +class BucketedWeightSender: + """ + Send model weights via bucketed IPC transfer over ZMQ. + + Packs weight tensors into a fixed-size communication buffer and sends them + in buckets to the receiver. Supports CUDA IPC and shared memory fallback. + + Args: + zmq_handle: ZMQ IPC socket path (e.g., "ipc:///tmp/rl-colocate-zmq-.sock") + bucket_size_mb: Communication buffer size in MB + use_shm: Use shared memory instead of CUDA IPC (for NPU compatibility) + """ + + def __init__( + self, + zmq_handle: str, + bucket_size_mb: int = 512, + use_shm: bool = False, + ): + self.zmq_handle = zmq_handle + self.bucket_size_mb = bucket_size_mb + self.bucket_size = int(bucket_size_mb) << 20 + self.use_shm = use_shm + + self.zmq_context = zmq.Context.instance() + self.socket = None + self.buffer = None + self.shm = None + + async def async_send_weights(self, weights): + """ + Send weights to the receiver. Accepts a sync generator or async iterator. 
+ + Args: + weights: Generator or async iterator yielding (name, tensor) pairs + """ + from verl.workers.rollout.utils import ensure_async_iterator + + try: + self._init_socket() + self._init_buffer() + + # send bucket weights + offset = 0 + bucket_meta: dict[str, TensorMetadata] = {} + # dtype = PrecisionType.to_dtype(self.config.dtype) + async for name, weight in ensure_async_iterator(weights): + # model parameters are in fp32 full precision + # (vermouth1992) we should not force cast weight here because some parameters + # (such as moe gate) have to keep fp32 precision. If a weight is bf16 in the rollout side, + # the rollout should automatically cast on demand. However, this would incur a higher weight + # transfer volume. + # weight = weight.to(dtype, non_blocking=True) + + # fill the tensor bucket + if offset + weight.nbytes > self.bucket_size: + get_torch_device().synchronize() + self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) + self.socket.recv() + bucket_meta = {} + offset = 0 + + # TODO: slice embedding layer weight into chunks + assert offset + weight.nbytes <= self.bucket_size, ( + f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." + f"Please increase rollout.update_weights_bucket_megabytes({self.bucket_size_mb} MB)." 
+ ) + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + self.buffer[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + offset += weight.nbytes + + # send the last bucket + get_torch_device().synchronize() + self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": True}) + self.socket.recv() + finally: + self._cleanup() + + def _init_socket(self): + """Initialize ZMQ REQ socket and bind.""" + self.socket = self.zmq_context.socket(zmq.REQ) + self.socket.bind(self.zmq_handle) + + def _init_buffer(self): + """build communication buffer""" + buffer, shm = None, None + if not self.use_shm: + buffer = torch.empty(self.bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:{get_device_id()}") + handle = reduce_tensor(buffer) + self.socket.send_pyobj(handle) + else: + import uuid + + # Create unique name for shared memory + shm_name = f"verl_weights_{uuid.uuid4().hex}" + shm = create_shared_memory(self.bucket_size, shm_name) + buffer = torch.frombuffer(shm.buf, dtype=torch.uint8) + + comm_metadata = {"name": shm_name, "size": self.bucket_size} + self.socket.send_pyobj(comm_metadata) + + self.socket.recv() + self.buffer = buffer + self.shm = shm + + def _cleanup(self): + """clean up""" + if self.socket is not None: + self.socket.close() + self.socket = None + del self.buffer + self.buffer = None + if self.shm is not None: + self.shm.close() + self.shm.unlink() + del self.shm + self.shm = None + gc.collect() + get_torch_device().ipc_collect() + get_torch_device().empty_cache() + + +class BucketedWeightReceiver: + """ + Receive model weights via bucketed IPC transfer over ZMQ. + + Receives weight tensors from BucketedWeightSender and passes each + bucket to a callback for processing (e.g., loading into the model). 
+ + Args: + zmq_handle: ZMQ IPC socket path (must match sender) + device: Target device for received tensors + use_shm: Use shared memory instead of CUDA IPC + """ + + def __init__( + self, + zmq_handle: str, + device: torch.device, + use_shm: bool = False, + ): + self.zmq_handle = zmq_handle + self.device = device + self.use_shm = use_shm + + self.zmq_context = zmq.Context.instance() + self.socket = None + self.buffer = None + self.shm = None + + def receive_weights(self, on_bucket_received: callable): + """ + Receive weights from sender and process each bucket via callback. + + Args: + on_bucket_received: Callback function(weights: list[(name, tensor)]) called per bucket. + """ + try: + self._init_socket() + self._init_buffer() + + # receive bucket and update weights + while True: + metadata = self.socket.recv_pyobj() + weights, tensor = [], None + for name, meta in metadata["bucket_meta"].items(): + shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"] + size = dtype.itemsize * shape.numel() + # NOTE: we need to clone the tensor to release CUDA IPC memory + # but for shared memory, it's not necessary and if we do clone, + # it will cause extra memory copy overhead and slow down the process. 
+ tensor = self.buffer[offset : offset + size].view(dtype=dtype).view(shape) + if not self.use_shm: + tensor = tensor.clone() + else: + tensor = tensor.to(self.device) + weights.append((name, tensor)) + get_torch_device().synchronize() + self.socket.send(b"") + on_bucket_received(weights) + del weights, tensor + if metadata["is_last"]: + break + finally: + self._cleanup() + + def _init_socket(self): + """Initialize ZMQ REP socket and connect.""" + self.socket = self.zmq_context.socket(zmq.REP) + self.socket.connect(self.zmq_handle) + + def _init_buffer(self): + """Receive and rebuild communication buffer from sender.""" + comm_metadata = self.socket.recv_pyobj() + buffer, shm = None, None + if not self.use_shm: + handle = comm_metadata + buffer = rebuild_ipc(handle, self.device.index) + assert buffer.dtype == torch.uint8 + else: + shm_name = comm_metadata["name"] + shm_size = comm_metadata["size"] + buffer, shm = rebuild_shared_memory(shm_name, shm_size, dtype=torch.uint8) + self.socket.send(b"") + self.buffer = buffer + self.shm = shm + + def _cleanup(self): + """clean up""" + if self.socket is not None: + self.socket.close() + self.socket = None + # Synchronize before releasing the buffer to ensure all async ops + # referencing it (e.g. clone, .to()) have completed. + get_torch_device().synchronize() + del self.buffer + self.buffer = None + if self.shm is not None: + self.shm.close() + del self.shm + self.shm = None + gc.collect() + get_torch_device().ipc_collect() + get_torch_device().empty_cache() diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index 7fa3b1dd67c..47bd68fa0d9 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -12,21 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import ctypes -import gc import json import logging import os import platform import signal import threading -from multiprocessing import shared_memory from types import MethodType -from typing import Any, Callable, Literal, TypedDict, get_args +from typing import Any, Literal, get_args import torch -import zmq -from verl.utils.device import get_torch_device, is_npu_available +from verl.utils.device import is_npu_available from verl.utils.vllm import TensorLoRARequest, VLLMHijack from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights @@ -105,42 +102,6 @@ def compute_logits( model.compute_logits = MethodType(compute_logits, model) -# copy from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/rlhf_utils.py -def rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: - func, args = handle - list_args = list(args) - if device_id is not None: - # the key is to change device id to the current device id - # in case two processes have different CUDA_VISIBLE_DEVICES - list_args[6] = device_id - buffer = func(*list_args) - return buffer - - -def create_shared_memory(size: int, name: str): - """Create shared memory for weight transfer. If already exists, attach to it.""" - try: - shm = shared_memory.SharedMemory(name=name, create=True, size=size) - except FileExistsError: - shm = shared_memory.SharedMemory(name=name) - return shm - - -def rebuild_shared_memory(name: str, size: int, dtype=torch.uint8): - """Rebuild tensor from shared memory.""" - shm = shared_memory.SharedMemory(name=name) - tensor = torch.frombuffer(shm.buf[:size], dtype=dtype) - - return tensor, shm - - -class TensorMetadata(TypedDict): - name: str - shape: torch.Size - dtype: torch.dtype - offset: int - - class vLLMColocateWorkerExtension: """ The class for vLLM's worker to inherit from, in the colocate setting. 
@@ -195,6 +156,8 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False """Update the weights of the rollout model.""" from vllm.platforms import current_platform + from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightReceiver + if current_platform.device_type == "npu" and self.device is None: self.device = torch.device(f"npu:{self.local_rank}") @@ -202,25 +165,6 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False if peft_config and base_sync_done: self.remove_lora(VLLM_LORA_INT_ID) - # build communication buffer - assert self.device is not None - if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: - self._zmq_ctx = zmq.Context() - socket = self._zmq_ctx.socket(zmq.REP) - socket.connect(self._get_zmq_handle()) - - comm_metadata = socket.recv_pyobj() - buffer, shm = None, None - if not use_shm: - handle = comm_metadata - buffer = rebuild_ipc(handle, self.device.index) - assert buffer.dtype == torch.uint8 - else: - shm_name = comm_metadata["name"] - shm_size = comm_metadata["size"] - buffer, shm = rebuild_shared_memory(shm_name, shm_size, dtype=torch.uint8) - socket.send(b"") - use_standard_weight_load = not (peft_config and base_sync_done) and not is_fp8_model( self.model_runner.vllm_config ) @@ -235,28 +179,17 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False # Re-apply here because async IPC weight sync can happen long after init and lose MoE weight_loader attrs. 
patch_vllm_moe_model_weight_loader(self.model_runner.model) - # receive bucket and update weights - while True: - metadata = socket.recv_pyobj() - weights, tensor = [], None - for name, meta in metadata["bucket_meta"].items(): - shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"] - size = dtype.itemsize * shape.numel() - # NOTE: we need to clone the tensor to release CUDA IPC memory - # but for shared memory, it's not necessary and if we do clone, - # it will cause extra memory copy overhead and slow down the process. - tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) - if not use_shm: - tensor = tensor.clone() - else: - tensor = tensor.to(self.device) - weights.append((name, tensor)) - get_torch_device().synchronize() - socket.send(b"") - self._update_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done) - del weights, tensor - if metadata["is_last"]: - break + assert self.device is not None + receiver = BucketedWeightReceiver( + zmq_handle=self._get_zmq_handle(), + device=self.device, + use_shm=use_shm, + ) + receiver.receive_weights( + on_bucket_received=lambda weights: self._update_weights( + weights, peft_config=peft_config, base_sync_done=base_sync_done + ) + ) if self._is_qat_model: # QAT: call process_weights_after_loading AFTER all buckets are received @@ -272,18 +205,6 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False model_config = self.model_runner.vllm_config.model_config process_weights_after_loading(model, model_config, self.device) - # clean up - socket.close() - del buffer - gc.collect() - if shm is not None: - shm.close() - del shm - get_torch_device().synchronize() - gc.collect() - get_torch_device().ipc_collect() - get_torch_device().empty_cache() - def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: dict, base_sync_done: bool): if peft_config and base_sync_done: weights = dict(weights) diff --git 
a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 53a433cc51e..f73e3892a21 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -26,7 +26,6 @@ - After inference, all the parameters that doesn't belong to this pp rank is freed. """ -import gc import logging import os import time @@ -34,18 +33,16 @@ import ray import torch -import zmq from packaging import version as vs from torch.distributed.device_mesh import DeviceMesh -from torch.multiprocessing.reductions import reduce_tensor from verl import DataProto from verl.third_party.vllm import VLLM_SLEEP_LEVEL, get_version -from verl.utils.device import get_device_id, get_device_name, get_torch_device, is_support_ipc +from verl.utils.device import get_device_id, is_support_ipc from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.base import BaseRollout -from verl.workers.rollout.utils import ensure_async_iterator -from verl.workers.rollout.vllm_rollout.utils import TensorMetadata, get_device_uuid +from verl.workers.rollout.bucketed_weight_transfer import BucketedWeightSender +from verl.workers.rollout.vllm_rollout.utils import get_device_uuid logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) @@ -98,7 +95,6 @@ def __init__( self.sleep_level = VLLM_SLEEP_LEVEL self.device_uuid = get_device_uuid(get_device_id()) - self.zmq_context = zmq.Context() self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock" self.use_shm = not is_support_ipc() @@ -165,81 +161,14 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None kwargs={**kwargs, "use_shm": self.use_shm}, ) - # build communication buffer bucket_size_mb = self.config.checkpoint_engine.update_weights_bucket_megabytes - bucket_size = int(bucket_size_mb) << 20 - s = self.zmq_context.socket(zmq.REQ) - s.bind(self.zmq_handle) - - 
buffer, shm = None, None - if not self.use_shm: - buffer = torch.empty(bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:{get_device_id()}") - handle = reduce_tensor(buffer) - s.send_pyobj(handle) - else: - import uuid - from multiprocessing import shared_memory - - # Create unique name for shared memory - shm_name = f"verl_weights_{uuid.uuid4().hex}" - shm = shared_memory.SharedMemory(name=shm_name, create=True, size=bucket_size) - buffer = torch.frombuffer(shm.buf, dtype=torch.uint8) - - comm_metadata = {"name": shm_name, "size": bucket_size} - s.send_pyobj(comm_metadata) - - s.recv() - - # send bucket weights - offset = 0 - bucket_meta: dict[str, TensorMetadata] = {} - # dtype = PrecisionType.to_dtype(self.config.dtype) - async for name, weight in ensure_async_iterator(weights): - # model parameters are in fp32 full precision - # (vermouth1992) we should not force cast weight here because some parameters - # (such as moe gate) have to keep fp32 precision. If a weight is bf16 in the rollout side, - # the rollout should automatically cast on demand. However, this would incur a higher weight - # transfer volume. - # weight = weight.to(dtype, non_blocking=True) - - # fill the tensor bucket - if offset + weight.nbytes > bucket_size: - get_torch_device().synchronize() - s.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) - s.recv() - bucket_meta = {} - offset = 0 - - # TODO: slice embedding layer weight into chunks - assert offset + weight.nbytes <= bucket_size, ( - f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." - f"Please increase rollout.update_weights_bucket_megabytes({bucket_size_mb} MB)." 
- ) - bucket_meta[name] = { - "name": name, - "shape": weight.shape, - "dtype": weight.dtype, - "offset": offset, - } - buffer[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) - offset += weight.nbytes - - # send the last bucket - get_torch_device().synchronize() - s.send_pyobj({"bucket_meta": bucket_meta, "is_last": True}) - s.recv() + sender = BucketedWeightSender( + zmq_handle=self.zmq_handle, + bucket_size_mb=bucket_size_mb, + use_shm=self.use_shm, + ) + await sender.async_send_weights(weights) - # clean up - s.close() - del buffer - gc.collect() - if shm is not None: - shm.close() - shm.unlink() - del shm - gc.collect() - get_torch_device().ipc_collect() - get_torch_device().empty_cache() if future is not None: await future