Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
dcaacfe
Fix partial load problem, Add vlm support for trtllm rollout
SchumiDing Jan 31, 2026
0394ab5
Precommit check
SchumiDing Jan 31, 2026
0664ab1
Add check for if the model is vlm in trtllmhttpserver
SchumiDing Jan 31, 2026
bf71c9b
Support latest trtllm
SchumiDing Feb 2, 2026
f6e58b8
Support for qwen2.5 vl
SchumiDing Feb 2, 2026
7af6917
Add trtllm rollout test script
SchumiDing Feb 2, 2026
94c4eb0
Add test_trtllm_rollout workflow to test trtllm_rollout
SchumiDing Feb 2, 2026
25518fe
Add back mistakenly deleted file
SchumiDing Feb 2, 2026
fd007fb
Precommit check
SchumiDing Feb 2, 2026
659ec01
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 4, 2026
55b55dc
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 5, 2026
e2cc50b
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 5, 2026
ca17f8a
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 6, 2026
62af0f2
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 11, 2026
24a6620
Modified to inherit the worker extension class of tensorrt llm
SchumiDing Feb 11, 2026
6f055a2
Modified to inherit the worker extension class of tensorrt llm
SchumiDing Feb 11, 2026
d0b1d1d
fix readability problem of multimodal config
SchumiDing Feb 11, 2026
6b021f4
Remove need for multimodal server config
SchumiDing Feb 11, 2026
a7faa7b
Add vlm unit test into existing trtllm unit test
SchumiDing Feb 11, 2026
8519d36
add e2e script to train qwen2.5-vl with trtllm rollout
SchumiDing Feb 11, 2026
9acdcd6
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 12, 2026
5a145a5
Change import statement
SchumiDing Feb 12, 2026
3776338
remove reward config in e2e script
SchumiDing Feb 12, 2026
1706e71
When multi modal input for trtllm, decode with special token first
SchumiDing Feb 12, 2026
90837f3
revert typo
SchumiDing Feb 12, 2026
57506e2
revert typo
SchumiDing Feb 12, 2026
e193d0d
pre commit check
SchumiDing Feb 12, 2026
81050ce
Fix bugs
SchumiDing Feb 27, 2026
91d8c59
Update
SchumiDing Feb 27, 2026
60dd50b
Update
SchumiDing Feb 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 66 additions & 19 deletions verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def __init__(
logger.warning(f"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto")
self.config.load_format = "auto"

self.is_vlm_model = (
self.model_config.hf_config is not None and hasattr(self.model_config.hf_config, "vision_config")
) or hasattr(self.model_config, "vision_config")

# used for http server
self._server_address = ray.util.get_node_ip_address().strip("[]")
self._server_port = None
Expand Down Expand Up @@ -125,7 +129,7 @@ async def launch_server(self):
"model": self.model_config.local_path,
"backend": "pytorch",
"orchestrator_type": "ray",
"ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
"ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension",
"kv_cache_config": kv_cache_config,
"max_seq_len": self.config.max_model_len,
"max_batch_size": self.config.max_num_seqs,
Expand Down Expand Up @@ -159,15 +163,42 @@ async def launch_server(self):
}
)

self.llm = await AsyncLLM(**llm_kwargs)
if self.is_vlm_model:
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig

multimodal_config = MultimodalServerConfig(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a unittest for this new feature? There is a test_trtllm_async_server.py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I'm adding one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test script and relating test workflow has been added

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't find the test_trtllm_async_server.py in verl repo, so I write a test script for test on both llm rollout and vlm rollout of tensorrt-llm rollout worker

Copy link
Contributor

@hchings hchings Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't find the test_trtllm_async_server.py in verl repo

We have a unittest MR that should be merge shortly, that contains the test_trtllm_async_server.py.

media_io_kwargs={
"image": {
"format": "pil",
"device": "cpu",
},
"video": {
"num_frames": 8,
"fps": 30,
"format": "pil",
"device": "cpu",
},
}
)
self.llm = await AsyncLLM(**llm_kwargs)
trtllm_server = OpenAIServer(
llm=self.llm,
model=self.model_config.local_path,
tool_parser=None,
server_role=None,
metadata_server_cfg=None,
multimodal_server_config=multimodal_config,
)
else:
self.llm = await AsyncLLM(**llm_kwargs)
trtllm_server = OpenAIServer(
llm=self.llm,
model=self.model_config.local_path,
tool_parser=None,
server_role=None,
metadata_server_cfg=None,
)

trtllm_server = OpenAIServer(
llm=self.llm,
model=self.model_config.local_path,
tool_parser=None,
server_role=None,
metadata_server_cfg=None,
)
app = trtllm_server.app
self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address)

Expand All @@ -179,9 +210,6 @@ async def generate(
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
) -> TokenOutput:
"""Generate sequence with token-in-token-out."""
assert image_data is None and video_data is None, "Multimodality is not yet supported in TRTLLMHttpServer."

from tensorrt_llm.llmapi import SamplingParams

max_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids))
Expand All @@ -192,15 +220,34 @@ async def generate(
sampling_params.update(self.sampling_args)

trt_llm_sampling_params = SamplingParams(**sampling_params)
outputs = await self.llm.generate_async(
inputs=prompt_ids,
sampling_params=trt_llm_sampling_params,
)

if self.is_vlm_model:
if image_data or video_data:
input_dict = {
"prompt_token_ids": prompt_ids,
"multi_modal_data": {},
}
if image_data:
input_dict["multi_modal_data"]["image"] = image_data
if video_data:
input_dict["multi_modal_data"]["video"] = video_data
outputs = await self.llm.generate_async(
inputs=input_dict,
sampling_params=trt_llm_sampling_params,
)
else:
outputs = await self.llm.generate_async(
inputs=prompt_ids,
sampling_params=trt_llm_sampling_params,
)
else:
outputs = await self.llm.generate_async(
inputs=prompt_ids,
sampling_params=trt_llm_sampling_params,
)
token_ids = outputs.outputs[0].token_ids
log_probs = None
if trt_llm_sampling_params.logprobs is not None:
log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs]
if outputs.outputs[0].logprobs is not None:
log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(outputs.outputs[0].logprobs)]
return TokenOutput(token_ids=token_ids, log_probs=log_probs)

async def wake_up(self):
Expand Down
35 changes: 28 additions & 7 deletions verl/workers/rollout/trtllm_rollout/trtllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def __init__(
self.is_leader_rank = None
self.replica_rank = None
self.is_dp_rank = None
self._supports_partial_loading = None

# hybrid mode
if self.device_mesh is not None:
Expand Down Expand Up @@ -312,6 +313,21 @@ def __init__(

self.node_ip = ray.util.get_node_ip_address().strip("[]")

async def get_supports_partial_loading(self) -> bool:
    """Return whether the rollout model supports partial weight loading.

    The answer is fetched once from the remote server actor and memoized in
    ``self._supports_partial_loading``; if the remote query fails we log a
    warning and conservatively assume ``False``.
    """
    # Fast path: the answer was already resolved on a previous call.
    if self._supports_partial_loading is None:
        await self._init_server_adapter()
        try:
            answer = await self.server_actor.supports_partial_loading.remote()
        except Exception as e:
            # Best-effort probe — default to the safe answer rather than fail.
            logger.warning(f"Failed to query partial loading support: {e}, defaulting to False")
            answer = False
        self._supports_partial_loading = answer
        logger.info(f"Model supports partial loading: {self._supports_partial_loading}")
    return self._supports_partial_loading

async def _init_server_adapter(self):
if self._adapter is not None:
return
Expand Down Expand Up @@ -406,15 +422,20 @@ async def flush():
cur_available_bytes = total_available_bytes
cur_handles = []

# Query if model supports partial loading
supports_partial_loading = await self.get_supports_partial_loading()

for name, param in weights:
size_in_bytes = param.element_size() * param.numel()
if size_in_bytes > cur_available_bytes:
await flush()
if supports_partial_loading:
size_in_bytes = param.element_size() * param.numel()
if size_in_bytes > cur_available_bytes:
await flush()

assert cur_available_bytes >= size_in_bytes, (
f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: {size_in_bytes:,} name: {name}"
)
cur_available_bytes -= size_in_bytes

assert cur_available_bytes >= size_in_bytes, (
f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: {size_in_bytes:,} name: {name}"
)
cur_available_bytes -= size_in_bytes
handle = reduce_tensor(param.detach())
cur_handles.append((name, handle))

Expand Down
81 changes: 81 additions & 0 deletions verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import inspect
import pickle
from typing import Optional

from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.logger import logger


class WorkerExtension:
Copy link
Collaborator

@Superjomn Superjomn Feb 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the latest WorkerExtension in the TensorRT-LLM repo. Are there any motivations for implementing a new one in verl repo? I am thinking about how to unify both. Ideally, we may update the one in the TensorRT-LLM codebase, but if we need a minor change on it before the next trtllm version bump up, @hchings do you have a suggestion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. Ideally, we should still use the worker extension from the tensorrt-llm repo. But to support models that do not allow partial loading, I suppose the use of self.engine.model_engine.model_loader.reload should work with the param allow_partial_loading=False

Copy link
Contributor Author

@SchumiDing SchumiDing Feb 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this new worker extension to support allow_partial_loading=False, because tensorrt-llm always sets this param to True, but some models do not support it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll prefer that we keep this at TensorRT-LLM repo instead and make it generic for other RL FWs to reuse in the future.

def __init__(self):
    # No per-instance state is set up here; attributes such as `self.engine`
    # are presumably provided by the TensorRT-LLM worker this extension class
    # is mixed into — confirm against the ray_worker_extension_cls mechanism.
    pass

@control_action_decorator
def supports_partial_loading(self) -> bool:
    """Check if the model supports partial weight loading.

    A model opts in by accepting an ``allow_partial_loading`` parameter in its
    ``load_weights`` method. Since callers pass this flag by keyword, we must
    accept it whether it is declared positional-or-keyword (``spec.args``) or
    keyword-only (``spec.kwonlyargs``); checking ``args`` alone would wrongly
    report a keyword-only declaration as unsupported.

    Returns:
        bool: True if ``load_weights`` accepts ``allow_partial_loading``;
        False otherwise, or if introspection fails for any reason.
    """
    try:
        spec = inspect.getfullargspec(self.engine.model_engine.model.load_weights)
        return "allow_partial_loading" in spec.args or "allow_partial_loading" in spec.kwonlyargs
    except Exception as e:
        # Best-effort probe: missing attributes or non-introspectable callables
        # (e.g. C-implemented) mean we assume partial loading is unsupported.
        logger.warning(f"Failed to check partial loading support: {e}")
        return False

@control_action_decorator
def update_weights(self, ipc_handles: Optional[dict] = None):
    """Load a batch of weights delivered via CUDA IPC handles, or finalize a reload cycle.

    Called repeatedly with per-device ``ipc_handles`` to stream weight batches,
    then once with ``ipc_handles=None`` to run post-load hooks and reset caches.
    Raises: re-raises any exception after logging it.

    NOTE(review): indentation was reconstructed from a flattened diff — confirm
    that ``reset_prefix_cache``/``delattr`` belong to the finalize (else) branch.
    """
    try:
        # One-time prep for this reload cycle: let modules release/stash state
        # via pre_reload_weights(). The sentinel attribute on the model marks
        # that prep has already run, so later batch calls skip it.
        if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"):
            for module in self.engine.model_engine.model.modules():
                if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False):
                    module.pre_reload_weights()
            self.engine.model_engine.model.first_pre_reload_weights = True

        if ipc_handles is not None:
            # Weight batch: select the handles addressed to this GPU.
            device_uuid = get_device_uuid()
            handles = ipc_handles.get(device_uuid, None)
            if handles is not None:
                # Handles arrive base64-encoded and pickled; this is a trusted
                # trainer-to-rollout IPC payload, not external input.
                weights = pickle.loads(base64.b64decode(handles))
                model = self.engine.model_engine.model
                # Only pass allow_partial_loading if the model's load_weights
                # actually declares it (some models reject partial loads).
                load_weights_args = inspect.getfullargspec(model.load_weights).args
                supports_partial_loading = "allow_partial_loading" in load_weights_args

                if supports_partial_loading:
                    self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True)
                else:
                    self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False)
        else:
            # Finalize: run per-module post-load hooks on everything that still
            # owns weights, then finish MoE load balancing if present.
            for module in self.engine.model_engine.model.modules():
                if hasattr(module, "process_weights_after_loading") and not getattr(
                    module, "_weights_removed", False
                ):
                    module.process_weights_after_loading()
                if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False):
                    module.post_load_weights()
            moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None)
            if isinstance(moe_load_balancer, MoeLoadBalancer):
                moe_load_balancer.register_weight_slots_after_to_cuda()
                logger.info("moe_load_balancer finalizing model...")
                moe_load_balancer.finalize_model()
                logger.info("moe_load_balancer finalize model done")
            # New weights invalidate any cached prefixes; also clear the
            # cycle sentinel so the next update runs pre_reload_weights again.
            self.engine.reset_prefix_cache()
            delattr(self.engine.model_engine.model, "first_pre_reload_weights")

    except Exception as e:
        logger.error("Encountered an error in update_weights")
        raise e
Loading