-
Notifications
You must be signed in to change notification settings - Fork 3.1k
[training_utils] refactor: Extend response slicing to handle multi-dimensional model outputs #4964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||||||||||||||||
| from tensordict import TensorDict | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| from verl.utils import tensordict_utils as tu | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -104,3 +105,45 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| values = full_values.squeeze(-1)[:, -max_response_len - 1 : -1] # (bsz, response_length) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return values | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def _slice_response_from_unpad_output(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||
| """Slice response from unpad model output. | ||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||
| tensor: model output tensor of shape [total_tokens, *] or NestedTensor of shape [bsz, prompt_len + response_len, *] | ||||||||||||||||||||||||||||||||||||||||||||||||
| data: TensorDict with "prompts", "responses", "attention_mask" | ||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||
| tensor: sliced response tensor of shape [bsz, max_response_len, *] | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| values = tensor.values() if tensor.is_nested else tensor | ||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_ids = data["prompts"] | ||||||||||||||||||||||||||||||||||||||||||||||||
| response_ids = data["responses"] | ||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask = data["attention_mask"] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if prompt_ids.is_nested: | ||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_lens = prompt_ids.offsets().diff() | ||||||||||||||||||||||||||||||||||||||||||||||||
| response_lens = response_ids.offsets().diff() | ||||||||||||||||||||||||||||||||||||||||||||||||
| max_response_len = response_lens.max().item() | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert not attention_mask.is_nested | ||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| max_response_len = response_ids.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+125
to
+133
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function does not handle empty batches, which will cause a runtime error. If To improve robustness and prevent crashes, it's important to handle this edge case. I suggest adding a check for empty batches at the start of this logic block and returning an appropriately shaped empty tensor.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_lens = prompt_lens + response_lens | ||||||||||||||||||||||||||||||||||||||||||||||||
| sequence_offsets = sequence_lens.cumsum(dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert sequence_offsets[-1].item() == values.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}" | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| response_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Skip padding dimensions after sequence dimensions, if any. | ||||||||||||||||||||||||||||||||||||||||||||||||
| skip_padding = (0, 0) * (values.ndim - 1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True): | ||||||||||||||||||||||||||||||||||||||||||||||||
| pad_size = max_response_len - resp_len | ||||||||||||||||||||||||||||||||||||||||||||||||
| # left-shift model output by one token for log_probs/values | ||||||||||||||||||||||||||||||||||||||||||||||||
| response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (*skip_padding, 0, pad_size))) | ||||||||||||||||||||||||||||||||||||||||||||||||
JacobHelwig marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| output = torch.stack(response_list, dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.