
Commit a2608a2

natolambert and claude authored
Add diagram infrastructure for reward model visualizations (#206)
Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent 92e656c commit a2608a2

16 files changed: +1701 -1 lines changed

.gitignore

Lines changed: 6 additions & 1 deletion
@@ -18,4 +18,9 @@ arxiv_check_results.json
 .claude

 # Python/uv
-uv.lock
+uv.lock
+
+# Generated diagrams (regenerate with: cd diagrams && make all)
+diagrams/generated/
+diagrams/feedback*/
+images/*_tokens.png

CLAUDE.md

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# RLHF Book - Claude Code Context

## Project Overview

This is the source repository for "RLHF Book" by Nathan Lambert - a comprehensive guide to Reinforcement Learning from Human Feedback.

**Live site:** https://rlhfbook.com

## Build System

- **Pandoc + Make** for multi-format output (HTML, PDF, EPUB, DOCX)
- Run `make` to build all formats
- Run `make html` for just the HTML site
- Dependencies: pandoc, pandoc-crossref, basictex (for PDF)

## Python Commands

**Always use `uv run python` instead of bare `python`** to ensure the correct virtual environment and dependencies:

```bash
# Correct
uv run python scripts/some_script.py
uv run python -c "import matplotlib"

# Incorrect
python scripts/some_script.py
```

## Directory Structure

```
chapters/    # Markdown source files (01-introduction.md, etc.)
images/      # Image assets referenced in chapters
assets/      # Brand assets (covers, logos)
templates/   # Pandoc templates for each output format
scripts/     # Build utilities
diagrams/    # Diagram sources (D2, Python scripts, specs)
build/       # Generated output (not tracked in git)
```

## Image Conventions

- Place images in `images/` directory
- Reference: `![Description](images/filename.png){#fig:label}`
- Optional sizing: `{width=450px}`
- Cross-reference with `@fig:label`

## Diagram Workflow

The `diagrams/` directory contains source files for generating figures:

1. **specs/** - YAML specifications defining diagram content
2. **d2/** - D2 language sources for pipeline diagrams
3. **scripts/** - Python scripts for token strip visualizations
4. **generated/** - Intermediate outputs

Generate diagrams with:
```bash
cd diagrams && make all
```

Then copy final versions to `images/` for use in chapters.

## Future: Multimodal Feedback Loop

Plan to integrate Gemini API for diagram feedback:

- Pass math content + generated diagrams to Gemini 2.5 Pro
- Get feedback on visual clarity, correctness, consistency
- Iterate on mockups before artist handoff

Example workflow:
```python
# Pseudocode for diagram feedback
import os

import google.generativeai as genai

# Assumes an API key is exported in the environment, e.g. GOOGLE_API_KEY
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

model = genai.GenerativeModel('gemini-2.5-pro')
# diagram_image (e.g. a PIL.Image) and latex_formula are placeholders
response = model.generate_content([
    "Review this reward model diagram for accuracy:",
    diagram_image,
    "The math should show: " + latex_formula,
    "Is this correct and clear?"
])
```

## Key Chapters for Diagrams

- **Chapter 7 (Reward Models)**: Bradley-Terry, ORM, PRM, Generative RM
- **Chapter 11 (Policy Gradients)**: PPO visualizations, async vs sync training
- **Chapter 12 (DPO)**: Direct alignment visualizations

## Style Notes

- Keep diagrams simple and artist-friendly
- Use consistent visual grammar across related figures
- Prefer SVG for scalability, PNG for final book assets
- Mockups are iterative - not pixel-perfect

## Next Steps (Diagrams PR)

1. **Finalize diagrams** - Review and polish the multilane diagrams (ORM, Value Function)
2. **Add diagrams to chapter text** - Insert figure references in `chapters/07-reward-models.md`
3. **Add RLHF overview diagram** - Add the same RLHF diagram to the start of the RM chapter to highlight where RMs fit in the pipeline
4. **Review PR** - Check over the full PR before merge

chapters/07-reward-models.md

Lines changed: 58 additions & 0 deletions
@@ -26,6 +26,8 @@ Later in this section we also compare these to Outcome Reward Models (ORMs), Pro
*Throughout this chapter, we use $x$ to denote prompts and $y$ to denote completions. This notation is common in the language model literature, where methods operate on full prompt-completion pairs rather than individual tokens.*

![The reward model in RLHF plays the role of the environment component that returns rewards in standard RL. The key difference is that in RLHF, we get to control and learn this reward function from human preferences, rather than having it fixed by the environment.](images/rlhf-overview.png){#fig:rm-role-in-rlhf}

## Training Reward Models

The canonical implementation of a reward model is derived from the Bradley-Terry model of preference [@BradleyTerry].
@@ -74,6 +76,8 @@ $$\mathcal{L}(\theta) = \log \left( 1 + e^{r_{\theta}(y_r \mid x) - r_{\theta}(y_c \mid x)} \right)$$
These are equivalent by letting $\Delta = r_{\theta}(y_c \mid x) - r_{\theta}(y_r \mid x)$ and using $\sigma(\Delta) = \frac{1}{1 + e^{-\Delta}}$, which implies $-\log\sigma(\Delta) = \log(1 + e^{-\Delta}) = \log\left(1 + e^{r_{\theta}(y_r \mid x) - r_{\theta}(y_c \mid x)}\right)$.
They both appear in the RLHF literature.
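
For concreteness, here is a minimal sketch of the pairwise loss in both forms, assuming the scalar scores for chosen and rejected completions are already computed (the tensor values and variable names below are placeholders, not a reference implementation):

```python
import torch
import torch.nn.functional as F

# Scores r_theta(y_c | x) and r_theta(y_r | x) for a batch of preference pairs,
# e.g. read off the linear head at the EOS token of each sequence.
score_chosen = torch.tensor([1.2, 0.3, -0.5])
score_rejected = torch.tensor([0.4, 0.9, -1.1])

# Form 1: negative log-sigmoid of the score difference
loss_sigmoid = -F.logsigmoid(score_chosen - score_rejected).mean()

# Form 2: log(1 + exp(r_rejected - r_chosen)), i.e. softplus of the reversed difference
loss_softplus = F.softplus(score_rejected - score_chosen).mean()

assert torch.allclose(loss_sigmoid, loss_softplus)  # the two forms are equivalent
```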

![Training a preference reward model requires pairs of chosen and rejected completions. The model computes a scalar score at the end-of-sequence (EOS) token for each, and the contrastive loss depends only on the score difference between the two.](images/pref_rm_training.png){#fig:pref_rm_training}

## Architecture

The most common way reward models are implemented is through an abstraction similar to Transformers' `AutoModelForSequenceClassification`, which appends a small linear head to the language model that performs classification between two outcomes -- chosen and rejected.
287291
This can be a noisy process, as the updates and loss propagates per token depending on outcomes and attention mappings.
288292
<!-- On the other hand, this process is more computationally intensive. [@cobbe2021gsm8k] posits a few potential benefits to these models, such as (1) implementation of ORMs often being done with both the standard next-token language modelling loss and the reward modelling loss above in @eq:orm_loss and (2) the ORM design as a token-level loss outperforms completion-level loss calculation used in standard RMs. -->
289293
294+
![At inference time, an outcome reward model outputs per-token correctness probabilities. Prompt tokens are masked (e.g., label=-100), while completion tokens each receive a probability indicating whether the model believes the response leads to a correct answer.](images/orm_inference.png){#fig:orm_inference}
295+
296+
![Training an outcome reward model uses offline labels from a verifier or dataset (e.g., all 1s for correct completions). Each completion token is trained with binary cross-entropy against the outcome label, and per-token probabilities are aggregated into a final score for verification, filtering, or reranking.](images/orm_training.png){#fig:orm_training}
297+
290298
These models have continued in use, but are less supported in open-source RLHF tools.
291299
For example, the same type of ORM was used in the seminal work *Let's Verify Step by Step* [@lightman2023let], but without the language modeling prediction piece of the loss.
292300
Then, the final loss is a cross-entropy loss on every token predicting if the final answer is correct.
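
A minimal sketch of this per-token objective is shown below, assuming a token-level classification head and -100 masking for prompt tokens as in the figures above; this is illustrative only, not the implementation from the cited works:

```python
import torch
import torch.nn.functional as F

batch, seq_len = 2, 6
# One logit per token from the ORM head (probability the final answer is correct).
logits = torch.randn(batch, seq_len)

# Offline outcome labels: prompt tokens masked with -100, completion tokens get 0/1.
labels = torch.tensor([
    [-100, -100, 1, 1, 1, 1],   # correct completion
    [-100, -100, 0, 0, 0, 0],   # incorrect completion
])

mask = labels != -100
# Binary cross-entropy on unmasked (completion) tokens only.
loss = F.binary_cross_entropy_with_logits(logits[mask], labels[mask].float())

# At inference, per-token probabilities can be aggregated into a sequence score.
per_token_prob = torch.sigmoid(logits)
seq_score = (per_token_prob * mask).sum(dim=1) / mask.sum(dim=1)
```
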
@@ -323,6 +331,8 @@ Traditionally PRMs are trained with a language modeling head that outputs a toke
These predictions tend to be -1 for incorrect, 0 for neutral, and 1 for correct.
These labels do not necessarily indicate whether the model is on the right overall path, only whether the individual step is correct.

![Process reward models provide supervision only at step boundaries (e.g., newline tokens). Each step receives a 3-class label: correct (+1), neutral (0), or incorrect (-1). All other tokens are masked during training.](images/prm_training_inference.png){#fig:prm_training_inference}

An example construction of a PRM is shown below.

```python
@@ -394,6 +404,54 @@ Some notes, given the above table has a lot of edge cases.
- Both in preference tuning and reasoning training, the value functions often have a discount factor of 1, which makes a value function even closer to an outcome reward model, but with a different training loss.
- A process reward model can be supervised by doing rollouts from an intermediate state and collecting outcome data. This blends multiple ideas, but if the *loss* is computed on per-reasoning-step labels, it is best referred to as a PRM.

**ORM vs. Value Function: The key distinction.**
ORMs and value functions can appear similar since both produce per-token outputs with the same head architecture, but they differ in *what they predict* and *where targets come from*:

- **ORMs** predict an immediate, token-local quantity: $p(\text{correct}_t)$ or $r_t$. Targets come from *offline labels* (a verifier or dataset marking tokens/sequences as correct or incorrect).
- **Value functions** predict the expected *remaining* return: $V(s_t) = \mathbb{E}[\sum_{k \geq t} \gamma^{k-t} r_k \mid s_t]$. Targets are typically *computed from on-policy rollouts* under the current policy $\pi_\theta$, and change as the policy changes (technically, value functions can also be off-policy, but this is not established for work in language modeling).

If you define a dense token reward $r_t = \mathbb{1}[\text{token is correct}]$ and use $\gamma = 1$, then an ORM is learning $r_t$ (or $p(r_t = 1)$) while the value head is learning the remaining-sum $\sum_{k \geq t} r_k$.
They can share the same base model and head dimensions, but the *semantics and supervision pipeline* differ: ORMs are trained offline from fixed labels, while value functions are trained on-policy and used to compute advantages $A_t = \hat{R}_t - V_t$ for policy gradients.
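
To make the distinction concrete, here is a toy sketch of the two training targets under the dense reward $r_t$ and $\gamma = 1$ described above (the numbers are illustrative only):

```python
import torch

# Dense per-token reward: 1 if the token is labeled correct, else 0.
r = torch.tensor([1., 1., 0., 1.])

# ORM target: the immediate, token-local label r_t itself.
orm_target = r

# Value target with gamma = 1: the remaining sum of rewards from each position,
# i.e. sum_{k >= t} r_k, computed here as a reversed cumulative sum.
value_target = torch.flip(torch.cumsum(torch.flip(r, dims=[0]), dim=0), dims=[0])
# -> tensor([3., 2., 1., 1.])

# During RL training, the value predictions V_t serve as a baseline:
V = torch.tensor([2.5, 1.8, 1.2, 0.9])   # hypothetical value head outputs
advantage = value_target - V              # A_t = R_hat_t - V_t
```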

### Inference Differences

The models handle data differently at inference time, i.e. once they have been trained, in order to support the suite of tasks that RMs are used for.

**Bradley-Terry RM (Preference Model):**

- *Input:* prompt $x$ + candidate completion $y$
- *Output:* single scalar $r_\theta(x, y)$ from the EOS hidden state
- *Usage:* rerank $k$ completions and pick the top-1 (best-of-N sampling); or provide the terminal reward for RLHF
- *Aggregation:* not needed with scalar outputs

**Outcome RM:**

- *Input:* prompt $x$ + completion $y$
- *Output:* per-token probabilities $p_t \approx P(\text{correct at token } t)$ over completion tokens
- *Usage:* score finished candidates; aggregate via mean, min (tail risk), or product $\prod_t p_t$ (equivalently $\sum_t \log p_t$)
- *Aggregation choices:* mean correctness, minimum $p_t$, average over the last $m$ tokens, or threshold flagging if any $p_t < \tau$ (see the code sketch after the summary below)

**Process RM:**

- *Input:* prompt $x$ + reasoning trace with step boundaries
- *Output:* scores at step boundaries (e.g., class logits for correct/neutral/incorrect)
- *Usage:* score a completed chain-of-thought; or guide search/decoding by pruning low-scoring branches
- *Aggregation:* over steps (not tokens) -- mean step score, minimum (fail-fast), or a weighted sum favoring later steps

**Value Function:**

- *Input:* prompt $x$ + current prefix $y_{\leq t}$ (a state)
- *Output:* $V_t$ at each token position in the completion (expected remaining return from state $t$)
- *Usage:* compute per-token advantages $A_t = \hat{R}_t - V_t$ during RL training; the values at each step serve as baselines
- *Aggregation:* typically take $V$ at the last generated token; the interpretation differs from "probability of correctness"

In summary, the way to understand the different models is:

- **RM:** "How good is this whole answer?" → scalar value
- **ORM:** "Which parts look correct?" → per-token correctness
- **PRM:** "Are the reasoning steps sound?" → per-step scores
- **Value:** "How much reward remains from here?" → baseline for RL advantages
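
As a concrete illustration of the ORM aggregation options listed above, here is a small sketch; the probabilities, the window size `m`, and the threshold `tau` are arbitrary placeholders:

```python
import torch

# Per-token correctness probabilities for one completion (prompt tokens excluded).
p = torch.tensor([0.95, 0.90, 0.60, 0.85, 0.80])
m, tau = 3, 0.7

score_mean = p.mean()                  # mean correctness
score_min = p.min()                    # minimum p_t (tail risk)
score_logprod = torch.log(p).sum()     # product of p_t, computed in log space
score_last_m = p[-m:].mean()           # average over the last m tokens
flag_suspect = bool((p < tau).any())   # threshold flagging if any p_t < tau
```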

## Generative Reward Modeling

Given the cost of preference data, a large research area emerged to use existing language models as a judge of human preferences or in other evaluation settings [@zheng2023judging].

chapters/11-policy-gradients.md

Lines changed: 2 additions & 0 deletions
@@ -418,6 +418,8 @@ Generalized Advantage Estimation (GAE) is considered the state-of-the-art and ca
A value function can also be learned with Monte Carlo estimates from the rollouts used to update the policy.
PPO has two losses -- one to learn the value function and another to use that value function to update the policy.

![Value function training uses on-policy rollouts to compute targets. The model predicts $V_t$ at each token, which is trained via MSE against the target return $\hat{V}_t$. The advantage $A_t = \hat{V}_t - V_t$ then weights the policy gradient update.](images/value_fn_training.png){#fig:value_fn_training}

A simple example implementation of a value network loss is shown below.

```python

diagrams/Makefile

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# Makefile for generating RLHF Book diagrams
#
# Usage:
#   make all    - Generate all diagrams
#   make tokens - Generate token strip diagrams
#   make clean  - Remove generated files

GENERATED_DIR := generated

.PHONY: all tokens clean help

all: tokens
	@echo "All diagrams generated in $(GENERATED_DIR)/"

# Token strip diagrams - requires matplotlib
# Uses two generators:
#   - generate_token_strips.py for Pref RM and PRM (simple strips)
#   - generate_multilane_strips.py for ORM and Value (multi-lane with targets/usage)
# Generates both PNG (digital) and SVG (print) in separate folders
tokens: | $(GENERATED_DIR)
	@mkdir -p $(GENERATED_DIR)/png $(GENERATED_DIR)/svg
	uv run python scripts/generate_token_strips.py --output-dir $(GENERATED_DIR)/png --format png
	uv run python scripts/generate_token_strips.py --output-dir $(GENERATED_DIR)/svg --format svg
	uv run python scripts/generate_multilane_strips.py --output-dir $(GENERATED_DIR)/png --format png
	uv run python scripts/generate_multilane_strips.py --output-dir $(GENERATED_DIR)/svg --format svg
	@echo "Token strip diagrams generated (PNG in png/, SVG in svg/)"

# Create output directory
$(GENERATED_DIR):
	mkdir -p $(GENERATED_DIR)

# Clean generated files
clean:
	rm -rf $(GENERATED_DIR)/*
	@echo "Cleaned $(GENERATED_DIR)/"

# Help
help:
	@echo "RLHF Book Diagram Generator"
	@echo ""
	@echo "Targets:"
	@echo "  all    - Generate all diagrams"
	@echo "  tokens - Generate token strip diagrams"
	@echo "  clean  - Remove generated files"
	@echo ""
	@echo "Requirements:"
	@echo "  - Python + matplotlib: uv add matplotlib"

diagrams/README.md

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Diagram Sources for RLHF Book

This directory contains source files for generating diagrams. These are **mockups** intended for iteration with coding tools, to be refined by a professional artist.

## Directory Structure

```
diagrams/
├── specs/       # YAML specifications for each diagram type
├── d2/          # D2 diagram source files (box-and-arrow flows)
├── scripts/     # Python scripts for generating token strips and other visuals
├── generated/   # Intermediate outputs (SVG, PNG before final placement)
└── README.md    # This file
```

## Workflow

1. **Edit specs** in `specs/` to define the conceptual content
2. **Generate diagrams** using the scripts or D2 CLI
3. **Review outputs** in `generated/`
4. **Copy final versions** to `images/` for use in the book
5. **Commit both sources and outputs** for reproducibility

## Tooling Requirements

### D2 (for pipeline diagrams)

Install D2: https://d2lang.com/tour/install

```bash
# macOS
brew install d2

# or via script
curl -fsSL https://d2lang.com/install.sh | sh -s --
```

Generate SVG/PNG:
```bash
d2 d2/pref_rm_pipeline.d2 generated/pref_rm_pipeline.svg
d2 d2/pref_rm_pipeline.d2 generated/pref_rm_pipeline.png
```

### Python (for token strip visuals)

Dependencies (matplotlib) are managed via uv:
```bash
uv add matplotlib  # if not already installed
```

Generate token strips:
```bash
uv run python scripts/generate_token_strips.py
```

## Generating All Diagrams

```bash
# From repo root
cd diagrams && make all

# Or just token strips (doesn't require D2)
cd diagrams && make tokens

# Copy generated diagrams to images/
cd diagrams && make install
```

## Diagram Types

### 1. Pipeline Diagrams (D2)

Box-and-arrow flows showing: Data → Model → Output → Loss

- `pref_rm_pipeline.d2` - Bradley-Terry Preference RM
- `orm_pipeline.d2` - Outcome RM
- `prm_pipeline.d2` - Process RM
- `gen_rm_pipeline.d2` - Generative RM / LLM-as-Judge

### 2. Token Strip Visualizations (Python)

Horizontal token sequences showing where supervision attaches:

- Preference RM: highlight EOS/last token only
- ORM: highlight all completion tokens (prompt masked)
- PRM: highlight step boundary tokens only
- Value function: highlight all tokens (state values)
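
As a rough illustration of what these scripts produce, here is a standalone matplotlib sketch of a single token strip (not the repository's actual `generate_token_strips.py`; token text, colors, and the output filename are placeholders):

```python
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

tokens = ["What", "is", "2+2", "?", "4", "<eos>"]
supervised = {5}  # e.g. Preference RM: only the EOS token carries the loss

fig, ax = plt.subplots(figsize=(6, 1.2))
for i, tok in enumerate(tokens):
    # Highlight supervised tokens; leave the rest gray.
    color = "#f4a261" if i in supervised else "#e0e0e0"
    ax.add_patch(Rectangle((i, 0), 0.9, 0.9, facecolor=color, edgecolor="black"))
    ax.text(i + 0.45, 0.45, tok, ha="center", va="center", fontsize=9)

ax.set_xlim(0, len(tokens))
ax.set_ylim(0, 1)
ax.axis("off")
fig.savefig("token_strip_example.png", dpi=200, bbox_inches="tight")
```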

### 3. Inference Usage Diagrams (D2)

Simple flows showing how each RM type is used at inference time.

## Handoff to Artist

When ready for professional refinement:

1. Export all diagrams as SVG
2. Provide the YAML specs as semantic documentation
3. Include a style guide (fonts, colors, stroke widths)
4. Use consistent naming: `fig_rm_{type}_{variant}.svg`
