Commit 5d1a583

ekfac implementation done (untested)

1 parent b5d38c4 commit 5d1a583

File tree

10 files changed: +1085 −53 lines changed

bergson/__main__.py

Lines changed: 16 additions & 2 deletions

@@ -6,7 +6,8 @@
 from simple_parsing import ArgumentParser, ConflictResolution
 
 from .build import build
-from .config import IndexConfig, QueryConfig, ReduceConfig, ScoreConfig
+from .config import HessianConfig, IndexConfig, QueryConfig, ReduceConfig, ScoreConfig
+from .hessians.hessian_approximations import approximate_hessians
 from .query.query_index import query
 from .reduce import reduce
 from .score.score import score_dataset
@@ -99,11 +100,24 @@ def execute(self):
         query(self.query_cfg)
 
 
+@dataclass
+class Hessian:
+    """Approximate Hessian matrices using KFAC or EKFAC."""
+
+    hessian_cfg: HessianConfig
+    index_cfg: IndexConfig
+
+    def execute(self):
+        """Compute Hessian approximation."""
+        validate_run_path(self.index_cfg)
+        approximate_hessians(self.index_cfg, self.hessian_cfg)
+
+
 @dataclass
 class Main:
     """Routes to the subcommands."""
 
-    command: Union[Build, Query, Reduce, Score]
+    command: Union[Build, Query, Reduce, Score, Hessian]
 
     def execute(self):
         """Run the script."""

bergson/collector/collector.py

Lines changed: 116 additions & 14 deletions

@@ -1,11 +1,13 @@
 import functools
+import hashlib
 import os
 from abc import ABC, abstractmethod
 from contextlib import ContextDecorator, nullcontext
 from dataclasses import astuple, dataclass, field
 from fnmatch import fnmatchcase
 from typing import Callable, Literal, Mapping, Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -24,15 +26,14 @@
 from tqdm.auto import tqdm
 from transformers import PreTrainedModel
 
-from bergson.config import AttentionConfig, IndexConfig
+from bergson.config import AttentionConfig, HessianConfig, IndexConfig
 from bergson.data import pad_and_tensor
 from bergson.gradients import (
     GradientProcessor,
     LayerAdapter,
 )
 from bergson.utils.logger import get_logger
 from bergson.utils.peft import set_peft_enabled
-from bergson.utils.utils import create_projection_matrix
 
 
 @dataclass
@@ -78,6 +79,7 @@ class HookCollectorBase(ContextDecorator, ABC):
     Optional configuration specifying how to split up the attention module gradients
     into per-head gradients. See also bergson.config.AttentionConfig.
     """
+    logger = get_logger("HookCollectorBase", level="INFO")
 
     def __post_init__(
         self,
@@ -256,6 +258,28 @@ def projection(
         self.processor._projection_matrices[key] = A
         return A
 
+    def with_batch(self, valid_mask: Tensor | None = None) -> "HookCollectorBase":
+        """
+        Set the current valid mask before entering the context.
+
+        This allows hooks to access the valid mask during forward/backward
+        passes.
+        Usage:
+            with collector.with_batch(valid_mask):
+                # forward/backward pass
+                # hooks can access self._current_valid_mask
+
+        Args:
+            valid_mask: Optional boolean tensor of shape [batch_size, seq_len]
+                indicating which positions have valid labels for loss computation.
+
+        Returns:
+            self, for use as a context manager.
+        """
+        self._current_valid_mask = valid_mask
+        return self
+
     def __enter__(self):
         """Register forward and backward hooks on all target modules."""
         for name in self.target_info:
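For illustration, a forward hook might consume the stashed mask as in the sketch below. This is hypothetical (the actual hooks are outside this diff), and the _stashed_activations attribute is invented for the example:

def _forward_hook(self, module, inputs, output):
    # Hypothetical: drop activations at positions without valid labels,
    # using the mask stashed by with_batch(). Not the repo's real hook.
    a = inputs[0]
    if self._current_valid_mask is not None:
        a = a * self._current_valid_mask.unsqueeze(-1).to(a.dtype)
    self._stashed_activations[module] = a.detach()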
@@ -484,15 +508,23 @@ def run_with_collector_hooks(
         ):
             batch = self.data[indices]
 
+            # Compute padded tensors and valid_mask before entering context
+            x, y, valid_mask = pad_and_tensor(
+                batch["input_ids"],
+                labels=batch.get("labels"),
+                device=self.model.device,
+            )
+            total_processed += valid_mask.sum()
+
             with (
-                self.collector,
+                self.collector.with_batch(valid_mask),
                 (
                     record_function(f"step_{step}")
                     if self.cfg.profile
                     else nullcontext()
                 ),
             ):
-                losses = self.forward_backward(self.model, batch)
+                losses = self.forward_backward(self.model, x, y, batch)
 
             # TODO: currently builder also calls torch.cuda.synchronize
             torch.cuda.synchronize() if torch.cuda.is_available() else None
@@ -503,11 +535,17 @@
             step += 1
 
             self.collector.process_batch(indices, losses=losses)
-            total_processed += len(indices)
 
         self.collector.teardown()
+
         if dist.is_initialized():
            dist.all_reduce(total_processed, op=dist.ReduceOp.SUM)
+
+        if self.rank == 0:
+            torch.save(
+                total_processed,
+                os.path.join(self.cfg.partial_run_path, "total_processed.pt"),
+            )
         self.logger.info(f"Total processed: {total_processed.item()}")
@@ -523,18 +561,17 @@ def fwd_bwd_factory(cfg: IndexConfig) -> Callable:
     summed loss.
 
     Returns:
-        A callable fwd_bwd(model, batch) -> Tensor that performs a forward pass and
-        backward pass, returning the per-sample losses.
-        The batch must contain "input_ids" and optionally "labels" and "advantage".
+        A callable fwd_bwd(model, x, y, batch) -> Tensor that performs a forward pass
+        and backward pass, returning the per-sample losses.
+
+        Args:
+            model: The model to run forward/backward on.
+            x: Padded input token ids tensor of shape [batch_size, seq_len].
+            y: Padded label tensor of shape [batch_size, seq_len] with -100 for padding.
+            batch: Original batch dict, used only for "advantage" if present.
+
        Returns a tensor of shape [batch_size] with one loss value per sample.
     """
 
-    def fwd_bwd(model, batch):
-        x, y = pad_and_tensor(
-            batch["input_ids"],  # type: ignore
-            labels=batch.get("labels"),  # type: ignore
-            device=model.device,
-        )
+    def fwd_bwd(model, x: Tensor, y: Tensor, batch: dict):
         logits = model(x).logits[:, :-1]
         masks = y[:, 1:] != -100
         denoms = (
@@ -571,3 +608,68 @@ def fwd_bwd(model, batch):
         return losses
 
     return fwd_bwd
+
+
+def fwd_bwd_hessian_factory(cfg: HessianConfig) -> Callable:
+    def fwd_bwd_hessian(model, x: Tensor, y: Tensor, batch: dict):
+        logits = model(x).logits[:, :-1]
+        masks = y[:, 1:] != -100
+        denoms = masks.sum(dim=1, dtype=model.dtype)
+
+        if cfg.use_dataset_labels:
+            # Empirical Fisher: use the dataset labels.
+            losses = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)),
+                y[:, 1:].flatten(),
+                reduction="none",
+            ).reshape_as(y[:, 1:])
+            losses = losses.sum(1) / denoms
+        else:
+            # True Fisher: sample labels from the model's own predictions.
+            with torch.no_grad():
+                probs = F.softmax(logits, dim=-1)
+                sampled_tokens = torch.multinomial(
+                    probs.reshape(-1, probs.size(-1)),
+                    num_samples=1,
+                    replacement=True,
+                ).reshape_as(y[:, 1:])
+            losses = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)),
+                sampled_tokens.flatten(),
+                reduction="none",
+            ).reshape_as(y[:, 1:])
+            # Mask padding positions (sampled labels are never -100) and
+            # reduce per sample, mirroring the branch above.
+            losses = (losses * masks).sum(1) / denoms
+
+        losses.sum().backward()
+        model.zero_grad()
+
+        return losses
+
+    return fwd_bwd_hessian
+
+
+def create_projection_matrix(
+    identifier: str,
+    m: int,
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    projection_type: Literal["normal", "rademacher"] = "normal",
+) -> Tensor:
+    """Create a projection matrix deterministically based on identifier and side."""
+    # Seed the PRNG with the name of the layer and what "side" we are projecting
+    message = bytes(identifier, "utf-8")
+    digest = hashlib.md5(message).digest()
+    seed = int.from_bytes(digest, byteorder="big") % (2**63 - 1)
+
+    if projection_type == "normal":
+        prng = torch.Generator(device).manual_seed(seed)
+        A = torch.randn(m, n, device=device, dtype=dtype, generator=prng)
+    elif projection_type == "rademacher":
+        numpy_rng = np.random.Generator(np.random.PCG64(seed))
+        random_bytes = numpy_rng.bytes((m * n + 7) // 8)
+        random_bytes = np.frombuffer(random_bytes, dtype=np.uint8)
+        A = np.unpackbits(random_bytes)[: m * n].reshape((m, n))
+        A = torch.from_numpy(A).to(device, dtype=dtype)
+        A = A.add_(-0.5).mul_(2)
+    else:
+        raise ValueError(f"Unknown projection type: {projection_type}")
+    A /= A.norm(dim=1, keepdim=True)
+    return A
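The new bergson/hessians/hessian_approximations.py module that consumes these backward passes is not included in this excerpt. For orientation only, here is a minimal sketch of the per-layer bookkeeping that KFAC/EKFAC methods of this kind typically perform; the names (KFACFactors, update, eigenbasis, ekfac_lambda) are illustrative, not bergson's API, and the masking follows the valid_mask convention introduced above:

import torch

class KFACFactors:
    """Illustrative per-layer KFAC state: A ~ E[a a^T], S ~ E[g g^T]."""

    def __init__(self, d_in: int, d_out: int, device, dtype=torch.float32):
        self.A = torch.zeros(d_in, d_in, device=device, dtype=dtype)
        self.S = torch.zeros(d_out, d_out, device=device, dtype=dtype)
        self.n = 0

    def update(self, a: torch.Tensor, g: torch.Tensor, valid_mask: torch.Tensor):
        # a: [B, T, d_in] layer inputs, g: [B, T, d_out] output grads; only
        # positions with valid labels (per valid_mask, [B, T]) contribute.
        a, g = a[valid_mask], g[valid_mask]
        self.A += a.T @ a
        self.S += g.T @ g
        self.n += a.shape[0]

    def eigenbasis(self):
        # KFAC approximates the layer's Fisher as (A/n) kron (S/n);
        # eigendecompose each small factor separately.
        _, Qa = torch.linalg.eigh(self.A / self.n)
        _, Qs = torch.linalg.eigh(self.S / self.n)
        return Qa, Qs

def ekfac_lambda(Qa, Qs, a, g, valid_mask):
    # EKFAC eigenvalue correction: second moments of per-position gradients
    # rotated into the Kronecker eigenbasis,
    # Lambda[i, j] = E[(Qs^T g)_i^2 * (Qa^T a)_j^2].
    a = a[valid_mask] @ Qa
    g = g[valid_mask] @ Qs
    return torch.einsum("ni,nj->ij", g**2, a**2) / a.shape[0]

Damping along the lines of lambda_damp_factor below would then add a multiple of each module's mean eigenvalue before inverting.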

bergson/config.py

Lines changed: 23 additions & 0 deletions

@@ -298,6 +298,29 @@ class ReduceConfig:
     """Whether to unit normalize the gradients before reducing them."""
 
 
+@dataclass
+class HessianConfig:
+    """Config for approximating the Hessian."""
+
+    method: Literal["kfac"] = "kfac"
+    """Method for approximating the Hessian."""
+
+    ev_correction: bool = False
+    """Whether to additionally compute the eigenvalue correction (EKFAC)."""
+
+    hessian_dtype: Literal["auto", "bf16", "fp16", "fp32"] = "auto"
+    """Precision (dtype) to use for the Hessian approximation."""
+
+    lambda_damp_factor: float = 0.1
+    """Damping factor for the Hessian approximation, expressed relative to
+    the average eigenvalue of each module."""
+
+    use_dataset_labels: bool = False
+    """Whether to use the dataset labels for the Hessian (empirical Fisher)
+    approximation. If False, labels sampled from the model's predictions
+    are used instead."""
+
+
 @dataclass
 class FaissConfig:
     """Configuration for FAISS index."""

bergson/data.py

Lines changed: 7 additions & 2 deletions

@@ -466,7 +466,7 @@ def pad_and_tensor(
     padding_value: int = 0,
     dtype: torch.dtype | None = torch.long,
     device: torch.device | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Pad a list of sequences to the same length and convert them to tensors.
     Returns a tuple of padded sequences and labels. The labels are the same as the
@@ -485,7 +485,12 @@
     # convert to tensor
     padded_tokens = torch.tensor(padded, dtype=dtype, device=device)
     padded_labels = torch.tensor(labels, dtype=dtype, device=device)
-    return padded_tokens, padded_labels
+    # Compute valid_masks: position i is valid if labels[i+1] != -100
+    N, S = padded_tokens.shape
+    valid_masks = torch.zeros(N, S, dtype=torch.bool, device=device)
+    valid_masks[:, :-1] = padded_labels[:, 1:] != -100
+
+    return padded_tokens, padded_labels, valid_masks
 
 
 def tokenize(batch: dict, *, args: DataConfig, tokenizer):
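A quick illustration of the mask's shift semantics (position i is valid exactly when label i+1 is not -100, matching the logits[:, :-1] / y[:, 1:] alignment in the fwd_bwd functions above):

import torch

labels = torch.tensor([[-100, 5, 7, -100]])
valid = torch.zeros_like(labels, dtype=torch.bool)
valid[:, :-1] = labels[:, 1:] != -100
print(valid)  # tensor([[ True,  True, False, False]])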
