Commit 4441c2a

k1 metrics and clamping
1 parent 2c67b1d commit 4441c2a


verl/trainer/distillation/losses.py

Lines changed: 7 additions & 2 deletions
@@ -159,7 +159,8 @@ def distillation_loss(
         compute_distillation_loss_range(distillation_losses=distillation_losses, response_mask=response_mask)
     )
     if loss_config.loss_max_clamp is not None:
-        distillation_losses = distillation_losses.clamp_max(loss_config.loss_max_clamp)
+        # Clamping the min is for the k1 loss, which can be negative.
+        distillation_losses = distillation_losses.clamp(min=-loss_config.loss_max_clamp, max=loss_config.loss_max_clamp)
 
     if loss_config.use_policy_gradient:
         # Use negative distillation loss as reward, as done by https://thinkingmachines.ai/blog/on-policy-distillation/.
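
For context, a minimal standalone sketch of the clamping change (plain PyTorch, not this repo's code; the tensor values and the local name loss_max_clamp are made up for illustration). The k1 estimator is student_log_probs - teacher_log_probs per token, which goes negative wherever the teacher assigns higher log-probability than the student, so the old one-sided clamp_max left large negative losses unbounded:

import torch

# Hypothetical per-token log-probs.
student_log_probs = torch.tensor([-1.0, -2.5, -0.1])
teacher_log_probs = torch.tensor([-0.5, -0.2, -3.0])
k1 = student_log_probs - teacher_log_probs  # tensor([-0.5000, -2.3000,  2.9000])

loss_max_clamp = 1.0
old = k1.clamp_max(loss_max_clamp)                       # upper bound only
new = k1.clamp(min=-loss_max_clamp, max=loss_max_clamp)  # symmetric bound

print(old)  # tensor([-0.5000, -2.3000,  1.0000]) -- the -2.3 survives
print(new)  # tensor([-0.5000, -1.0000,  1.0000]) -- bounded on both sides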
@@ -298,4 +299,8 @@ def compute_distillation_loss_reverse_kl_estimator(
     distillation_losses = kl_penalty(
         logprob=student_log_probs, ref_logprob=teacher_log_probs, kl_penalty=loss_config.loss_mode
     )
-    return distillation_losses, {}
+    # Since k1 can be negative, log the mean absolute loss.
+    metrics = {
+        "distillation/abs_loss": Metric(AggregationType.MEAN, distillation_losses[response_mask].abs().mean()),
+    }
+    return distillation_losses, metrics
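
Similarly, a minimal sketch of the new metric (plain PyTorch; the tensors are hypothetical, and the repo's Metric/AggregationType wrappers are left out). Assuming response_mask is a boolean tensor, as the indexing in the commit suggests, boolean indexing flattens distillation_losses down to the response tokens, and taking the absolute value keeps the logged summary meaningful even when the possibly negative k1 losses cancel in a plain mean:

import torch

# Hypothetical (batch, seq_len) losses and mask over response tokens.
distillation_losses = torch.tensor([[0.3, -1.2, 0.0],
                                    [2.0, -0.5, 0.0]])
response_mask = torch.tensor([[True, True, False],
                              [True, True, False]])

masked = distillation_losses[response_mask]  # tensor([ 0.3000, -1.2000,  2.0000, -0.5000])
print(masked.mean())        # tensor(0.1500) -- signs cancel, hiding the magnitude
print(masked.abs().mean())  # tensor(1.0000) -- what the commit logs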
