
Commit 5a1f43b

natolambert and claude authored
Fix PPO implementation bugs: returns and KL masking (#216)
Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent 9137ab2 commit 5a1f43b
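
The first fix replaces the simple per-token advantage, `advantages = rewards - values.detach()`, with discounted returns-to-go that are accumulated backward over the sequence and gated by `completion_mask`, so rewards and KL penalties that land on padding positions cannot leak into earlier tokens' returns. The snippet below is a minimal standalone sketch of that backward pass, not code from the chapter; the tensor values are made up for illustration, and `gamma` / `completion_mask` follow the names used in the diff.

```python
import torch

# Toy batch: 2 completions, max length 4; the second completion has only 2 valid tokens.
gamma = 1.0  # the patch notes gamma is typically 1.0 for LM RLHF
rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0],
                        [0.0, 2.0, 0.0, 0.0]])
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                                [1.0, 1.0, 0.0, 0.0]])

# Backward accumulation of reward-to-go, zeroed on padding so nothing
# propagates from (or through) pad positions.
returns = torch.zeros_like(rewards)
running = torch.zeros(rewards.shape[0], dtype=rewards.dtype)
for t in reversed(range(rewards.shape[1])):
    running = (rewards[:, t] + gamma * running) * completion_mask[:, t]
    returns[:, t] = running

print(returns)
# tensor([[1., 1., 1., 1.],
#         [2., 2., 0., 0.]])
```

With `gamma = 1.0`, each valid token's return is simply the sum of rewards from that position to the end of the completion, and padded positions stay at zero.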

File tree: 1 file changed (+16, -7 lines)

chapters/11-policy-gradients.md

Lines changed: 16 additions & 7 deletions
@@ -883,9 +883,18 @@ rewards = rewards - self.beta * per_token_kl # Shape: (B*G, L)
 # Get value predictions
 values = value_net(completions) # Shape: (B*G, L)
 
-# Compute simple advantages
-advantages = rewards - values.detach() # Shape: (B*G, L)
-# Note: We detach the value network here to not update the parameters of
+# Compute returns via backward pass (gamma typically 1.0 for LM RLHF)
+# Mask rewards to avoid padding tokens (which may have KL penalties) leaking into returns
+returns = torch.zeros_like(rewards)
+running = torch.zeros(rewards.shape[0], device=rewards.device, dtype=rewards.dtype)
+for t in reversed(range(rewards.shape[1])):
+    # Zero out padding: only accumulate rewards/returns for valid completion tokens
+    running = (rewards[:, t] + self.gamma * running) * completion_mask[:, t]
+    returns[:, t] = running
+
+# Compute advantages: A_t = G_t - V(s_t)
+advantages = returns - values.detach() # Shape: (B*G, L)
+# Note: We detach the value network here to not update the parameters of
 # the value function when computing the policy-gradient loss
 
 # Normalize advantages (optional but stable)
@@ -900,8 +909,8 @@ pg_losses1 = -advantages * ratio # Shape: (B*G, L)
 pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - eps, 1.0 + eps) # Shape: (B*G, L)
 pg_loss_max = torch.max(pg_losses1, pg_losses2) # Shape: (B*G, L)
 
-# Simple value function loss
-vf_loss = 0.5 * ((rewards - values) ** 2) # Shape: (B*G, L)
+# Value function loss: predict returns
+vf_loss = 0.5 * ((returns - values) ** 2) # Shape: (B*G, L)
 
 # Combine policy and value losses
 per_token_loss = pg_loss_max + self.vf_coef * vf_loss # Shape: (B*G, L)
@@ -916,7 +925,7 @@ with torch.no_grad():
 clip_frac = ((pg_losses2 > pg_losses1).float() * completion_mask).sum() / completion_mask.sum()
 
 # Compute approximate KL
-approx_kl = 0.5 * ((new_per_token_logps - per_token_logps)**2).mean()
+approx_kl = (0.5 * ((new_per_token_logps - per_token_logps)**2) * completion_mask).sum() / completion_mask.sum()
 
 # Compute value loss for logging
 value_loss = vf_loss.mean()
@@ -1004,7 +1013,7 @@ with torch.no_grad():
 clip_frac = ((pg_losses2 > pg_losses1).float() * completion_mask).sum() / completion_mask.sum()
 
 # Compute approximate KL
-approx_kl = 0.5 * ((new_per_token_logps - per_token_logps)**2).mean()
+approx_kl = (0.5 * ((new_per_token_logps - per_token_logps)**2) * completion_mask).sum() / completion_mask.sum()
 ```
 
 For more details on how to interpret this code, see the PPO section above. The core differences from the PPO example are:
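
The second fix applies `completion_mask` to the approximate-KL diagnostic, so padding tokens no longer distort the logged value. Below is a small standalone comparison of the old unmasked mean against the patched masked mean; the numbers are invented and the snippet is only meant to show why the two estimates differ on a padded batch.

```python
import torch

# Toy per-token log-probabilities; the second row has 2 padding positions at the end.
new_per_token_logps = torch.tensor([[-1.0, -1.2, -0.9, -1.1],
                                    [-0.5, -0.7,  0.0,  0.0]])
per_token_logps = torch.tensor([[-1.1, -1.0, -1.0, -1.0],
                                [-0.6, -0.6,  0.3, -0.4]])
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                                [1.0, 1.0, 0.0, 0.0]])

sq = 0.5 * (new_per_token_logps - per_token_logps) ** 2

old_kl = sq.mean()                                             # averages over padding too
new_kl = (sq * completion_mask).sum() / completion_mask.sum()  # averages over valid tokens only

print(old_kl.item(), new_kl.item())
```

Only the masked version averages over tokens the policy actually generated; the unmasked `.mean()` also folds in whatever happens to sit at pad positions.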
