Commit 37ff251
[training_utils] fix: mask out-of-bounds vocab entries fused kernel LCE logsumexp (verl-project#5349)
### What does this PR do?
Fixes verl-project#2656 and verl-project#2899. When `vocab_size % BLOCK_SIZE_N != 0`, the final
Triton tile loads out-of-bounds weight rows, which `tl.load` pads with
zeros. Each phantom token then contributes `exp(0) = 1` to the logsumexp
accumulator, which inflates the denominator and corrupts
log-probabilities and entropy. The error goes unnoticed for uniform
distributions, but it is severe (up to 100% token mismatch) when the
softmax is peaked, as is common during RL training.
Fix: mask OOB positions to `-inf` before the running max/exp
accumulation via `tl.where`. The original (zero-padded) logits are kept
for the entropy accumulator.
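
To make the failure mode concrete, here is a minimal pure-Python sketch (not the Triton kernel). The helper names `logsumexp` and `padded_lse`, the toy vocabulary of 6 entries, and the tile width of 4 are all illustrative assumptions. It compares the exact logsumexp with a zero-padded one and a `-inf`-padded one on a peaked distribution.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def padded_lse(logits, block_size, pad_value):
    """logsumexp after padding logits up to a multiple of block_size."""
    n_pad = (-len(logits)) % block_size
    return logsumexp(list(logits) + [pad_value] * n_pad)


# Peaked toy distribution: one dominant logit, the rest strongly negative.
logits = [3.0] + [-15.0] * 5  # hypothetical vocab_size = 6

exact = logsumexp(logits)
# Two phantom lanes filled with 0.0, as a zero-padded load would produce.
zero_padded = padded_lse(logits, block_size=4, pad_value=0.0)
# The same two lanes masked to -inf contribute exp(-inf) = 0.
masked = padded_lse(logits, block_size=4, pad_value=float("-inf"))

print(f"exact={exact:.6f} zero_padded={zero_padded:.6f} masked={masked:.6f}")
```

In this toy setting the zero-padded value exceeds the exact one because each phantom lane adds `exp(0) = 1` to the sum, while the `-inf`-padded value matches the exact one.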
### Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: [is:pr lce vocab padding](https://github.com/verl-project/verl/pulls?q=is%3Apr+lce+vocab+padding)
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`,
    `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`,
    `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`,
    `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like
    `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature,
    etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
Added `test_lce_non_divisible_vocab_padding` to
`tests/utils/test_linear_cross_entropy.py`. The test constructs a peaked
weight matrix (`w[:, 0] = -15·T`, `w[0, 0] = 3·T`) with
`vocab_size=152064` (152064 mod 1024 = 512) to reliably trigger the bug,
plus a divisible-vocab control case.
GPU results (H200, `vocab_size=152064`):
| Case | Without fix | With fix |
|---|---|---|
| non-divisible vocab (mod1024=512) | max\_diff=3.274581, mismatch=100.0% ❌ | max\_diff=0.000001, mismatch=0.0% ✅ |
| divisible vocab (mod1024=0) | max\_diff=0.000000, mismatch=0.0% ✅ | max\_diff=0.000000, mismatch=0.0% ✅ |
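
The tile arithmetic behind the two cases can be checked with a small sketch. It assumes a tile width of 1024 (inferred from the "mod1024" labels) and a single partially filled final tile. `phantom_lanes` is a hypothetical helper, and the divisible control size `151552` is an illustrative choice, since the size used by the control case is not stated here.

```python
def phantom_lanes(vocab_size: int, block_size_n: int) -> int:
    """Out-of-bounds lanes in the last tile when tiles have width block_size_n."""
    remainder = vocab_size % block_size_n
    return 0 if remainder == 0 else block_size_n - remainder


BLOCK_SIZE_N = 1024  # assumption: taken from the "mod1024" labels in the table

# Non-divisible case from the test: remainder 512, so 512 lanes are padding.
print(152064 % BLOCK_SIZE_N, phantom_lanes(152064, BLOCK_SIZE_N))  # 512 512

# Hypothetical divisible control size: no padding lanes at all.
print(151552 % BLOCK_SIZE_N, phantom_lanes(151552, BLOCK_SIZE_N))  # 0 0
```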
### API and Usage Example
No API changes.
### Design & Code Changes
- `verl/utils/kernel/kernels.py` (`efficient_entropy_kernel_general_mainloop`):
  compute `vocab_bound = min((pid_n + 1) * vocab_per_split, vocab_size)`,
  derive `logits_for_lse = tl.where(offs_bn[None, :] < vocab_bound, logits, float("-inf"))`,
  and use it for the running max and exp-sum while keeping the original
  `logits` for the entropy accumulator.
- `tests/utils/test_linear_cross_entropy.py`: add
  `test_lce_non_divisible_vocab_padding`.
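
The masking step can be sketched in plain Python as a tiled online logsumexp. This is an illustration under simplifying assumptions, not the kernel code: `tiled_lse` is a hypothetical name, there is a single vocab split (so `vocab_bound` reduces to `min(tile_end, vocab_size)`), and out-of-bounds lanes are modeled as logits of `0.0`.

```python
import math

NEG_INF = float("-inf")


def tiled_lse(logits, block_size, mask_oob):
    """Online logsumexp over fixed-width tiles; lanes past the end read as 0.0."""
    vocab_size = len(logits)
    running_max, running_sum = NEG_INF, 0.0
    for start in range(0, vocab_size, block_size):
        offs = range(start, start + block_size)
        # Emulates a zero-filled load for out-of-bounds lanes.
        tile = [logits[i] if i < vocab_size else 0.0 for i in offs]
        if mask_oob:
            # Emulates tl.where(offs < vocab_bound, logits, -inf).
            vocab_bound = min(start + block_size, vocab_size)
            tile = [x if i < vocab_bound else NEG_INF for i, x in zip(offs, tile)]
        new_max = max(running_max, max(tile))
        # Rescale the previous sum to the new max, then add this tile's terms.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in tile
        )
        running_max = new_max
    return running_max + math.log(running_sum)


logits = [3.0] + [-15.0] * 5  # hypothetical peaked logits, vocab_size = 6
reference = math.log(sum(math.exp(x) for x in logits))
print(tiled_lse(logits, 4, mask_oob=True) - reference)   # close to zero
print(tiled_lse(logits, 4, mask_oob=False) - reference)  # clearly positive
```

Only the max and exp-sum path sees the masked values in this sketch; the unmasked tile would remain available for any other accumulator, mirroring the split between `logits_for_lse` and `logits` described above.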
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs). (not
needed)
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: test requires a GPU
and is already in
`tests/utils/test_linear_cross_entropy.py` which is part of the existing
test suite.
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [x] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`. (not needed)

1 parent f5c34bb commit 37ff251
2 files changed (+53, -2), under `tests/utils` and `verl/utils/kernel`.