#!/usr/bin/env bash
# GRPO training example: Qwen2.5-VL-7B-Instruct on Geometry3K, using FSDP2 for
# the actor/reference models and the TensorRT-LLM ("trtllm") async rollout
# backend. Configured for a single node with 8 GPUs (rollout tensor parallel
# size 4).
#
# Prepare the dataset first:
#   python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k
#
# Any extra arguments are forwarded to the trainer as additional Hydra
# overrides, e.g.:
#   bash run_qwen2_5_vl-7b-trtllm.sh trainer.total_epochs=1
set -x

# Overrides prefixed with "+" add keys that are not present in the default
# config (Hydra append syntax); all others replace existing values.
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    algorithm.rollout_correction.rollout_is_threshold=2.0 \
    data.train_files="$HOME/data/geo3k/train.parquet" \
    data.val_files="$HOME/data/geo3k/test.parquet" \
    data.train_batch_size=512 \
    data.max_prompt_length=1024 \
    data.max_response_length=2048 \
    data.return_raw_chat=True \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.trust_remote_code=True \
    actor_rollout_ref.hybrid_engine=True \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.strategy=fsdp2 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    +actor_rollout_ref.model.override_config.attn_implementation=eager \
    +actor_rollout_ref.ref.model.override_config.attn_implementation=eager \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=trtllm \
    actor_rollout_ref.rollout.mode="async" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.rollout.max_num_seqs=256 \
    actor_rollout_ref.rollout.max_num_batched_tokens=16384 \
    +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_timeout_iters=32 \
    +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_max_tokens_ratio=0.5 \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.ref.strategy=fsdp2 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console"]' \
    trainer.project_name='verl_grpo_example_geo3k' \
    trainer.experiment_name='qwen2_5_vl_7b_trtllm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=10 \
    trainer.test_freq=5 \
    trainer.resume_mode=disable \
    trainer.total_epochs=10 \
    "$@"
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPU integration tests for the TensorRT-LLM rollout server.

Each test class starts a standalone ``TRTLLMReplica`` inside a fresh Ray
session and sends token-in/token-out generation requests to its server handle:
text-only prompts, image + text prompts on a vision-language model, and a
sleep / wake-up cycle. All classes are skipped when CUDA is unavailable.
"""
import asyncio
import uuid

import numpy as np
import pytest
import ray
import torch
from omegaconf import OmegaConf
from PIL import Image
from transformers import AutoTokenizer

UNIMODAL_MODEL_PATH = "Qwen/Qwen2.5-Math-7B"
MULTIMODAL_MODEL_PATH = "Qwen/Qwen2.5-VL-7B-Instruct"

MAX_MODEL_LEN = 4096
RESPONSE_LENGTH = 256
MAX_NUM_SEQS = 16
GPU_MEMORY_UTILIZATION = 0.8
TENSOR_PARALLEL_SIZE = 1

# Sampling settings shared by every request; tests only toggle "logprobs".
_BASE_SAMPLING_PARAMS = {"temperature": 0.7, "top_p": 0.9, "top_k": 50}


def create_test_image(width: int = 224, height: int = 224) -> Image.Image:
    """Build a deterministic RGB gradient image of the requested size.

    Red grows with the row index, green with the column index and blue with
    their sum, each scaled into ``[0, 255)`` and truncated to ``uint8``.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        A PIL image of size ``(width, height)``.
    """
    rows = np.arange(height).reshape(-1, 1)
    cols = np.arange(width).reshape(1, -1)
    shape = (height, width)
    # Vectorised form of the per-pixel int(255 * i / height) computation;
    # astype() truncates toward zero exactly like int() for these
    # non-negative values.
    red = np.broadcast_to(255 * rows / height, shape)
    green = np.broadcast_to(255 * cols / width, shape)
    blue = 255 * (rows + cols) / (height + width)
    img_array = np.stack([red, green, blue], axis=-1).astype(np.uint8)
    return Image.fromarray(img_array)


def create_rollout_config_dict():
    """Return the OmegaConf rollout config used by every replica in this module."""
    config_dict = {
        "_target_": "verl.workers.config.RolloutConfig",
        "name": "trtllm",
        "mode": "async",
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "do_sample": True,
        "n": 1,
        "prompt_length": 512,
        "response_length": RESPONSE_LENGTH,
        "dtype": "bfloat16",
        "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
        "ignore_eos": False,
        "enforce_eager": True,
        "free_cache_engine": False,
        "data_parallel_size": 1,
        "tensor_model_parallel_size": TENSOR_PARALLEL_SIZE,
        "pipeline_model_parallel_size": 1,
        "max_num_batched_tokens": 8192,
        "max_model_len": MAX_MODEL_LEN,
        "max_num_seqs": MAX_NUM_SEQS,
        "load_format": "auto",
        "enable_chunked_prefill": True,
        "enable_prefix_caching": True,
    }
    return OmegaConf.create(config_dict)


def create_model_config_dict(model_path: str):
    """Return the OmegaConf model config pointing at ``model_path``."""
    config_dict = {
        "_target_": "verl.workers.config.HFModelConfig",
        "path": model_path,
        "trust_remote_code": True,
        "load_tokenizer": True,
    }
    return OmegaConf.create(config_dict)


def get_tokenizer(model_path: str):
    """Load the Hugging Face tokenizer for ``model_path``."""
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)


def get_processor(model_path: str):
    """Load the Hugging Face multimodal processor for ``model_path``."""
    from transformers import AutoProcessor

    return AutoProcessor.from_pretrained(model_path, trust_remote_code=True)


def _close_event_loop(loop) -> None:
    """Close ``loop`` and stop advertising it as the current event loop."""
    loop.close()
    asyncio.set_event_loop(None)


def _start_standalone_replica(model_path: str):
    """Create a ``TRTLLMReplica`` for ``model_path`` and initialise it standalone.

    Returns:
        ``(replica, loop)`` where ``loop`` is the event loop that ran the
        initialisation; the caller must release it with ``_close_event_loop``.
    """
    # Imported lazily so the module can still be collected (and skipped)
    # on machines where the TRT-LLM stack is not importable.
    from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica

    replica = TRTLLMReplica(
        replica_rank=0,
        config=create_rollout_config_dict(),
        model_config=create_model_config_dict(model_path),
        gpus_per_node=torch.cuda.device_count(),
        is_reward_model=False,
    )

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(replica.init_standalone())
    except Exception:
        # Do not leak the loop when server start-up fails.
        _close_event_loop(loop)
        raise
    return replica, loop


def _encode_chat(tokenizer, messages) -> list:
    """Render ``messages`` with the chat template and return the prompt token ids."""
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return tokenizer.encode(text, return_tensors="pt")[0].tolist()


def _image_chat(prompt: str) -> list:
    """Return a single-turn user message holding one image placeholder and ``prompt``."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def _encode_image_chat(processor, messages):
    """Render a multimodal chat and return ``(text, prompt token ids)``."""
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist()
    return text, input_ids


def _submit(replica, prompt_ids, *, logprobs: bool = False, image_data=None):
    """Send one generation request to the replica's server and return the Ray object ref."""
    request = {
        "prompt_ids": prompt_ids,
        "sampling_params": {**_BASE_SAMPLING_PARAMS, "logprobs": logprobs},
        "request_id": str(uuid.uuid4()),
    }
    # Text-only requests omit the keyword entirely, as the server defaults it.
    if image_data is not None:
        request["image_data"] = image_data
    return replica.server_handle.generate.remote(**request)


def _generate(replica, prompt_ids, *, logprobs: bool = False, image_data=None):
    """Send one generation request and block until its output is available."""
    return ray.get(_submit(replica, prompt_ids, logprobs=logprobs, image_data=image_data))


def _assert_has_tokens(output) -> None:
    """Assert that ``output`` is a generation result with at least one token."""
    assert output is not None
    assert hasattr(output, "token_ids")
    assert len(output.token_ids) > 0


@pytest.fixture(scope="class")
def ray_context():
    """Run each test class inside its own freshly initialised Ray session."""
    if ray.is_initialized():
        ray.shutdown()
    ray.init(ignore_reinit_error=True)
    yield
    ray.shutdown()


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA not available",
)
class TestUnimodalTRTLLMRollout:
    """Text-only generation against a replica serving the unimodal model."""

    @pytest.fixture(scope="class")
    def trtllm_replica(self, ray_context):
        replica, loop = _start_standalone_replica(UNIMODAL_MODEL_PATH)
        yield replica
        _close_event_loop(loop)

    @pytest.fixture(scope="class")
    def tokenizer(self):
        return get_tokenizer(UNIMODAL_MODEL_PATH)

    @pytest.mark.parametrize(
        "prompt",
        [
            "What is 2 + 2?",
            "Solve for x: 3x + 5 = 20",
            "Calculate the derivative of x^2 + 3x + 1",
        ],
    )
    def test_unimodal_generate(self, trtllm_replica, tokenizer, prompt):
        messages = [
            {"role": "system", "content": "You are a helpful math assistant."},
            {"role": "user", "content": prompt},
        ]
        input_ids = _encode_chat(tokenizer, messages)

        output = _generate(trtllm_replica, input_ids, logprobs=True)

        _assert_has_tokens(output)

        generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True)
        print("\n[Unimodal Test]")
        print(f"Prompt: {prompt}")
        print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...")

    def test_unimodal_batch_generate(self, trtllm_replica, tokenizer):
        prompts = [
            "What is 1 + 1?",
            "What is 2 * 3?",
            "What is 10 / 2?",
        ]

        # Submit every request before waiting on any of them so the requests
        # are in flight together instead of being served one at a time.
        pending = [
            _submit(trtllm_replica, _encode_chat(tokenizer, [{"role": "user", "content": prompt}]))
            for prompt in prompts
        ]
        results = ray.get(pending)

        assert len(results) == len(prompts)
        for i, (prompt, result) in enumerate(zip(prompts, results, strict=True)):
            assert result is not None
            assert len(result.token_ids) > 0
            generated = tokenizer.decode(result.token_ids, skip_special_tokens=True)
            print(f"\n[Batch {i}] Prompt: {prompt}")
            print(f"Generated: {generated[:100]}...")


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA not available",
)
class TestMultimodalTRTLLMRollout:
    """Image + text and text-only generation against a replica serving the VLM."""

    @pytest.fixture(scope="class")
    def trtllm_vlm_replica(self, ray_context):
        replica, loop = _start_standalone_replica(MULTIMODAL_MODEL_PATH)
        yield replica
        _close_event_loop(loop)

    @pytest.fixture(scope="class")
    def tokenizer(self):
        return get_tokenizer(MULTIMODAL_MODEL_PATH)

    @pytest.fixture(scope="class")
    def processor(self):
        return get_processor(MULTIMODAL_MODEL_PATH)

    @pytest.mark.parametrize(
        "prompt",
        [
            "Describe this image in detail.",
            "What colors do you see in this image?",
            "What patterns are visible in this image?",
        ],
    )
    def test_multimodal_generate_with_image(self, trtllm_vlm_replica, processor, tokenizer, prompt):
        test_image = create_test_image(224, 224)

        text, input_ids = _encode_image_chat(processor, _image_chat(prompt))
        print("text: ", text)
        print(
            "input_ids decoded: ",
            processor.tokenizer.decode(input_ids, skip_special_tokens=False, add_special_tokens=False),
        )

        output = _generate(trtllm_vlm_replica, input_ids, image_data=[test_image])

        _assert_has_tokens(output)

        generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True)
        print("\n[Multimodal Test]")
        print(f"Prompt: {prompt}")
        print(f"Image size: {test_image.size}")
        print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...")

    @pytest.mark.parametrize(
        "image_size",
        [(224, 224), (384, 384), (512, 512)],
    )
    def test_multimodal_different_image_sizes(self, trtllm_vlm_replica, processor, tokenizer, image_size):
        width, height = image_size
        test_image = create_test_image(width, height)

        _, input_ids = _encode_image_chat(processor, _image_chat("What is shown in this image?"))

        output = _generate(trtllm_vlm_replica, input_ids, image_data=[test_image])

        assert output is not None
        assert len(output.token_ids) > 0
        print(f"\n[Image Size {image_size}] Generated {len(output.token_ids)} tokens")

    def test_multimodal_text_only_fallback(self, trtllm_vlm_replica, tokenizer):
        prompt = "What is the capital of China?"
        input_ids = _encode_chat(tokenizer, [{"role": "user", "content": prompt}])

        # No image is attached: the VLM replica must still answer plain text.
        output = _generate(trtllm_vlm_replica, input_ids)

        assert output is not None
        assert len(output.token_ids) > 0

        generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True)
        print("\n[Text-only on VLM]")
        print(f"Prompt: {prompt}")
        print(f"Generated: {generated_text}")


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA not available",
)
class TestTRTLLMServerLifecycle:
    """Generation must keep working after a sleep / wake-up cycle."""

    @pytest.fixture(scope="class")
    def trtllm_replica_lifecycle(self, ray_context):
        replica, loop = _start_standalone_replica(UNIMODAL_MODEL_PATH)
        # The loop is handed to the test so it can drive sleep()/wake_up().
        yield replica, loop
        _close_event_loop(loop)

    @pytest.fixture(scope="class")
    def tokenizer(self):
        return get_tokenizer(UNIMODAL_MODEL_PATH)

    def test_wake_sleep_cycle(self, trtllm_replica_lifecycle, tokenizer):
        replica, loop = trtllm_replica_lifecycle

        input_ids = _encode_chat(tokenizer, [{"role": "user", "content": "Hello, world!"}])

        output1 = _generate(replica, input_ids)
        assert output1 is not None
        assert len(output1.token_ids) > 0
        print(f"\n[Before Sleep] Generated {len(output1.token_ids)} tokens")

        loop.run_until_complete(replica.sleep())
        print("[Sleep] Server put to sleep")

        loop.run_until_complete(replica.wake_up())
        print("[Wake Up] Server woken up")

        output2 = _generate(replica, input_ids)
        assert output2 is not None
        assert len(output2.token_ids) > 0
        print(f"[After Wake Up] Generated {len(output2.token_ids)} tokens")
def launch_server(self): "enable_chunked_prefill": self.config.enable_chunked_prefill, "skip_tokenizer_init": self.config.skip_tokenizer_init, "orchestrator_type": "ray", - "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + "ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension", "kv_cache_config": kv_cache_config, "max_seq_len": self.config.max_model_len, "max_batch_size": self.config.max_num_seqs, @@ -167,7 +171,6 @@ async def launch_server(self): ) self.llm = await AsyncLLM(**llm_kwargs) - trtllm_server = OpenAIServer( llm=self.llm, model=self.model_config.local_path, @@ -175,20 +178,18 @@ async def launch_server(self): server_role=None, metadata_server_cfg=None, ) + app = trtllm_server.app self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address) async def generate( self, - prompt_ids: list[int], + prompt_ids: Union[str, list[int]], sampling_params: dict[str, Any], request_id: str, image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, ) -> TokenOutput: - """Generate sequence with token-in-token-out.""" - assert image_data is None and video_data is None, "Multimodality is not yet supported in TRTLLMHttpServer." 
async def get_supports_partial_loading(self) -> bool:
    """Tell whether the served model accepts partial weight loading.

    The server actor is asked once; the answer is stored on the instance
    and reused by later calls. If the remote query raises, a warning is
    logged and ``False`` is stored instead.
    """
    supported = self._supports_partial_loading
    if supported is None:
        # First call: make sure the server actor handle exists, then ask it.
        await self._init_server_adapter()
        try:
            supported = await self.server_actor.supports_partial_loading.remote()
        except Exception as e:
            logger.warning(f"Failed to query partial loading support: {e}, defaulting to False")
            supported = False
        self._supports_partial_loading = supported

        logger.info(f"Model supports partial loading: {self._supports_partial_loading}")
    return supported
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorRT-LLM Ray worker extension used by verl to push updated weights into the engine."""
import base64
import inspect
from typing import Optional

from tensorrt_llm import serialization
from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension as TrtllmWorkerExtension
from tensorrt_llm.logger import logger

# Allow-list handed to ``tensorrt_llm.serialization.loads`` when decoding
# base64-encoded weight handles: basic builtins plus the torch classes and
# functions needed to rebuild tensors.
# NOTE(review): ``torch.storage._load_from_bytes`` appears to delegate to
# ``torch.load``; confirm it is acceptable in a restricted-unpickler
# allow-list for the trust level of the sender.
_APPROVED_IMPORTS = {
    "builtins": [
        "list",
        "tuple",
        "str",
        "int",
        "float",
        "bool",
        "bytes",
        "dict",
        "NoneType",
        "type",
    ],
    "torch": [
        "Tensor",
        "FloatTensor",
        "DoubleTensor",
        "HalfTensor",
        "BFloat16Tensor",
        "IntTensor",
        "LongTensor",
        "ShortTensor",
        "CharTensor",
        "ByteTensor",
        "BoolTensor",
        "Size",
        "dtype",
        "device",
        "float32",
        "float16",
        "int32",
        "int64",
        "int16",
        "int8",
        "uint8",
        "bool",
    ],
    "torch.multiprocessing.reductions": [
        "rebuild_cuda_tensor",
        "rebuild_tensor",
    ],
    "torch._utils": [
        "_rebuild_tensor_v2",
    ],
    "torch.storage": [
        "_load_from_bytes",
        "_TypedStorage",
        "UntypedStorage",
        "TypedStorage",
    ],
}


def _accepts_partial_loading(model) -> bool:
    """Return True if ``model.load_weights`` declares an ``allow_partial_loading`` parameter.

    ``inspect.signature`` is used so that the parameter is found whether it
    is positional-or-keyword or keyword-only.
    """
    return "allow_partial_loading" in inspect.signature(model.load_weights).parameters


def _decode_weight_handles(serialized_handles):
    """Return the ``(param_name, tensor_handle)`` list carried by one device's IPC entry.

    A ``str`` entry is treated as base64-encoded pickled bytes and decoded
    with the restricted unpickler; any other value is returned unchanged
    (backward compatibility with senders that pass the list directly).

    Raises:
        ValueError: if the decoded payload is not a list.
    """
    if not isinstance(serialized_handles, str):
        return serialized_handles

    logger.info("Deserializing base64-encoded weight handles")
    decoded_data = base64.b64decode(serialized_handles)
    all_handles = serialization.loads(decoded_data, approved_imports=_APPROVED_IMPORTS)
    if not isinstance(all_handles, list):
        raise ValueError(f"Deserialized data must be a list, got {type(all_handles).__name__} instead")
    return all_handles


class WorkerExtension(TrtllmWorkerExtension):
    """Worker extension adding a partial-loading capability query and IPC weight updates."""

    def __init__(self):
        pass

    @control_action_decorator
    def supports_partial_loading(self) -> bool:
        """Check if the model supports partial weight loading.

        Returns ``False`` (after logging a warning) if the check itself fails.
        """
        try:
            return _accepts_partial_loading(self.engine.model_engine.model)
        except Exception as e:
            logger.warning(f"Failed to check partial loading support: {e}")
            return False

    @control_action_decorator
    def update_weights(self, ipc_handles: Optional[dict] = None):
        """Load one batch of weights from IPC handles, or finalize the update.

        Args:
            ipc_handles: Mapping from device UUID to that device's weight
                handles (a list of ``(param_name, tensor_handle)`` pairs or
                its base64-encoded pickled form). ``None`` means all batches
                were sent and the update should be finalized.

        Raises:
            ValueError: if this worker's device UUID is missing from
                ``ipc_handles`` or the decoded payload is not a list.
        """
        try:
            model = self.engine.model_engine.model

            # Run the pre-reload hooks once per update; the marker attribute
            # is removed again when the update is finalized below.
            if not hasattr(model, "first_pre_reload_weights"):
                for module in model.modules():
                    if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False):
                        module.pre_reload_weights()
                model.first_pre_reload_weights = True

            if ipc_handles is not None:
                logger.info("Update weights from IPC handles")
                device_uuid = get_device_uuid(self.device_id)

                if device_uuid not in ipc_handles:
                    raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

                weights = {}
                for param_name, tensor_handle in _decode_weight_handles(ipc_handles[device_uuid]):
                    func, args = tensor_handle
                    list_args = list(args)
                    # Rebind the handle to this worker's device. Presumably
                    # index 6 is the storage-device slot of torch's
                    # rebuild_cuda_tensor arguments; TODO confirm.
                    list_args[6] = self.device_id
                    weights[param_name] = func(*list_args)

                logger.info(f"weights key size: {len(weights)}")

                # Only request partial loading when the model's load_weights
                # declares support for it.
                self.engine.model_engine.model_loader.reload(
                    model, weights, allow_partial_loading=_accepts_partial_loading(model)
                )
            else:
                logger.info("Finalize update weights")
                for module in model.modules():
                    if hasattr(module, "process_weights_after_loading") and not getattr(
                        module, "_weights_removed", False
                    ):
                        module.process_weights_after_loading()
                    if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False):
                        module.post_load_weights()
                moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None)
                if isinstance(moe_load_balancer, MoeLoadBalancer):
                    moe_load_balancer.register_weight_slots_after_to_cuda()
                    logger.info("moe_load_balancer finalizing model...")
                    moe_load_balancer.finalize_model()
                    logger.info("moe_load_balancer finalize model done")
                self.engine.reset_prefix_cache()
                delattr(model, "first_pre_reload_weights")

        except Exception:
            logger.error("Encountered an error in update_weights")
            raise