
Commit 6f4942b

kip-cxj and gemini-code-assist[bot] authored
[ckpt] feat: add kimi ckpt engine backend (#4954)
### What does this PR do?

Based on the checkpoint-engine abstraction ([add checkpoint-engine abstraction](#4775)), this PR adds the kimi_ckpt_engine backend to support both NVIDIA GPU and Huawei Ascend NPU. Since establishing communication domains across trainer and rollout workers is required, this PR also depends on the [newly added communication domain support](MoonshotAI/checkpoint-engine#66) in kimi_ckpt_engine.

TODO:
- [x] Add detailed performance testing results to the checkpoint engine README.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: [add Hccl ckpt engine backend](#4885)
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

We have verified the functionality on both GPU and NPU. Performance benchmarks in a 32-NPU environment show promising results; however, due to a lack of available GPU resources, performance data for GPU is still pending.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

---------

Co-authored-by: kip-cxj <cuixiaojin@huawei.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ea042c2 commit 6f4942b

File tree

5 files changed: +503 −6 lines


tests/checkpoint_engine/test_correctness_on_gpu.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@
     RayResourcePool,
     split_resource_pool,
 )
+from verl.utils.device import get_device_name
 from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig
 
 
@@ -127,6 +128,53 @@ async def test_nixl_checkpoint_engine(
     ray.shutdown()
 
 
+@pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
+@pytest.mark.asyncio
+@pytest.mark.parametrize("rebuild_group", [False])
+@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
+async def test_kimi_checkpoint_engine(
+    rebuild_group,
+    num_trainer,
+    num_rollout,
+    num_nodes=1,
+    num_gpus_per_node=8,
+    check_allclose=True,
+    model_path="~/models/Qwen/Qwen3-8B-Base",
+):
+    model_path = os.path.expanduser(model_path)
+    ray.init(
+        runtime_env={
+            "env_vars": {
+                "NCCL_IB_HCA": "mlx5",
+                "VERL_LOGGING_LEVEL": "DEBUG",
+            }
+        }
+    )
+
+    # initialize config
+    checkpoint_engine_config = CheckpointEngineConfig(
+        backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}}
+    )
+    model_config = HFModelConfig(path=model_path, use_remove_padding=True)
+    rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)
+
+    # create trainer and rollout worker group
+    resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
+    resource_pool.get_placement_groups(device_name=get_device_name())
+    trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
+    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
+    trainer.reset()
+    rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)
+
+    # create checkpoint engine manager
+    checkpoint_manager = CheckpointEngineManager(backend="kimi_ckpt_engine", trainer=trainer, replicas=replicas)
+    for _ in range(3):
+        await checkpoint_manager.update_weights()
+        rollout.check_weights()
+
+    ray.shutdown()
+
+
 if __name__ == "__main__":
     test_nccl_checkpoint_engine(
         rebuild_group=False,
```
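The test above keys backend-specific options under the backend's name inside `engine_kwargs` (`engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": ...}}`), so each backend can pull out only its own options. A minimal standalone sketch of that lookup convention (the helper name is illustrative, not verl's actual config API):

```python
# Hypothetical helper mirroring the nested engine_kwargs layout used in
# the test: options are namespaced by backend name, and a backend reads
# only its own sub-dict (unknown backends fall back to {}).

def get_backend_kwargs(backend: str, engine_kwargs: dict) -> dict:
    # Return the sub-dict registered for this backend, or an empty dict.
    return engine_kwargs.get(backend, {})

engine_kwargs = {"kimi_ckpt_engine": {"rebuild_group": False}}

kwargs = get_backend_kwargs("kimi_ckpt_engine", engine_kwargs)
assert kwargs == {"rebuild_group": False}

# A backend with no registered options simply gets defaults.
assert get_backend_kwargs("nccl", engine_kwargs) == {}
```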

tests/checkpoint_engine/test_correctness_on_npu.py

Lines changed: 47 additions & 0 deletions
```diff
@@ -74,6 +74,53 @@ async def test_hccl_checkpoint_engine(
     ray.shutdown()
 
 
+@pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
+@pytest.mark.asyncio
+@pytest.mark.parametrize("rebuild_group", [False])
+@pytest.mark.parametrize("num_trainer, num_rollout", [(4, 28)])
+async def test_kimi_checkpoint_engine(
+    rebuild_group,
+    num_trainer,
+    num_rollout,
+    num_nodes=2,
+    num_gpus_per_node=16,
+    check_allclose=True,
+    model_path="~/models/Qwen/Qwen3-32B",
+):
+    model_path = os.path.expanduser(model_path)
+    ray.init(
+        runtime_env={
+            "env_vars": {
+                "HCCL_CONNECT_TIMEOUT": "1500",
+                "VERL_LOGGING_LEVEL": "DEBUG",
+            }
+        }
+    )
+
+    # initialize config
+    checkpoint_engine_config = CheckpointEngineConfig(
+        backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}}
+    )
+    model_config = HFModelConfig(path=model_path, use_remove_padding=True)
+    rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)
+
+    # create trainer and rollout worker group
+    resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
+    resource_pool.get_placement_groups(device_name=get_device_name())
+    trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
+    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
+    trainer.reset()
+    rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)
+
+    # create checkpoint engine manager
+    checkpoint_manager = CheckpointEngineManager(backend="kimi_ckpt_engine", trainer=trainer, replicas=replicas)
+    for _ in range(3):
+        await checkpoint_manager.update_weights()
+        rollout.check_weights()
+
+    ray.shutdown()
+
+
 if __name__ == "__main__":
     test_hccl_checkpoint_engine(
         rebuild_group=False,
```
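Both tests drive the same async pattern: awaiting `update_weights()` three times and verifying after each round. A minimal, framework-free asyncio sketch of that driver loop (the coroutine names here are stand-ins, not the real verl API):

```python
# Minimal asyncio sketch mirroring the tests' driver loop: an async
# weight update awaited three times, with a synchronous check per round.
import asyncio

async def update_weights(version: int) -> int:
    # Stand-in for the real async P2P/broadcast transfer.
    await asyncio.sleep(0)
    return version + 1

def check_weights(version: int) -> None:
    # Stand-in for rollout.check_weights(): verify the latest version landed.
    assert version > 0

async def main() -> int:
    version = 0
    for _ in range(3):
        version = await update_weights(version)
        check_weights(version)
    return version

# After three update rounds, all rollout replicas hold version 3.
assert asyncio.run(main()) == 3
```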

verl/checkpoint_engine/README.md

Lines changed: 17 additions & 5 deletions
````diff
@@ -18,16 +18,27 @@ Checkpoint Engine is an unified abstract layer to synchronize weights between va
 |nccl|NCCL|all_gather+broadcast|NVIDIA GPU & NCCL|Very High|Low: rebuild nccl group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Fixed clusters
 |hccl|HCCL|all_gather+broadcast|Ascend NPU & HCCL|High|Low: rebuild hccl group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Fixed clusters
 |nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)<br>- UCX<br>- UCCL<br>- Mooncake|Medium/High|High: dynamic adjust ring topology|Off-policy training<br>- Trainer/rollout disaggregated<br>- Elastic rollout<br>- Rollout fault tolerance<br>- Heterogeneous hardware rollout
+|kimi_ckpt_engine|Mooncake+NCCL/HCCL|p2p+broadcast|NVIDIA/Ascend|High|Low: rebuild communication group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Save checkpoint each time
+
+##### kimi_ckpt_engine details
+
+In the kimi_ckpt_engine workflow, the trainer first offloads the weights to the CPU, and the rollout creates a sub-communication group that includes all rollout devices. Then, using the Mooncake transfer engine, the weights are transmitted via P2P to a specific rollout worker, followed by a broadcast to all other rollout workers.
+
+<img src="https://github.com/kip-cxj/verl/blob/cxj/doc_imgs/docs/_static/kimi_ckpt_engine.png?raw=true" alt="kimi-ckpt-engine" width="50%">
+
+This mode requires the P2P feature of checkpoint-engine. Please ensure you have installed it via `pip install 'checkpoint-engine[p2p]'` and that your version is 0.4.0 or higher.
+
+In addition, installing checkpoint-engine[p2p] also installs the transfer engine. However, this library has no prebuilt packages for Ascend devices and must be compiled from source. For detailed compilation instructions, see [transfer-engine: ascend direct](https://github.com/kvcache-ai/Mooncake/blob/main/docs/source/design/transfer-engine/ascend_direct_transport.md).
 
 ### Benchmark
 1. benchmark setup
 - model: Qwen/Qwen3-30B-A3B-Base
-- trainer: fsdp world_size=2
+- trainer: fsdp world_size=2 (since Ascend 910C has 64GB of HBM, we set world_size=4)
 - rollout: num_rollout=30 (only receive weight without cuda ipc to vllm/sglang)
 ```bash
-python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py
-python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py
-python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
+pytest tests/checkpoint_engine/test_correctness_on_gpu.py
+pytest tests/checkpoint_engine/test_correctness_on_npu.py
+pytest tests/checkpoint_engine/test_special_server_adapter.py
 ```
 
 2. benchmark result
@@ -36,4 +47,5 @@ python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
 |----|----|----|----|
 |4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NCCL | ~7 | 8.25|
 |4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NIXL | ~7 | 8.25|
-|2*16 Ascend 910C, inner suppernode| HCCL | ~11 | 5.3|
+|2*16 Ascend 910C, inner supernode| HCCL | ~11 | 5.3|
+|2*16 Ascend 910C, inner supernode| kimi_ckpt_engine | offload: 7 update: 3.5 | 16.5|
````
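The three-stage workflow the README addition describes (offload to CPU, P2P to one rollout worker, broadcast to the rest) can be sketched as a framework-free simulation; all class and function names below are illustrative, not the actual checkpoint-engine API, and the real backend uses the Mooncake transfer engine plus an NCCL/HCCL communication group instead of plain Python assignment:

```python
# Simulation of the kimi_ckpt_engine update flow:
# 1) trainer offloads weights to host memory,
# 2) P2P transfer to a single designated rollout worker,
# 3) broadcast from that worker to the rest of the rollout group.

class RolloutWorker:
    def __init__(self, rank: int):
        self.rank = rank
        self.weights = None

def update_weights(trainer_gpu_weights: dict, rollout_workers: list) -> list:
    # Stage 1: offload from (simulated) device memory to the CPU.
    cpu_weights = dict(trainer_gpu_weights)

    # Stage 2: P2P transfer to one designated rollout worker (rank 0 here).
    head = rollout_workers[0]
    head.weights = cpu_weights

    # Stage 3: broadcast from the head worker to every other worker in the
    # rollout sub-communication group.
    for worker in rollout_workers[1:]:
        worker.weights = head.weights

    return rollout_workers

workers = update_weights({"layer0.weight": [0.1, 0.2]}, [RolloutWorker(r) for r in range(4)])
# Every rollout worker now holds the same weights as the head worker.
assert all(w.weights == workers[0].weights for w in workers)
```

The design point this illustrates is why the README table lists the cost as "p2p+broadcast": only one cross-cluster transfer happens per update, and the remaining fan-out stays inside the rollout group's fast interconnect.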

verl/checkpoint_engine/__init__.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -44,10 +44,16 @@
 except ImportError:
     HCCLCheckpointEngine = None
 
-
 try:
     from .nixl_checkpoint_engine import NIXLCheckpointEngine
 
     __all__ += ["NIXLCheckpointEngine"]
 except ImportError:
     NIXLCheckpointEngine = None
+
+try:
+    from .kimi_checkpoint_engine import KIMICheckpointEngine
+
+    __all__ += ["KIMICheckpointEngine"]
+except ImportError:
+    KIMICheckpointEngine = None
```
