From dcaacfec53ab96fb2a36a82b073266b352142689 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Sat, 31 Jan 2026 19:23:45 +0800 Subject: [PATCH 01/24] Fix partial load problem, Add vlm support for trtllm rollout --- .../trtllm_rollout/trtllm_async_server.py | 84 ++++++++++++++----- .../rollout/trtllm_rollout/trtllm_rollout.py | 35 ++++++-- .../trtllm_rollout/trtllm_worker_extension.py | 77 +++++++++++++++++ 3 files changed, 168 insertions(+), 28 deletions(-) create mode 100644 verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index f669a7bfe3b..6448075e476 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -125,7 +125,7 @@ async def launch_server(self): "model": self.model_config.local_path, "backend": "pytorch", "orchestrator_type": "ray", - "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + "ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension", "kv_cache_config": kv_cache_config, "max_seq_len": self.config.max_model_len, "max_batch_size": self.config.max_num_seqs, @@ -159,18 +159,45 @@ async def launch_server(self): } ) - self.llm = await AsyncLLM(**llm_kwargs) - - trtllm_server = OpenAIServer( - llm=self.llm, - model=self.model_config.local_path, - tool_parser=None, - server_role=None, - metadata_server_cfg=None, - ) + if self.is_vlm_model: + from tensorrt_llm.inputs.multimodal import MultimodalServerConfig + multimodal_config = MultimodalServerConfig( + media_io_kwargs={ + "image": { + "format": "pil", + "device": "cpu", + }, + "video": { + "num_frames": 8, + "fps": 30, + "format": "pil", + "device": "cpu", + }, + } + ) + self.llm = await AsyncLLM(**llm_kwargs) + trtllm_server = OpenAIServer( + llm=self.llm, + model=self.model_config.local_path, + tool_parser=None, + server_role=None, + metadata_server_cfg=None, + multimodal_server_config=multimodal_config, + ) + else: + self.llm = await AsyncLLM(**llm_kwargs) + trtllm_server = OpenAIServer( + llm=self.llm, + model=self.model_config.local_path, + tool_parser=None, + server_role=None, + metadata_server_cfg=None, + ) + app = trtllm_server.app self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address) + @resume_on_abort async def generate( self, prompt_ids: list[int], @@ -179,11 +206,7 @@ async def generate( image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, ) -> TokenOutput: - """Generate sequence with token-in-token-out.""" - assert image_data is None and video_data is None, "Multimodality is not yet supported in TRTLLMHttpServer." - from tensorrt_llm.llmapi import SamplingParams - max_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids)) sampling_params["max_tokens"] = max_tokens sampling_params["logprobs"] = 1 if sampling_params.pop("logprobs", False) else None @@ -192,15 +215,34 @@ async def generate( sampling_params.update(self.sampling_args) trt_llm_sampling_params = SamplingParams(**sampling_params) - outputs = await self.llm.generate_async( - inputs=prompt_ids, - sampling_params=trt_llm_sampling_params, - ) - + if self.is_vlm_model: + if image_data or video_data: + input_dict = { + "prompt_token_ids": prompt_ids, + "multi_modal_data": {}, + } + if image_data: + input_dict["multi_modal_data"]["image"] = image_data + if video_data: + input_dict["multi_modal_data"]["video"] = video_data + outputs = await self.llm.generate_async( + inputs=input_dict, + sampling_params=trt_llm_sampling_params, + ) + else: + outputs = await self.llm.generate_async( + inputs=prompt_ids, + sampling_params=trt_llm_sampling_params, + ) + else: + outputs = await self.llm.generate_async( + inputs=prompt_ids, + sampling_params=trt_llm_sampling_params, + ) token_ids = outputs.outputs[0].token_ids log_probs = None - if trt_llm_sampling_params.logprobs is not None: - log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs] + if outputs.outputs[0].logprobs is not None: + log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(outputs.outputs[0].logprobs)] return TokenOutput(token_ids=token_ids, log_probs=log_probs) async def wake_up(self): diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 3c42ee7bc73..ba6a991b57d 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -281,6 +281,7 @@ def __init__( self.is_leader_rank = None self.replica_rank = None self.is_dp_rank = None + self._supports_partial_loading = None # hybrid mode if self.device_mesh is not None: @@ -312,6 +313,21 @@ def __init__( self.node_ip = ray.util.get_node_ip_address().strip("[]") + async def get_supports_partial_loading(self) -> bool: + """Query and cache whether the model supports partial weight loading.""" + if self._supports_partial_loading is not None: + return self._supports_partial_loading + + await self._init_server_adapter() + try: + self._supports_partial_loading = await self.server_actor.supports_partial_loading.remote() + except Exception as e: + logger.warning(f"Failed to query partial loading support: {e}, defaulting to False") + self._supports_partial_loading = False + + logger.info(f"Model supports partial loading: {self._supports_partial_loading}") + return self._supports_partial_loading + async def _init_server_adapter(self): if self._adapter is not None: return @@ -405,16 +421,21 @@ async def flush(): await self.update_weights_from_ipc_handles(serialized_device_handles) cur_available_bytes = total_available_bytes cur_handles = [] + + # Query if model supports partial loading + supports_partial_loading = await self.get_supports_partial_loading() for name, param in weights: - size_in_bytes = param.element_size() * param.numel() - if size_in_bytes > cur_available_bytes: - await flush() + if supports_partial_loading: + size_in_bytes = param.element_size() * param.numel() + if size_in_bytes > cur_available_bytes: + await flush() + + assert cur_available_bytes >= size_in_bytes, ( + f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: {size_in_bytes:,} name: {name}" + ) + cur_available_bytes -= size_in_bytes - assert cur_available_bytes >= size_in_bytes, ( - f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: {size_in_bytes:,} name: {name}" - ) - cur_available_bytes -= size_in_bytes handle = reduce_tensor(param.detach()) cur_handles.append((name, handle)) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py new file mode 100644 index 00000000000..86b341dbf84 --- /dev/null +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -0,0 +1,77 @@ +import base64 +import inspect +import pickle +from typing import Optional + +from tensorrt_llm._ray_utils import control_action_decorator +from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer +from tensorrt_llm._torch.utils import get_device_uuid +from tensorrt_llm.logger import logger + + +class WorkerExtension: + + def __init__(self): + pass + + @control_action_decorator + def supports_partial_loading(self) -> bool: + """Check if the model supports partial weight loading.""" + try: + model = self.engine.model_engine.model + load_weights_args = inspect.getfullargspec(model.load_weights).args + return "allow_partial_loading" in load_weights_args + except Exception as e: + logger.warning(f"Failed to check partial loading support: {e}") + return False + + @control_action_decorator + def update_weights(self, ipc_handles: Optional[dict] = None): + try: + if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): + for module in self.engine.model_engine.model.modules(): + if hasattr(module, "pre_reload_weights") and not getattr( + module, "_weights_removed", False + ): + module.pre_reload_weights() + setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) + + if ipc_handles is not None: + device_uuid = get_device_uuid() + handles = ipc_handles.get(device_uuid, None) + if handles is not None: + weights = pickle.loads(base64.b64decode(handles)) + model = self.engine.model_engine.model + load_weights_args = inspect.getfullargspec(model.load_weights).args + supports_partial_loading = "allow_partial_loading" in load_weights_args + + if supports_partial_loading: + self.engine.model_engine.model_loader.reload( + model, weights, allow_partial_loading=True + ) + else: + self.engine.model_engine.model_loader.reload( + model, weights, allow_partial_loading=False + ) + else: + for module in self.engine.model_engine.model.modules(): + if hasattr(module, "process_weights_after_loading") and not getattr( + module, "_weights_removed", False + ): + module.process_weights_after_loading() + if hasattr(module, "post_load_weights") and not getattr( + module, "_weights_removed", False + ): + module.post_load_weights() + moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) + if isinstance(moe_load_balancer, MoeLoadBalancer): + moe_load_balancer.register_weight_slots_after_to_cuda() + logger.info("moe_load_balancer finalizing model...") + moe_load_balancer.finalize_model() + logger.info("moe_load_balancer finalize model done") + self.engine.reset_prefix_cache() + delattr(self.engine.model_engine.model, "first_pre_reload_weights") + + except Exception as e: + logger.error("Encountered an error in update_weights") + raise e From 0394ab512fdefe6be8a6b8fa5e2393dfa5e0777e Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Sat, 31 Jan 2026 20:03:04 +0800 Subject: [PATCH 02/24] Precommit check --- .../trtllm_rollout/trtllm_async_server.py | 5 +-- .../rollout/trtllm_rollout/trtllm_rollout.py | 4 +-- .../trtllm_rollout/trtllm_worker_extension.py | 32 +++++++++++-------- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 6448075e476..3317a641fc1 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -161,6 +161,7 @@ async def launch_server(self): if self.is_vlm_model: from tensorrt_llm.inputs.multimodal import MultimodalServerConfig + multimodal_config = MultimodalServerConfig( media_io_kwargs={ "image": { @@ -193,11 +194,10 @@ async def launch_server(self): server_role=None, metadata_server_cfg=None, ) - + app = trtllm_server.app self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address) - @resume_on_abort async def generate( self, prompt_ids: list[int], @@ -207,6 +207,7 @@ async def generate( video_data: Optional[list[Any]] = None, ) -> TokenOutput: from tensorrt_llm.llmapi import SamplingParams + max_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids)) sampling_params["max_tokens"] = max_tokens sampling_params["logprobs"] = 1 if sampling_params.pop("logprobs", False) else None diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index ba6a991b57d..ce2527c66e7 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -281,7 +281,7 @@ def __init__( self.is_leader_rank = None self.replica_rank = None self.is_dp_rank = None - self._supports_partial_loading = None + self._supports_partial_loading = None # hybrid mode if self.device_mesh is not None: @@ -421,7 +421,7 @@ async def flush(): await self.update_weights_from_ipc_handles(serialized_device_handles) cur_available_bytes = total_available_bytes cur_handles = [] - + # Query if model supports partial loading supports_partial_loading = await self.get_supports_partial_loading() diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index 86b341dbf84..a7a96f607fa 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -1,3 +1,16 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import base64 import inspect import pickle @@ -10,7 +23,6 @@ class WorkerExtension: - def __init__(self): pass @@ -30,11 +42,9 @@ def update_weights(self, ipc_handles: Optional[dict] = None): try: if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): for module in self.engine.model_engine.model.modules(): - if hasattr(module, "pre_reload_weights") and not getattr( - module, "_weights_removed", False - ): + if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False): module.pre_reload_weights() - setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) + self.engine.model_engine.model.first_pre_reload_weights = True if ipc_handles is not None: device_uuid = get_device_uuid() @@ -46,22 +56,16 @@ def update_weights(self, ipc_handles: Optional[dict] = None): supports_partial_loading = "allow_partial_loading" in load_weights_args if supports_partial_loading: - self.engine.model_engine.model_loader.reload( - model, weights, allow_partial_loading=True - ) + self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True) else: - self.engine.model_engine.model_loader.reload( - model, weights, allow_partial_loading=False - ) + self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False) else: for module in self.engine.model_engine.model.modules(): if hasattr(module, "process_weights_after_loading") and not getattr( module, "_weights_removed", False ): module.process_weights_after_loading() - if hasattr(module, "post_load_weights") and not getattr( - module, "_weights_removed", False - ): + if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False): module.post_load_weights() moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) if isinstance(moe_load_balancer, MoeLoadBalancer): From 0664ab102b059063d29b4d420d28abbd146eef1c Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Sat, 31 Jan 2026 22:55:36 +0800 Subject: [PATCH 03/24] Add check for if the model is vlm in trtllmhttpserver --- verl/workers/rollout/trtllm_rollout/trtllm_async_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 3317a641fc1..24b380a6962 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -89,6 +89,10 @@ def __init__( logger.warning(f"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto") self.config.load_format = "auto" + self.is_vlm_model = ( + self.model_config.hf_config is not None and hasattr(self.model_config.hf_config, "vision_config") + ) or hasattr(self.model_config, "vision_config") + # used for http server self._server_address = ray.util.get_node_ip_address().strip("[]") self._server_port = None From bf71c9b4c3c19b6496133d93568119aea6d8951d Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 17:12:58 +0800 Subject: [PATCH 04/24] Support latest trtllm --- .../config/test_optim_config_on_cpu.py | 48 ---------- .../trtllm_rollout/trtllm_worker_extension.py | 96 +++++++++++++++---- 2 files changed, 80 insertions(+), 64 deletions(-) delete mode 100644 tests/workers/config/test_optim_config_on_cpu.py diff --git a/tests/workers/config/test_optim_config_on_cpu.py b/tests/workers/config/test_optim_config_on_cpu.py deleted file mode 100644 index b44cb40c6b1..00000000000 --- a/tests/workers/config/test_optim_config_on_cpu.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from verl.workers.config.optimizer import FSDPOptimizerConfig - - -class TestFSDPOptimizerConfigCPU: - def test_default_configuration(self): - config = FSDPOptimizerConfig(lr=0.1) - assert config.min_lr_ratio is None - assert config.lr_scheduler_type == "constant" - assert config.num_cycles == 0.5 - - @pytest.mark.parametrize("lr_scheduler_type", ["constant", "cosine"]) - def test_valid_lr_scheduler_types(self, lr_scheduler_type): - config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1) - assert config.lr_scheduler_type == lr_scheduler_type - - @pytest.mark.parametrize("warmup_style", ["constant", "cosine"]) - def test_valid_warmup_style_types(self, warmup_style): - config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1) - assert config.lr_scheduler_type == warmup_style - - def test_invalid_lr_scheduler_type(self): - with pytest.raises((ValueError, AssertionError)): - FSDPOptimizerConfig(lr_scheduler_type="invalid_style", lr=0.1) - - def test_invalid_warmup_style_type(self): - with pytest.raises((ValueError, AssertionError)): - FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1) - - @pytest.mark.parametrize("num_cycles", [0.1, 1.0, 2.5]) - def test_num_cycles_configuration(self, num_cycles): - config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1) - assert config.num_cycles == num_cycles diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index a7a96f607fa..ce56c8b9b5c 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -13,9 +13,11 @@ # limitations under the License. import base64 import inspect -import pickle from typing import Optional +import torch + +from tensorrt_llm import serialization from tensorrt_llm._ray_utils import control_action_decorator from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer from tensorrt_llm._torch.utils import get_device_uuid @@ -42,30 +44,85 @@ def update_weights(self, ipc_handles: Optional[dict] = None): try: if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): for module in self.engine.model_engine.model.modules(): - if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False): + if hasattr(module, "pre_reload_weights") and not getattr( + module, "_weights_removed", False + ): module.pre_reload_weights() - self.engine.model_engine.model.first_pre_reload_weights = True + setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) if ipc_handles is not None: - device_uuid = get_device_uuid() - handles = ipc_handles.get(device_uuid, None) - if handles is not None: - weights = pickle.loads(base64.b64decode(handles)) - model = self.engine.model_engine.model - load_weights_args = inspect.getfullargspec(model.load_weights).args - supports_partial_loading = "allow_partial_loading" in load_weights_args - - if supports_partial_loading: - self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True) - else: - self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False) + logger.info("Update weights from IPC handles") + device_uuid = get_device_uuid(self.device_id) + + if device_uuid not in ipc_handles: + raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles") + + weights = {} + + serialized_handles = ipc_handles[device_uuid] + if isinstance(serialized_handles, str): + # Data is base64-encoded pickled bytes - deserialize it + # using restricted unpickler from tensorrt_llm.serialization + logger.info("Deserializing base64-encoded weight handles") + decoded_data = base64.b64decode(serialized_handles) + # Allow basic builtins and all torch modules + approved_imports = { + "builtins": [ + "list", + "tuple", + "str", + "int", + "float", + "bool", + "bytes", + "dict", + "NoneType", + "type", + ], + } + all_handles = serialization.loads( + decoded_data, + approved_imports=approved_imports, + approved_module_patterns=[r"^torch.*"], + ) + + # Verify the result is a list as expected + if not isinstance(all_handles, list): + raise ValueError( + f"Deserialized data must be a list, got {type(all_handles).__name__} instead" + ) + else: + # Data is already in the correct format (backward compatibility) + all_handles = serialized_handles + + for param_name, tensor_handle in all_handles: + func, args = tensor_handle + list_args = list(args) + list_args[6] = self.device_id + tensor = func(*list_args) + weights[param_name] = tensor + + logger.info(f"weights key size: {len(weights.keys())}") + + # Check if model supports partial loading and use appropriate strategy + model = self.engine.model_engine.model + load_weights_args = inspect.getfullargspec(model.load_weights).args + supports_partial_loading = "allow_partial_loading" in load_weights_args + + if supports_partial_loading: + self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True) + else: + self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False) else: + logger.info("Finalize update weights") for module in self.engine.model_engine.model.modules(): if hasattr(module, "process_weights_after_loading") and not getattr( module, "_weights_removed", False ): module.process_weights_after_loading() - if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False): + if hasattr(module, "post_load_weights") and not getattr( + module, "_weights_removed", False + ): module.post_load_weights() moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) if isinstance(moe_load_balancer, MoeLoadBalancer): @@ -79,3 +136,10 @@ def update_weights(self, ipc_handles: Optional[dict] = None): except Exception as e: logger.error("Encountered an error in update_weights") raise e + + def check_weights_updated(self) -> bool: + """Check if the weights are updated to 0.""" + weights_updated = True + for name, p in self.engine.model_engine.model.named_parameters(): + weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p)) + return weights_updated From f6e58b882ecccdfc42fde76dda58b556c9ea3fc6 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 19:53:43 +0800 Subject: [PATCH 05/24] Support for qwen2.5 vl --- verl/workers/rollout/trtllm_rollout/trtllm_async_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 24b380a6962..221655847a4 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -225,6 +225,7 @@ async def generate( input_dict = { "prompt_token_ids": prompt_ids, "multi_modal_data": {}, + "mm_processor_kwargs": {}, } if image_data: input_dict["multi_modal_data"]["image"] = image_data From 7af6917e3ed9d0b3c9fa68a769a1e7d2fc8d4ee6 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 19:56:46 +0800 Subject: [PATCH 06/24] Add trtllm rollout test script --- .../rollout/rollout_trtllm/__init__.py | 14 + .../test_trtllm_rollout_utils.py | 458 ++++++++++++++++++ 2 files changed, 472 insertions(+) create mode 100644 tests/workers/rollout/rollout_trtllm/__init__.py create mode 100644 tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py diff --git a/tests/workers/rollout/rollout_trtllm/__init__.py b/tests/workers/rollout/rollout_trtllm/__init__.py new file mode 100644 index 00000000000..46866da4cd9 --- /dev/null +++ b/tests/workers/rollout/rollout_trtllm/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py new file mode 100644 index 00000000000..dd99f09f60c --- /dev/null +++ b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py @@ -0,0 +1,458 @@ +import asyncio +import os +import uuid + +import numpy as np +import pytest +import ray +import torch +from omegaconf import OmegaConf +from PIL import Image +from transformers import AutoTokenizer + +UNIMODAL_MODEL_PATH = "Qwen/Qwen2.5-Math-7B" +MULTIMODAL_MODEL_PATH = "Qwen/Qwen2.5-VL-7B-Instruct" + +MAX_MODEL_LEN = 4096 +RESPONSE_LENGTH = 256 +MAX_NUM_SEQS = 16 +GPU_MEMORY_UTILIZATION = 0.8 +TENSOR_PARALLEL_SIZE = 1 + + +def create_test_image(width: int = 224, height: int = 224) -> Image.Image: + img_array = np.zeros((height, width, 3), dtype=np.uint8) + for i in range(height): + for j in range(width): + img_array[i, j] = [ + int(255 * i / height), + int(255 * j / width), + int(255 * (i + j) / (height + width)), + ] + return Image.fromarray(img_array) + + +def create_rollout_config_dict(): + config_dict = { + "_target_": "verl.workers.config.RolloutConfig", + "name": "trtllm", + "mode": "async", + "temperature": 0.7, + "top_k": 50, + "top_p": 0.9, + "do_sample": True, + "n": 1, + "prompt_length": 512, + "response_length": RESPONSE_LENGTH, + "dtype": "bfloat16", + "gpu_memory_utilization": GPU_MEMORY_UTILIZATION, + "ignore_eos": False, + "enforce_eager": True, + "free_cache_engine": False, + "data_parallel_size": 1, + "tensor_model_parallel_size": TENSOR_PARALLEL_SIZE, + "pipeline_model_parallel_size": 1, + "max_num_batched_tokens": 8192, + "max_model_len": MAX_MODEL_LEN, + "max_num_seqs": MAX_NUM_SEQS, + "load_format": "auto", + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + } + return OmegaConf.create(config_dict) + + +def create_model_config_dict(model_path: str): + config_dict = { + "_target_": "verl.workers.config.HFModelConfig", + "path": model_path, + "trust_remote_code": True, + "load_tokenizer": True, + } + return OmegaConf.create(config_dict) + + +def get_tokenizer(model_path: str): + return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + +def get_processor(model_path: str): + from transformers import AutoProcessor + return AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) +class TestUnimodalTRTLLMRollout: + + @pytest.fixture(scope="class") + def ray_context(self): + if ray.is_initialized(): + ray.shutdown() + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + @pytest.fixture(scope="class") + def trtllm_replica(self, ray_context): + from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica + + rollout_config = create_rollout_config_dict() + model_config = create_model_config_dict(UNIMODAL_MODEL_PATH) + + replica = TRTLLMReplica( + replica_rank=0, + config=rollout_config, + model_config=model_config, + gpus_per_node=torch.cuda.device_count(), + is_reward_model=False, + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(replica.init_standalone()) + + yield replica + + loop.close() + + @pytest.fixture(scope="class") + def tokenizer(self): + return get_tokenizer(UNIMODAL_MODEL_PATH) + + @pytest.mark.parametrize( + "prompt", + [ + "What is 2 + 2?", + "Solve for x: 3x + 5 = 20", + "Calculate the derivative of x^2 + 3x + 1", + ], + ) + def test_unimodal_generate(self, trtllm_replica, tokenizer, prompt): + replica = trtllm_replica + + messages = [ + {"role": "system", "content": "You are a helpful math assistant."}, + {"role": "user", "content": prompt}, + ] + + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": True, + } + + request_id = str(uuid.uuid4()) + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=request_id, + )) + + assert output is not None + assert hasattr(output, "token_ids") + assert len(output.token_ids) > 0 + + generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) + print(f"\n[Unimodal Test]") + print(f"Prompt: {prompt}") + print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...") + + def test_unimodal_batch_generate(self, trtllm_replica, tokenizer): + replica = trtllm_replica + + prompts = [ + "What is 1 + 1?", + "What is 2 * 3?", + "What is 10 / 2?", + ] + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": False, + } + + results = [] + + for i, prompt in enumerate(prompts): + messages = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() + + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + )) + results.append(output) + + assert len(results) == len(prompts) + for i, (prompt, result) in enumerate(zip(prompts, results)): + assert result is not None + assert len(result.token_ids) > 0 + generated = tokenizer.decode(result.token_ids, skip_special_tokens=True) + print(f"\n[Batch {i}] Prompt: {prompt}") + print(f"Generated: {generated[:100]}...") + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) +class TestMultimodalTRTLLMRollout: + + @pytest.fixture(scope="class") + def ray_context(self): + if ray.is_initialized(): + ray.shutdown() + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + @pytest.fixture(scope="class") + def trtllm_vlm_replica(self, ray_context): + from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica + + rollout_config = create_rollout_config_dict() + model_config = create_model_config_dict(MULTIMODAL_MODEL_PATH) + + replica = TRTLLMReplica( + replica_rank=0, + config=rollout_config, + model_config=model_config, + gpus_per_node=torch.cuda.device_count(), + is_reward_model=False, + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(replica.init_standalone()) + + yield replica + + loop.close() + + @pytest.fixture(scope="class") + def tokenizer(self): + return get_tokenizer(MULTIMODAL_MODEL_PATH) + + @pytest.fixture(scope="class") + def processor(self): + return get_processor(MULTIMODAL_MODEL_PATH) + + @pytest.mark.parametrize( + "prompt", + [ + "Describe this image in detail.", + "What colors do you see in this image?", + "What patterns are visible in this image?", + ], + ) + def test_multimodal_generate_with_image(self, trtllm_vlm_replica, processor, tokenizer, prompt): + replica = trtllm_vlm_replica + + test_image = create_test_image(224, 224) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": prompt}, + ], + } + ] + + text = processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + print("text: ", text) + input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist() + + print("input_ids decoded: ", processor.tokenizer.decode(input_ids, skip_special_tokens=False, add_special_tokens=False)) + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": False, + } + + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + image_data=[test_image], + )) + + assert output is not None + assert hasattr(output, "token_ids") + assert len(output.token_ids) > 0 + + generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) + print(f"\n[Multimodal Test]") + print(f"Prompt: {prompt}") + print(f"Image size: {test_image.size}") + print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...") + + @pytest.mark.parametrize( + "image_size", + [(224, 224), (384, 384), (512, 512)], + ) + def test_multimodal_different_image_sizes(self, trtllm_vlm_replica, processor, tokenizer, image_size): + replica = trtllm_vlm_replica + + width, height = image_size + test_image = create_test_image(width, height) + + prompt = "What is shown in this image?" + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": prompt}, + ], + } + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist() + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": False, + } + + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + image_data=[test_image], + )) + + assert output is not None + assert len(output.token_ids) > 0 + print(f"\n[Image Size {image_size}] Generated {len(output.token_ids)} tokens") + + def test_multimodal_text_only_fallback(self, trtllm_vlm_replica, tokenizer): + replica = trtllm_vlm_replica + + prompt = "What is the capital of China?" + messages = [{"role": "user", "content": prompt}] + + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": False, + } + + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + )) + + assert output is not None + assert len(output.token_ids) > 0 + + generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) + print(f"\n[Text-only on VLM]") + print(f"Prompt: {prompt}") + print(f"Generated: {generated_text}") + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) +class TestTRTLLMServerLifecycle: + + @pytest.fixture(scope="class") + def ray_context(self): + if ray.is_initialized(): + ray.shutdown() + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + @pytest.fixture(scope="class") + def trtllm_replica_lifecycle(self, ray_context): + from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica + + rollout_config = create_rollout_config_dict() + model_config = create_model_config_dict(UNIMODAL_MODEL_PATH) + + replica = TRTLLMReplica( + replica_rank=0, + config=rollout_config, + model_config=model_config, + gpus_per_node=torch.cuda.device_count(), + is_reward_model=False, + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(replica.init_standalone()) + + yield replica, loop + + loop.close() + + @pytest.fixture(scope="class") + def tokenizer(self): + return get_tokenizer(UNIMODAL_MODEL_PATH) + + def test_wake_sleep_cycle(self, trtllm_replica_lifecycle, tokenizer): + replica, loop = trtllm_replica_lifecycle + + prompt = "Hello, world!" + messages = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() + + sampling_params = {"temperature": 0.7, "top_p": 0.9, "top_k": 50, "logprobs": False} + + output1 = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + )) + assert output1 is not None + assert len(output1.token_ids) > 0 + print(f"\n[Before Sleep] Generated {len(output1.token_ids)} tokens") + + loop.run_until_complete(replica.sleep()) + print("[Sleep] Server put to sleep") + + loop.run_until_complete(replica.wake_up()) + print("[Wake Up] Server woken up") + + output2 = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + )) + assert output2 is not None + assert len(output2.token_ids) > 0 + print(f"[After Wake Up] Generated {len(output2.token_ids)} tokens") From 94c4eb0a1adce3881fd6d978c3a7e07cb8d6c0ae Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 19:57:22 +0800 Subject: [PATCH 07/24] Add test_trtllm_rollout workflow to test trtllm_rollout --- .github/workflows/test_trtllm_rollout.yml | 82 +++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 .github/workflows/test_trtllm_rollout.yml diff --git a/.github/workflows/test_trtllm_rollout.yml b/.github/workflows/test_trtllm_rollout.yml new file mode 100644 index 00000000000..9c714de4892 --- /dev/null +++ b/.github/workflows/test_trtllm_rollout.yml @@ -0,0 +1,82 @@ +name: test_trtllm_rollout + +on: + push: + branches: + - main + - v0.* + paths: + - "verl/workers/rollout/trtllm_rollout/**" + - "tests/workers/rollout/rollout_trtllm/**" + - ".github/workflows/test_trtllm_rollout.yml" + pull_request: + branches: + - main + - v0.* + paths: + - "verl/workers/rollout/trtllm_rollout/**" + - "tests/workers/rollout/rollout_trtllm/**" + - ".github/workflows/test_trtllm_rollout.yml" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +permissions: + contents: read + +env: + IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:trtllm1.2.0rc6" + DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" + +jobs: + setup: + if: github.repository_owner == 'verl-project' + runs-on: ubuntu-latest + outputs: + runner-label: ${{ steps.create-runner.outputs.runner-label }} + mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} + steps: + - uses: actions/checkout@v4 + - id: create-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "create" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-image: "${{ env.IMAGE }}" + + test_trtllm_rollout: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install -r requirements-test.txt + pip3 install --no-deps -e . + - name: Run TRT-LLM rollout tests + run: | + ray stop --force + pytest -v -s tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py + + cleanup: + runs-on: ubuntu-latest + needs: [setup, test_trtllm_rollout] + if: always() + steps: + - id: destroy-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "destroy" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" + From 25518fee2daee8e5d06575d1526a08c7a20fe124 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 20:00:34 +0800 Subject: [PATCH 08/24] Add back mistakenly deleted file --- .../config/test_optim_config_on_cpu.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/workers/config/test_optim_config_on_cpu.py diff --git a/tests/workers/config/test_optim_config_on_cpu.py b/tests/workers/config/test_optim_config_on_cpu.py new file mode 100644 index 00000000000..5aae6bb8c2c --- /dev/null +++ b/tests/workers/config/test_optim_config_on_cpu.py @@ -0,0 +1,48 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from verl.workers.config.optimizer import FSDPOptimizerConfig + + +class TestFSDPOptimizerConfigCPU: + def test_default_configuration(self): + config = FSDPOptimizerConfig(lr=0.1) + assert config.min_lr_ratio is None + assert config.lr_scheduler_type == "constant" + assert config.num_cycles == 0.5 + + @pytest.mark.parametrize("lr_scheduler_type", ["constant", "cosine"]) + def test_valid_lr_scheduler_types(self, lr_scheduler_type): + config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1) + assert config.lr_scheduler_type == lr_scheduler_type + + @pytest.mark.parametrize("warmup_style", ["constant", "cosine"]) + def test_valid_warmup_style_types(self, warmup_style): + config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1) + assert config.lr_scheduler_type == warmup_style + + def test_invalid_lr_scheduler_type(self): + with pytest.raises((ValueError, AssertionError)): + FSDPOptimizerConfig(lr_scheduler_type="invalid_style", lr=0.1) + + def test_invalid_warmup_style_type(self): + with pytest.raises((ValueError, AssertionError)): + FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1) + + @pytest.mark.parametrize("num_cycles", [0.1, 1.0, 2.5]) + def test_num_cycles_configuration(self, num_cycles): + config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1) + assert config.num_cycles == num_cycles \ No newline at end of file From fd007fb333b72d8fed9cc11aa105407f6fe9eaa5 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 20:03:20 +0800 Subject: [PATCH 09/24] Precommit check --- .../config/test_optim_config_on_cpu.py | 2 +- .../rollout/rollout_trtllm/__init__.py | 3 +- .../test_trtllm_rollout_utils.py | 121 +++++++++++------- .../trtllm_rollout/trtllm_worker_extension.py | 15 +-- 4 files changed, 80 insertions(+), 61 deletions(-) diff --git a/tests/workers/config/test_optim_config_on_cpu.py b/tests/workers/config/test_optim_config_on_cpu.py index 5aae6bb8c2c..b44cb40c6b1 100644 --- a/tests/workers/config/test_optim_config_on_cpu.py +++ b/tests/workers/config/test_optim_config_on_cpu.py @@ -45,4 +45,4 @@ def test_invalid_warmup_style_type(self): @pytest.mark.parametrize("num_cycles", [0.1, 1.0, 2.5]) def test_num_cycles_configuration(self, num_cycles): config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1) - assert config.num_cycles == num_cycles \ No newline at end of file + assert config.num_cycles == num_cycles diff --git a/tests/workers/rollout/rollout_trtllm/__init__.py b/tests/workers/rollout/rollout_trtllm/__init__.py index 46866da4cd9..d828409b82e 100644 --- a/tests/workers/rollout/rollout_trtllm/__init__.py +++ b/tests/workers/rollout/rollout_trtllm/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2026 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py index dd99f09f60c..21ab5689113 100644 --- a/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py +++ b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py @@ -1,5 +1,17 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import asyncio -import os import uuid import numpy as np @@ -78,6 +90,7 @@ def get_tokenizer(model_path: str): def get_processor(model_path: str): from transformers import AutoProcessor + return AutoProcessor.from_pretrained(model_path, trust_remote_code=True) @@ -86,7 +99,6 @@ def get_processor(model_path: str): reason="CUDA not available", ) class TestUnimodalTRTLLMRollout: - @pytest.fixture(scope="class") def ray_context(self): if ray.is_initialized(): @@ -153,18 +165,20 @@ def test_unimodal_generate(self, trtllm_replica, tokenizer, prompt): } request_id = str(uuid.uuid4()) - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=request_id, - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=request_id, + ) + ) assert output is not None assert hasattr(output, "token_ids") assert len(output.token_ids) > 0 generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) - print(f"\n[Unimodal Test]") + print("\n[Unimodal Test]") print(f"Prompt: {prompt}") print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...") @@ -191,15 +205,17 @@ def test_unimodal_batch_generate(self, trtllm_replica, tokenizer): text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + ) results.append(output) assert len(results) == len(prompts) - for i, (prompt, result) in enumerate(zip(prompts, results)): + for i, (prompt, result) in enumerate(zip(prompts, results, strict=False)): assert result is not None assert len(result.token_ids) > 0 generated = tokenizer.decode(result.token_ids, skip_special_tokens=True) @@ -212,7 +228,6 @@ def test_unimodal_batch_generate(self, trtllm_replica, tokenizer): reason="CUDA not available", ) class TestMultimodalTRTLLMRollout: - @pytest.fixture(scope="class") def ray_context(self): if ray.is_initialized(): @@ -283,8 +298,11 @@ def test_multimodal_generate_with_image(self, trtllm_vlm_replica, processor, tok print("text: ", text) input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist() - print("input_ids decoded: ", processor.tokenizer.decode(input_ids, skip_special_tokens=False, add_special_tokens=False)) - + print( + "input_ids decoded: ", + processor.tokenizer.decode(input_ids, skip_special_tokens=False, add_special_tokens=False), + ) + sampling_params = { "temperature": 0.7, "top_p": 0.9, @@ -292,19 +310,21 @@ def test_multimodal_generate_with_image(self, trtllm_vlm_replica, processor, tok "logprobs": False, } - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - image_data=[test_image], - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + image_data=[test_image], + ) + ) assert output is not None assert hasattr(output, "token_ids") assert len(output.token_ids) > 0 generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) - print(f"\n[Multimodal Test]") + print("\n[Multimodal Test]") print(f"Prompt: {prompt}") print(f"Image size: {test_image.size}") print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...") @@ -340,12 +360,14 @@ def test_multimodal_different_image_sizes(self, trtllm_vlm_replica, processor, t "logprobs": False, } - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - image_data=[test_image], - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + image_data=[test_image], + ) + ) assert output is not None assert len(output.token_ids) > 0 @@ -367,17 +389,19 @@ def test_multimodal_text_only_fallback(self, trtllm_vlm_replica, tokenizer): "logprobs": False, } - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + ) assert output is not None assert len(output.token_ids) > 0 generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) - print(f"\n[Text-only on VLM]") + print("\n[Text-only on VLM]") print(f"Prompt: {prompt}") print(f"Generated: {generated_text}") @@ -387,7 +411,6 @@ def test_multimodal_text_only_fallback(self, trtllm_vlm_replica, tokenizer): reason="CUDA not available", ) class TestTRTLLMServerLifecycle: - @pytest.fixture(scope="class") def ray_context(self): if ray.is_initialized(): @@ -433,11 +456,13 @@ def test_wake_sleep_cycle(self, trtllm_replica_lifecycle, tokenizer): sampling_params = {"temperature": 0.7, "top_p": 0.9, "top_k": 50, "logprobs": False} - output1 = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - )) + output1 = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + ) assert output1 is not None assert len(output1.token_ids) > 0 print(f"\n[Before Sleep] Generated {len(output1.token_ids)} tokens") @@ -448,11 +473,13 @@ def test_wake_sleep_cycle(self, trtllm_replica_lifecycle, tokenizer): loop.run_until_complete(replica.wake_up()) print("[Wake Up] Server woken up") - output2 = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - )) + output2 = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + ) assert output2 is not None assert len(output2.token_ids) > 0 print(f"[After Wake Up] Generated {len(output2.token_ids)} tokens") diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index ce56c8b9b5c..6bb5190dfbc 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -16,7 +16,6 @@ from typing import Optional import torch - from tensorrt_llm import serialization from tensorrt_llm._ray_utils import control_action_decorator from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer @@ -44,11 +43,9 @@ def update_weights(self, ipc_handles: Optional[dict] = None): try: if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): for module in self.engine.model_engine.model.modules(): - if hasattr(module, "pre_reload_weights") and not getattr( - module, "_weights_removed", False - ): + if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False): module.pre_reload_weights() - setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) + self.engine.model_engine.model.first_pre_reload_weights = True if ipc_handles is not None: logger.info("Update weights from IPC handles") @@ -88,9 +85,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): # Verify the result is a list as expected if not isinstance(all_handles, list): - raise ValueError( - f"Deserialized data must be a list, got {type(all_handles).__name__} instead" - ) + raise ValueError(f"Deserialized data must be a list, got {type(all_handles).__name__} instead") else: # Data is already in the correct format (backward compatibility) all_handles = serialized_handles @@ -120,9 +115,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): module, "_weights_removed", False ): module.process_weights_after_loading() - if hasattr(module, "post_load_weights") and not getattr( - module, "_weights_removed", False - ): + if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False): module.post_load_weights() moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) if isinstance(moe_load_balancer, MoeLoadBalancer): From 24a66207d0a101ae09b1c437e0a0ca521179802e Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Wed, 11 Feb 2026 08:30:23 +0800 Subject: [PATCH 10/24] Modified to inherit the worker extension class of tensorrt llm --- .../trtllm_rollout/trtllm_worker_extension.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index 6bb5190dfbc..decabcc96f9 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -21,9 +21,9 @@ from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer from tensorrt_llm._torch.utils import get_device_uuid from tensorrt_llm.logger import logger +import tensorrt_llm.llmapi.rlhf_utils.WorkerExtension as trtllm_worker_extension - -class WorkerExtension: +class WorkerExtension(trtllm_worker_extension.WorkerExtension): def __init__(self): pass @@ -128,11 +128,4 @@ def update_weights(self, ipc_handles: Optional[dict] = None): except Exception as e: logger.error("Encountered an error in update_weights") - raise e - - def check_weights_updated(self) -> bool: - """Check if the weights are updated to 0.""" - weights_updated = True - for name, p in self.engine.model_engine.model.named_parameters(): - weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p)) - return weights_updated + raise e \ No newline at end of file From 6f055a213b4ab382bdf8ce906c363258060dcc0f Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Wed, 11 Feb 2026 08:31:02 +0800 Subject: [PATCH 11/24] Modified to inherit the worker extension class of tensorrt llm --- verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index decabcc96f9..d205566ab8b 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -23,7 +23,7 @@ from tensorrt_llm.logger import logger import tensorrt_llm.llmapi.rlhf_utils.WorkerExtension as trtllm_worker_extension -class WorkerExtension(trtllm_worker_extension.WorkerExtension): +class WorkerExtension(trtllm_worker_extension): def __init__(self): pass From d0b1d1dc8a12f180e64ed26c2d5ca0e04bc2873a Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Wed, 11 Feb 2026 08:44:18 +0800 Subject: [PATCH 12/24] fix readability problem of multimodal config --- .../trtllm_rollout/trtllm_async_server.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 80b83d3b501..c14c6262c9d 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -163,6 +163,7 @@ async def launch_server(self): } ) + multimodal_config = None if self.is_vlm_model: from tensorrt_llm.inputs.multimodal import MultimodalServerConfig @@ -180,24 +181,15 @@ async def launch_server(self): }, } ) - self.llm = await AsyncLLM(**llm_kwargs) - trtllm_server = OpenAIServer( - llm=self.llm, - model=self.model_config.local_path, - tool_parser=None, - server_role=None, - metadata_server_cfg=None, - multimodal_server_config=multimodal_config, - ) - else: - self.llm = await AsyncLLM(**llm_kwargs) - trtllm_server = OpenAIServer( - llm=self.llm, - model=self.model_config.local_path, - tool_parser=None, - server_role=None, - metadata_server_cfg=None, - ) + self.llm = await AsyncLLM(**llm_kwargs) + trtllm_server = OpenAIServer( + llm=self.llm, + model=self.model_config.local_path, + tool_parser=None, + server_role=None, + metadata_server_cfg=None, + multimodal_server_config=multimodal_config, + ) app = trtllm_server.app self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address) From 6b021f430c21562d6b2bc65d8f51a1b9cb8ad67f Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Wed, 11 Feb 2026 08:51:46 +0800 Subject: [PATCH 13/24] Remove need for multimodal server config --- .../trtllm_rollout/trtllm_async_server.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index c14c6262c9d..6cb71d6e412 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -163,24 +163,6 @@ async def launch_server(self): } ) - multimodal_config = None - if self.is_vlm_model: - from tensorrt_llm.inputs.multimodal import MultimodalServerConfig - - multimodal_config = MultimodalServerConfig( - media_io_kwargs={ - "image": { - "format": "pil", - "device": "cpu", - }, - "video": { - "num_frames": 8, - "fps": 30, - "format": "pil", - "device": "cpu", - }, - } - ) self.llm = await AsyncLLM(**llm_kwargs) trtllm_server = OpenAIServer( llm=self.llm, @@ -188,7 +170,6 @@ async def launch_server(self): tool_parser=None, server_role=None, metadata_server_cfg=None, - multimodal_server_config=multimodal_config, ) app = trtllm_server.app From a7faa7b801f4c66c41bff52a5fe001624b6e16a7 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Wed, 11 Feb 2026 08:58:53 +0800 Subject: [PATCH 14/24] Add vlm unit test into exisiting trtllm unit test --- .../workflows/e2e_ppo_grpo_trainer_trtllm.yml | 23 +++--- .github/workflows/test_trtllm_rollout.yml | 82 ------------------- 2 files changed, 13 insertions(+), 92 deletions(-) delete mode 100644 .github/workflows/test_trtllm_rollout.yml diff --git a/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml b/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml index c169f9dee55..fd9032e67a3 100644 --- a/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml +++ b/.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml @@ -41,13 +41,13 @@ on: - main - v0.* paths: - - "**/*.py" - # Other entrypoints - - "!verl/trainer/fsdp_sft_trainer.py" - # Recipes - - "!recipe/**" - # FSDP - - "!verl/workers/**/*dp_*.py" + - "verl/workers/rollout/trtllm_rollout/**" + - "tests/workers/rollout/rollout_trtllm/**" + - ".github/workflows/e2e_ppo_grpo_trainer_trtllm.yml" + - "examples/data_preprocess/gsm8k.py" + - "examples/data_preprocess/geo3k.py" + - "examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh" + - "examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh" pull_request: branches: - main @@ -68,8 +68,9 @@ on: # FSDP - "!verl/workers/**/*dp_*.py" # Entrypoints - - "verl/workers/rollout/trtllm_rollout/*" - - ".github/workflows/e2e_ppo_grpo_trainer_trtllm" + - "verl/workers/rollout/trtllm_rollout/**" + - "tests/workers/rollout/rollout_trtllm/**" + - ".github/workflows/e2e_ppo_grpo_trainer_trtllm.yml" - "examples/data_preprocess/gsm8k.py" - "examples/data_preprocess/geo3k.py" # add back when ppo flow is ready @@ -128,9 +129,11 @@ jobs: - name: Run TRTLLM unit tests run: | export TRTLLM_TEST_MODEL_PATH_ROOT="${HOME}/models" + ray stop --force pytest -v -s \ tests/workers/rollout/rollout_trtllm/test_adapter.py \ - tests/workers/rollout/rollout_trtllm/test_async_server.py + tests/workers/rollout/rollout_trtllm/test_async_server.py \ + tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py e2e_grpo_trainer_fsdp-qwen2: needs: setup diff --git a/.github/workflows/test_trtllm_rollout.yml b/.github/workflows/test_trtllm_rollout.yml deleted file mode 100644 index 9c714de4892..00000000000 --- a/.github/workflows/test_trtllm_rollout.yml +++ /dev/null @@ -1,82 +0,0 @@ -name: test_trtllm_rollout - -on: - push: - branches: - - main - - v0.* - paths: - - "verl/workers/rollout/trtllm_rollout/**" - - "tests/workers/rollout/rollout_trtllm/**" - - ".github/workflows/test_trtllm_rollout.yml" - pull_request: - branches: - - main - - v0.* - paths: - - "verl/workers/rollout/trtllm_rollout/**" - - "tests/workers/rollout/rollout_trtllm/**" - - ".github/workflows/test_trtllm_rollout.yml" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -permissions: - contents: read - -env: - IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:trtllm1.2.0rc6" - DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" - -jobs: - setup: - if: github.repository_owner == 'verl-project' - runs-on: ubuntu-latest - outputs: - runner-label: ${{ steps.create-runner.outputs.runner-label }} - mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} - steps: - - uses: actions/checkout@v4 - - id: create-runner - uses: volcengine/vemlp-github-runner@v1 - with: - mode: "create" - faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" - mlp-image: "${{ env.IMAGE }}" - - test_trtllm_rollout: - needs: setup - runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] - timeout-minutes: 60 - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -r requirements-test.txt - pip3 install --no-deps -e . - - name: Run TRT-LLM rollout tests - run: | - ray stop --force - pytest -v -s tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py - - cleanup: - runs-on: ubuntu-latest - needs: [setup, test_trtllm_rollout] - if: always() - steps: - - id: destroy-runner - uses: volcengine/vemlp-github-runner@v1 - with: - mode: "destroy" - faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" - mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" - From 8519d36ef9274d6df653847a054ccb0c9a604ea3 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Wed, 11 Feb 2026 09:06:44 +0800 Subject: [PATCH 15/24] add e2e script to train qwen2.5-vl with trtllm rollout --- .../grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh new file mode 100644 index 00000000000..0907c35304b --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh @@ -0,0 +1,60 @@ +set -x + +# python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + algorithm.rollout_correction.rollout_is_threshold=2.0 \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.trust_remote_code=True \ + actor_rollout_ref.hybrid_engine=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + +actor_rollout_ref.model.override_config.attn_implementation=eager \ + +actor_rollout_ref.ref.model.override_config.attn_implementation=eager \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=trtllm \ + actor_rollout_ref.rollout.mode="async" \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=256 \ + actor_rollout_ref.rollout.max_num_batched_tokens=16384 \ + +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_timeout_iters=32 \ + +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_max_tokens_ratio=0.5 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.ref.strategy=fsdp2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + reward_manager.name=naive \ + reward_manager.source=register \ + trainer.critic_warmup=0 \ + trainer.logger='["console"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_trtllm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.resume_mode=disable \ + trainer.total_epochs=10 \ No newline at end of file From 5a145a5b67febdd5c85eec7045ee8ef52bfcf582 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Thu, 12 Feb 2026 09:01:21 +0800 Subject: [PATCH 16/24] Change import statement --- verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index d205566ab8b..f9dc5e86aeb 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -21,7 +21,7 @@ from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer from tensorrt_llm._torch.utils import get_device_uuid from tensorrt_llm.logger import logger -import tensorrt_llm.llmapi.rlhf_utils.WorkerExtension as trtllm_worker_extension +from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension as trtllm_worker_extension class WorkerExtension(trtllm_worker_extension): def __init__(self): From 37763386e4c61acfe5a9cebc0855341b398d61b1 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Thu, 12 Feb 2026 09:22:33 +0800 Subject: [PATCH 17/24] remove reward config in e2e script --- examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh index 0907c35304b..7f0dd590850 100644 --- a/examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh @@ -46,8 +46,6 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.strategy=fsdp2 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ - reward_manager.name=naive \ - reward_manager.source=register \ trainer.critic_warmup=0 \ trainer.logger='["console"]' \ trainer.project_name='verl_grpo_example_geo3k' \ From 1706e71d33c8f8ac643c89da499c49fb87958d56 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Thu, 12 Feb 2026 14:00:49 +0800 Subject: [PATCH 18/24] When multi modal input for trtllm, decode with special token first --- verl/experimental/agent_loop/agent_loop.py | 2 +- .../trtllm_rollout/trtllm_async_server.py | 9 +++-- .../rollout/trtllm_rollout/trtllm_rollout.py | 2 +- .../trtllm_rollout/trtllm_worker_extension.py | 40 ++++++++++++++++++- 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 228d2248b7e..88665a9f8b1 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -113,7 +113,7 @@ async def generate( server = self._choose_server(request_id) output = await server.generate.remote( request_id=uuid4().hex, # use new request_id for each turn - prompt_ids=prompt_ids, + prompt_ids=prompt_ids, # for trtllm, this is the raw prompt sampling_params=sampling_params, image_data=image_data, video_data=video_data, diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 6afaa31354f..f949b014881 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -184,7 +184,7 @@ async def launch_server(self): async def generate( self, - prompt_ids: list[int], + prompt_ids: str, sampling_params: dict[str, Any], request_id: str, image_data: Optional[list[Any]] = None, @@ -201,9 +201,11 @@ async def generate( trt_llm_sampling_params = SamplingParams(**sampling_params) if self.is_vlm_model: + org_prompt = self.llm.tokenizer.decode(prompt_ids) if image_data or video_data: + input_dict = { - "prompt_token_ids": prompt_ids, + "prompt": org_prompt, "multi_modal_data": {}, "mm_processor_kwargs": {}, } @@ -211,6 +213,7 @@ async def generate( input_dict["multi_modal_data"]["image"] = image_data if video_data: input_dict["multi_modal_data"]["video"] = video_data + outputs = await self.llm.generate_async( inputs=input_dict, sampling_params=trt_llm_sampling_params, @@ -369,7 +372,7 @@ async def launch_servers(self): node_id=node_id, soft=False, ), - runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}, + runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", "TLLM_NUMA_AWARE_WORKER_AFFINITY":"0"}}, name=name, ).remote( config=self.config, diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 08457f38267..1abb1f107df 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -414,7 +414,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None total_available_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) * 1024 * 1024 try: - device_uuid = get_device_uuid(self.gpu_id) + device_uuid = get_device_uuid(int(self.gpu_id)) except Exception as e: logger.error(f"Failed to get device UUID in update_weights(): {e}") device_uuid = None diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index f9dc5e86aeb..44241c7b229 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -62,7 +62,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): # using restricted unpickler from tensorrt_llm.serialization logger.info("Deserializing base64-encoded weight handles") decoded_data = base64.b64decode(serialized_handles) - # Allow basic builtins and all torch modules + # Allow basic builtins and torch tensor reconstruction classes approved_imports = { "builtins": [ "list", @@ -76,11 +76,47 @@ def update_weights(self, ipc_handles: Optional[dict] = None): "NoneType", "type", ], + "torch": [ + "Tensor", + "FloatTensor", + "DoubleTensor", + "HalfTensor", + "BFloat16Tensor", + "IntTensor", + "LongTensor", + "ShortTensor", + "CharTensor", + "ByteTensor", + "BoolTensor", + "Size", + "dtype", + "device", + "float32", + "float16", + "int32", + "int64", + "int16", + "int8", + "uint8", + "bool", + ], + "torch.multiprocessing.reductions": [ + "rebuild_cuda_tensor", + "rebuild_tensor", + ], + "torch._utils": [ + "_rebuild_tensor_v2", + ], + "torch.storage": [ + "_load_from_bytes", + "_TypedStorage", + "UntypedStorage", + "TypedStorage", + ], } all_handles = serialization.loads( decoded_data, approved_imports=approved_imports, - approved_module_patterns=[r"^torch.*"], ) # Verify the result is a list as expected From 90837f33f5b292bbdc2af203b9f109af5b070fcb Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Thu, 12 Feb 2026 14:02:05 +0800 Subject: [PATCH 19/24] rever typo --- verl/workers/rollout/trtllm_rollout/trtllm_async_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index f949b014881..a27dba19c43 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -372,7 +372,7 @@ async def launch_servers(self): node_id=node_id, soft=False, ), - runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", "TLLM_NUMA_AWARE_WORKER_AFFINITY":"0"}}, + runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}, name=name, ).remote( config=self.config, From 57506e28b42c4a98ce7c4aadf2ef8b6e1894c5e1 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Thu, 12 Feb 2026 14:03:35 +0800 Subject: [PATCH 20/24] revert typo --- verl/experimental/agent_loop/agent_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 88665a9f8b1..228d2248b7e 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -113,7 +113,7 @@ async def generate( server = self._choose_server(request_id) output = await server.generate.remote( request_id=uuid4().hex, # use new request_id for each turn - prompt_ids=prompt_ids, # for trtllm, this is the raw prompt + prompt_ids=prompt_ids, sampling_params=sampling_params, image_data=image_data, video_data=video_data, From e193d0dda09d07548ba593ecf5ea288daaaf5338 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Thu, 12 Feb 2026 14:09:53 +0800 Subject: [PATCH 21/24] pre commit check --- verl/workers/rollout/trtllm_rollout/trtllm_async_server.py | 3 +-- .../rollout/trtllm_rollout/trtllm_worker_extension.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index a27dba19c43..f2eaaecac13 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -203,7 +203,6 @@ async def generate( if self.is_vlm_model: org_prompt = self.llm.tokenizer.decode(prompt_ids) if image_data or video_data: - input_dict = { "prompt": org_prompt, "multi_modal_data": {}, @@ -213,7 +212,7 @@ async def generate( input_dict["multi_modal_data"]["image"] = image_data if video_data: input_dict["multi_modal_data"]["video"] = video_data - + outputs = await self.llm.generate_async( inputs=input_dict, sampling_params=trt_llm_sampling_params, diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index 44241c7b229..9b804228c59 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -15,13 +15,13 @@ import inspect from typing import Optional -import torch from tensorrt_llm import serialization from tensorrt_llm._ray_utils import control_action_decorator from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer from tensorrt_llm._torch.utils import get_device_uuid +from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension as trtllm_worker_extension from tensorrt_llm.logger import logger -from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension as trtllm_worker_extension + class WorkerExtension(trtllm_worker_extension): def __init__(self): @@ -164,4 +164,4 @@ def update_weights(self, ipc_handles: Optional[dict] = None): except Exception as e: logger.error("Encountered an error in update_weights") - raise e \ No newline at end of file + raise e From 81050ceded79fb2d44d90a81c2c7d81d2139a45a Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Fri, 27 Feb 2026 12:01:56 +0800 Subject: [PATCH 22/24] Fix bugs --- .../rollout/trtllm_rollout/trtllm_async_server.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index f2eaaecac13..e9484688bca 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -184,7 +184,7 @@ async def launch_server(self): async def generate( self, - prompt_ids: str, + prompt_ids: Union[str, list[int]], sampling_params: dict[str, Any], request_id: str, image_data: Optional[list[Any]] = None, @@ -200,9 +200,8 @@ async def generate( sampling_params.update(self.sampling_args) trt_llm_sampling_params = SamplingParams(**sampling_params) - if self.is_vlm_model: - org_prompt = self.llm.tokenizer.decode(prompt_ids) - if image_data or video_data: + if self.is_vlm_model and (image_data or video_data): + org_prompt = self.llm.tokenizer.decode(prompt_ids) input_dict = { "prompt": org_prompt, "multi_modal_data": {}, @@ -217,11 +216,6 @@ async def generate( inputs=input_dict, sampling_params=trt_llm_sampling_params, ) - else: - outputs = await self.llm.generate_async( - inputs=prompt_ids, - sampling_params=trt_llm_sampling_params, - ) else: outputs = await self.llm.generate_async( inputs=prompt_ids, @@ -230,7 +224,8 @@ async def generate( token_ids = outputs.outputs[0].token_ids log_probs = None if outputs.outputs[0].logprobs is not None: - log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(outputs.outputs[0].logprobs)] + # When logprobs=1, TRT-LLM returns only the sampled token's logprob at each position + log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs] return TokenOutput(token_ids=token_ids, log_probs=log_probs) async def wake_up(self): From 91d8c5969dfdd0b09b91e8c1029951d296edc4c3 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Fri, 27 Feb 2026 15:12:19 +0800 Subject: [PATCH 23/24] Update --- verl/workers/rollout/trtllm_rollout/trtllm_async_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index e9484688bca..e1ff979c454 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -14,7 +14,7 @@ import asyncio import logging import os -from typing import Any, Optional +from typing import Any, Optional, Union import ray import torch From 60dd50b307e71c9fcc3d5151aa1c57282592c348 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Fri, 27 Feb 2026 15:29:49 +0800 Subject: [PATCH 24/24] Update --- .../workers/rollout/trtllm_rollout/trtllm_worker_extension.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index 9b804228c59..4beb85f70e2 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -19,11 +19,11 @@ from tensorrt_llm._ray_utils import control_action_decorator from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer from tensorrt_llm._torch.utils import get_device_uuid -from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension as trtllm_worker_extension +from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension as TrtllmWorkerExtension from tensorrt_llm.logger import logger -class WorkerExtension(trtllm_worker_extension): +class WorkerExtension(TrtllmWorkerExtension): def __init__(self): pass