diff --git a/tests/checkpoint_engine/test_correctness_on_gpu.py b/tests/checkpoint_engine/test_correctness_on_gpu.py index 05cf27cf4a2..2045d85ce87 100644 --- a/tests/checkpoint_engine/test_correctness_on_gpu.py +++ b/tests/checkpoint_engine/test_correctness_on_gpu.py @@ -54,7 +54,7 @@ async def test_nccl_checkpoint_engine( # initialize config checkpoint_engine_config = CheckpointEngineConfig( - backend="nccl", engine_kwargs={"nccl": {"rebuild_group": rebuild_group}} + backend="nccl", update_weights_bucket_megabytes=256, engine_kwargs={"nccl": {"rebuild_group": rebuild_group}} ) model_config = HFModelConfig(path=model_path, use_remove_padding=True) rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config) diff --git a/tests/checkpoint_engine/test_correctness_on_npu.py b/tests/checkpoint_engine/test_correctness_on_npu.py index 17f7dbe4b8e..cee5cb69570 100644 --- a/tests/checkpoint_engine/test_correctness_on_npu.py +++ b/tests/checkpoint_engine/test_correctness_on_npu.py @@ -54,7 +54,7 @@ async def test_hccl_checkpoint_engine( # initialize config checkpoint_engine_config = CheckpointEngineConfig( - backend="hccl", engine_kwargs={"hccl": {"rebuild_group": rebuild_group}} + backend="hccl", update_weights_bucket_megabytes=256, engine_kwargs={"hccl": {"rebuild_group": rebuild_group}} ) model_config = HFModelConfig(path=model_path, use_remove_padding=True) rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config) diff --git a/tests/checkpoint_engine/test_naive_correctness.py b/tests/checkpoint_engine/test_naive_correctness.py new file mode 100644 index 00000000000..6f5a6d8863a --- /dev/null +++ b/tests/checkpoint_engine/test_naive_correctness.py @@ -0,0 +1,159 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import pytest +import ray +from omegaconf import DictConfig + +from verl.checkpoint_engine import CheckpointEngineManager +from verl.experimental.agent_loop.agent_loop import AgentLoopManager +from verl.protocol import DataProto +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import get_device_name +from verl.utils.tokenizer import hf_tokenizer +from verl.workers.config import CheckpointEngineConfig +from verl.workers.engine_workers import ActorRolloutRefWorker + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + + config.trainer.n_gpus_per_node = 8 + config.trainer.nnodes = 1 + config.actor_rollout_ref.actor.use_dynamic_bsz = True + config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen3-VL-2B-Instruct") + config.actor_rollout_ref.rollout.name = os.environ.get("ROLLOUT_NAME", "vllm") + config.actor_rollout_ref.rollout.skip_tokenizer_init = False + config.actor_rollout_ref.rollout.max_num_seqs = 256 + config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.8 + config.actor_rollout_ref.rollout.agent.num_workers = 2 + config.actor_rollout_ref.rollout.checkpoint_engine.backend = "naive" + config.actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes = 256 + config.actor_rollout_ref.rollout.enforce_eager = True + + return config + + +@pytest.mark.skip(reason="temporary skip since our ci environment is not ready") +@pytest.mark.asyncio +def test_server_adapter_colocated_weight_update(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + "VLLM_DISABLE_COMPILE_CACHE": "1", + "HCCL_HOST_SOCKET_PORT_RANGE": "60000-60050", + "HCCL_NPU_SOCKET_PORT_RANGE": "61000-61050", + } + } + ) + + # 0. init actor rollout worker group + resource_pool = RayResourcePool( + process_on_nodes=[init_config.trainer.n_gpus_per_node] * init_config.trainer.nnodes, max_colocate_count=3 + ) + actor_rollout_cls = ray.remote(ActorRolloutRefWorker) + cls_dict = { + "actor_rollout": RayClassWithInitArgs( + cls=actor_rollout_cls, config=init_config.actor_rollout_ref, role="actor_rollout" + ) + } + ray_cls_with_init = create_colocated_worker_cls(cls_dict) + wg_dict = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name=get_device_name() + ) + spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys()) + actor_rollout_wg = spawn_wg["actor_rollout"] + actor_rollout_wg.init_model() + + # 1. create AgentLoopManager + agent_loop_manager = AgentLoopManager( + config=init_config, + worker_group=actor_rollout_wg, + rollout_resource_pool=resource_pool, + ) + + # 2. create CheckpointEngineManager + checkpoint_engine_config: CheckpointEngineConfig = omega_conf_to_dataclass( + init_config.actor_rollout_ref.rollout.checkpoint_engine + ) + checkpoint_manager = CheckpointEngineManager( + config=checkpoint_engine_config, + trainer=actor_rollout_wg, + replicas=agent_loop_manager.rollout_replicas, + ) + checkpoint_manager.sleep_replicas() + + # 3. generate prompts + raw_prompts = [ + [ + { + "role": "user", + "content": "This is a test for weight update. If the weight has been correctly " + 'updated and you understand my meaning, please respond with "Test Passed".', + } + ], + [ + { + "role": "user", + "content": "This is a test for weight update. If the weight has been correctly " + 'updated and you understand my meaning, please respond with "Test Passed".', + } + ], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array(raw_prompts), + "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)), + "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), + }, + ) + + # 4. update weights and generate sequences, check if the responses are correct + for _ in range(3): + checkpoint_manager.update_weights() + result = agent_loop_manager.generate_sequences(batch) + checkpoint_manager.sleep_replicas() + + # Check response + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + + for i in range(len(responses)): + valid_tokens = responses[i][response_mask[i].bool()] + response = tokenizer.decode(valid_tokens) + assert "test passed" in response.lower(), f"Response does not contain 'test passed': {response}" + + print("=========================") + print("[OUTPUT]:", response) + print("---") + + ray.shutdown() diff --git a/verl/checkpoint_engine/base.py b/verl/checkpoint_engine/base.py index 91eb658fba8..1eae910594a 100644 --- a/verl/checkpoint_engine/base.py +++ b/verl/checkpoint_engine/base.py @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import logging +import os +import time from abc import ABC, abstractmethod -from typing import Any, Generator, TypedDict +from dataclasses import dataclass +from typing import Any, Callable, Generator import ray import torch +import zmq from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register @@ -26,12 +31,97 @@ from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +@dataclass +class TensorMeta: + """Metadata for a tensor in the checkpoint bucket.""" -class TensorMeta(TypedDict): name: str shape: torch.Size dtype: torch.dtype offset: int + chunk_idx: int | None = None + total_chunks: int | None = None + + +@dataclass +class MasterMetadata: + """Metadata for master process communication. + + Args: + zmq_ip: IP address for ZMQ communication. + zmq_port: Port for ZMQ communication. + dist_ip: IP address for distributed communication (HCCL only). + dist_port: Port for distributed communication (HCCL only). + """ + + zmq_ip: str + zmq_port: int + dist_ip: str | None = None + dist_port: int | None = None + + +class BroadcastOperation: + """Broadcast operation that can be async or sync. + + Args: + rank: The rank of the current process. + bucket: The tensor to broadcast. + metadata: The metadata of the tensor. + socket: The zeromq socket to communicate with master. + topic: The topic to subscribe. + broadcast_fn: The function to broadcast tensor. + async_mode: Whether to execute broadcast asynchronously. If False, runs in __init__. + """ + + def __init__( + self, + rank: int, + bucket: torch.Tensor, + metadata: dict[str, TensorMeta] | None, + socket: zmq.Socket, + topic: str, + broadcast_fn: Callable[[torch.Tensor, int], None], + async_mode: bool = True, + ) -> None: + self.rank = rank + self.bucket = bucket + self.metadata = metadata + self.socket = socket + self.topic = topic + self._broadcast_fn = broadcast_fn + self._async_mode = async_mode + + if self._async_mode: + loop = asyncio.get_running_loop() + self._task = loop.run_in_executor(None, self._run) + else: + self._run() + + def _run(self): + # broadcast tensor meta via zeromq PUB/SUB + if self.rank == 0: + self.socket.send_string(self.topic, flags=zmq.SNDMORE) + self.socket.send_pyobj(self.metadata) + else: + self.socket.recv_string() + self.metadata = self.socket.recv_pyobj() + + # broadcast tensor via backend-specific function + self._broadcast_fn(self.bucket, src_rank=0) + + async def wait_for_complete(self) -> dict[str, TensorMeta]: + """Wait for the broadcast operation to complete. + + Returns: + dict[str, TensorMeta]: The bucket meta after broadcast. + """ + if self._async_mode: + await self._task + return self.metadata class CheckpointEngineRegistry: @@ -158,6 +248,51 @@ def finalize(self): """ raise NotImplementedError + @property + @abstractmethod + def bucket_size(self) -> int: + """Return the bucket size in bytes.""" + raise NotImplementedError + + def _slice_weight_into_chunks(self, name: str, weight: torch.Tensor) -> list[tuple[torch.Tensor, dict]]: + """Slice a large weight tensor into chunks that fit in bucket. + + Args: + name: Name of the weight tensor. + weight: The weight tensor to slice. + + Returns: + List of (chunk, metadata) tuples. + """ + from verl.utils.tensor_utils import compute_weight_chunks + + chunk_infos = compute_weight_chunks(name, weight, self.bucket_size) + + # Check if no slicing needed (single chunk covering entire tensor) + if len(chunk_infos) == 1 and chunk_infos[0].chunk_idx == 0 and chunk_infos[0].total_chunks == 1: + meta = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": 0, + } + return [(weight, meta)] + + chunks = [] + for info in chunk_infos: + chunk = weight[info.start_idx : info.end_idx] + meta = { + "name": name, + "shape": chunk.shape, + "dtype": chunk.dtype, + "offset": 0, # Will be set when filling bucket + "chunk_idx": info.chunk_idx, + "total_chunks": info.total_chunks, + } + chunks.append((chunk, meta)) + + return chunks + @abstractmethod async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): """Send the weights of the model. @@ -176,6 +311,402 @@ async def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, Non """ raise NotImplementedError + def _yield_tensors_from_buffer( + self, + buffer: torch.Tensor, + bucket_meta: dict, + pending_chunks: dict, + ) -> Generator[tuple[str, torch.Tensor], None, None]: + """Yield tensors from buffer, handling chunk merging for large weights. + + Args: + buffer: The buffer containing the tensor data. + bucket_meta: The metadata of the bucket. + pending_chunks: Dictionary to collect chunks for weights that were sliced. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + for name, meta in bucket_meta.items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = buffer[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + + # Check if this is a chunk of a sliced weight + if "chunk_idx" in meta and "total_chunks" in meta: + # This is a chunk, store it for later merging + original_name = meta["name"] + chunk_idx = meta["chunk_idx"] + if original_name not in pending_chunks: + pending_chunks[original_name] = {} + # NOTE: we need to clone the tensor here because the buffer will be + # reused for next bucket, which will overwrite the tensor data + pending_chunks[original_name][chunk_idx] = tensor.clone() + + # Check if we have all chunks for this weight + if len(pending_chunks[original_name]) == meta["total_chunks"]: + # Merge all chunks back into one tensor + chunks_dict = pending_chunks[original_name] + sorted_chunks = [chunks_dict[i] for i in range(meta["total_chunks"])] + merged_tensor = torch.cat(sorted_chunks, dim=0) + yield original_name, merged_tensor + del pending_chunks[original_name] + else: + yield name, tensor + + +class CollectiveCheckpointEngine(CheckpointEngine): + """Base class for collective communication checkpoint engines (NCCL, HCCL). + + This class provides common logic for collective communication backends like NCCL and HCCL. + It implements send_weights and receive_weights with bucket-based double-buffering and + chunked weight handling for large tensors. + + Subclasses must implement: + - _broadcast(bucket, src_rank): Broadcast tensor using backend-specific collective operation + - _synchronize(): Synchronize device operations (e.g., torch.cuda.synchronize) + - _copy_to_buffer(buffer, tensor, offset): Copy tensor to buffer at given offset + - prepare(): Allocate send/receive buffers and return MasterMetadata if master + - finalize(): Free buffers and optionally destroy process group + - init_process_group(rank, world_size, master_metadata): Initialize the process group + + Args: + bucket_size: Bucket size in bytes to transfer multiple weights at one time. + group_name: The name of the process group. + rebuild_group: Whether to rebuild the process group in each update. + is_master: Whether the current process is the master process. + rollout_dtype: The dtype of the weights received from rollout workers. + """ + + def __init__( + self, + bucket_size: int, + group_name: str = "default", + rebuild_group: bool = False, + is_master: bool = False, + rollout_dtype: torch.dtype = torch.bfloat16, + ) -> None: + self._bucket_size = bucket_size + self._rank: int | None = None + self._send_buf = None + self._recv_buf = None + self._world_size: int | None = None + self.group_name = group_name + self.rebuild_group = rebuild_group + self.rollout_dtype = rollout_dtype + self.is_master = is_master + self.topic = "bucket_metadata" + self._async_broadcast_mode = True + + if self.is_master: + self._start_zmq_server() + + @property + def bucket_size(self) -> int: + """Return the bucket size in bytes.""" + return self._bucket_size + + @property + def rank(self) -> int: + """Return the rank of the current process.""" + return self._rank + + @property + def send_buf(self): + """Return the send buffer.""" + return self._send_buf + + @property + def recv_buf(self): + """Return the receive buffer.""" + return self._recv_buf + + def _start_zmq_server(self): + """Start zeromq server for broadcasting bucket tensor metadata.""" + from verl.utils.net_utils import get_free_port, is_valid_ipv6_address + + self._ip = ray.util.get_node_ip_address().strip("[]") + self._zmq_port, self._listen_sock = get_free_port(self._ip) + + context = zmq.Context() + self._socket = context.socket(zmq.PUB) + if is_valid_ipv6_address(self._ip): + address = f"tcp://[{self._ip}]:{self._zmq_port}" + self._socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{self._ip}:{self._zmq_port}" + self._socket.bind(address) + + def _connect_zmq_client(self, metadata: MasterMetadata): + """Connect to zeromq server for receiving bucket tensor metadata.""" + from verl.utils.net_utils import is_valid_ipv6_address + + assert not self.is_master, "Master process should not connect to other processes." + context = zmq.Context() + self._socket = context.socket(zmq.SUB) + if is_valid_ipv6_address(metadata.zmq_ip): + address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" + self._socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" + self._socket.connect(address) + self._socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) + + @classmethod + def build_topology( + cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict] + ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: + """Build communication topology between all workers. + + Args: + trainer_world_size: The world size of the trainer worker group. + rollout_world_size: The world size of the rollout replica. + metadata: A list of metadata `prepare` from all workers. + + Returns: + A tuple of two dictionaries for trainer and rollout worker group. + """ + trainer_kwargs = { + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "master_metadata": [metadata[0]] * trainer_world_size, + } + rollout_kwargs = { + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "master_metadata": [metadata[0]] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + + # ========== Abstract methods to be implemented by subclasses ========== + + @abstractmethod + def _broadcast(self, bucket, src_rank: int): + """Broadcast tensor using backend-specific collective operation. + + Args: + bucket: The tensor to broadcast. + src_rank: The source rank to broadcast from. + """ + raise NotImplementedError + + @abstractmethod + def _synchronize(self): + """Synchronize device operations.""" + raise NotImplementedError + + @abstractmethod + def _copy_to_buffer(self, buffer, tensor, offset): + """Copy tensor to buffer at given offset. + + Args: + buffer: The buffer to copy to. + tensor: The tensor to copy. + offset: The offset in the buffer. + """ + raise NotImplementedError + + @abstractmethod + def prepare(self) -> MasterMetadata | None: + """Prepare checkpoint engine before each step send_weights/receive_weights.""" + raise NotImplementedError + + @abstractmethod + def finalize(self): + """Finalize checkpoint engine after each step send_weights/receive_weights.""" + raise NotImplementedError + + @abstractmethod + def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): + """Initialize the process group. + + Args: + rank: The rank of the current process. + world_size: The total number of processes. + master_metadata: The metadata from the master process. + """ + raise NotImplementedError + + def _create_broadcast_send_op(self, bucket, metadata) -> BroadcastOperation: + """Create broadcast operation for sending weights.""" + return BroadcastOperation( + rank=self._rank, + bucket=bucket, + metadata=metadata, + socket=self._socket, + topic=self.topic, + broadcast_fn=self._broadcast, + async_mode=self._async_broadcast_mode, + ) + + def _create_broadcast_recv_op(self, bucket) -> BroadcastOperation: + """Create broadcast operation for receiving weights.""" + return BroadcastOperation( + rank=self._rank, + bucket=bucket, + metadata=None, + socket=self._socket, + topic=self.topic, + broadcast_fn=self._broadcast, + async_mode=self._async_broadcast_mode, + ) + + @torch.no_grad() + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send the weights of the model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + + assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights." + + # For trainer rank other than 0, consume weights without sending. + if self.rank < 0: + for name, weight in weights: + pass + return + + send_buf, recv_buf = self.send_buf, self.recv_buf + broadcast_op = None + + start_time = time.time() + bucket_meta: dict[str, TensorMeta] = {} + offset = 0 + + for name, weight in weights: + weight_size = weight.nbytes + # Check if the weight needs to be sliced into chunks + if weight_size > self.bucket_size: + # Slice the weight into chunks + chunks = self._slice_weight_into_chunks(name, weight) + + for chunk, chunk_meta in chunks: + chunk_size = chunk.nbytes + + # Fill bucket with chunk + if offset + chunk_size > self.bucket_size: + self._synchronize() + + # wait previous broadcast op finish + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = self._create_broadcast_send_op( + send_buf, {"bucket_meta": bucket_meta, "is_last": False} + ) + + # swap send_buf and recv_buf + send_buf, recv_buf = recv_buf, send_buf + bucket_meta = {} + offset = 0 + + # Update offset in meta (for key, we use indexed key) + indexed_key = f"{name}_chunk_{chunk_meta['chunk_idx']}" + bucket_meta[indexed_key] = { + "name": chunk_meta["name"], + "shape": chunk_meta["shape"], + "dtype": chunk_meta["dtype"], + "offset": offset, + "chunk_idx": chunk_meta["chunk_idx"], + "total_chunks": chunk_meta["total_chunks"], + } + self._copy_to_buffer(send_buf, chunk, offset) + offset += chunk_size + + continue + + # fill the tensor bucket + if offset + weight_size > self.bucket_size: + self._synchronize() + + # wait previous broadcast op finish + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = self._create_broadcast_send_op(send_buf, {"bucket_meta": bucket_meta, "is_last": False}) + + # swap send_buf and recv_buf + send_buf, recv_buf = recv_buf, send_buf + bucket_meta = {} + offset = 0 + + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + self._copy_to_buffer(send_buf, weight, offset) + offset += weight_size + + # broadcast last bucket + self._synchronize() + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = self._create_broadcast_send_op(send_buf, {"bucket_meta": bucket_meta, "is_last": True}) + await broadcast_op.wait_for_complete() + logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s") + + @torch.no_grad() + async def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]: + """Receive the weights of the model. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + assert self.rank > 0, "Rank 0 should not receive weights." + send_buf, recv_buf = self.send_buf, self.recv_buf + total_bytes, total_params = 0, 0 + + # Buffer to collect chunks for weights that were sliced + pending_chunks = {} # name -> {chunk_idx: tensor, ...} + + # receive first bucket + start_time = time.time() + broadcast_op = self._create_broadcast_recv_op(recv_buf) + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + + # swap send_buf and recv_buf + send_buf, recv_buf = recv_buf, send_buf + + while not metadata["is_last"]: + # 1. receive next bucket + broadcast_op = self._create_broadcast_recv_op(recv_buf) + + # 2. yield tensor from send_buf + for tensor_tuple in self._yield_tensors_from_buffer(send_buf, metadata["bucket_meta"], pending_chunks): + total_params += 1 + yield tensor_tuple + + # 3. wait for next bucket broadcast finish + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + + # 4. swap send_buf and recv_buf + self._synchronize() + send_buf, recv_buf = recv_buf, send_buf + + # yield tensor from send_buf + for tensor_tuple in self._yield_tensors_from_buffer(send_buf, metadata["bucket_meta"], pending_chunks): + total_params += 1 + yield tensor_tuple + + # Check if there are any remaining chunks that weren't processed + if pending_chunks: + raise RuntimeError( + f"Received chunks for weights {list(pending_chunks.keys())} but did not receive all chunks for them." + ) + + time_cost = time.time() - start_time + bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) + logger.info( + f"Rank {self.rank} receive weights done, total_params: {total_params}, " + f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" + ) + class CheckpointEngineWithCache(CheckpointEngine): """Checkpoint engine with local cache: shm, disk, etc. This allow to synchronize weights without interrupting @@ -210,6 +741,16 @@ def __init__(self, bucket_size: int, is_master: bool = False) -> None: self.bucket_size = bucket_size self.is_master = is_master + @property + def bucket_size(self) -> int: + """Return the bucket size in bytes.""" + return self._bucket_size + + @bucket_size.setter + def bucket_size(self, value: int): + """Set the bucket size in bytes.""" + self._bucket_size = value + def prepare(self): raise NotImplementedError diff --git a/verl/checkpoint_engine/hccl_checkpoint_engine.py b/verl/checkpoint_engine/hccl_checkpoint_engine.py index c4839999ddf..b195471763c 100644 --- a/verl/checkpoint_engine/hccl_checkpoint_engine.py +++ b/verl/checkpoint_engine/hccl_checkpoint_engine.py @@ -13,84 +13,23 @@ # limitations under the License. import logging import os -import time -from dataclasses import dataclass -from typing import AsyncGenerator, Generator -import ray import torch -import zmq -from vllm.distributed.utils import StatelessProcessGroup -from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta +from verl.checkpoint_engine.base import ( + CheckpointEngineRegistry, + CollectiveCheckpointEngine, + MasterMetadata, +) from verl.utils.distributed import stateless_init_process_group -from verl.utils.net_utils import get_free_port, is_valid_ipv6_address +from verl.utils.net_utils import get_free_port logger = logging.getLogger(__name__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -@dataclass -class MasterMetadata: - zmq_ip: str - zmq_port: int - dist_ip: str - dist_port: int - - -class BroadcastOperation: - """Async broadcast operation with HCCL in separate thread. - - Args: - rank (int): The rank of the current process. - group_name (str): The name of the HCCL process group. - bucket (torch.Tensor): The tensor to broadcast. - metadata (dict[str, TensorMeta]): The metadata of the tensor. - socket (zmq.Socket): The zeromq socket to communicate with master. - topic (str): The topic to subscribe. - """ - - def __init__( - self, - rank: int, - process_group: StatelessProcessGroup | str, - bucket: torch.Tensor, - metadata: dict[str, TensorMeta], - socket: zmq.Socket, - topic: str, - ) -> None: - self.rank = rank - self.pyhccl = process_group - self.bucket = bucket - self.metadata = metadata - self.socket = socket - self.topic = topic - - self._run() - - def _run(self): - # broadcast tensor meta via zeromq PUB/SUB - if self.rank == 0: - self.socket.send_string(self.topic, flags=zmq.SNDMORE) - self.socket.send_pyobj(self.metadata) - else: - self.socket.recv_string() - self.metadata = self.socket.recv_pyobj() - - # broadcast tensor via HCCL - self.pyhccl.broadcast(self.bucket, src=0) - - async def wait_for_complete(self) -> dict[str, TensorMeta]: - """Wait for the broadcast operation to complete. - - Returns: - dict[str, TensorMeta]: The bucket meta after broadcast. - """ - return self.metadata - - @CheckpointEngineRegistry.register("hccl") -class HCCLCheckpointEngine(CheckpointEngine): +class HCCLCheckpointEngine(CollectiveCheckpointEngine): """HCCL checkpoint engine with collective communication. Args: @@ -110,255 +49,92 @@ def __init__( is_master: bool = False, rollout_dtype: torch.dtype = torch.bfloat16, ) -> None: - self.bucket_size = bucket_size - self.group_name = group_name - self.rebuild_group = rebuild_group - self.rollout_dtype = rollout_dtype - self.pyhccl = None - self.device = torch.npu.current_device() - - # start zeromq server for broadcasting bucket tensor metadata - self.is_master = is_master - self.topic = "bucket_metadata" + self._pyhccl = None + self._device = torch.npu.current_device() + super().__init__( + bucket_size=bucket_size, + group_name=group_name, + rebuild_group=rebuild_group, + is_master=is_master, + rollout_dtype=rollout_dtype, + ) + # HCCL does not support async broadcast, so we set it to False here: + # https://github.com/verl-project/verl/pull/5029/changes#r2773396972 + self._async_broadcast_mode = False if self.is_master: - self._start_zmq_server() - self.dist_port, _ = get_free_port(self.ip) + self._dist_port, _ = get_free_port(self._ip) - def prepare(self) -> MasterMetadata: - self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") - self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") + def prepare(self) -> MasterMetadata | None: + """Prepare checkpoint engine before each step send_weights/receive_weights.""" + self._send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") + self._recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") - return ( - MasterMetadata(zmq_ip=self.ip, zmq_port=self.zmq_port, dist_ip=self.ip, dist_port=self.dist_port) - if self.is_master - else None - ) + if self.is_master: + return MasterMetadata( + zmq_ip=self._ip, + zmq_port=self._zmq_port, + dist_ip=self._ip, + dist_port=self._dist_port, + ) + return None def finalize(self): - """Destroy the HCCL process group if rebuild_group is True.""" + """Finalize checkpoint engine after each step send_weights/receive_weights.""" if self.rebuild_group: - if self.rank >= 0: - self.pyhccl.destroyComm(self.pyhccl.comm) - self.pyhccl = None - self.rank = None - self.world_size = None + if self._rank is not None and self._rank >= 0: + self._pyhccl.destroyComm(self._pyhccl.comm) + self._pyhccl = None + self._rank = None + self._world_size = None - self.send_buf = None - self.recv_buf = None - torch.npu.empty_cache() - - @classmethod - def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): - trainer_kwargs = { - "rank": [0] + [-1] * (trainer_world_size - 1), - "world_size": [rollout_world_size + 1] * trainer_world_size, - "master_metadata": [metadata[0]] * trainer_world_size, - } - rollout_kwargs = { - "rank": list(range(1, rollout_world_size + 1)), - "world_size": [rollout_world_size + 1] * rollout_world_size, - "master_metadata": [metadata[0]] * rollout_world_size, - } - return trainer_kwargs, rollout_kwargs + self._send_buf = None + self._recv_buf = None - def _start_zmq_server(self): - self.ip = ray.util.get_node_ip_address().strip("[]") - self.zmq_port, _ = get_free_port(self.ip) - - context = zmq.Context() - self.socket = context.socket(zmq.PUB) - if is_valid_ipv6_address(self.ip): - address = f"tcp://[{self.ip}]:{self.zmq_port}" - self.socket.setsockopt(zmq.IPV6, 1) - else: - address = f"tcp://{self.ip}:{self.zmq_port}" - - self.socket.bind(address) - - def _connect_zmq_client(self, metadata: MasterMetadata): - assert not self.is_master, "Master process should not connect to other processes." - context = zmq.Context() - self.socket = context.socket(zmq.SUB) - if is_valid_ipv6_address(metadata.zmq_ip): - address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" - self.socket.setsockopt(zmq.IPV6, 1) - else: - address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" - - self.socket.connect(address) - self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) + torch.npu.empty_cache() def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): """Initialize the HCCL process group. Args: - rank (int): The rank of the current process. - world_size (int): The total number of processes. + rank: The rank of the current process. + world_size: The total number of processes. + master_metadata: The metadata from the master process. """ # For trainer workers other than rank 0, their rank should be -1. if rank < 0: - self.rank = rank - self.world_size = world_size + self._rank = rank + self._world_size = world_size return - if self.rebuild_group or self.pyhccl is None: - self.pyhccl = stateless_init_process_group( - master_metadata.dist_ip, master_metadata.dist_port, rank, world_size, self.device + if self.rebuild_group or self._pyhccl is None: + self._pyhccl = stateless_init_process_group( + master_metadata.dist_ip, master_metadata.dist_port, rank, world_size, self._device ) - self.rank = rank - self.world_size = world_size + self._rank = rank + self._world_size = world_size else: - assert self.rank == rank, f"rank {rank} is not equal to self.rank {self.rank}" - assert self.world_size == world_size, ( - f"world_size {world_size} is not equal to self.world_size {self.world_size}" + assert self._rank == rank, f"rank {rank} is not equal to self.rank {self._rank}" + assert self._world_size == world_size, ( + f"world_size {world_size} is not equal to self.world_size {self._world_size}" ) - if self.rank > 0: + if self._rank > 0: self._connect_zmq_client(master_metadata) # barrier signal = torch.tensor([1], dtype=torch.int8, device=torch.npu.current_device()) - self.pyhccl.all_reduce(signal) - - logger.info(f"init_process_group rank: {self.rank}, world_size: {self.world_size}") - - @torch.no_grad() - async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): - """Send the weights of the model. + self._pyhccl.all_reduce(signal) - Args: - weights: A generator that yields the name of the weight tensor and the tensor itself. - """ - assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights." - - # For trainer rank other than 0, consume weights without sending. - if self.rank < 0: - for name, weight in weights: - pass - return - - send_buf, recv_buf = self.send_buf, self.recv_buf - broadcast_op = None - - start_time = time.time() - bucket_meta: dict[str, TensorMeta] = {} - offset = 0 - for name, weight in weights: - # fill the tensor bucket - if offset + weight.nbytes > self.bucket_size: - torch.npu.synchronize() - - # wait previous broadcast op finish - if broadcast_op is not None: - await broadcast_op.wait_for_complete() - - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=send_buf, - metadata={"bucket_meta": bucket_meta, "is_last": False}, - socket=self.socket, - topic=self.topic, - ) - - # swap send_buf and recv_buf - send_buf, recv_buf = recv_buf, send_buf - bucket_meta = {} - offset = 0 - - assert offset + weight.nbytes <= self.bucket_size, ( - f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." - ) + logger.info(f"init_process_group rank: {self._rank}, world_size: {self._world_size}") - bucket_meta[name] = { - "name": name, - "shape": weight.shape, - "dtype": weight.dtype, - "offset": offset, - } - send_buf[offset : offset + weight.nbytes] = weight.view(-1).view(torch.uint8) - offset += weight.nbytes + def _broadcast(self, bucket, src_rank: int): + """Broadcast tensor using HCCL.""" + self._pyhccl.broadcast(bucket, src=src_rank) - # broadcast last bucket + def _synchronize(self): + """Synchronize NPU operations.""" torch.npu.synchronize() - if broadcast_op is not None: - await broadcast_op.wait_for_complete() - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=send_buf, - metadata={"bucket_meta": bucket_meta, "is_last": True}, - socket=self.socket, - topic=self.topic, - ) - await broadcast_op.wait_for_complete() - logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s") - - @torch.no_grad() - async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: - """Receive the weights of the model. - - Yields: - A tuple of the name of the weight tensor and the tensor itself. - """ - assert self.rank > 0, "Rank 0 should not receive weights." - send_buf, recv_buf = self.send_buf, self.recv_buf - total_bytes, total_params = 0, 0 - - # receive first bucket - start_time = time.time() - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=recv_buf, - metadata=None, - socket=self.socket, - topic=self.topic, - ) - metadata = await broadcast_op.wait_for_complete() - total_bytes += self.bucket_size - total_params += len(metadata["bucket_meta"]) - - # swap send_buf and recv_buf - send_buf, recv_buf = recv_buf, send_buf - while not metadata["is_last"]: - # 1. receive next bucket - broadcast_op = BroadcastOperation( - rank=self.rank, - process_group=self.pyhccl, - bucket=recv_buf, - metadata=None, - socket=self.socket, - topic=self.topic, - ) - - # 2. yield tensor from send_buf - for name, meta in metadata["bucket_meta"].items(): - dtype, shape = meta["dtype"], meta["shape"] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) - yield name, tensor - - # 3. wait for next bucket broadcast finish - metadata = await broadcast_op.wait_for_complete() - total_bytes += self.bucket_size - total_params += len(metadata["bucket_meta"]) - - # 4. swap send_buf and recv_buf - torch.npu.synchronize() # sync non-blocking copy - send_buf, recv_buf = recv_buf, send_buf - - # yield tensor from send_buf - for name, meta in metadata["bucket_meta"].items(): - dtype, shape = meta["dtype"], meta["shape"] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) - yield name, tensor - - time_cost = time.time() - start_time - bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) - logger.info( - f"Rank {self.rank} receive weights done, total_params: {total_params}, " - f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" - ) + def _copy_to_buffer(self, buffer, tensor, offset): + """Copy tensor to buffer using torch.""" + buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(torch.uint8) diff --git a/verl/checkpoint_engine/kimi_checkpoint_engine.py b/verl/checkpoint_engine/kimi_checkpoint_engine.py index f042c3489d8..f33934d909c 100644 --- a/verl/checkpoint_engine/kimi_checkpoint_engine.py +++ b/verl/checkpoint_engine/kimi_checkpoint_engine.py @@ -245,6 +245,16 @@ def __init__( self.initialized = False self.checkpoint_name = "kimi_checkpoint_engine" + @property + def bucket_size(self) -> int: + """Return the bucket size in bytes.""" + return self._bucket_size + + @bucket_size.setter + def bucket_size(self, value: int): + """Set the bucket size in bytes.""" + self._bucket_size = value + def prepare(self) -> MasterMetadata: if self.is_master: self.ip = ray.util.get_node_ip_address().strip("[]") diff --git a/verl/checkpoint_engine/nccl_checkpoint_engine.py b/verl/checkpoint_engine/nccl_checkpoint_engine.py index 279733900d6..42005a0525c 100644 --- a/verl/checkpoint_engine/nccl_checkpoint_engine.py +++ b/verl/checkpoint_engine/nccl_checkpoint_engine.py @@ -11,90 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import logging import os -import time -from dataclasses import dataclass -from typing import AsyncGenerator, Generator from unittest.mock import patch with patch("importlib.metadata.distributions", return_value=[]): import cupy as cp -import ray import ray.util.collective as collective import torch -import zmq -from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta -from verl.utils.net_utils import get_free_port, is_valid_ipv6_address +from verl.checkpoint_engine.base import ( + CheckpointEngineRegistry, + CollectiveCheckpointEngine, + MasterMetadata, +) logger = logging.getLogger(__name__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -@dataclass -class MasterMetadata: - zmq_ip: str - zmq_port: int - - -class BroadcastOperation: - """Async broadcast operation with NCCL in separate thread. - - Args: - rank (int): The rank of the current process. - group_name (str): The name of the NCCL process group. - bucket (cp.ndarray | torch.Tensor): The tensor to broadcast. - metadata (dict[str, TensorMeta]): The metadata of the tensor. - socket (zmq.Socket): The zeromq socket to communicate with master. - topic (str): The topic to subscribe. - """ - - def __init__( - self, - rank: int, - group_name: str, - bucket: cp.ndarray | torch.Tensor, - metadata: dict[str, TensorMeta], - socket: zmq.Socket, - topic: str, - ) -> None: - self.rank = rank - self.group_name = group_name - self.bucket = bucket - self.metadata = metadata - self.socket = socket - self.topic = topic - - loop = asyncio.get_running_loop() - self._task = loop.run_in_executor(None, self._run) - - def _run(self): - # broadcast tensor meta via zeromq PUB/SUB - if self.rank == 0: - self.socket.send_string(self.topic, flags=zmq.SNDMORE) - self.socket.send_pyobj(self.metadata) - else: - self.socket.recv_string() - self.metadata = self.socket.recv_pyobj() - - # broadcast tensor via NCCL - collective.broadcast(self.bucket, src_rank=0, group_name=self.group_name) - - async def wait_for_complete(self) -> dict[str, TensorMeta]: - """Wait for the broadcast operation to complete. - - Returns: - dict[str, TensorMeta]: The bucket meta after broadcast. - """ - await self._task - return self.metadata - - @CheckpointEngineRegistry.register("nccl") -class NCCLCheckpointEngine(CheckpointEngine): +class NCCLCheckpointEngine(CollectiveCheckpointEngine): """NCCL checkpoint engine with collective communication. Args: @@ -114,249 +52,81 @@ def __init__( is_master: bool = False, rollout_dtype: torch.dtype = torch.bfloat16, ) -> None: - self.bucket_size = bucket_size - self.group_name = group_name - self.rebuild_group = rebuild_group - self.rollout_dtype = rollout_dtype - - # start zeromq server for broadcasting bucket tensor metadata - self.is_master = is_master - self.topic = "bucket_metadata" - if self.is_master: - self._start_zmq_server() + super().__init__( + bucket_size=bucket_size, + group_name=group_name, + rebuild_group=rebuild_group, + is_master=is_master, + rollout_dtype=rollout_dtype, + ) + self._async_broadcast_mode = True # NCCL uses async broadcast - def prepare(self) -> MasterMetadata: + def prepare(self) -> MasterMetadata | None: + """Prepare checkpoint engine before each step send_weights/receive_weights.""" # For master process, use cupy instead of torch to avoid memory register error # when `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. if self.is_master: - self.send_buf = cp.zeros(self.bucket_size, dtype=cp.uint8) - self.recv_buf = cp.zeros(self.bucket_size, dtype=cp.uint8) + self._send_buf = cp.zeros(self.bucket_size, dtype=cp.uint8) + self._recv_buf = cp.zeros(self.bucket_size, dtype=cp.uint8) else: - self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="cuda") - self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="cuda") + self._send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="cuda") + self._recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="cuda") - return MasterMetadata(zmq_ip=self.ip, zmq_port=self.listen_port) if self.is_master else None + if self.is_master: + return MasterMetadata(zmq_ip=self._ip, zmq_port=self._zmq_port) + return None def finalize(self): - """Destroy the NCCL process group if rebuild_group is True.""" + """Finalize checkpoint engine after each step send_weights/receive_weights.""" if self.rebuild_group: - if self.rank >= 0: + if self._rank is not None and self._rank >= 0: collective.destroy_collective_group(self.group_name) - self.rank = None - self.world_size = None + self._rank = None + self._world_size = None - self.send_buf = None - self.recv_buf = None + self._send_buf = None + self._recv_buf = None torch.cuda.empty_cache() - @classmethod - def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): - trainer_kwargs = { - "rank": [0] + [-1] * (trainer_world_size - 1), - "world_size": [rollout_world_size + 1] * trainer_world_size, - "master_metadata": [metadata[0]] * trainer_world_size, - } - rollout_kwargs = { - "rank": list(range(1, rollout_world_size + 1)), - "world_size": [rollout_world_size + 1] * rollout_world_size, - "master_metadata": [metadata[0]] * rollout_world_size, - } - return trainer_kwargs, rollout_kwargs - - def _start_zmq_server(self): - self.ip = ray.util.get_node_ip_address().strip("[]") - self.listen_port, _ = get_free_port(self.ip) - - context = zmq.Context() - self.socket = context.socket(zmq.PUB) - if is_valid_ipv6_address(self.ip): - address = f"tcp://[{self.ip}]:{self.listen_port}" - self.socket.setsockopt(zmq.IPV6, 1) - else: - address = f"tcp://{self.ip}:{self.listen_port}" - - self.socket.bind(address) - - def _connect_zmq_client(self, metadata: MasterMetadata): - assert not self.is_master, "Master process should not connect to other processes." - context = zmq.Context() - self.socket = context.socket(zmq.SUB) - if is_valid_ipv6_address(metadata.zmq_ip): - address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" - self.socket.setsockopt(zmq.IPV6, 1) - else: - address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" - - self.socket.connect(address) - self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) - def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): """Initialize the NCCL process group. Args: - rank (int): The rank of the current process. - world_size (int): The total number of processes. + rank: The rank of the current process. + world_size: The total number of processes. + master_metadata: The metadata from the master process. """ # For trainer workers other than rank 0, their rank should be -1. if rank < 0: - self.rank = rank - self.world_size = world_size + self._rank = rank + self._world_size = world_size return if self.rebuild_group or not collective.is_group_initialized(self.group_name): collective.init_collective_group(world_size, rank, "nccl", self.group_name) - self.rank = rank - self.world_size = world_size + self._rank = rank + self._world_size = world_size else: - assert self.rank == rank, f"rank {rank} is not equal to self.rank {self.rank}" - assert self.world_size == world_size, ( - f"world_size {world_size} is not equal to self.world_size {self.world_size}" + assert self._rank == rank, f"rank {rank} is not equal to self.rank {self._rank}" + assert self._world_size == world_size, ( + f"world_size {world_size} is not equal to self.world_size {self._world_size}" ) - if self.rank > 0: + if self._rank > 0: self._connect_zmq_client(master_metadata) collective.barrier(self.group_name) - logger.info(f"init_process_group rank: {self.rank}, world_size: {self.world_size}") - - @torch.no_grad() - async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): - """Send the weights of the model. - - Args: - weights: A generator that yields the name of the weight tensor and the tensor itself. - """ - assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights." - - # For trainer rank other than 0, consume weights without sending. - if self.rank < 0: - for name, weight in weights: - pass - return - - send_buf, recv_buf = self.send_buf, self.recv_buf - broadcast_op = None - - start_time = time.time() - bucket_meta: dict[str, TensorMeta] = {} - offset = 0 - for name, weight in weights: - # fill the tensor bucket - if offset + weight.nbytes > self.bucket_size: - torch.cuda.synchronize() - - # wait previous broadcast op finish - if broadcast_op is not None: - await broadcast_op.wait_for_complete() - - broadcast_op = BroadcastOperation( - rank=self.rank, - group_name=self.group_name, - bucket=send_buf, - metadata={"bucket_meta": bucket_meta, "is_last": False}, - socket=self.socket, - topic=self.topic, - ) - - # swap send_buf and recv_buf - send_buf, recv_buf = recv_buf, send_buf - bucket_meta = {} - offset = 0 - - assert offset + weight.nbytes <= self.bucket_size, ( - f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." - ) + logger.info(f"init_process_group rank: {self._rank}, world_size: {self._world_size}") - bucket_meta[name] = { - "name": name, - "shape": weight.shape, - "dtype": weight.dtype, - "offset": offset, - } - send_buf[offset : offset + weight.nbytes] = cp.asarray(weight.view(-1).view(torch.uint8)) - offset += weight.nbytes + def _broadcast(self, bucket, src_rank: int): + """Broadcast tensor using NCCL.""" + collective.broadcast(bucket, src_rank=src_rank, group_name=self.group_name) - # broadcast last bucket + def _synchronize(self): + """Synchronize CUDA operations.""" torch.cuda.synchronize() - if broadcast_op is not None: - await broadcast_op.wait_for_complete() - broadcast_op = BroadcastOperation( - rank=self.rank, - group_name=self.group_name, - bucket=send_buf, - metadata={"bucket_meta": bucket_meta, "is_last": True}, - socket=self.socket, - topic=self.topic, - ) - await broadcast_op.wait_for_complete() - logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s") - - @torch.no_grad() - async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: - """Receive the weights of the model. - - Yields: - A tuple of the name of the weight tensor and the tensor itself. - """ - assert self.rank > 0, "Rank 0 should not receive weights." - send_buf, recv_buf = self.send_buf, self.recv_buf - total_bytes, total_params = 0, 0 - - # receive first bucket - start_time = time.time() - broadcast_op = BroadcastOperation( - rank=self.rank, - group_name=self.group_name, - bucket=recv_buf, - metadata=None, - socket=self.socket, - topic=self.topic, - ) - metadata = await broadcast_op.wait_for_complete() - total_bytes += self.bucket_size - total_params += len(metadata["bucket_meta"]) - - # swap send_buf and recv_buf - send_buf, recv_buf = recv_buf, send_buf - while not metadata["is_last"]: - # 1. receive next bucket - broadcast_op = BroadcastOperation( - rank=self.rank, - group_name=self.group_name, - bucket=recv_buf, - metadata=None, - socket=self.socket, - topic=self.topic, - ) - - # 2. yield tensor from send_buf - for name, meta in metadata["bucket_meta"].items(): - dtype, shape = meta["dtype"], meta["shape"] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) - yield name, tensor - - # 3. wait for next bucket broadcast finish - metadata = await broadcast_op.wait_for_complete() - total_bytes += self.bucket_size - total_params += len(metadata["bucket_meta"]) - - # 4. swap send_buf and recv_buf - torch.cuda.synchronize() # sync non-blocking copy - send_buf, recv_buf = recv_buf, send_buf - - # yield tensor from send_buf - for name, meta in metadata["bucket_meta"].items(): - dtype, shape = meta["dtype"], meta["shape"] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) - yield name, tensor - - time_cost = time.time() - start_time - bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) - logger.info( - f"Rank {self.rank} receive weights done, total_params: {total_params}, " - f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" - ) + def _copy_to_buffer(self, buffer, tensor, offset): + """Copy tensor to buffer using cupy.""" + buffer[offset : offset + tensor.nbytes] = cp.asarray(tensor.view(-1).view(torch.uint8)) diff --git a/verl/checkpoint_engine/nixl_checkpoint_engine.py b/verl/checkpoint_engine/nixl_checkpoint_engine.py index fbdefc5b230..303c5cb286c 100644 --- a/verl/checkpoint_engine/nixl_checkpoint_engine.py +++ b/verl/checkpoint_engine/nixl_checkpoint_engine.py @@ -256,6 +256,16 @@ def __init__( self.agent = NixlAgent() self.is_master = is_master + @property + def bucket_size(self) -> int: + """Return the bucket size in bytes.""" + return self._bucket_size + + @bucket_size.setter + def bucket_size(self, value: int): + """Set the bucket size in bytes.""" + self._bucket_size = value + def prepare(self) -> NixlAgentMetadata: """Prepare send and recv bucket. @@ -385,8 +395,54 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, bucket_meta: dict[str, TensorMeta] = {} offset = 0 for name, weight in weights: + weight_size = weight.nbytes + # Check if the weight needs to be sliced into chunks + if weight_size > self.bucket_size: + # Use base class method to slice weight into chunks + chunks = self._slice_weight_into_chunks(name, weight) + + for chunk, chunk_meta in chunks: + chunk_size = chunk.nbytes + + # Fill bucket with chunk + if offset + chunk_size > self.bucket_size: + torch.cuda.synchronize() + + # wait previous bucket to be received + if readable_op is not None: + await readable_op.wait_for_complete() + + # send bucket meta to next agent + readable_op = ReadableOperation( + self.agent, + self.next_agent, + send_descs, + {"bucket_meta": bucket_meta, "is_last": False}, + ) + + # swap send and recv buf + send_buf, recv_buf = recv_buf, send_buf + send_descs, recv_descs = recv_descs, send_descs + bucket_meta = {} + offset = 0 + + # Update offset in meta (for key, we use indexed key) + indexed_key = f"{name}_chunk_{chunk_meta['chunk_idx']}" + bucket_meta[indexed_key] = { + "name": chunk_meta["name"], + "shape": chunk_meta["shape"], + "dtype": chunk_meta["dtype"], + "offset": offset, + "chunk_idx": chunk_meta["chunk_idx"], + "total_chunks": chunk_meta["total_chunks"], + } + send_buf[offset : offset + chunk_size].copy_(chunk.view(-1).view(torch.uint8), non_blocking=True) + offset += chunk_size + + continue + # fill the tensor bucket - if offset + weight.nbytes > self.bucket_size: + if offset + weight_size > self.bucket_size: torch.cuda.synchronize() # wait previous bucket to be received @@ -407,18 +463,14 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, bucket_meta = {} offset = 0 - assert offset + weight.nbytes <= self.bucket_size, ( - f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." - ) - bucket_meta[name] = { "name": name, "shape": weight.shape, "dtype": weight.dtype, "offset": offset, } - send_buf[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) - offset += weight.nbytes + send_buf[offset : offset + weight_size].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + offset += weight_size # send last bucket meta to next agent torch.cuda.synchronize() @@ -443,6 +495,9 @@ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None send_descs, recv_descs = self.send_descs, self.recv_descs total_bytes, total_params = 0, 0 + # Buffer to collect chunks for weights that were sliced + pending_chunks = {} # name -> {chunk_idx: tensor, ...} + # receive first bucket from previous agent start_time = time.time() read_op = ReadOperation(self.agent, self.prev_agent, recv_descs, self.bucket_size) @@ -450,7 +505,6 @@ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None read_op.begin_read() await read_op.wait_for_complete() total_bytes += self.bucket_size - total_params += len(metadata["bucket_meta"]) # swap send and recv buf send_buf, recv_buf = recv_buf, send_buf @@ -472,18 +526,15 @@ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None read_op.begin_read() # 3. yield tensor from send_buf - for name, meta in metadata["bucket_meta"].items(): - dtype, shape = meta["dtype"], meta["shape"] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) - yield name, tensor + for tensor_tuple in self._yield_tensors_from_buffer(send_buf, metadata["bucket_meta"], pending_chunks): + total_params += 1 + yield tensor_tuple # 4. wait for next agent read complete and read from previous agent complete if readable_op is not None: await readable_op.wait_for_complete() await read_op.wait_for_complete() total_bytes += self.bucket_size - total_params += len(next_metadata["bucket_meta"]) # 5. swap send and recv buf torch.cuda.synchronize() # sync non-blocking copy @@ -502,15 +553,19 @@ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None ) # yield tensor from send_buf - for name, meta in metadata["bucket_meta"].items(): - dtype, shape = meta["dtype"], meta["shape"] - size = dtype.itemsize * shape.numel() - tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) - yield name, tensor + for tensor_tuple in self._yield_tensors_from_buffer(send_buf, metadata["bucket_meta"], pending_chunks): + total_params += 1 + yield tensor_tuple # wait for next agent read complete if readable_op is not None: await readable_op.wait_for_complete() + + # Check if there are any remaining chunks that weren't processed + if pending_chunks: + raise RuntimeError( + f"Received chunks for weights {list(pending_chunks.keys())} but did not receive all chunks for them." + ) time_cost = time.time() - start_time bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) logger.info( diff --git a/verl/utils/tensor_utils.py b/verl/utils/tensor_utils.py new file mode 100644 index 00000000000..64d5ef9fffe --- /dev/null +++ b/verl/utils/tensor_utils.py @@ -0,0 +1,108 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions for tensor operations.""" + +import logging +import os +from dataclasses import dataclass +from functools import reduce + +import torch + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +@dataclass +class WeightChunkInfo: + """Information about a chunk of a sliced weight tensor. + + Attributes: + start_idx: Start index along the first dimension. + end_idx: End index along the first dimension (exclusive). + chunk_idx: Index of this chunk (0-based). + total_chunks: Total number of chunks. + """ + + start_idx: int + end_idx: int + chunk_idx: int + total_chunks: int + + +def compute_weight_chunks( + name: str, + weight: torch.Tensor, + bucket_size: int, +) -> list[WeightChunkInfo]: + """Compute how to slice a weight tensor into chunks that fit in bucket. + + This function calculates the chunking strategy for large weight tensors + that exceed the bucket size. The tensor is sliced along its first dimension. + + Args: + name: Name of the weight tensor (for error messages and logging). + weight: The weight tensor to slice. + bucket_size: Maximum size in bytes for each chunk. + + Returns: + List of WeightChunkInfo, one for each chunk. + Returns a single-element list if the weight doesn't need slicing. + + Raises: + ValueError: If a single slice is larger than bucket_size. + """ + + weight_size = weight.nbytes + if weight_size <= bucket_size: + # No slicing needed + return [WeightChunkInfo(start_idx=0, end_idx=weight.shape[0], chunk_idx=0, total_chunks=1)] + + # Slice the weight along the first dimension into chunks + dtype_size = weight.element_size() + numel_per_chunk = bucket_size // dtype_size + + # Calculate chunk size along the first dimension + first_dim_size = weight.shape[0] + elements_per_row = reduce(lambda x, y: x * y, weight.shape[1:], 1) + if elements_per_row == 0: + # Empty tensor, return as is + return [WeightChunkInfo(start_idx=0, end_idx=first_dim_size, chunk_idx=0, total_chunks=1)] + + chunk_dim_size = numel_per_chunk // elements_per_row + if chunk_dim_size == 0: + raise ValueError( + f"Weight '{name}' with shape {weight.shape} is too large to be chunked. A single slice " + f"along the first dimension is larger than the bucket size ({bucket_size} bytes). " + f"Please increase `checkpoint_engine.update_weights_bucket_megabytes`." + ) + + num_chunks = (first_dim_size + chunk_dim_size - 1) // chunk_dim_size + logger.info(f"Slicing weight {name} ({weight.shape}, {weight.dtype}, {weight_size} bytes) into {num_chunks} chunks") + + chunks = [] + start_idx = 0 + for chunk_idx in range(num_chunks): + end_idx = min(start_idx + chunk_dim_size, first_dim_size) + chunks.append( + WeightChunkInfo( + start_idx=start_idx, + end_idx=end_idx, + chunk_idx=chunk_idx, + total_chunks=num_chunks, + ) + ) + start_idx = end_idx + + return chunks diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index 7fa3b1dd67c..b2568325baa 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -236,6 +236,8 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False patch_vllm_moe_model_weight_loader(self.model_runner.model) # receive bucket and update weights + # Buffer to collect chunks for weights that were sliced + pending_chunks = {} # name -> {chunk_idx: tensor, ...} while True: metadata = socket.recv_pyobj() weights, tensor = [], None @@ -250,7 +252,27 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False tensor = tensor.clone() else: tensor = tensor.to(self.device) - weights.append((name, tensor)) + + # Check if this is a chunk of a sliced weight + if "chunk_idx" in meta and "total_chunks" in meta: + # This is a chunk, store it for later merging + original_name = meta["name"] + chunk_idx = meta["chunk_idx"] + if original_name not in pending_chunks: + pending_chunks[original_name] = {} + pending_chunks[original_name][chunk_idx] = tensor + + # Check if we have all chunks for this weight + if len(pending_chunks[original_name]) == meta["total_chunks"]: + # Merge all chunks back into one tensor + chunks_dict = pending_chunks[original_name] + sorted_chunks = [chunks_dict[i] for i in range(meta["total_chunks"])] + merged_tensor = torch.cat(sorted_chunks, dim=0) + weights.append((original_name, merged_tensor)) + del pending_chunks[original_name] + else: + weights.append((name, tensor)) + get_torch_device().synchronize() socket.send(b"") self._update_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done) @@ -258,6 +280,12 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False if metadata["is_last"]: break + # Check if there are any remaining chunks that weren't processed + if pending_chunks: + raise RuntimeError( + f"Received chunks for weights {list(pending_chunks.keys())} but did not receive all chunks for them." + ) + if self._is_qat_model: # QAT: call process_weights_after_loading AFTER all buckets are received from verl.utils.qat import manual_process_weights_after_loading diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 53a433cc51e..c04bb0453a0 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -42,6 +42,7 @@ from verl import DataProto from verl.third_party.vllm import VLLM_SLEEP_LEVEL, get_version from verl.utils.device import get_device_id, get_device_name, get_torch_device, is_support_ipc +from verl.utils.tensor_utils import compute_weight_chunks from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.base import BaseRollout from verl.workers.rollout.utils import ensure_async_iterator @@ -202,27 +203,55 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None # transfer volume. # weight = weight.to(dtype, non_blocking=True) + # Check if the weight needs to be sliced into chunks + # (e.g., large embedding layer that exceeds bucket_size) + weight_size = weight.nbytes + if weight_size > bucket_size: + # Use shared utility to compute chunk info + chunk_infos = compute_weight_chunks(name, weight, bucket_size) + + for info in chunk_infos: + # Extract chunk along first dimension + chunk = weight[info.start_idx : info.end_idx] + chunk_size = chunk.nbytes + + # Fill bucket with chunk + if offset + chunk_size > bucket_size: + get_torch_device().synchronize() + s.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) + s.recv() + bucket_meta = {} + offset = 0 + + bucket_meta[f"{name}_chunk_{info.chunk_idx}"] = { + "name": name, + "shape": chunk.shape, + "dtype": chunk.dtype, + "offset": offset, + "chunk_idx": info.chunk_idx, + "total_chunks": info.total_chunks, + } + buffer[offset : offset + chunk_size].copy_(chunk.view(-1).view(torch.uint8), non_blocking=True) + offset += chunk_size + + continue + # fill the tensor bucket - if offset + weight.nbytes > bucket_size: + if offset + weight_size > bucket_size: get_torch_device().synchronize() s.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) s.recv() bucket_meta = {} offset = 0 - # TODO: slice embedding layer weight into chunks - assert offset + weight.nbytes <= bucket_size, ( - f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." - f"Please increase rollout.update_weights_bucket_megabytes({bucket_size_mb} MB)." - ) bucket_meta[name] = { "name": name, "shape": weight.shape, "dtype": weight.dtype, "offset": offset, } - buffer[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) - offset += weight.nbytes + buffer[offset : offset + weight_size].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + offset += weight_size # send the last bucket get_torch_device().synchronize()