Commit d6bc629
authored
[fsdp] feat: integrate PrefixGrouper for GRPO training acceleration (#4368)
### What does this PR do?
Integrate [PrefixGrouper](https://github.com/johncaged/PrefixGrouper)
into verl's FSDP worker to accelerate GRPO training by reducing
redundant prefix computations.
In GRPO training, each prompt is copied `G` times (rollout.n), leading
to redundant self-attention computation on shared prefixes.
PrefixGrouper decomposes this into **prefix self-attention + suffix
concat-attention**, significantly reducing computation and memory usage.
**Key changes:**
- Add `use_prefix_grouper` config option in `ActorConfig`
- Implement PG forward path in
`DataParallelPPOActor._forward_micro_batch`
- Add utility functions in `verl/trainer/ppo/prefix_grouper_utils.py`
- Add example scripts and documentation in
`examples/prefix_grouper_examples/`
### Test
**Benchmark Results** (Qwen3-4B, 4×H800, `rollout.n=4`):
| Context Length | Metric | PG | No PG | Speedup |
|----------------|--------|-----|-------|---------|
| **4K** | `old_log_prob` | 1.31s | 1.70s | **1.30x** |
| | `update_actor` | 4.80s | 6.07s | **1.26x** |
| | `step` | 17.08s | 19.40s | **1.14x** |
| **8K** | `old_log_prob` | 1.69s | 2.63s | **1.56x** |
| | `update_actor` | 5.98s | 10.18s | **1.70x** |
| | `step` | 19.48s | 24.71s | **1.27x** |
<img width="2234" height="1475" alt="timing_comparison_combined"
src="https://github.com/user-attachments/assets/3ef5dc69-1b3a-46d7-9d60-608a3fdc56f5"
/>
As context length increases, the speedup becomes more pronounced.
### API and Usage Example
```python
# Enable PrefixGrouper in training config
actor_rollout_ref.actor.use_prefix_grouper=True
trainer.balance_batch=False # Required: PG is incompatible with balance_batch
actor_rollout_ref.model.use_remove_padding=False # Required: PG is incompatible with remove_padding
```
```bash
# Run example script
bash examples/prefix_grouper_examples/run_qwen3_pg.sh
```
### Design & Code Changes
**High-level Design:**
PrefixGrouper optimizes GRPO training by avoiding redundant computation
on shared prefixes. When `rollout.n > 1`, multiple responses share the
same prompt, but standard attention computes the prefix `n` times.
PrefixGrouper decomposes this into:
1. **Prefix self-attention**: Compute once per unique prompt
2. **Suffix concat-attention**: Each response attends to the shared
prefix output
### Design & Code Changes
**High-level Design:**
PrefixGrouper optimizes GRPO training by avoiding redundant computation
on shared prefixes. When `rollout.n > 1`, multiple responses share the
same prompt, but standard attention computes the prefix `n` times.
PrefixGrouper decomposes this into:
1. **Prefix self-attention**: Compute once per unique prompt
2. **Suffix concat-attention**: Each response attends to the shared
prefix output
**Code Changes:**
| File | Change |
|------|--------|
| `verl/workers/config/actor.py` | Add `use_prefix_grouper: bool =
False` config option |
| `verl/trainer/config/actor/actor.yaml` | Add `use_prefix_grouper:
false` default config |
| `verl/workers/actor/dp_actor.py` | (1) Add `self.use_prefix_grouper`
and `self.use_dynamic_bsz` attributes in `__init__`; (2) Add PG forward
path in `_forward_micro_batch` with lazy import and incompatibility
checks; (3) Select extra keys (`prompts`, `response_mask`, `uid`) for PG
in `compute_log_prob`; (4) Select extra keys (`prompts`, `uid`) for PG
in `update_policy` |
| `verl/trainer/ppo/prefix_grouper_utils.py` | New file with:
`build_position_ids_for_prefix_grouper()` for position encoding,
`build_pg_from_micro_batch()` to construct PrefixGrouper from micro
batch, `pg_forward()` to execute PG-optimized forward pass |
| `verl/workers/fsdp_workers.py` | Sync `use_prefix_grouper` config from
actor to ref policy in `init_model` to ensure both use the same forward
path |
| `verl/trainer/ppo/ray_trainer.py` | Add `ValueError` check for
`use_prefix_grouper + balance_batch` incompatibility at initialization |
| `examples/prefix_grouper_examples/` | New directory with: `README.md`
documentation, `run_qwen3_prefix_grouper.sh` example script,
`qwen3/modeling_qwen3.py` modified model supporting PrefixGrouper |
### Limitations
- **FSDP worker only**: Megatron worker is not supported yet
- **Incompatible configurations:**
- `use_dynamic_bsz=True`
- `use_remove_padding=True` (Flash Attention V2 variable length)
- `use_fused_kernels=True`
- `use_ulysses_sp=True` (Ulysses sequence parallelism)
- **Model modification required**: The model must accept
`prefix_grouper` argument in its `forward` method
### Checklist Before Submitting
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
(Added
[examples/prefix_grouper_examples/README.md](cci:7://file:///d:/workspace/verl-tpx/examples/prefix_grouper_examples/README.md:0:0-0:0))
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: PrefixGrouper
requires modified model files and specific hardware setup, tested
manually with benchmark results above.
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1).1 parent 718523c commit d6bc629
File tree
13 files changed
+656
-11
lines changed- examples/prefix_grouper
- tests/utils
- verl
- models/transformers
- trainer
- config
- actor
- ppo
- utils
- workers
- actor
- config
13 files changed
+656
-11
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
200 | 200 | | |
201 | 201 | | |
202 | 202 | | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
33 | 33 | | |
34 | 34 | | |
35 | 35 | | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
36 | 74 | | |
37 | 75 | | |
38 | 76 | | |
| |||
251 | 289 | | |
252 | 290 | | |
253 | 291 | | |
| 292 | + | |
254 | 293 | | |
255 | 294 | | |
256 | 295 | | |
257 | 296 | | |
258 | | - | |
| 297 | + | |
259 | 298 | | |
260 | 299 | | |
261 | 300 | | |
| |||
276 | 315 | | |
277 | 316 | | |
278 | 317 | | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
279 | 321 | | |
280 | 322 | | |
281 | 323 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
85 | 85 | | |
86 | 86 | | |
87 | 87 | | |
| 88 | + | |
88 | 89 | | |
89 | 90 | | |
90 | 91 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
72 | 72 | | |
73 | 73 | | |
74 | 74 | | |
| 75 | + | |
75 | 76 | | |
76 | 77 | | |
77 | 78 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
94 | 94 | | |
95 | 95 | | |
96 | 96 | | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
97 | 100 | | |
98 | 101 | | |
99 | 102 | | |
| |||
0 commit comments