chapters/07-reward-models.md
58 lines changed: 58 additions & 0 deletions
@@ -26,6 +26,8 @@ Later in this section we also compare these to Outcome Reward Models (ORMs), Pro
*Throughout this chapter, we use $x$ to denote prompts and $y$ to denote completions. This notation is common in the language model literature, where methods operate on full prompt-completion pairs rather than individual tokens.*
{#fig:rm-role-in-rlhf}
## Training Reward Models
The canonical implementation of a reward model is derived from the Bradley-Terry model of preference [@BradleyTerry].
These are equivalent by letting $\Delta = r_{\theta}(y_c \mid x) - r_{\theta}(y_r \mid x)$ and using $\sigma(\Delta) = \frac{1}{1 + e^{-\Delta}}$, which implies $-\log\sigma(\Delta) = \log(1 + e^{-\Delta}) = \log\left(1 + e^{r_{\theta}(y_r \mid x) - r_{\theta}(y_c \mid x)}\right)$.
They both appear in the RLHF literature.
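
A minimal PyTorch sketch of this pairwise loss, assuming `chosen` and `rejected` hold the scalar rewards $r_{\theta}(y_c \mid x)$ and $r_{\theta}(y_r \mid x)$ for a batch of preference pairs (variable names and values are illustrative, not from the chapter):

```python
import torch
import torch.nn.functional as F

def bradley_terry_loss(chosen: torch.Tensor, rejected: torch.Tensor) -> torch.Tensor:
    # -log sigmoid(r_chosen - r_rejected) == log(1 + exp(r_rejected - r_chosen))
    return -F.logsigmoid(chosen - rejected).mean()

# Illustrative scalar rewards for three preference pairs.
chosen = torch.tensor([1.2, 0.3, 2.1])
rejected = torch.tensor([0.4, 0.8, -0.5])
loss = bradley_terry_loss(chosen, rejected)
```
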
{#fig:pref_rm_training}
## Architecture
The most common way reward models are implemented is through an abstraction similar to Transformers' `AutoModelForSequenceClassification`, which appends a small linear head to the language model that performs classification between two outcomes -- chosen and rejected.
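
A hedged sketch of this setup with Hugging Face Transformers, using a single-logit head (`num_labels=1`); the base model name is a placeholder, and pooling details vary by model class:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

base = "example-org/base-llm"  # placeholder base model name
tokenizer = AutoTokenizer.from_pretrained(base)
# A single output logit acts as the scalar reward head on top of the LM,
# typically pooled from the final token's hidden state.
reward_model = AutoModelForSequenceClassification.from_pretrained(base, num_labels=1)

inputs = tokenizer("PROMPT TEXT" + "COMPLETION TEXT", return_tensors="pt")
reward = reward_model(**inputs).logits[0, 0]  # scalar score for this prompt-completion pair
```
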
@@ -287,6 +291,10 @@ The important intuition here is that an ORM will output a probability of correct
This can be a noisy process, as the updates and loss propagate per token depending on outcomes and attention mappings.
<!-- On the other hand, this process is more computationally intensive. [@cobbe2021gsm8k] posits a few potential benefits to these models, such as (1) implementation of ORMs often being done with both the standard next-token language modelling loss and the reward modelling loss above in @eq:orm_loss and (2) the ORM design as a token-level loss outperforms completion-level loss calculation used in standard RMs. -->
{#fig:orm_inference}

{#fig:orm_training}
These models have remained in use, but are less supported in open-source RLHF tools.
For example, the same type of ORM was used in the seminal work *Let's Verify Step by Step* [@lightman2023let], but without the language modeling prediction piece of the loss.
Then, the final loss is a cross-entropy loss on every token predicting whether the final answer is correct.
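
A rough sketch of this per-token loss, assuming a single outcome label per completion that is broadcast to every completion token (tensor names, shapes, and the masking scheme are assumptions rather than a reference implementation):

```python
import torch
import torch.nn.functional as F

def orm_loss(per_token_logits: torch.Tensor,
             is_correct: torch.Tensor,
             completion_mask: torch.Tensor) -> torch.Tensor:
    """Per-token cross-entropy against a broadcast outcome label.

    per_token_logits: [batch, seq_len] raw scores from the ORM head
    is_correct:       [batch] 0/1 outcome label for each completion
    completion_mask:  [batch, seq_len] 1.0 on completion tokens, 0.0 elsewhere
    """
    targets = is_correct[:, None].expand_as(per_token_logits).float()
    loss = F.binary_cross_entropy_with_logits(per_token_logits, targets, reduction="none")
    return (loss * completion_mask).sum() / completion_mask.sum()
```
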
@@ -323,6 +331,8 @@ Traditionally PRMs are trained with a language modeling head that outputs a toke
These predictions tend to be -1 for incorrect, 0 for neutral, and 1 for correct.
These labels do not necessarily indicate whether the model is on the right path overall, but whether the individual step is correct.
{#fig:prm_training_inference}
An example construction of a PRM is shown below.
```python
@@ -394,6 +404,54 @@ Some notes, given the above table has a lot of edge cases.
- In both preference tuning and reasoning training, value functions often use a discount factor of 1, which makes the value function even closer to an outcome reward model, but with a different training loss.
- A process reward model can be supervised by doing rollouts from an intermediate state and collecting outcome data. This blends multiple ideas, but if the *loss* is computed on per-reasoning-step labels, it is best referred to as a PRM.
**ORM vs. Value Function: The key distinction.**
ORMs and value functions can appear similar since both produce per-token outputs with the same head architecture, but they differ in *what they predict* and *where targets come from*:
- **ORMs** predict an immediate, token-local quantity: $p(\text{correct}_t)$ or $r_t$. Targets come from *offline labels* (a verifier or dataset marking tokens/sequences as correct or incorrect).
- **Value functions** predict the expected *remaining* return: $V(s_t) = \mathbb{E}[\sum_{k \geq t} \gamma^{k-t} r_k \mid s_t]$. Targets are typically *computed from on-policy rollouts* under the current policy $\pi_\theta$, and change as the policy changes (technically, value functions can also be off-policy, but this is not established for work in language modeling).
If you define a dense token reward $r_t = \mathbb{1}[\text{token is correct}]$ and use $\gamma = 1$, then an ORM is learning $r_t$ (or $p(r_t = 1)$) while the value head is learning the remaining-sum $\sum_{k \geq t} r_k$.
They can share the same base model and head dimensions, but the *semantics and supervision pipeline* differ: ORMs are trained offline from fixed labels, while value functions are trained on-policy and used to compute advantages $A_t = \hat{R}_t - V_t$ for policy gradients.
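
A toy numerical illustration of the difference in targets, with hypothetical per-token labels; in practice value targets are computed from on-policy rollouts (often with bootstrapping or GAE) rather than from a fixed label vector:

```python
import torch

# Hypothetical per-token correctness labels for one completion (offline ORM supervision).
r = torch.tensor([0., 1., 1., 0., 1.])

orm_targets = r  # the ORM head is trained against the token-local labels themselves

# Value targets with gamma = 1: the remaining sum of rewards from each position,
# i.e. target_t = sum_{k >= t} r_k.
value_targets = torch.flip(torch.cumsum(torch.flip(r, dims=[0]), dim=0), dims=[0])
# value_targets == tensor([3., 3., 2., 1., 1.])
```
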
### Inference Differences
The models handle data differently at inference time, i.e., once they have been trained, in order to support the suite of tasks that RMs are used for.
**Bradley-Terry RM (Preference Model):**
- *Input:* prompt $x$ + candidate completion $y$
- *Output:* single scalar $r_\theta(x, y)$ from EOS hidden state
- *Usage:* rerank $k$ completions and pick the top-1 (best-of-N sampling, as sketched after this list); or provide a terminal reward for RLHF
- *Aggregation:* Not needed with scalar outputs
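
A minimal best-of-N reranking sketch under the assumptions above (a sequence-classification style RM with a single logit and its tokenizer); the function and variable names are illustrative:

```python
import torch

def best_of_n(reward_model, tokenizer, prompt: str, completions: list[str]) -> str:
    """Score each candidate with the scalar RM and return the highest-scoring one."""
    scores = []
    for y in completions:
        inputs = tokenizer(prompt + y, return_tensors="pt")
        with torch.no_grad():
            scores.append(reward_model(**inputs).logits[0, 0].item())
    return completions[int(torch.tensor(scores).argmax())]
```
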
**Outcome RM:**
- *Input:* prompt $x$ + completion $y$
- *Output:* per-token probabilities $p_t \approx P(\text{correct at token } t)$ over completion tokens
- *Usage:* score finished candidates; aggregate via mean, min (tail risk), or the product of probabilities (equivalently $\sum_t \log p_t$), as sketched after this list
- *Aggregation choices:* mean correctness, minimum $p_t$, average over last $m$ tokens, or threshold flagging if any $p_t < \tau$
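
A small sketch of these aggregation choices over hypothetical per-token probabilities:

```python
import torch

# Per-token correctness probabilities from the ORM for one candidate (hypothetical values).
p = torch.tensor([0.9, 0.8, 0.95, 0.6])

mean_score = p.mean()            # average correctness
min_score = p.min()              # tail risk: worst token
logprod_score = p.log().sum()    # product of probabilities, in log space
flagged = bool((p < 0.7).any())  # threshold flagging with tau = 0.7
```
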
**Process RM:**
- *Input:* prompt $x$ + reasoning trace with step boundaries
- *Output:* scores at step boundaries (e.g., class logits for correct/neutral/incorrect)
- *Usage:* score completed chain-of-thought; or guide search/decoding by pruning low-scoring branches
- *Aggregation:* over steps (not tokens) -- mean step score, minimum (fail-fast), or weighted sum favoring later steps
**Value Function:**
- *Input:* prompt $x$ + current prefix $y_{\leq t}$ (a state)
- *Output:* $V_t$ at each token position in the completion (expected remaining return from state $t$)
- *Usage:* compute per-token advantages $A_t = \hat{R}_t - V_t$ during RL training; the values at each step serve as baselines
- *Aggregation:* typically take $V$ at the last generated token; interpretation differs from "probability of correctness"
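
A toy sketch of this advantage computation with hypothetical returns and value predictions, using plain Monte Carlo returns and $\gamma = 1$ (GAE, discussed in the policy gradients chapter, is the common refinement):

```python
import torch

# Hypothetical per-token returns (Monte Carlo, gamma = 1) and value-head predictions.
returns = torch.tensor([3.0, 3.0, 2.0, 1.0, 1.0])  # R_hat_t
values = torch.tensor([2.5, 2.8, 2.2, 0.9, 0.7])   # V_t

advantages = returns - values  # A_t = R_hat_t - V_t, used to weight the policy gradient
```
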
In summary, the way to understand the different models is:
- **RM:** "How good is this whole answer?" → scalar value
- **ORM:** "Which parts look correct?" → per-token correctness
- **PRM:** "Are the reasoning steps sound?" → per-step scores
- **Value:** "How much reward remains from here?" → baseline for RL advantages
## Generative Reward Modeling
Given the cost of preference data, a large research area emerged that uses existing language models as a judge of human preferences or in other evaluation settings [@zheng2023judging].
chapters/11-policy-gradients.md
2 lines changed: 2 additions & 0 deletions
@@ -418,6 +418,8 @@ Generalized Advantage Estimation (GAE) is considered the state-of-the-art and ca
A value function can also be learned with Monte Carlo estimates from the rollouts used to update the policy.
PPO has two losses -- one to learn the value function and another to use that value function to update the policy.
{#fig:value_fn_training}
A simple example implementation of a value network loss is shown below.
This directory contains source files for generating diagrams. These are **mockups** intended for iteration with coding tools, to be refined by a professional artist.
## Directory Structure
```
diagrams/
├── specs/ # YAML specifications for each diagram type