@@ -883,9 +883,18 @@ rewards = rewards - self.beta * per_token_kl # Shape: (B*G, L)
 # Get value predictions
 values = value_net(completions) # Shape: (B*G, L)
 
-# Compute simple advantages
-advantages = rewards - values.detach() # Shape: (B*G, L)
-# Note: We detach the value network here to not update the parameters of
+# Compute returns via backward pass (gamma typically 1.0 for LM RLHF)
+# Mask rewards to avoid padding tokens (which may have KL penalties) leaking into returns
+returns = torch.zeros_like(rewards)
+running = torch.zeros(rewards.shape[0], device=rewards.device, dtype=rewards.dtype)
+for t in reversed(range(rewards.shape[1])):
+    # Zero out padding: only accumulate rewards/returns for valid completion tokens
+    running = (rewards[:, t] + self.gamma * running) * completion_mask[:, t]
+    returns[:, t] = running
+
+# Compute advantages: A_t = G_t - V(s_t)
+advantages = returns - values.detach() # Shape: (B*G, L)
+# Note: We detach the value network here to not update the parameters of
 # the value function when computing the policy-gradient loss
 
 # Normalize advantages (optional but stable)
@@ -900,8 +909,8 @@ pg_losses1 = -advantages * ratio # Shape: (B*G, L)
 pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - eps, 1.0 + eps) # Shape: (B*G, L)
 pg_loss_max = torch.max(pg_losses1, pg_losses2) # Shape: (B*G, L)
 
-# Simple value function loss
-vf_loss = 0.5 * ((rewards - values) ** 2) # Shape: (B*G, L)
+# Value function loss: predict returns
+vf_loss = 0.5 * ((returns - values) ** 2) # Shape: (B*G, L)
 
 # Combine policy and value losses
 per_token_loss = pg_loss_max + self.vf_coef * vf_loss # Shape: (B*G, L)
@@ -916,7 +925,7 @@ with torch.no_grad():
     clip_frac = ((pg_losses2 > pg_losses1).float() * completion_mask).sum() / completion_mask.sum()
 
     # Compute approximate KL
-    approx_kl = 0.5 * ((new_per_token_logps - per_token_logps) ** 2).mean()
+    approx_kl = (0.5 * ((new_per_token_logps - per_token_logps) ** 2) * completion_mask).sum() / completion_mask.sum()
 
     # Compute value loss for logging
     value_loss = vf_loss.mean()
@@ -1004,7 +1013,7 @@ with torch.no_grad():
     clip_frac = ((pg_losses2 > pg_losses1).float() * completion_mask).sum() / completion_mask.sum()
 
     # Compute approximate KL
-    approx_kl = 0.5 * ((new_per_token_logps - per_token_logps) ** 2).mean()
+    approx_kl = (0.5 * ((new_per_token_logps - per_token_logps) ** 2) * completion_mask).sum() / completion_mask.sum()
 ```
 
 For more details on how to interpret this code, see the PPO section above. The core differences from the PPO example are:
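
To make the first hunk's change concrete, here is a minimal standalone sketch of the masked backward return computation and the resulting advantages. The helper name and the toy inputs below are invented for illustration; the loop mirrors the added code, with `self.gamma` replaced by a plain `gamma` argument (1.0, as the added comment suggests).

```python
import torch

def masked_returns_and_advantages(rewards, values, completion_mask, gamma=1.0):
    """G_t = r_t + gamma * G_{t+1}, computed right-to-left and zeroed on padding."""
    returns = torch.zeros_like(rewards)
    running = torch.zeros(rewards.shape[0], device=rewards.device, dtype=rewards.dtype)
    for t in reversed(range(rewards.shape[1])):
        # Padding positions (mask == 0) reset the running return, so rewards or KL
        # penalties that land on pad tokens never leak into the returns.
        running = (rewards[:, t] + gamma * running) * completion_mask[:, t]
        returns[:, t] = running
    # A_t = G_t - V(s_t); detach so the value net is not updated through the policy loss
    advantages = returns - values.detach()
    return returns, advantages

# Toy batch: 2 completions of length 4; the last token of the second one is padding.
rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0],
                        [0.0, 0.5, 1.0, 9.9]])   # 9.9 sits on a pad token
values = torch.zeros_like(rewards)
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                                [1.0, 1.0, 1.0, 0.0]])

returns, _ = masked_returns_and_advantages(rewards, values, completion_mask)
print(returns)
# expected: [[1.0, 1.0, 1.0, 1.0],
#            [1.5, 1.5, 1.0, 0.0]]
```

Running this shows that the reward placed on the pad token never enters the returns, which is exactly what the `completion_mask` multiplication in the commit is for.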