Final EKFAC implementation #123
base: main
Conversation
Force-pushed from 51cc6d9 to 5d1a583
Code Review

I've found 5 high-signal issues in this PR:

Issue 1: CLAUDE.md violation - `torch.cuda.empty_cache()` usage

The code calls `torch.cuda.empty_cache()`, which CLAUDE.md disallows.

Fix: Remove both calls.

Issue 2: Bug - Missing gradient masking in `kfac.py`

The `backward_hook` computes the gradient covariance over all positions, while the activation covariance is restricted to valid (non-padding) positions. This results in a mathematically incorrect KFAC approximation, since both covariances must be computed over the same data points.

Fix: Apply the same masking in `backward_hook`:

```python
def backward_hook(self, module: nn.Module, g: Tensor) -> None:
    """Compute gradient covariance: G^T @ G."""
    name = assert_type(str, module._name)
    S_cov_po = self.S_cov_dict[name]
    mask = self._current_valid_mask
    assert mask is not None, "Valid mask not set for backward hook."
    # g: [N, S, O], valid_masks: [N, S] -> select valid positions
    g_bo = g[mask]  # [num_valid, O]
    # Compute local covariance
    local_update_oo = g_bo.mT @ g_bo
    # ... rest of the function
```

Issue 3: Bug - Missing gradient masking in `tkfac.py`

Same issue as in `kfac.py`.

Fix: Apply the same masking in `backward_hook`:

```python
def backward_hook(self, module: nn.Module, g: Tensor) -> None:
    """Compute trace-weighted gradient covariance."""
    name = assert_type(str, module._name)
    S_tcov_po = self.S_tcov_dict[name]
    A_tcov_ki = self.A_tcov_dict[name]
    mask = self._current_valid_mask
    assert mask is not None, "Valid mask not set for backward hook."
    # g: [N, S, O], valid_masks: [N, S] -> select valid positions
    g_bo = g[mask]  # [num_valid, O]
    a_bi = module._inputs
    # ... rest of the function
```

Issue 4: Bug - Missing loss normalization when `use_dataset_labels=True`

In the dataset-labels branch, the summed per-token losses are not normalized by `denoms`, unlike the sampled-labels branch. This inconsistency will cause issues in downstream code expecting per-sample losses.

Fix: Add normalization in the `else` branch:

```python
else:
    # Use dataset labels instead of sampling
    losses = F.cross_entropy(
        logits[:, :-1].flatten(0, 1),
        y[:, 1:].flatten(0, 1),
        reduction="none",
        ignore_index=-100,
    ).reshape_as(y[:, 1:])
    losses = losses.sum(1) / denoms  # <-- Add this line
```

Issue 5: Bug - dtype string not converted to `torch.dtype`

When `hessian_dtype` is given as a string, it is never converted to a `torch.dtype`.

Fix: Use:

```python
hessian_dtype = (
    model.dtype
    if hessian_cfg.hessian_dtype == "auto"
    else convert_precision_to_torch(hessian_cfg.hessian_dtype)
)
```

The same pattern is already implemented correctly elsewhere.
@claude Give it another go. I think I fixed the issues now.
Claude finished @LouisYRYJ's task in 2m 10s. Re-reviewing PR after fixes
Review Results

Great progress! Most issues have been fixed. Here's the status:

✅ Fixed Issues
@claude Per https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html, ignore_index is set to -100 by default, so this should be fine?
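For reference, a minimal check of that default behavior (this snippet is illustrative and not from the PR):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)               # 4 positions, 10 classes
targets = torch.tensor([1, 2, -100, 3])   # third position uses the padding label

# With reduction="none", the ignored position gets a loss of exactly 0.
per_token = F.cross_entropy(logits, targets, reduction="none")
assert per_token[2] == 0.0

# With the default "mean" reduction, ignored targets are excluded from the average.
mean_loss = F.cross_entropy(logits, targets)
assert torch.isclose(mean_loss, per_token[[0, 1, 3]].mean())
```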
Claude finished @LouisYRYJ's task in 37s. Re-reviewing PR after fixes
```python
@dataclass
class Hessian:
```
Follow-up suggestion: we could consider adding the ability to compute the preconditioners + normalizers here too. It seems like a fairly different operation from a full run (less data, more need to parallelize preconditioner computations). Not sure if this makes sense yet.
I have thought about it too and I agree! For now I want to keep it lower priority, but it sounds like just shuffling some pipelines around. Also, we now have the compute_eigendecomposition function in eigenvectors.py, which does distributed eigenvector decomposition and may be reused.
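For context, a rough sketch of the kind of per-layer distributed eigendecomposition being referred to; the helper name, signature, and layer-to-rank assignment below are illustrative, not the actual compute_eigendecomposition API:

```python
import torch
import torch.distributed as dist

def eigendecompose_distributed(
    covariances: dict[str, torch.Tensor],
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """Each rank eigendecomposes a round-robin subset of the per-layer
    covariances, then broadcasts its results to every other rank."""
    rank, world = dist.get_rank(), dist.get_world_size()
    results: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
    for i, name in enumerate(sorted(covariances)):
        cov = covariances[name]
        owner = i % world  # layer i is decomposed by rank (i mod world_size)
        if rank == owner:
            # Symmetric eigendecomposition in float64 for numerical stability.
            eigvals, eigvecs = torch.linalg.eigh(cov.double())
        else:
            eigvals = torch.empty(cov.shape[0], dtype=torch.float64, device=cov.device)
            eigvecs = torch.empty(cov.shape, dtype=torch.float64, device=cov.device)
        dist.broadcast(eigvals, src=owner)
        dist.broadcast(eigvecs, src=owner)
        results[name] = (eigvals, eigvecs)
    return results
```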
* Fix mask bug and add batch size invariance test with toy model

  The backward_hook was using g.reshape(-1, O), which includes padding positions in the covariance computation. This causes incorrect results when batches have different sequence lengths. Before this commit, the added test failed with:

  > FAILED tests/ekfac_tests/test_batch_size_invariance.py::test_trace_batch_invariant[seq_lengths1-20] - AssertionError: Scalars are not close!
  >
  > Expected 1.231401894309304 but got 0.8983965093439276.
  > Absolute difference: 0.33300538496537635 (up to 1e-4 allowed)
  > Relative difference: 0.27042786478102654 (up to 0.01 allowed)

* Fix use_dataset_labels condition and add FIM accuracy test

  The condition `if not hessian_cfg.use_dataset_labels:` was inverted, causing the empirical Fisher (with dataset labels) to use sampled labels and vice versa. Add test_fim_accuracy.py, which verifies that KFAC approximates the Fisher Information Matrix within tolerance for both the empirical FIM (dataset labels) and the true FIM (sampled labels).

* Add ground truth ekfac tests

  This is still missing FSDP support and test_apply_ekfac.py from #68.

  Co-Authored-By: LouisYRYJ <louis.yousif@yahoo.de>
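A minimal sketch of the kind of invariance the first commit's test checks; the toy shapes and helper below are illustrative, not the actual test code:

```python
import torch

def masked_grad_covariance(g: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Gradient covariance over valid positions only. g: [N, S, O], mask: [N, S]."""
    g_valid = g[mask]            # [num_valid, O]
    return g_valid.mT @ g_valid  # [O, O]

def test_covariance_invariant_to_batching_and_padding():
    torch.manual_seed(0)
    O = 4
    g1 = torch.randn(1, 3, O)  # sequence of length 3
    g2 = torch.randn(1, 5, O)  # sequence of length 5

    # Batch the two sequences together, padding the shorter one to length 5.
    g_batched = torch.zeros(2, 5, O)
    g_batched[0, :3] = g1[0]
    g_batched[1] = g2[0]
    mask = torch.tensor([[True, True, True, False, False],
                         [True, True, True, True, True]])

    # The covariance accumulated over the padded batch must equal the sum of the
    # covariances of the two sequences processed individually.
    expected = (masked_grad_covariance(g1, torch.ones(1, 3, dtype=torch.bool))
                + masked_grad_covariance(g2, torch.ones(1, 5, dtype=torch.bool)))
    torch.testing.assert_close(masked_grad_covariance(g_batched, mask), expected)
```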
- Replace set_all_seeds with the existing setup_reproducibility
- Reuse approximate_hessians instead of doing something equivalent manually.
Allow configuring the number of samples from the pile-10k dataset via a pytest command-line option instead of hardcoding 100. The dataset directory is now named dynamically (e.g., pile_100_examples).
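A sketch of how such an option could be wired up; the option, fixture, and directory names are assumptions, not necessarily what the PR uses:

```python
# conftest.py (illustrative)
import pytest

def pytest_addoption(parser):
    parser.addoption(
        "--num-samples",
        type=int,
        default=100,
        help="Number of pile-10k examples to use in the EKFAC tests.",
    )

@pytest.fixture(scope="session")
def num_samples(request) -> int:
    return request.config.getoption("--num-samples")

@pytest.fixture(scope="session")
def dataset_dir(tmp_path_factory, num_samples) -> str:
    # Directory name encodes the sample count, e.g. pile_100_examples.
    return str(tmp_path_factory.mktemp(f"pile_{num_samples}_examples", numbered=False))
```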
Restore the calls to dist.barrier that existed in #13; without them the process would hang when running with world_size > 1. For testing, we add _allocate_batches_world to compute the batches for the ground truth. The tests don't pass yet due to numerical errors; this is handled in the next commit by changing our comparison logic.
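For illustration, a sketch of the synchronization and batch-allocation pattern described above; the collector methods and the function wrapping them are hypothetical, not the PR's actual code:

```python
import torch.distributed as dist

def _allocate_batches_world(batches: list, world_size: int) -> list[list]:
    # Deterministic round-robin split, so a single-process ground-truth run can
    # reproduce exactly which batches each rank processed.
    return [batches[rank::world_size] for rank in range(world_size)]

def collect_covariances_distributed(collector, batches: list) -> None:
    per_rank = _allocate_batches_world(batches, dist.get_world_size())
    for batch in per_rank[dist.get_rank()]:
        collector.update(batch)  # hypothetical per-batch covariance accumulation

    # All ranks must reach this point before the collected statistics are
    # reduced; without the barrier, a fast rank can enter the collective while
    # a slow rank is still accumulating, and the run hangs.
    dist.barrier()
    collector.all_reduce()  # hypothetical cross-rank reduction
```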
- Eigenvectors: check |cosine_similarity| ≈ 1 per column, which naturally handles sign ambiguity (eigenvectors are only defined up to sign)
- Covariances: check the relative Frobenius norm, since values should match exactly
- Eigenvalue corrections: align signs based on eigenvector orientation, then check relative error (λ[i,j] transforms as sign_G[i] * sign_A[j])
- Also re-enable CPU tests, which pass after this change.
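A sketch of comparison checks along these lines; the helper names and tolerances are illustrative, not the actual test code:

```python
import torch

def check_eigenvectors(Q_ref: torch.Tensor, Q_test: torch.Tensor, tol: float = 1e-3) -> None:
    # Per-column |cosine similarity|; the absolute value handles the sign
    # ambiguity of eigenvectors (each column is only defined up to +/- 1).
    cos = torch.cosine_similarity(Q_ref, Q_test, dim=0).abs()
    assert (1 - cos).mean() < tol

def check_covariances(C_ref: torch.Tensor, C_test: torch.Tensor, tol: float = 1e-6) -> None:
    # Relative Frobenius norm; covariances carry no sign ambiguity and should match closely.
    rel = torch.linalg.norm(C_test - C_ref) / torch.linalg.norm(C_ref)
    assert rel < tol

def check_eigenvalue_corrections(
    lam_ref: torch.Tensor,   # [d_G, d_A], from the ground-truth run
    lam_test: torch.Tensor,  # [d_G, d_A], from the distributed run
    QG_ref: torch.Tensor, QG_test: torch.Tensor,
    QA_ref: torch.Tensor, QA_test: torch.Tensor,
    tol: float = 1e-2,
) -> None:
    # lambda[i, j] transforms as sign_G[i] * sign_A[j] under eigenvector sign
    # flips, so align the signs before comparing.
    sign_G = torch.sign((QG_ref * QG_test).sum(0))
    sign_A = torch.sign((QA_ref * QA_test).sum(0))
    aligned = lam_test * sign_G[:, None] * sign_A[None, :]
    rel_err = (aligned - lam_ref).abs() / lam_ref.abs().clamp_min(1e-12)
    assert rel_err.max() < tol
```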
With world_size > 1, floating-point reduction order differs between the ground truth (single process) and the distributed run, causing larger numerical differences in some layers. For eigenvectors, use the average |cos_sim| instead of the minimum; this tolerates occasional outlier eigenvectors while maintaining a stricter threshold (1e-3, vs. the 0.1 that would be needed for the min). For eigenvalue corrections, use atol=0.2 when world_size > 1.
Force-pushed from 179706c to 6fa74be
…apply_hessian (WIP)
lmk when this is ready!

This is it. The final EKFAC merge. Most of it is pretty self-contained and doesn't interfere much with the rest of the code, as we had already set up much of the work before.
I made slight changes in the code to accommodate some functionality that we were missing. In particular, I needed to pass a mask tensor to the CovarianceCollector to ensure that we only process the correct activations (and not, e.g., the activations of the user tokens).
Technically this needs to happen for any collector, but GradientCollector, for example, gets around it because gradients will always be 0 at the masked tokens: when we compute a*g, the masked-token activations get multiplied by 0 and are ignored (h/t @smarter).
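A toy sketch of this point; the shapes and tensors below are illustrative, not the actual collector code. The per-token outer products g aᵀ vanish wherever g is zero, while an activations-only covariance does not:

```python
import torch

torch.manual_seed(0)
N, S, I, O = 2, 4, 3, 2
a = torch.randn(N, S, I)   # layer inputs: nonzero even at masked (e.g. user) tokens
g = torch.randn(N, S, O)   # gradients w.r.t. the layer output
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])
g[~mask] = 0.0             # the loss ignores masked tokens, so their gradients are 0

# Per-token weight-gradient contributions are the outer products g_t a_t^T; masked
# tokens drop out automatically because they are multiplied by a zero gradient.
contrib = g.unsqueeze(-1) * a.unsqueeze(-2)   # [N, S, O, I]
assert torch.allclose(contrib.sum((0, 1)), contrib[mask].sum(0))

# An activation-only covariance is NOT automatically masked, so the
# CovarianceCollector needs the valid mask passed in explicitly.
a_flat = a.reshape(-1, I)
assert not torch.allclose(a_flat.mT @ a_flat, a[mask].mT @ a[mask])
```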