-
Notifications
You must be signed in to change notification settings - Fork 3.3k
[rollout] feat: Fix partial load problem, Add vlm support for trtllm rollout #5149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
dcaacfe
0394ab5
0664ab1
bf71c9b
f6e58b8
7af6917
94c4eb0
25518fe
fd007fb
659ec01
55b55dc
e2cc50b
ca17f8a
62af0f2
24a6620
6f055a2
d0b1d1d
6b021f4
a7faa7b
8519d36
9acdcd6
5a145a5
3776338
1706e71
90837f3
57506e2
e193d0d
81050ce
91d8c59
60dd50b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| # Copyright 2026 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import base64 | ||
| import inspect | ||
| import pickle | ||
| from typing import Optional | ||
|
|
||
| from tensorrt_llm._ray_utils import control_action_decorator | ||
| from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer | ||
| from tensorrt_llm._torch.utils import get_device_uuid | ||
| from tensorrt_llm.logger import logger | ||
|
|
||
|
|
||
class WorkerExtension:
    """TRT-LLM Ray worker extension for in-place model weight refresh.

    Weight updates arrive as base64-encoded pickled payloads keyed by device
    UUID in ``ipc_handles``; a final call with ``ipc_handles=None`` runs the
    post-load finalization pass (module hooks, MoE load-balancer wiring,
    prefix-cache reset).
    """

    def __init__(self):
        # Stateless: all state is read from ``self.engine``, which the
        # hosting worker attaches before these methods are invoked.
        pass

    def _model_supports_partial_loading(self) -> bool:
        """Return True when ``model.load_weights`` accepts ``allow_partial_loading``.

        Shared capability probe used by both ``supports_partial_loading``
        and ``update_weights`` (previously duplicated in each).
        """
        model = self.engine.model_engine.model
        return "allow_partial_loading" in inspect.getfullargspec(model.load_weights).args

    @control_action_decorator
    def supports_partial_loading(self) -> bool:
        """Check if the model supports partial weight loading.

        Returns False instead of raising when introspection fails, so
        callers can safely probe the capability.
        """
        try:
            return self._model_supports_partial_loading()
        except Exception as e:
            logger.warning(f"Failed to check partial loading support: {e}")
            return False

    @control_action_decorator
    def update_weights(self, ipc_handles: Optional[dict] = None):
        """Load new weights from IPC handles, or finalize a reload cycle.

        Args:
            ipc_handles: Mapping from device UUID to a base64-encoded,
                pickled weight payload. ``None`` signals that all weight
                shards have been delivered and finalization should run.

        Raises:
            Exception: re-raises any failure after logging it.
        """
        try:
            model_engine = self.engine.model_engine
            model = model_engine.model
            # Run the pre-reload hooks exactly once per reload cycle; the
            # sentinel attribute is deleted again during finalization below.
            if not hasattr(model, "first_pre_reload_weights"):
                for module in model.modules():
                    if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False):
                        module.pre_reload_weights()
                model.first_pre_reload_weights = True

            if ipc_handles is not None:
                handles = ipc_handles.get(get_device_uuid())
                if handles is not None:
                    # NOTE(review): pickle on the payload assumes it comes from
                    # the trusted trainer process — never feed untrusted data.
                    weights = pickle.loads(base64.b64decode(handles))
                    # Both branches of the original called reload() with only
                    # the boolean differing; pass the probe result directly.
                    model_engine.model_loader.reload(
                        model, weights, allow_partial_loading=self._model_supports_partial_loading()
                    )
            else:
                # Finalization pass: post-process every live module, wire up
                # the MoE load balancer (if any), and drop cached prefixes.
                for module in model.modules():
                    if getattr(module, "_weights_removed", False):
                        continue
                    if hasattr(module, "process_weights_after_loading"):
                        module.process_weights_after_loading()
                    if hasattr(module, "post_load_weights"):
                        module.post_load_weights()
                moe_load_balancer = getattr(model_engine, "moe_load_balancer", None)
                if isinstance(moe_load_balancer, MoeLoadBalancer):
                    moe_load_balancer.register_weight_slots_after_to_cuda()
                    logger.info("moe_load_balancer finalizing model...")
                    moe_load_balancer.finalize_model()
                    logger.info("moe_load_balancer finalize model done")
                self.engine.reset_prefix_cache()
                delattr(model, "first_pre_reload_weights")
        except Exception:
            logger.error("Encountered an error in update_weights")
            # Bare raise preserves the original traceback without adding
            # a redundant re-raise frame (was ``raise e``).
            raise
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a unittest for this new feature? There is a
test_trtllm_async_server.py.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, I'm adding one
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test script and the related test workflow have been added
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't find test_trtllm_async_server.py in the verl repo, so I wrote a test script that tests both the LLM rollout and the VLM rollout of the TensorRT-LLM rollout worker
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have a unittest MR that should be merged shortly; it contains the test_trtllm_async_server.py.