Conversation

@JacobHelwig (Collaborator) commented Jan 17, 2026

What does this PR do?

  • Moves _slice_response_from_unpad_output out of verl.workers.utils.losses so that modules imported by verl.workers.utils.losses can import _slice_response_from_unpad_output without a circular import (see the import sketch below)
  • Extends _slice_response_from_unpad_output to multi-dimensional tensors (e.g., topk_log_probs of shape (S, K) instead of log_probs of shape (S,))

Both changes are used by #4897 for computing top-k distillation loss.
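
As a rough sketch of the resulting import structure (the destination module verl.workers.utils.padding is taken from the review below; everything else here is illustrative):

    # verl/workers/utils/padding.py: new, dependency-light home for the helper
    def _slice_response_from_unpad_output(values, prompt_ids, response_ids, attention_mask):
        ...

    # verl/workers/utils/losses.py: still uses the helper, now via a one-way import
    from verl.workers.utils.padding import _slice_response_from_unpad_output

    # Modules imported by losses.py can likewise import from padding.py directly,
    # without importing losses.py back, which is what previously created the cycle.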

Design & Code Changes

Mainly a refactor; the extension to multi-dimensional tensors is in these lines:

    # Skip padding dimensions after sequence dimensions, if any.
    skip_padding = (0, 0) * (values.ndim - 1)
    for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True):
        pad_size = max_response_len - resp_len
        # left-shift model output by one token for log_probs/values
        response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (*skip_padding, 0, pad_size)))
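
F.pad consumes its padding tuple from the last dimension backwards, so the (0, 0) pairs in skip_padding leave every trailing dimension (e.g., the top-k dimension) untouched while the final (0, pad_size) right-pads the leading sequence dimension. A standalone sketch with made-up shapes:

    import torch
    import torch.nn.functional as F

    values = torch.randn(7, 4)  # e.g., topk_log_probs for one packed sequence, K = 4
    skip_padding = (0, 0) * (values.ndim - 1)  # one (0, 0) pair: keep the K dim as-is
    # Take 4 "response" rows, then right-pad the sequence dim back to length 7.
    padded = F.pad(values[2:6], (*skip_padding, 0, 3))
    print(padded.shape)  # torch.Size([7, 4])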

Additionally corrects the computation of max_response_len for nested tensors:

    if prompt_ids.is_nested:
        prompt_lens = prompt_ids.offsets().diff()
        response_lens = response_ids.offsets().diff()
        max_response_len = response_lens.max().item()

Also adds a check that prompts are non-empty, since with prompt_len == 0 the left-shifted slice start seq_offset - resp_len - 1 would reach into the preceding sequence:

    assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request refactors the _slice_response_from_unpad_output function by moving it to a different module to avoid circular dependencies and extends its functionality to handle multi-dimensional tensors. The extension for multi-dimensional tensors appears correct. However, a critical issue has been introduced: the module where the function is now used (verl/workers/utils/losses.py) is missing the necessary import, which will lead to a runtime error. I have also provided a suggestion to improve the accuracy of the function's docstring for better maintainability.

@JacobHelwig (Collaborator, Author):

/gemini review

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request refactors _slice_response_from_unpad_output by moving it to verl.workers.utils.padding to resolve a circular dependency and extends it to support multi-dimensional tensors. The changes are logical and well-structured for the intended purpose. However, I've identified a critical pre-existing bug in the moved function related to the calculation of max_response_len for nested tensors, which could lead to incorrect behavior and potential crashes. I've provided a suggestion to fix it.

    if prompt_ids.is_nested:
        prompt_lens = prompt_ids.offsets().diff()
        response_lens = response_ids.offsets().diff()
        max_response_len = response_ids.offsets().max().item()

@gemini-code-assist (bot, Contributor) commented on the lines above:

critical

There's a critical bug in how max_response_len is calculated for nested tensors. response_ids.offsets().max().item() computes the total number of tokens in all responses in the batch, not the maximum length of a single response. This will lead to incorrect padding sizes, likely causing out-of-memory errors or producing tensors with incorrect shapes.

To fix this, you should calculate the maximum of the individual response lengths.

Suggested change:

    - max_response_len = response_ids.offsets().max().item()
    + max_response_len = response_ids.offsets().diff().max().item()
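
To see why .diff() is needed: offsets() of a jagged nested tensor stores cumulative sequence boundaries, so its max is the total token count while its diff gives per-sequence lengths. A minimal sketch with made-up lengths, assuming the jagged layout available in recent PyTorch versions:

    import torch

    response_ids = torch.nested.nested_tensor(
        [torch.zeros(3, dtype=torch.long), torch.zeros(5, dtype=torch.long)],
        layout=torch.jagged,
    )
    print(response_ids.offsets())        # tensor([0, 3, 8]): cumulative boundaries
    print(response_ids.offsets().max())  # tensor(8): total tokens, not max length
    print(response_ids.offsets().diff().max().item())  # 5, the true max_response_len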

@JacobHelwig (Collaborator, Author):

/gemini review

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request refactors _slice_response_from_unpad_output by moving it to a new file and extending its functionality to handle multi-dimensional tensors. While the refactoring and extension are well-motivated, I've identified a critical bug in the slicing logic that can lead to incorrect memory access when processing sequences with empty prompts. I have provided a detailed comment with a suggested fix to address this issue.
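
To make the empty-prompt failure mode concrete: with prompt_len == 0, the slice start seq_offset - resp_len - 1 lands one position inside the previous packed sequence. A numeric sketch with made-up lengths:

    import torch

    # Packed outputs for two sequences laid end to end:
    #   seq 0: prompt 2 + response 3 tokens -> cumulative offset 5
    #   seq 1: prompt 0 + response 4 tokens -> cumulative offset 9
    values = torch.arange(9)
    seq_offset, resp_len = 9, 4
    # Slice start = 9 - 4 - 1 = 4, which is still inside sequence 0.
    print(values[seq_offset - resp_len - 1 : seq_offset - 1])  # tensor([4, 5, 6, 7])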

@JacobHelwig (Collaborator, Author):

/gemini review

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request refactors the _slice_response_from_unpad_output function by moving it to a separate utility file to resolve circular dependencies and extends its functionality to support multi-dimensional model outputs. The changes are well-structured and include a necessary bug fix for calculating max_response_len and an assertion to enforce the assumption of non-empty prompts. My review includes one high-severity suggestion to handle empty batches gracefully, which would currently cause a runtime error.

Comment on lines +125 to +133
    if prompt_ids.is_nested:
        prompt_lens = prompt_ids.offsets().diff()
        response_lens = response_ids.offsets().diff()
        max_response_len = response_lens.max().item()
    else:
        assert not attention_mask.is_nested
        prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
        response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
        max_response_len = response_ids.shape[1]

@gemini-code-assist (bot, Contributor) commented on the lines above:

high

The function does not handle empty batches, which will cause a runtime error. If prompt_ids represents an empty batch, response_lens will be an empty tensor, and calling .max() on it will raise an exception. For non-nested tensors, an empty batch will lead to an IndexError on sequence_offsets later on.

To improve robustness and prevent crashes, it's important to handle this edge case. I suggest adding a check for empty batches at the start of this logic block and returning an appropriately shaped empty tensor.

Suggested change:

    + if (prompt_ids.is_nested and not prompt_ids.numel()) or (not prompt_ids.is_nested and prompt_ids.shape[0] == 0):
    +     assert values.numel() == 0, "Non-empty values with empty batch"
    +     max_response_len = 0 if prompt_ids.is_nested else response_ids.shape[1]
    +     return torch.empty(0, max_response_len, *values.shape[1:], device=values.device, dtype=values.dtype)
      if prompt_ids.is_nested:
          prompt_lens = prompt_ids.offsets().diff()
          response_lens = response_ids.offsets().diff()
          max_response_len = response_lens.max().item()
      else:
          assert not attention_mask.is_nested
          prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
          response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
          max_response_len = response_ids.shape[1]
