[training_utils] refactor: Extend response slicing to handle multi-dimensional model outputs #4964
Conversation
Code Review
This pull request refactors the _slice_response_from_unpad_output function by moving it to a different module to avoid circular dependencies and extends its functionality to handle multi-dimensional tensors. The extension for multi-dimensional tensors appears correct. However, a critical issue has been introduced: the module where the function is now used (verl/workers/utils/losses.py) is missing the necessary import, which will lead to a runtime error. I have also provided a suggestion to improve the accuracy of the function's docstring for better maintainability.
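If the relocated helper is still called from `verl/workers/utils/losses.py`, the fix would be a one-line import. A minimal sketch, assuming the destination module named in the later review (`verl.workers.utils.padding`):

```python
# Hypothetical import fix for the missing-import issue flagged above;
# the module path is inferred from the review text, not from the diff.
from verl.workers.utils.padding import _slice_response_from_unpad_output
```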
/gemini review
Code Review
This pull request refactors _slice_response_from_unpad_output by moving it to verl.workers.utils.padding to resolve a circular dependency and extends it to support multi-dimensional tensors. The changes are logical and well-structured for the intended purpose. However, I've identified a critical pre-existing bug in the moved function related to the calculation of max_response_len for nested tensors, which could lead to incorrect behavior and potential crashes. I've provided a suggestion to fix it.
verl/workers/utils/padding.py (outdated diff)
```python
if prompt_ids.is_nested:
    prompt_lens = prompt_ids.offsets().diff()
    response_lens = response_ids.offsets().diff()
    max_response_len = response_ids.offsets().max().item()
```
There's a critical bug in how max_response_len is calculated for nested tensors. response_ids.offsets().max().item() computes the total number of tokens in all responses in the batch, not the maximum length of a single response. This will lead to incorrect padding sizes, likely causing out-of-memory errors or producing tensors with incorrect shapes.
To fix this, you should calculate the maximum of the individual response lengths.
```diff
-    max_response_len = response_ids.offsets().max().item()
+    max_response_len = response_ids.offsets().diff().max().item()
```
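To see why the original line is wrong, note that for a jagged nested tensor `offsets()` holds cumulative token counts, so its max is the batch-wide total rather than the longest response. A minimal repro, assuming PyTorch's `torch.jagged` layout (which the `.offsets()` call suggests):

```python
import torch

# Three responses of lengths 3, 5, and 2 packed as a jagged nested tensor.
response_ids = torch.nested.nested_tensor(
    [torch.arange(3), torch.arange(5), torch.arange(2)],
    layout=torch.jagged,
)
print(response_ids.offsets())                      # tensor([ 0,  3,  8, 10])
print(response_ids.offsets().max().item())         # 10 -> total tokens (the bug)
print(response_ids.offsets().diff().max().item())  # 5  -> longest response (the fix)
```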
/gemini review
Code Review
This pull request refactors _slice_response_from_unpad_output by moving it to a new file and extending its functionality to handle multi-dimensional tensors. While the refactoring and extension are well-motivated, I've identified a critical bug in the slicing logic that can lead to incorrect memory access when processing sequences with empty prompts. I have provided a detailed comment with a suggested fix to address this issue.
/gemini review
Code Review
This pull request refactors the _slice_response_from_unpad_output function by moving it to a separate utility file to resolve circular dependencies and extends its functionality to support multi-dimensional model outputs. The changes are well-structured and include a necessary bug fix for calculating max_response_len and an assertion to enforce the assumption of non-empty prompts. My review includes one high-severity suggestion to handle empty batches gracefully, which would currently cause a runtime error.
```python
if prompt_ids.is_nested:
    prompt_lens = prompt_ids.offsets().diff()
    response_lens = response_ids.offsets().diff()
    max_response_len = response_lens.max().item()
else:
    assert not attention_mask.is_nested
    prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
    response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
    max_response_len = response_ids.shape[1]
```
The function does not handle empty batches, which will cause a runtime error. If prompt_ids represents an empty batch, response_lens will be an empty tensor, and calling .max() on it will raise an exception. For non-nested tensors, an empty batch will lead to an IndexError on sequence_offsets later on.
To improve robustness and prevent crashes, it's important to handle this edge case. I suggest adding a check for empty batches at the start of this logic block and returning an appropriately shaped empty tensor.
Suggested change:

```diff
+if (prompt_ids.is_nested and not prompt_ids.numel()) or (not prompt_ids.is_nested and prompt_ids.shape[0] == 0):
+    assert values.numel() == 0, "Non-empty values with empty batch"
+    max_response_len = 0 if prompt_ids.is_nested else response_ids.shape[1]
+    return torch.empty(0, max_response_len, *values.shape[1:], device=values.device, dtype=values.dtype)
 if prompt_ids.is_nested:
     prompt_lens = prompt_ids.offsets().diff()
     response_lens = response_ids.offsets().diff()
     max_response_len = response_lens.max().item()
 else:
     assert not attention_mask.is_nested
     prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
     response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
     max_response_len = response_ids.shape[1]
```
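A quick repro of the failure mode described above, showing why an empty batch needs the early return (a sketch, not project code):

```python
import torch

response_lens = torch.tensor([], dtype=torch.int64)  # lengths for an empty batch
try:
    response_lens.max()  # .max() on an empty tensor raises
except RuntimeError as e:
    print(e)  # max(): Expected reduction dim to be specified for input.numel() == 0
```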
What does this PR do?
- Moves `_slice_response_from_unpad_output` outside of `verl.workers.utils.losses` so that modules imported by `verl.workers.utils.losses` can import `_slice_response_from_unpad_output` without a circular import.
- Extends `_slice_response_from_unpad_output` to multi-dimensional tensors (e.g., instead of `log_probs` of shape `(S,)`, `topk_log_probs` of shape `(S, K)`).

Both changes are used by #4897 for computing top-k distillation loss.
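For intuition, a minimal sketch of response slicing that carries trailing dimensions through, so the same code path serves `(S,)` log-probs and `(S, K)` top-k log-probs. All names and the exact offset arithmetic are assumptions for illustration, not the PR's implementation:

```python
import torch

def slice_response_sketch(values, seq_starts, prompt_lens, response_lens, max_response_len):
    """Gather per-sequence response slices from a flat unpadded token stream."""
    batch_size = len(seq_starts)
    # *values.shape[1:] preserves any extra trailing dims, e.g. K for top-k.
    out = torch.zeros(batch_size, max_response_len, *values.shape[1:],
                      dtype=values.dtype, device=values.device)
    for i in range(batch_size):
        start = seq_starts[i] + prompt_lens[i]  # first response token of sequence i
        out[i, : response_lens[i]] = values[start : start + response_lens[i]]
    return out

# Two sequences packed into one 8-token stream; values carry K=3 per token.
values = torch.arange(24, dtype=torch.float32).reshape(8, 3)
out = slice_response_sketch(
    values,
    seq_starts=torch.tensor([0, 5]),
    prompt_lens=torch.tensor([2, 1]),
    response_lens=torch.tensor([3, 2]),
    max_response_len=3,
)
print(out.shape)  # torch.Size([2, 3, 3])
```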
Design & Code Changes
Mainly a refactor; the extension to multi-dimensional tensors touches only a few lines.

Additionally corrects the computation of `max_response_len` (verl/verl/workers/utils/losses.py, lines 72 to 75 in 65eb5a1).

Also adds a check for non-empty prompts:
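A minimal sketch of what this check could look like (the actual assertion in the PR may differ; `prompt_lens` here is illustrative):

```python
import torch

# Hypothetical non-empty-prompt guard; message and placement are assumptions.
prompt_lens = torch.tensor([2, 1, 4])
assert (prompt_lens > 0).all(), "_slice_response_from_unpad_output assumes non-empty prompts"
```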