Skip to content

Commit 6f33c55

Browse files
authored
[ckpt] feat: add checkpoint-engine abstraction (#4775)
### What does this PR do? Add Checkpoint Engine abstraction. #### Overview Checkpoint Engine is a unified abstraction layer that synchronizes weights between various training backends and inference backends. It provides three unified APIs: - send_weights: get named tensors from a generator and send them in a streaming manner. - receive_weights: return a tensor generator that yields named tensors in a streaming manner. - get_weights: return a tensor generator that yields named tensors in a streaming manner; used by each inference instance to update weights independently from a local cache (e.g. shared memory, disk). For more details, see `verl/checkpoint_engine/README.md`. #### verl core <img width="640" height="167" alt="image" src="https://github.com/user-attachments/assets/fbd125d7-b461-4c89-9678-b95a2ef89c33" /> #### checkpoint engine <img width="1004" height="409" alt="checkpoint-engine" src="https://github.com/user-attachments/assets/fc263c1f-17b2-4579-9842-87b24e12abc7" />
1 parent c408a6e commit 6f33c55

File tree

12 files changed

+1433
-1
lines changed

12 files changed

+1433
-1
lines changed

.github/workflows/npu_unit_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ jobs:
109109
- name: Run all NPU unit tests
110110
run: |
111111
export PYTHONPATH=$PYTHONPATH:/Megatron-LM
112-
pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob="*on_cpu.py" --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob="tests/special*" --ignore-glob="tests/experimental" --ignore-glob="tests/workers/reward_model" --ignore-glob="*test_rvdz*" --ignore-glob="*test_ray_collectives*" --ignore-glob="*test_nvtx_profile*" tests/
112+
pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob="*on_cpu.py" --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob="tests/special*" --ignore-glob="tests/experimental" --ignore-glob="tests/workers/reward_model" --ignore-glob="*test_rvdz*" --ignore-glob="*test_ray_collectives*" --ignore-glob="*test_nvtx_profile*" --ignore-glob="*test_nccl*" --ignore-glob="*test_nixl*" tests/
113113
- name: Testing FSDP2 actor functionality
114114
run: |
115115
torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/workers/actor/test_special_dp_actor.py
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import pytest
17+
import ray
18+
19+
from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group
20+
from verl.single_controller.ray.base import (
21+
RayResourcePool,
22+
split_resource_pool,
23+
)
24+
25+
26+
@pytest.mark.parametrize("rebuild_group", [False, True])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
def test_nccl_checkpoint_engine(
    rebuild_group,
    num_trainer,
    num_rollout,
    num_nodes=1,
    num_gpus_per_node=8,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-8B-Base",
):
    """End-to-end test of the NCCL checkpoint engine.

    Splits one resource pool into a trainer group and a rollout group, then runs
    three rounds of prepare -> init_process_group -> update_weights -> finish,
    checking after each round that the rollout workers received the trainer's
    weights (when ``check_allclose`` is enabled).

    Args:
        rebuild_group: whether the checkpoint engine rebuilds its process group
            each round (forwarded via ``checkpoint_kwargs``).
        num_trainer: number of trainer workers.
        num_rollout: number of rollout workers.
        num_nodes: number of Ray nodes to allocate.
        num_gpus_per_node: GPUs per node in the resource pool.
        check_allclose: if True, rollout workers keep received tensors and
            compare them against the reference model weights.
        model_path: HF checkpoint used by both worker groups.
    """
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
                "UCX_TLS": "rc,tcp,cuda",
                "UCX_MAX_RNDV_RAILS": "4",
                "UCX_LOG_LEVEL": "INFO",
                "VERL_LOGGING_LEVEL": "DEBUG",
            }
        }
    )
    # Fix: shut Ray down even when an assertion or a worker step fails;
    # otherwise the leaked session breaks subsequent parametrized runs.
    try:
        resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
        trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
        checkpoint_kwargs = {
            "bucket_size": 2 * 1024 * 1024 * 1024,  # 2GB
            "rebuild_group": rebuild_group,
        }

        trainer = create_trainer_worker_group(model_path, trainer_pool, "nccl", checkpoint_kwargs)
        trainer.reset()
        rollout = create_rollout_worker_group(
            model_path, rollout_pool, "nccl", checkpoint_kwargs, check_allclose=check_allclose
        )

        for _ in range(3):
            # 1. prepare all workers
            metadata = ray.get(
                trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size)
                + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size)
            )
            # Trainer rank 0 is the master (global rank 0); the remaining
            # trainer ranks pass -1, rollout workers take ranks 1..world_size.
            trainer_kwargs = {
                "method": ["init_process_group"] * trainer.world_size,
                "rank": [0] + [-1] * (trainer.world_size - 1),
                "world_size": [rollout.world_size + 1] * trainer.world_size,
                "master_metadata": [metadata[0]] * trainer.world_size,
            }
            rollout_kwargs = {
                "method": ["init_process_group"] * rollout.world_size,
                "rank": list(range(1, rollout.world_size + 1)),
                "world_size": [rollout.world_size + 1] * rollout.world_size,
                "master_metadata": [metadata[0]] * rollout.world_size,
            }

            # 2. init process group between all workers
            ray.get(
                trainer.execute_checkpoint_engine(**trainer_kwargs)
                + rollout.execute_checkpoint_engine(**rollout_kwargs)
            )

            # 3. update weights of all workers
            ray.get(trainer.update_weights() + rollout.update_weights())

            # 4. finish all workers
            ray.get(
                trainer.execute_checkpoint_engine(["finish"] * trainer.world_size)
                + rollout.execute_checkpoint_engine(["finish"] * rollout.world_size)
            )

            # 5. check weights of rollout workers
            rollout.check_weights()
    finally:
        ray.shutdown()
99+
100+
101+
if __name__ == "__main__":
    # Manual large-scale run (4 nodes x 8 GPUs) outside of pytest.
    run_kwargs = dict(
        rebuild_group=False,
        num_trainer=2,
        num_rollout=30,
        num_nodes=4,
        num_gpus_per_node=8,
        check_allclose=False,
        model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base",
    )
    test_nccl_checkpoint_engine(**run_kwargs)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import pytest
17+
import ray
18+
19+
from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group
20+
from verl.single_controller.ray.base import (
21+
RayResourcePool,
22+
split_resource_pool,
23+
)
24+
25+
26+
@pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
def test_nixl_checkpoint_engine(
    num_trainer,
    num_rollout,
    device,
    num_nodes=1,
    num_gpus_per_node=8,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-8B-Base",
):
    """End-to-end test of the NIXL checkpoint engine.

    Splits one resource pool into a trainer group and a rollout group and runs
    three rounds of prepare -> init_process_group -> update_weights -> finish.
    NIXL workers are wired as a chain via prev/next agent metadata rather than
    a single master (contrast with the NCCL test).

    Args:
        num_trainer: number of trainer workers.
        num_rollout: number of rollout workers.
        device: device the checkpoint engine stages tensors on ("cuda"/"cpu").
        num_nodes: number of Ray nodes to allocate.
        num_gpus_per_node: GPUs per node in the resource pool.
        check_allclose: if True, rollout workers keep received tensors and
            compare them against the reference model weights.
        model_path: HF checkpoint used by both worker groups.
    """
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
                # TODO: it's pretty hard to set these environment variables right, please consult
                # with your network admin. Maybe auto adjust UCX_* according to NCCL_IB_*?
                "UCX_TLS": "rc,ud,cuda",
                # "UCX_IB_GID_INDEX": "3",  # NCCL_IB_GID_INDEX
                # "UCX_IB_DEVICES": "mlx5_1:1,mlx5_2:1,mlx5_3:1",  # NCCL_IB_HCA
                "UCX_RC_TIMEOUT": "30s",  # NCCL_IB_TIMEOUT
                "UCX_RC_RETRY_COUNT": "7",  # NCCL_IB_RETRY_COUNT
                "UCX_KEEPALIVE_INTERVAL": "1s",
                "UCX_KEEPALIVE_NUM_EPS": "10",
                "UCX_MAX_RNDV_RAILS": "4",
                "UCX_LOG_LEVEL": "INFO",
                "VERL_LOGGING_LEVEL": "DEBUG",
            }
        }
    )
    # Fix: shut Ray down even when an assertion or a worker step fails;
    # otherwise the leaked session breaks subsequent parametrized runs.
    try:
        resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
        trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
        checkpoint_kwargs = {
            "bucket_size": 2 * 1024 * 1024 * 1024,  # 2GB
            "device": device,
        }

        trainer = create_trainer_worker_group(model_path, trainer_pool, "nixl", checkpoint_kwargs)
        trainer.reset()
        rollout = create_rollout_worker_group(
            model_path, rollout_pool, "nixl", checkpoint_kwargs, device=device, check_allclose=check_allclose
        )

        for _ in range(3):
            # 1. prepare all workers
            metadata = ray.get(
                trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size)
                + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size)
            )

            # Trainer rank 0 heads the chain and points at the first rollout
            # worker; rollout workers link prev/next to their neighbors.
            trainer_kwargs = {
                "method": ["init_process_group"] * trainer.world_size,
                "rank": [0] + [-1] * (trainer.world_size - 1),
                "world_size": [rollout.world_size + 1] * trainer.world_size,
                "prev_agent_metadata": [None] * trainer.world_size,
                "next_agent_metadata": [metadata[-rollout.world_size]] + [None] * (trainer.world_size - 1),
            }

            rollout_kwargs = {
                "method": ["init_process_group"] * rollout.world_size,
                "rank": list(range(1, rollout.world_size + 1)),
                "world_size": [rollout.world_size + 1] * rollout.world_size,
                "prev_agent_metadata": [metadata[0]] + metadata[-rollout.world_size : -1],
                "next_agent_metadata": metadata[-rollout.world_size + 1 :] + [None],
            }

            # 2. init process group between all workers
            ray.get(
                trainer.execute_checkpoint_engine(**trainer_kwargs)
                + rollout.execute_checkpoint_engine(**rollout_kwargs)
            )

            # 3. update weights of all workers
            ray.get(trainer.update_weights() + rollout.update_weights())

            # 4. finish all workers
            ray.get(
                trainer.execute_checkpoint_engine(["finish"] * trainer.world_size)
                + rollout.execute_checkpoint_engine(["finish"] * rollout.world_size)
            )

            # 5. check weights of rollout workers
            rollout.check_weights()
    finally:
        ray.shutdown()
112+
113+
114+
if __name__ == "__main__":
    # Manual large-scale run (4 nodes x 8 GPUs) outside of pytest.
    run_kwargs = dict(
        num_trainer=2,
        num_rollout=30,
        device="cuda",
        num_nodes=4,
        num_gpus_per_node=8,
        check_allclose=False,
        model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base",
    )
    test_nixl_checkpoint_engine(**run_kwargs)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import ray
16+
import torch
17+
from transformers import AutoModelForCausalLM
18+
19+
from verl.checkpoint_engine import CheckpointEngineRegistry
20+
from verl.single_controller.base.decorator import Dispatch, register
21+
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
22+
from verl.utils.fs import copy_to_local
23+
from verl.workers.config import FSDPEngineConfig, HFModelConfig
24+
from verl.workers.engine_workers import TrainingWorker, TrainingWorkerConfig
25+
26+
27+
class TrainingWorkerTest(TrainingWorker):
    """Trainer-side test worker that streams its weights out via a checkpoint engine."""

    def __init__(self, config: TrainingWorkerConfig, checkpoint_backend: str, checkpoint_kwargs: dict) -> None:
        copy_to_local(config.model_config.path)
        super().__init__(config)
        # Only global rank 0 acts as the NCCL master. Fix: copy the kwargs
        # instead of mutating the caller-provided dict in place.
        if torch.distributed.get_rank() == 0 and checkpoint_backend == "nccl":
            checkpoint_kwargs = {**checkpoint_kwargs, "is_master": True}
        self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    async def update_weights(self):
        """Stream this worker's parameters out through the checkpoint engine."""
        per_tensor_param, _ = self.engine.get_per_tensor_param()
        await self.checkpoint_engine.send_weights(per_tensor_param)

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def execute_checkpoint_engine(self, method: str, *args, **kwargs):
        """Invoke a checkpoint-engine method by name (e.g. "prepare", "finish")."""
        return getattr(self.checkpoint_engine, method)(*args, **kwargs)
43+
44+
45+
class RolloutWorkerTest:
    """Rollout-side test worker that receives streamed weights and optionally verifies them."""

    def __init__(
        self,
        model_path,
        checkpoint_backend: str,
        checkpoint_kwargs: dict,
        device: str = "cuda",
        check_allclose: bool = True,
    ) -> None:
        self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs)
        local_path = copy_to_local(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(local_path, torch_dtype=torch.bfloat16)
        self.model.to(device)
        self.check_allclose = check_allclose
        # name -> tensor received in the current round (populated only when check_allclose)
        self.received_weights: dict[str, torch.Tensor] = {}

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    async def update_weights(self):
        """Drain the incoming weight stream, keeping copies for verification when requested."""
        async for name, weight in self.checkpoint_engine.receive_weights():
            # Clone once: the yielded tensor may alias a reused receive buffer.
            # Fix: the previous code cloned a second time before storing, which
            # was redundant — the tensor is already an independent copy here.
            weight = weight.clone()
            if self.check_allclose:
                self.received_weights[name] = weight.to(torch.bfloat16)

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def execute_checkpoint_engine(self, method: str, *args, **kwargs):
        """Invoke a checkpoint-engine method by name (e.g. "prepare", "finish")."""
        return getattr(self.checkpoint_engine, method)(*args, **kwargs)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def check_weights(self):
        """Assert every model weight was received and matches, then clear the cache."""
        if not self.check_allclose:
            return
        for name, weight in self.model.state_dict().items():
            assert name in self.received_weights, f"weight {name} not received"
            assert torch.allclose(weight, self.received_weights[name]), f"weight {name} not equal"
        self.received_weights.clear()
80+
81+
82+
def create_trainer_worker_group(
    model_path: str, resource_pool: RayResourcePool, checkpoint_backend: str, checkpoint_kwargs: dict
) -> RayWorkerGroup:
    """Spawn a RayWorkerGroup of TrainingWorkerTest actors on ``resource_pool``."""
    local_path = copy_to_local(model_path)
    trainer_config = TrainingWorkerConfig(
        model_type="language_model",
        model_config=HFModelConfig(path=local_path, use_remove_padding=True),
        engine_config=FSDPEngineConfig(forward_only=True, fsdp_size=resource_pool.world_size, strategy="fsdp"),
    )
    cls_with_init = RayClassWithInitArgs(
        cls=ray.remote(TrainingWorkerTest),
        config=trainer_config,
        checkpoint_backend=checkpoint_backend,
        checkpoint_kwargs=checkpoint_kwargs,
    )
    cls_with_init.update_options(
        {"runtime_env": {"env_vars": {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}}}
    )
    return RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=cls_with_init)
111+
112+
113+
def create_rollout_worker_group(
    model_path: str, resource_pool: RayResourcePool, checkpoint_backend: str, checkpoint_kwargs: dict, **kwargs
) -> RayWorkerGroup:
    """Spawn a RayWorkerGroup of RolloutWorkerTest actors on ``resource_pool``.

    Extra keyword arguments (e.g. ``device``, ``check_allclose``) are forwarded
    to the RolloutWorkerTest constructor.
    """
    cls_with_init = RayClassWithInitArgs(
        cls=ray.remote(RolloutWorkerTest),
        model_path=model_path,
        checkpoint_backend=checkpoint_backend,
        checkpoint_kwargs=checkpoint_kwargs,
        **kwargs,
    )
    return RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=cls_with_init)
File renamed without changes.

tests/special_sanity/check_device_api_usage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"verl/workers/engine/veomni/transformer_impl.py", # appear in default device_name
4343
"verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in config.cudagraph_capture_sizes
4444
"verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES
45+
"verl/checkpoint_engine", # checkpoint engine backends are device-specific
4546
]
4647

4748
# directory or file path must contain keyword "nccl"

0 commit comments

Comments
 (0)