41 changes: 1 addition & 40 deletions verl/workers/utils/losses.py
@@ -14,7 +14,6 @@


import torch
import torch.nn.functional as F
from tensordict import TensorDict

from verl.trainer.ppo.core_algos import agg_loss, compute_value_loss, get_policy_loss_fn, kl_penalty
@@ -23,6 +22,7 @@
from verl.utils.metric import AggregationType, Metric
from verl.utils.torch_functional import masked_mean, masked_sum
from verl.workers.config import ActorConfig, CriticConfig
from verl.workers.utils.padding import _slice_response_from_unpad_output


def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):
@@ -54,45 +54,6 @@ def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
    return loss, {}


def _slice_response_from_unpad_output(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:
    """Slice response from unpad model output.

    Args:
        tensor: model output tensor of shape [bsz, 1]
        data: TensorDict with "prompt_ids", "response_ids", "attention_mask"

    Returns:
        tensor: sliced response tensor of shape [bsz, max_response_len]
    """
    values = tensor.values() if tensor.is_nested else tensor
    prompt_ids = data["prompts"]
    response_ids = data["responses"]
    attention_mask = data["attention_mask"]

    if prompt_ids.is_nested:
        prompt_lens = prompt_ids.offsets().diff()
        response_lens = response_ids.offsets().diff()
        max_response_len = response_ids.offsets().max().item()
    else:
        assert not attention_mask.is_nested
        prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
        response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
        max_response_len = response_ids.shape[1]

    sequence_lens = prompt_lens + response_lens
    sequence_offsets = sequence_lens.cumsum(dim=0)
    assert sequence_offsets[-1].item() == values.shape[0]

    response_list = []
    for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True):
        pad_size = max_response_len - resp_len
        # left-shift model output by one token for log_probs/values
        response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (0, pad_size)))

    output = torch.stack(response_list, dim=0)
    return output


def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None):
    log_prob = _slice_response_from_unpad_output(model_output["log_probs"], data)
    entropy = model_output.get("entropy", None)
43 changes: 43 additions & 0 deletions verl/workers/utils/padding.py
@@ -13,6 +13,7 @@
# limitations under the License.

import torch
import torch.nn.functional as F
from tensordict import TensorDict

from verl.utils import tensordict_utils as tu
@@ -104,3 +105,45 @@
    values = full_values.squeeze(-1)[:, -max_response_len - 1 : -1]  # (bsz, response_length)

    return values


def _slice_response_from_unpad_output(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:
    """Slice response from unpad model output.

    Args:
        tensor: model output tensor of shape [total_tokens, *] or NestedTensor of shape [bsz, prompt_len + response_len, *]

Check failure on line 114 in verl/workers/utils/padding.py
GitHub Actions / pre-commit (3.12): Ruff (E501)
verl/workers/utils/padding.py:114:121: E501 Line too long (123 > 120)

        data: TensorDict with "prompts", "responses", "attention_mask"

    Returns:
        tensor: sliced response tensor of shape [bsz, max_response_len, *]
    """
    values = tensor.values() if tensor.is_nested else tensor
    prompt_ids = data["prompts"]
    response_ids = data["responses"]
    attention_mask = data["attention_mask"]

    if prompt_ids.is_nested:
        prompt_lens = prompt_ids.offsets().diff()
        response_lens = response_ids.offsets().diff()
        max_response_len = response_lens.max().item()
    else:
        assert not attention_mask.is_nested
        prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
        response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
        max_response_len = response_ids.shape[1]
Comment on lines +125 to +133
Contributor

high

The function does not handle empty batches, which will cause a runtime error. If prompt_ids represents an empty batch, response_lens will be an empty tensor, and calling .max() on it will raise an exception. For non-nested tensors, an empty batch will lead to an IndexError on sequence_offsets later on.

To improve robustness and prevent crashes, it's important to handle this edge case. I suggest adding a check for empty batches at the start of this logic block and returning an appropriately shaped empty tensor.

Suggested change
-    if prompt_ids.is_nested:
-        prompt_lens = prompt_ids.offsets().diff()
-        response_lens = response_ids.offsets().diff()
-        max_response_len = response_lens.max().item()
-    else:
-        assert not attention_mask.is_nested
-        prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
-        response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
-        max_response_len = response_ids.shape[1]
+    if (prompt_ids.is_nested and not prompt_ids.numel()) or (not prompt_ids.is_nested and prompt_ids.shape[0] == 0):
+        assert values.numel() == 0, "Non-empty values with empty batch"
+        max_response_len = 0 if prompt_ids.is_nested else response_ids.shape[1]
+        return torch.empty(0, max_response_len, *values.shape[1:], device=values.device, dtype=values.dtype)
+    if prompt_ids.is_nested:
+        prompt_lens = prompt_ids.offsets().diff()
+        response_lens = response_ids.offsets().diff()
+        max_response_len = response_lens.max().item()
+    else:
+        assert not attention_mask.is_nested
+        prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
+        response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
+        max_response_len = response_ids.shape[1]
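For reference, the failure the comment describes is easy to reproduce in isolation. A minimal check, assuming any recent PyTorch build (nothing here is specific to verl):

import torch

# Reducing .max() over an empty tensor raises a RuntimeError, which is what the
# nested-tensor branch would hit for an empty batch when it calls response_lens.max().
torch.empty(0).max()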


    sequence_lens = prompt_lens + response_lens
    sequence_offsets = sequence_lens.cumsum(dim=0)
    assert sequence_offsets[-1].item() == values.shape[0]
    assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"

    response_list = []
    # Skip padding dimensions after sequence dimensions, if any.
    skip_padding = (0, 0) * (values.ndim - 1)
    for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True):
        pad_size = max_response_len - resp_len
        # left-shift model output by one token for log_probs/values
        response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (*skip_padding, 0, pad_size)))

    output = torch.stack(response_list, dim=0)
    return output
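
To make the moved helper's behaviour concrete, here is a small, illustrative call exercising the non-nested (padded tensor) path. The shapes and token values are invented for the example; only the import path, the helper itself, and the TensorDict keys ("prompts", "responses", "attention_mask") come from the diff above:

import torch
from tensordict import TensorDict

from verl.workers.utils.padding import _slice_response_from_unpad_output

# Sample 0: 2 prompt tokens and 3 response tokens; sample 1: 1 prompt token
# (left-padded) and 2 response tokens (right-padded).
data = TensorDict(
    {
        "prompts": torch.tensor([[11, 12], [0, 21]]),  # [bsz=2, max_prompt_len=2]
        "responses": torch.tensor([[13, 14, 15], [22, 23, 0]]),  # [bsz=2, max_response_len=3]
        "attention_mask": torch.tensor([[1, 1, 1, 1, 1], [0, 1, 1, 1, 0]]),  # [bsz=2, 5]
    },
    batch_size=[2],
)

# Unpadded ("packed") model output: one value per valid token, 5 + 3 = 8 in total.
log_probs = torch.arange(8, dtype=torch.float)

out = _slice_response_from_unpad_output(log_probs, data)
# Because the output is left-shifted by one token, sample 0 takes packed positions
# 1..3 and sample 1 takes positions 5..6, right-padded with a zero:
# tensor([[1., 2., 3.],
#         [5., 6., 0.]])
print(out)

The nested-tensor branch follows the same indexing; it only takes the per-sample lengths from the NestedTensor offsets instead of the attention mask.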