File tree (expand / collapse): 1 file changed, +7 -2 lines changed
verl/trainer/distillation Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -159,7 +159,8 @@ def distillation_loss(
159159 compute_distillation_loss_range (distillation_losses = distillation_losses , response_mask = response_mask )
160160 )
161161 if loss_config .loss_max_clamp is not None :
162- distillation_losses = distillation_losses .clamp_max (loss_config .loss_max_clamp )
162+ # Clamp symmetrically: the k1 loss can be negative, so the lower bound must be clamped as well as the upper bound.
163+ distillation_losses = distillation_losses .clamp (min = - loss_config .loss_max_clamp , max = loss_config .loss_max_clamp )
163164
164165 if loss_config .use_policy_gradient :
165166 # Use negative distillation loss as reward, as done by https://thinkingmachines.ai/blog/on-policy-distillation/.
@@ -298,4 +299,8 @@ def compute_distillation_loss_reverse_kl_estimator(
298299 distillation_losses = kl_penalty (
299300 logprob = student_log_probs , ref_logprob = teacher_log_probs , kl_penalty = loss_config .loss_mode
300301 )
301- return distillation_losses , {}
302+ # Since k1 can be negative, log the mean absolute loss.
303+ metrics = {
304+ "distillation/abs_loss" : Metric (AggregationType .MEAN , distillation_losses [response_mask ].abs ().mean ()),
305+ }
306+ return distillation_losses , metrics
You can’t perform that action at this time.
0 commit comments