Skip to content
Open
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
dcaacfe
Fix partial load problem, Add vlm support for trtllm rollout
SchumiDing Jan 31, 2026
0394ab5
Precommit check
SchumiDing Jan 31, 2026
0664ab1
Add check for if the model is vlm in trtllmhttpserver
SchumiDing Jan 31, 2026
bf71c9b
Support latest trtllm
SchumiDing Feb 2, 2026
f6e58b8
Support for qwen2.5 vl
SchumiDing Feb 2, 2026
7af6917
Add trtllm rollout test script
SchumiDing Feb 2, 2026
94c4eb0
Add test_trtllm_rollout workflow to test trtllm_rollout
SchumiDing Feb 2, 2026
25518fe
Add back mistakenly deleted file
SchumiDing Feb 2, 2026
fd007fb
Precommit check
SchumiDing Feb 2, 2026
659ec01
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 4, 2026
55b55dc
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 5, 2026
e2cc50b
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 5, 2026
ca17f8a
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 6, 2026
62af0f2
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 11, 2026
24a6620
Modified to inherit the worker extension class of tensorrt llm
SchumiDing Feb 11, 2026
6f055a2
Modified to inherit the worker extension class of tensorrt llm
SchumiDing Feb 11, 2026
d0b1d1d
fix readability problem of multimodal config
SchumiDing Feb 11, 2026
6b021f4
Remove need for multimodal server config
SchumiDing Feb 11, 2026
a7faa7b
Add vlm unit test into exisiting trtllm unit test
SchumiDing Feb 11, 2026
8519d36
add e2e script to train qwen2.5-vl with trtllm rollout
SchumiDing Feb 11, 2026
9acdcd6
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 12, 2026
5a145a5
Change import statement
SchumiDing Feb 12, 2026
3776338
remove reward config in e2e script
SchumiDing Feb 12, 2026
1706e71
When multi modal input for trtllm, decode with special token first
SchumiDing Feb 12, 2026
90837f3
rever typo
SchumiDing Feb 12, 2026
57506e2
revert typo
SchumiDing Feb 12, 2026
e193d0d
pre commit check
SchumiDing Feb 12, 2026
81050ce
Fix bugs
SchumiDing Feb 27, 2026
91d8c59
Update
SchumiDing Feb 27, 2026
60dd50b
Update
SchumiDing Feb 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions .github/workflows/e2e_ppo_grpo_trainer_trtllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ on:
- main
- v0.*
paths:
- "**/*.py"
# Other entrypoints
- "!verl/trainer/fsdp_sft_trainer.py"
# Recipes
- "!recipe/**"
# FSDP
- "!verl/workers/**/*dp_*.py"
- "verl/workers/rollout/trtllm_rollout/**"
- "tests/workers/rollout/rollout_trtllm/**"
- ".github/workflows/e2e_ppo_grpo_trainer_trtllm.yml"
- "examples/data_preprocess/gsm8k.py"
- "examples/data_preprocess/geo3k.py"
- "examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh"
- "examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh"
pull_request:
branches:
- main
Expand All @@ -68,8 +68,9 @@ on:
# FSDP
- "!verl/workers/**/*dp_*.py"
# Entrypoints
- "verl/workers/rollout/trtllm_rollout/*"
- ".github/workflows/e2e_ppo_grpo_trainer_trtllm"
- "verl/workers/rollout/trtllm_rollout/**"
- "tests/workers/rollout/rollout_trtllm/**"
- ".github/workflows/e2e_ppo_grpo_trainer_trtllm.yml"
- "examples/data_preprocess/gsm8k.py"
- "examples/data_preprocess/geo3k.py"
# add back when ppo flow is ready
Expand Down Expand Up @@ -128,9 +129,11 @@ jobs:
- name: Run TRTLLM unit tests
run: |
export TRTLLM_TEST_MODEL_PATH_ROOT="${HOME}/models"
ray stop --force
pytest -v -s \
tests/workers/rollout/rollout_trtllm/test_adapter.py \
tests/workers/rollout/rollout_trtllm/test_async_server.py
tests/workers/rollout/rollout_trtllm/test_async_server.py \
tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py

e2e_grpo_trainer_fsdp-qwen2:
needs: setup
Expand Down
58 changes: 58 additions & 0 deletions examples/grpo_trainer/run_qwen2_5_vl-7b-trtllm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
set -x

# python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
algorithm.rollout_correction.rollout_is_threshold=2.0 \
data.train_files=$HOME/data/geo3k/train.parquet \
data.val_files=$HOME/data/geo3k/test.parquet \
data.train_batch_size=512 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.return_raw_chat=True \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.trust_remote_code=True \
actor_rollout_ref.hybrid_engine=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
actor_rollout_ref.model.trust_remote_code=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.strategy=fsdp2 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+actor_rollout_ref.model.override_config.attn_implementation=eager \
+actor_rollout_ref.ref.model.override_config.attn_implementation=eager \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=trtllm \
actor_rollout_ref.rollout.mode="async" \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.max_num_seqs=256 \
actor_rollout_ref.rollout.max_num_batched_tokens=16384 \
+actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_timeout_iters=32 \
+actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_max_tokens_ratio=0.5 \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.ref.strategy=fsdp2 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console"]' \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen2_5_vl_7b_trtllm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=10 \
trainer.test_freq=5 \
trainer.resume_mode=disable \
trainer.total_epochs=10
13 changes: 13 additions & 0 deletions tests/workers/rollout/rollout_trtllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading