Commit d6bc629
[fsdp] feat: integrate PrefixGrouper for GRPO training acceleration (#4368)
### What does this PR do?

Integrate [PrefixGrouper](https://github.com/johncaged/PrefixGrouper) into verl's FSDP worker to accelerate GRPO training by reducing redundant prefix computation. In GRPO training, each prompt is copied `G` times (`rollout.n`), leading to redundant self-attention computation on shared prefixes. PrefixGrouper decomposes this into **prefix self-attention + suffix concat-attention**, significantly reducing computation and memory usage.

**Key changes:**

- Add `use_prefix_grouper` config option in `ActorConfig`
- Implement the PG forward path in `DataParallelPPOActor._forward_micro_batch`
- Add utility functions in `verl/trainer/ppo/prefix_grouper_utils.py`
- Add example scripts and documentation in `examples/prefix_grouper_examples/`

### Test

**Benchmark Results** (Qwen3-4B, 4×H800, `rollout.n=4`):

| Context Length | Metric | PG | No PG | Speedup |
|----------------|--------|-----|-------|---------|
| **4K** | `old_log_prob` | 1.31s | 1.70s | **1.30x** |
| | `update_actor` | 4.80s | 6.07s | **1.26x** |
| | `step` | 17.08s | 19.40s | **1.14x** |
| **8K** | `old_log_prob` | 1.69s | 2.63s | **1.56x** |
| | `update_actor` | 5.98s | 10.18s | **1.70x** |
| | `step` | 19.48s | 24.71s | **1.27x** |

<img width="2234" height="1475" alt="timing_comparison_combined" src="https://github.com/user-attachments/assets/3ef5dc69-1b3a-46d7-9d60-608a3fdc56f5" />

As context length increases, the speedup becomes more pronounced.

### API and Usage Example

```python
# Enable PrefixGrouper in training config
actor_rollout_ref.actor.use_prefix_grouper=True
trainer.balance_batch=False  # Required: PG is incompatible with balance_batch
actor_rollout_ref.model.use_remove_padding=False  # Required: PG is incompatible with remove_padding
```

```bash
# Run example script
bash examples/prefix_grouper_examples/run_qwen3_pg.sh
```

### Design & Code Changes

**High-level design:** PrefixGrouper optimizes GRPO training by avoiding redundant computation on shared prefixes. When `rollout.n > 1`, multiple responses share the same prompt, but standard attention computes the prefix `n` times. PrefixGrouper decomposes this into:

1. **Prefix self-attention**: computed once per unique prompt
2. **Suffix concat-attention**: each response attends to the shared prefix output

**Code changes:**

| File | Change |
|------|--------|
| `verl/workers/config/actor.py` | Add `use_prefix_grouper: bool = False` config option |
| `verl/trainer/config/actor/actor.yaml` | Add `use_prefix_grouper: false` default config |
| `verl/workers/actor/dp_actor.py` | (1) Add `self.use_prefix_grouper` and `self.use_dynamic_bsz` attributes in `__init__`; (2) add the PG forward path in `_forward_micro_batch` with lazy import and incompatibility checks; (3) select extra keys (`prompts`, `response_mask`, `uid`) for PG in `compute_log_prob`; (4) select extra keys (`prompts`, `uid`) for PG in `update_policy` |
| `verl/trainer/ppo/prefix_grouper_utils.py` | New file with `build_position_ids_for_prefix_grouper()` for position encoding, `build_pg_from_micro_batch()` to construct a PrefixGrouper from a micro batch, and `pg_forward()` to execute the PG-optimized forward pass |
| `verl/workers/fsdp_workers.py` | Sync the `use_prefix_grouper` config from actor to ref policy in `init_model` so both use the same forward path |
| `verl/trainer/ppo/ray_trainer.py` | Add a `ValueError` check for the `use_prefix_grouper` + `balance_batch` incompatibility at initialization |
| `examples/prefix_grouper_examples/` | New directory with `README.md` documentation, `run_qwen3_prefix_grouper.sh` example script, and `qwen3/modeling_qwen3.py`, a modified model supporting PrefixGrouper |

### Limitations

- **FSDP worker only**: the Megatron worker is not supported yet
- **Incompatible configurations:**
  - `use_dynamic_bsz=True`
  - `use_remove_padding=True` (Flash Attention V2 variable length)
  - `use_fused_kernels=True`
  - `use_ulysses_sp=True` (Ulysses sequence parallelism)
- **Model modification required**: the model must accept a `prefix_grouper` argument in its `forward` method

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs). (Added `examples/prefix_grouper_examples/README.md`)
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: PrefixGrouper requires modified model files and a specific hardware setup; tested manually with the benchmark results above.
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1).
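The computational saving from the shared-prefix decomposition can be estimated with a back-of-the-envelope attention-cost model. The sketch below is illustrative only (it is not verl or PrefixGrouper code, and it counts only attention-score work, ignoring MLPs and constant factors):

```python
# Rough attention-work model for grouped-prefix GRPO training.
# All formulas are simplifying assumptions, not measurements.

def attention_cost_standard(prefix_len: int, suffix_len: int, group_size: int) -> int:
    """Each of the G prompt copies runs full self-attention over prefix + suffix."""
    total = prefix_len + suffix_len
    return group_size * total * total

def attention_cost_prefix_grouper(prefix_len: int, suffix_len: int, group_size: int) -> int:
    """Prefix self-attention once, then each suffix attends over prefix + suffix."""
    return prefix_len * prefix_len + group_size * suffix_len * (prefix_len + suffix_len)

if __name__ == "__main__":
    P, S, G = 4096, 1024, 4  # long shared prompt, shorter responses, rollout.n = 4
    std = attention_cost_standard(P, S, G)
    pg = attention_cost_prefix_grouper(P, S, G)
    print(f"attention-score work ratio: {std / pg:.2f}x")
```

Under this model the ratio grows as the prefix dominates the sequence, which is consistent with the benchmark trend of larger speedups at longer context lengths.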
1 parent 718523c commit d6bc629

File tree

13 files changed: +656 −11 lines changed

examples/prefix_grouper/README.md

Lines changed: 85 additions & 0 deletions

# PrefixGrouper Examples

This directory contains examples for using **PrefixGrouper**, an optimization technique that groups samples by shared prompts to reduce redundant computation in GRPO.

## Introduction

> Official Repository: [https://github.com/johncaged/PrefixGrouper](https://github.com/johncaged/PrefixGrouper)

``PrefixGrouper`` is a plug-and-play efficient GRPO training tool that requires minimal modifications to existing codebases to achieve reduced computation, lower device memory consumption, and accelerated training.

In current mainstream GRPO training pipelines, policy model training copies each prefix (typically questions, multimodal inputs, etc.) `G` times. Consequently, when training-data prefixes are sufficiently long (e.g., long-context reasoning, image/long-video inference), the redundant computation during training becomes non-negligible.

**PrefixGrouper** decomposes the original redundant self-attention operation into prefix self-attention + suffix concat-attention.

<h3 align="center">
<img src="https://raw.githubusercontent.com/johncaged/PrefixGrouper/main/assets/images/method.jpg">
</h3>

## Installation

```bash
pip install prefix_grouper
```

## Limitations

- Currently only supports the FSDP worker (the Megatron worker is not supported yet).
- Incompatible with `use_dynamic_bsz=True`.
- Incompatible with `use_remove_padding=True` (Flash Attention V2 variable length).
- Incompatible with `use_fused_kernels=True`.
- Incompatible with Ulysses sequence parallelism (`use_ulysses_sp=True`) and ring attention.

Note: `balance_batch=True` is now supported with group-level balancing, which keeps samples with the same uid together on the same rank. However, this requires `batch_size % (world_size * rollout.n) == 0`. For example, with `world_size=8` and `rollout.n=4`, `batch_size` must be a multiple of 32.

## How to Use

### 1. Enable PrefixGrouper in Config

Simply set `use_prefix_grouper=True` in your training config:

```yaml
actor_rollout_ref:
  actor:
    use_prefix_grouper: True
  model:
    use_remove_padding: False
```

Optionally enable `balance_batch` for better load distribution:

```yaml
trainer:
  balance_batch: True  # Now supported with group-level balancing
```

### 2. Run Training

Use the provided script `run_qwen3_prefix_grouper.sh` as an example:

```bash
bash examples/prefix_grouper/run_qwen3_prefix_grouper.sh
```

## How It Works

When `use_prefix_grouper=True`, verl automatically patches the attention functions in `transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS` to support the `prefix_grouper` parameter. No model code modifications are needed.

The patch wraps each attention function to:

1. Extract `prefix_grouper` from kwargs
2. If `prefix_grouper` is None, call the original attention function
3. If `prefix_grouper` is provided, use PrefixGrouper's optimized attention computation

## Performance

**Benchmark Results** (Qwen3-4B, 4×H800, `rollout.n=4`):

| Context Length | Metric | PG | No PG | Speedup |
|----------------|--------|-----|-------|---------|
| **4K** | `old_log_prob` | 1.31s | 1.70s | **1.30x** |
| | `update_actor` | 4.80s | 6.07s | **1.26x** |
| | `step` | 17.08s | 19.40s | **1.14x** |
| **8K** | `old_log_prob` | 1.69s | 2.63s | **1.56x** |
| | `update_actor` | 5.98s | 10.18s | **1.70x** |
| | `step` | 19.48s | 24.71s | **1.27x** |

As context length increases, the speedup becomes more pronounced.
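One detail the PR body mentions is position encoding for the grouped layout (`build_position_ids_for_prefix_grouper`). The sketch below is an assumption about how such position ids could be laid out for one shared prefix followed by packed suffixes; the real helper's signature and packing scheme may differ:

```python
# Illustrative position-id construction for a grouped forward pass.
# Assumption: the prefix appears once, and each suffix continues from
# position P as if it immediately followed the prefix.

def grouped_position_ids(prefix_len: int, suffix_lens: list[int]) -> list[int]:
    """Positions 0..P-1 for the shared prefix, then P..P+S_i-1 per suffix."""
    ids = list(range(prefix_len))
    for s in suffix_lens:
        ids.extend(range(prefix_len, prefix_len + s))
    return ids

# A 3-token prefix shared by two responses of lengths 2 and 3:
print(grouped_position_ids(3, [2, 3]))  # [0, 1, 2, 3, 4, 3, 4, 5]
```

This keeps rotary/absolute position embeddings for each response identical to what they would be in the standard per-copy layout, which is what makes the decomposition a drop-in replacement.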
examples/prefix_grouper/run_qwen3_prefix_grouper.sh

Lines changed: 43 additions & 0 deletions
```bash
set -x

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen3-8B \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.use_prefix_grouper=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen3_function_rm_pg' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.balance_batch=True \
    trainer.total_epochs=15 $@
```

tests/utils/test_seqlen_balancing.py

Lines changed: 76 additions & 0 deletions

New tests appended after `test_seqlen_balancing_distributed_params`:

```python
def test_group_balanced_partitions():
    """Test group-level balancing keeps same-uid samples together."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Create test data: 4 groups with different sizes
    # Group 0 (uid=0): indices 0,1,2,3 with seqlens [100, 100, 100, 100]
    # Group 1 (uid=1): indices 4,5,6,7 with seqlens [200, 200, 200, 200]
    # Group 2 (uid=2): indices 8,9,10,11 with seqlens [150, 150, 150, 150]
    # Group 3 (uid=3): indices 12,13,14,15 with seqlens [50, 50, 50, 50]
    seqlen_list = [100] * 4 + [200] * 4 + [150] * 4 + [50] * 4
    uid_list = [0] * 4 + [1] * 4 + [2] * 4 + [3] * 4

    # Partition into 2 groups
    partitions = get_group_balanced_partitions(seqlen_list, uid_list, k_partitions=2)

    assert len(partitions) == 2

    # Verify all indices are covered
    all_indices = set()
    for partition in partitions:
        all_indices.update(partition)
    assert all_indices == set(range(16))

    # Verify same-uid samples stay together
    for partition in partitions:
        uids_in_partition = set(uid_list[i] for i in partition)
        for uid in uids_in_partition:
            # All samples with this uid should be in this partition
            uid_indices = [i for i, u in enumerate(uid_list) if u == uid]
            assert all(i in partition for i in uid_indices), f"uid {uid} samples split across partitions"


def test_group_balanced_partitions_single_sample_groups():
    """Test group balancing with single-sample groups (n=1)."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Each sample is its own group
    seqlen_list = [100, 200, 150, 50, 300, 250]
    uid_list = [0, 1, 2, 3, 4, 5]

    partitions = get_group_balanced_partitions(seqlen_list, uid_list, k_partitions=2)

    assert len(partitions) == 2
    all_indices = set()
    for partition in partitions:
        all_indices.update(partition)
    assert all_indices == set(range(6))


def test_group_balanced_partitions_equal_size():
    """Test group balancing with equal_size constraint simulation."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # 8 groups of 2 samples each, partitioned into 4 (simulating world_size=4)
    seqlen_list = [100, 100, 200, 200, 150, 150, 50, 50, 300, 300, 250, 250, 180, 180, 120, 120]
    uid_list = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]

    partitions = get_group_balanced_partitions(seqlen_list, uid_list, k_partitions=4)

    assert len(partitions) == 4

    # Verify all indices are covered
    all_indices = set()
    for partition in partitions:
        all_indices.update(partition)
    assert all_indices == set(range(16))

    # Verify same-uid samples stay together
    for partition in partitions:
        uids_in_partition = set(uid_list[i] for i in partition)
        for uid in uids_in_partition:
            uid_indices = [i for i, u in enumerate(uid_list) if u == uid]
            assert all(i in partition for i in uid_indices)
```
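The behavior these tests pin down — whole uid-groups assigned to partitions with roughly balanced total sequence length — can be satisfied by a simple greedy scheme. The sketch below is an assumed minimal implementation, not verl's actual `get_group_balanced_partitions`:

```python
# Greedy longest-processing-time partitioning at group granularity:
# group samples by uid, then assign whole groups (heaviest first) to the
# currently lightest partition, so same-uid samples never split.
from collections import defaultdict

def group_balanced_partitions(seqlen_list, uid_list, k_partitions):
    groups = defaultdict(list)
    for idx, uid in enumerate(uid_list):
        groups[uid].append(idx)
    # Heaviest groups first gives better balance under the greedy rule
    order = sorted(groups.values(),
                   key=lambda idxs: sum(seqlen_list[i] for i in idxs),
                   reverse=True)
    partitions = [[] for _ in range(k_partitions)]
    loads = [0] * k_partitions
    for idxs in order:
        target = loads.index(min(loads))  # lightest partition so far
        partitions[target].extend(idxs)
        loads[target] += sum(seqlen_list[i] for i in idxs)
    return partitions
```

Because assignment happens at group granularity, the `batch_size % (world_size * rollout.n) == 0` requirement from the README falls out naturally: only then can every rank receive the same number of whole groups.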

verl/models/transformers/monkey_patch.py

Lines changed: 43 additions & 1 deletion

New module-level patch machinery:

```python
_PREFIX_GROUPER_PATCHED = False
_PREFIX_GROUPER_SUPPORTED_ATTENTIONS = {"flash_attention_2", "flash_attention_3", "sdpa", "flex_attention", "eager"}


def _create_prefix_grouper_wrapper(original_fn):
    """Wrap attention function to support prefix_grouper in kwargs."""

    def wrapped(module, query, key, value, attention_mask, *args, **kwargs):
        prefix_grouper = kwargs.pop("prefix_grouper", None)
        if prefix_grouper is None:
            return original_fn(module, query, key, value, attention_mask, *args, **kwargs)

        def attn_func(q, k, v, attn_mask, *inner_args, **inner_kwargs):
            out, _ = original_fn(module, q, k, v, attn_mask, *inner_args, **inner_kwargs)
            return out

        return prefix_grouper.forward(attn_func, query, key, value, *args, **kwargs), None

    return wrapped


def apply_prefix_grouper_patch():
    """Patch ALL_ATTENTION_FUNCTIONS to support the prefix_grouper parameter."""
    global _PREFIX_GROUPER_PATCHED
    if _PREFIX_GROUPER_PATCHED:
        return

    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    patched = []
    for name in list(ALL_ATTENTION_FUNCTIONS.keys()):
        if name in _PREFIX_GROUPER_SUPPORTED_ATTENTIONS:
            ALL_ATTENTION_FUNCTIONS[name] = _create_prefix_grouper_wrapper(ALL_ATTENTION_FUNCTIONS[name])
            patched.append(name)

    _PREFIX_GROUPER_PATCHED = True
    print(f"[PrefixGrouper] Patched: {patched}")
```

`apply_monkey_patch` gains a `use_prefix_grouper: bool = False` parameter, and its docstring now reads "Apply monkey patch to the models for ulysses sequence parallel, fused kernel, tiled MLP and prefix grouper." Inside the function, after the `apply_tiled_mlp_monkey_patch(...)` call:

```python
    # Apply PrefixGrouper patch if enabled
    if use_prefix_grouper:
        apply_prefix_grouper_patch()
```
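The dispatch pattern above — wrapping every callable in a registry so an optional kwarg reroutes the call — can be demonstrated without torch or transformers. The snippet below is a toy stand-in (fake registry, fake grouper), not verl code:

```python
# Toy demonstration of the registry-wrapping dispatch used by the patch:
# a popped `prefix_grouper` kwarg decides between the original kernel and
# the grouped path, which itself delegates back to the wrapped kernel.

calls = []

def eager_attention(module, q, k, v, mask, **kwargs):
    calls.append("eager")
    return ("attn_out", "attn_weights")  # (output, weights), like HF kernels

class FakePrefixGrouper:
    def forward(self, attn_func, q, k, v, **kwargs):
        calls.append("grouped")
        return attn_func(q, k, v, None)  # delegate to the wrapped kernel

def make_wrapper(original_fn):
    def wrapped(module, q, k, v, mask, **kwargs):
        pg = kwargs.pop("prefix_grouper", None)
        if pg is None:
            return original_fn(module, q, k, v, mask, **kwargs)

        def attn_func(q, k, v, mask, **inner):
            out, _ = original_fn(module, q, k, v, mask, **inner)
            return out

        return pg.forward(attn_func, q, k, v, **kwargs), None
    return wrapped

registry = {"eager": eager_attention}
registry["eager"] = make_wrapper(registry["eager"])

registry["eager"](None, 1, 2, 3, None)                                       # normal path
registry["eager"](None, 1, 2, 3, None, prefix_grouper=FakePrefixGrouper())   # grouped path
print(calls)  # ['eager', 'grouped', 'eager']
```

Because the wrapper preserves the `(output, weights)` return shape and pops its extra kwarg, callers that never pass `prefix_grouper` are unaffected — which is why the patch can be applied unconditionally to every supported attention backend.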

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -85,6 +85,7 @@ actor_rollout_ref:
     entropy_coeff: 0
     calculate_entropy: false
     use_kl_loss: false
+    use_prefix_grouper: false
     use_torch_compile: true
     kl_loss_coef: 0.001
     kl_loss_type: low_var_kl
```

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -72,6 +72,7 @@ actor_rollout_ref:
     entropy_coeff: 0
     calculate_entropy: false
     use_kl_loss: false
+    use_prefix_grouper: false
    use_torch_compile: true
     kl_loss_coef: 0.001
     kl_loss_type: low_var_kl
```

verl/trainer/config/actor/actor.yaml

Lines changed: 3 additions & 0 deletions

```diff
@@ -94,6 +94,9 @@ calculate_entropy: false
 # Whether to use KL loss instead of KL reward penalty. True for GRPO
 use_kl_loss: false
 
+# Whether to enable PrefixGrouper shared-prefix forward
+use_prefix_grouper: false
+
 # Whether to use torch.compile()
 # oc.select: the default val for ref.use_torch_compile
 use_torch_compile: true
```
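The PR's Limitations section lists the option combinations that must be rejected when `use_prefix_grouper` is enabled. A standalone sketch of such a sanity check is below; it is hypothetical (flat dict instead of verl's Hydra config objects, and the real checks live in `dp_actor.py` and `ray_trainer.py`):

```python
# Hypothetical standalone version of the incompatibility checks the PR
# describes; key names mirror the config options, the function is illustrative.

def validate_prefix_grouper_config(cfg: dict) -> None:
    """Raise ValueError if use_prefix_grouper is combined with an option
    listed as incompatible in the PR's Limitations section."""
    if not cfg.get("use_prefix_grouper", False):
        return
    incompatible = ("use_dynamic_bsz", "use_remove_padding",
                    "use_fused_kernels", "use_ulysses_sp")
    for name in incompatible:
        if cfg.get(name, False):
            raise ValueError(f"use_prefix_grouper is incompatible with {name}=True")

# A valid PG config passes silently:
validate_prefix_grouper_config({"use_prefix_grouper": True, "use_remove_padding": False})
```

Failing fast at initialization, as the PR does with its `ValueError` in `ray_trainer.py`, surfaces misconfiguration before any GPU work starts.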
