[ckpt] feat: implement large tensor slicing in vllm rollout and CheckpointEngine for weight updating#5378
Conversation
…rge tensors Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
5d76fbc to
fe7e78e
Compare
There was a problem hiding this comment.
Code Review
This pull request introduces chunked weight handling for large tensors in vLLM rollouts, which is a valuable feature for working with large models. The changes involve slicing oversized weights on the sender side and reassembling them on the receiver side. A new test is also added to verify the correctness of this new functionality. My review identified a critical bug in the chunking logic that could lead to a ZeroDivisionError, and a high-severity issue in the new test related to a hardcoded, user-specific file path. I have provided code suggestions to address both of these points. Apart from these issues, the implementation appears solid.
| config.trainer.n_gpus_per_node = 8 | ||
| config.trainer.nnodes = 1 | ||
| config.actor_rollout_ref.actor.use_dynamic_bsz = True | ||
| config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen3-VL-2B-Instruct") |
There was a problem hiding this comment.
The model path is hardcoded to a user-specific local directory. This makes the test non-portable and difficult for other developers to run. It's recommended to use a model from the Hugging Face Hub that can be downloaded automatically, or at least make this path configurable via an environment variable. Using a smaller model for this test would also make it faster and more efficient.
| config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen3-VL-2B-Instruct") | |
| config.actor_rollout_ref.model.path = os.path.expanduser(os.environ.get("VERL_TEST_MODEL_PATH", "hf-internal-testing/tiny-random-LlamaForCausalLM")) |
…isionError Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
|
CheckpointEngine should also slice large tensor as well. |
OK, I will also implement large tensor slicing in CheckpointEngine. |
|
should merge after #5309 |
…rge tensors Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…isionError Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…ort large tensors Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…ode reuse Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…hunked weight handling Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…class method Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…CL checkpoint engines Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…ne configs and clone tensors in pending chunks Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…t engines Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
[ckpt] feat: implement large tensor slicing in CheckpointEngine
Thank you for the suggestion. Since similar logic has already been implemented and tested in |
There was a problem hiding this comment.
Code Review
This pull request introduces support for large tensor slicing during weight updates in the CheckpointEngine and vllm rollout, refactors common logic into a new CollectiveCheckpointEngine base class, and adds a correctness test for the naive backend. While these changes enhance functionality and code structure, a high-severity security issue was identified: the use of pickle (via zmq.recv_pyobj) for deserializing metadata over the network poses a significant risk of Remote Code Execution (RCE). It is strongly recommended to switch to a safer serialization format like JSON. Additionally, there is a significant piece of duplicated code for the tensor slicing logic that could benefit from refactoring to improve long-term maintainability.
| self.socket.send_pyobj(self.metadata) | ||
| else: | ||
| self.socket.recv_string() | ||
| self.metadata = self.socket.recv_pyobj() |
There was a problem hiding this comment.
The use of zmq.Socket.recv_pyobj() is insecure as it relies on the pickle module for deserialization. pickle is known to be vulnerable to arbitrary code execution if the input data is controlled by an attacker. Since this data is received over the network from the master process, an attacker who can connect to the ZMQ port or spoof the master can achieve Remote Code Execution (RCE) on the worker processes.
Recommendation: Use a safer serialization format such as JSON or msgpack. For non-serializable types like torch.Size or torch.dtype, convert them to strings or lists before sending and reconstruct them on the receiving end.
…cing in ServerAdapter and CheckpointEngine Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
|
I think this PR is ready for review. cc @wuxibin89 @pengwu22 |
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
…ass configuration Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
What does this PR do?
This PR does the following:
CheckpointEngine.CheckpointEngineandCollectiveCheckpointEngineclasses.Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,veomni,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)recipesubmodule, please also update the reference to the submodule commit viagit submodule update --remoteorcd recipe && git pull origin main.