Commit 37ff251

[training_utils] fix: mask out-of-bounds vocab entries in fused-kernel LCE logsumexp (verl-project#5349)
### What does this PR do?

Fixes verl-project#2656 and verl-project#2899.

When `vocab_size % BLOCK_SIZE_N != 0`, the final Triton tile loads out-of-bounds weight rows (padded to zero by `tl.load`), contributing `exp(0) = 1` per phantom token to the logsumexp accumulator. This inflates the denominator and corrupts log-probabilities and entropy. For near-uniform distributions the error goes undetected, but it becomes severe (up to 100% token mismatch) when the softmax is peaked, as is common during RL training.

Fix: mask OOB positions to `-inf` via `tl.where` before the running max/exp accumulation. The original (zero-padded) logits are kept for the entropy accumulator.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: [is:pr lce vocab padding](https://github.com/verl-project/verl/pulls?q=is%3Apr+lce+vocab+padding)
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

Added `test_lce_non_divisible_vocab_padding` to `tests/utils/test_linear_cross_entropy.py`. The test constructs a peaked weight matrix (`w[:, 0] = -15·T`, `w[0, 0] = 3·T`, `vocab_size=152064`, which is 512 mod 1024) to reliably trigger the bug, plus a divisible-vocab control case.

GPU results (H200, `vocab_size=152064`):

| Case | Without fix | With fix |
|---|---|---|
| non-divisible vocab (mod 1024 = 512) | max_diff=3.274581, mismatch=100.0% ❌ | max_diff=0.000001, mismatch=0.0% ✅ |
| divisible vocab (mod 1024 = 0) | max_diff=0.000000, mismatch=0.0% ✅ | max_diff=0.000000, mismatch=0.0% ✅ |

### API and Usage Example

No API changes.

### Design & Code Changes

- `verl/utils/kernel/kernels.py` (`efficient_entropy_kernel_general_mainloop`): compute `vocab_bound = min((pid_n + 1) * vocab_per_split, vocab_size)`, derive `logits_for_lse = tl.where(offs_bn[None, :] < vocab_bound, logits, float("-inf"))`, and use it for the running max and exp-sum while keeping the original `logits` for the entropy accumulator.
- `tests/utils/test_linear_cross_entropy.py`: add `test_lce_non_divisible_vocab_padding`.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs). (not needed)
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: the test requires a GPU and lives in `tests/utils/test_linear_cross_entropy.py`, which is part of the existing test suite.
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [x] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`. (not needed)
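The denominator inflation described above can be checked with back-of-the-envelope arithmetic. A minimal standalone sketch, where the 511 low logits and the phantom count of 25 are illustrative values mirroring the test's docstring, not numbers read from the kernel's actual tiling:

```python
import math

# Peaked softmax: one logit at +3, 511 logits at -15 (illustrative counts).
true_denom = math.exp(3.0) + 511 * math.exp(-15.0)  # ~= 20.09

# Each zero-padded phantom position contributes exp(0) = 1 to the
# logsumexp denominator; assume ~25 phantom positions for this sketch.
buggy_denom = true_denom + 25 * math.exp(0.0)       # ~= 45.09

# Per-token log-prob error introduced by the phantom mass.
error = math.log(buggy_denom) - math.log(true_denom)
print(f"per-token log-prob error: {error:.2f}")
```

Because the phantom mass is fixed while the true denominator shrinks as the distribution gets more peaked, the relative error grows exactly in the RL regime the PR description calls out.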
1 parent f5c34bb commit 37ff251

File tree

2 files changed: +53 −2 lines changed


tests/utils/test_linear_cross_entropy.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -348,6 +348,52 @@ def check_storage_all(self):
         self.check_storage("Kernel", linear_cross_entropy)


+def test_lce_non_divisible_vocab_padding():
+    """Regression test for the logsumexp padding bug.
+
+    When vocab_size % BLOCK_SIZE_N != 0 the last tile has fewer than
+    BLOCK_SIZE_N valid entries. Without the fix, out-of-bounds positions
+    are loaded as weight=0 → logit=0 → exp(0)=1, adding phantom probability
+    mass to the logsumexp denominator. For peaked softmax distributions
+    (small denominator) this causes large log-prob errors.
+
+    Reproducing construction: one token-logit at +3, all others at -15
+    → denominator ≈ 20, phantom adds ≈ 25 → error ≈ 0.82 per token.
+    """
+    if not torch.cuda.is_available():
+        return
+
+    torch.manual_seed(0)
+
+    V = 152064  # vocab_size % 1024 == 512 (triggers bug)
+    V_div = 149 * 1024  # vocab_size % 1024 == 0 (control)
+    D = 3584
+    N = 512
+    T = 1.5
+
+    def reference(hidden, weight, labels):
+        h = hidden.squeeze(0).float()
+        logits = torch.matmul(h, weight.float().T) / T
+        lp = -torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction="none")
+        pd = torch.nn.functional.softmax(logits, dim=-1)
+        ent = torch.logsumexp(logits, dim=-1) - (pd * logits).sum(-1)
+        return lp, ent
+
+    for vocab_size, desc in [(V, "non-divisible vocab (mod1024=512)"), (V_div, "divisible vocab (mod1024=0)")]:
+        w = torch.zeros(vocab_size, D, dtype=torch.bfloat16, device="cuda")
+        w[:, 0] = -15.0 * T
+        w[0, 0] = 3.0 * T
+        h = torch.zeros(1, N, D, dtype=torch.bfloat16, device="cuda")
+        h[:, :, 0] = 1.0
+        labels = torch.zeros(1, N, dtype=torch.long, device="cuda")
+
+        ref_lp, ref_ent = reference(h, w, labels)
+        ker_lp, ker_ent = linear_cross_entropy(h, w, labels, T)
+
+        torch.testing.assert_close(ref_lp, ker_lp, atol=1e-3, rtol=1e-3, msg=f"logprob mismatch: {desc}")
+        torch.testing.assert_close(ref_ent, ker_ent, atol=1e-3, rtol=1e-3, msg=f"entropy mismatch: {desc}")
+
+
 if __name__ == "__main__":
     # torch.cuda.memory._record_memory_history()
@@ -358,4 +404,6 @@ def check_storage_all(self):
     test.verify_correctness()
     test.check_storage_all()

+    test_lce_non_divisible_vocab_padding()
+
     # torch.cuda.memory._dump_snapshot("test_linear_cross_entropy.pkl")
```

verl/utils/kernel/kernels.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -263,6 +263,7 @@ def efficient_entropy_kernel_general_mainloop(
     _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
     _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
     _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
+    vocab_bound = min((pid_n + 1) * vocab_per_split, vocab_size)
     for n in range(0, num_pid_n):
         start_offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N
         offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N)
@@ -308,12 +309,14 @@ def efficient_entropy_kernel_general_mainloop(
         # scale logits by temperature
         logits *= rcp_temperature

+        logits_for_lse = tl.where(offs_bn[None, :] < vocab_bound, logits, float("-inf"))
+
         # update global maximum
         _max_old = _max
-        m_pid_n = tl.max(logits, axis=1)
+        m_pid_n = tl.max(logits_for_lse, axis=1)
         _max = tl.maximum(_max_old, m_pid_n)

-        exp_logits = tl.exp(logits - _max[:, None])
+        exp_logits = tl.exp(logits_for_lse - _max[:, None])
         coeff = tl.exp(_max_old - _max)
         _accu = coeff * _accu + tl.sum(exp_logits, axis=1)
```