 import functools
+import hashlib
 import os
 from abc import ABC, abstractmethod
 from contextlib import ContextDecorator, nullcontext
 from dataclasses import astuple, dataclass, field
 from fnmatch import fnmatchcase
 from typing import Callable, Literal, Mapping, Optional

+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from tqdm.auto import tqdm
 from transformers import PreTrainedModel

-from bergson.config import AttentionConfig, IndexConfig
+from bergson.config import AttentionConfig, HessianConfig, IndexConfig
 from bergson.data import pad_and_tensor
 from bergson.gradients import (
     GradientProcessor,
     LayerAdapter,
 )
 from bergson.utils.logger import get_logger
 from bergson.utils.peft import set_peft_enabled
-from bergson.utils.utils import create_projection_matrix


 @dataclass
@@ -78,6 +79,7 @@ class HookCollectorBase(ContextDecorator, ABC):
     Optional configuration specifying how to split up the attention module gradients
     into per-head gradients. See also bergson.config.AttentionConfig.
     """
+    logger = get_logger("HookCollectorBase", level="INFO")

     def __post_init__(
         self,
@@ -256,6 +258,28 @@ def projection(
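+        # Cache the matrix on the processor so later calls with the same key
+        # reuse it instead of regenerating.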
         self.processor._projection_matrices[key] = A
         return A

+    def with_batch(self, valid_mask: Tensor | None = None) -> "HookCollectorBase":
+        """
+        Set the current valid mask before entering the context.
+
+        This allows hooks to access the valid mask during forward/backward
+        passes.
+        Usage:
+            with collector.with_batch(valid_mask):
+                # forward/backward pass
+                # hooks can access self._current_valid_mask
+
+        Args:
+            valid_mask: Optional boolean tensor of shape [batch_size, seq_len]
+                indicating which positions have valid labels for loss computation.
+
+        Returns:
+            self, for use as a context manager.
+        """
+        self._current_valid_mask = valid_mask
+        return self
+
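+    # Note: with_batch() only stores per-batch state; hook registration and
+    # removal still happen in __enter__/__exit__ below. Returning self lets
+    # both compose in one statement: `with collector.with_batch(mask): ...`.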
     def __enter__(self):
         """Register forward and backward hooks on all target modules."""
         for name in self.target_info:
@@ -484,15 +508,23 @@ def run_with_collector_hooks(
         ):
             batch = self.data[indices]

+            # Compute padded tensors and valid_mask before entering the context
+            x, y, valid_mask = pad_and_tensor(
+                batch["input_ids"],
+                labels=batch.get("labels"),
+                device=self.model.device,
+            )
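+            # valid_mask.sum() counts label positions, so total_processed now
+            # tracks supervised tokens rather than whole sequences.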
+            total_processed += valid_mask.sum()
+
             with (
-                self.collector,
+                self.collector.with_batch(valid_mask),
                 (
                     record_function(f"step_{step}")
                     if self.cfg.profile
                     else nullcontext()
                 ),
             ):
-                losses = self.forward_backward(self.model, batch)
+                losses = self.forward_backward(self.model, x, y, batch)

                 # TODO: currently builder also calls torch.cuda.synchronize
                 torch.cuda.synchronize() if torch.cuda.is_available() else None
@@ -503,11 +535,17 @@ def run_with_collector_hooks(
             step += 1

             self.collector.process_batch(indices, losses=losses)
-            total_processed += len(indices)

         self.collector.teardown()
+
         if dist.is_initialized():
             dist.all_reduce(total_processed, op=dist.ReduceOp.SUM)
+
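+        # After the all_reduce, every rank holds the global token count;
+        # only rank 0 persists it to disk.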
+        if self.rank == 0:
+            torch.save(
+                total_processed,
+                os.path.join(self.cfg.partial_run_path, "total_processed.pt"),
+            )
         self.logger.info(f"Total processed: {total_processed.item()}")


@@ -523,18 +561,17 @@ def fwd_bwd_factory(cfg: IndexConfig) -> Callable:
     summed loss.

     Returns:
-        A callable fwd_bwd(model, batch) -> Tensor that performs a forward pass and
-        backward pass, returning the per-sample losses.
-        The batch must contain "input_ids" and optionally "labels" and "advantage".
+        A callable fwd_bwd(model, x, y, batch) -> Tensor that performs a forward
+        and backward pass, returning the per-sample losses, where:
+            model: the model to run forward/backward on;
+            x: padded input token ids of shape [batch_size, seq_len];
+            y: padded labels of shape [batch_size, seq_len], -100 at padding;
+            batch: the original batch dict, used only for "advantage" if present.
         Returns a tensor of shape [batch_size] with one loss value per sample.
     """

-    def fwd_bwd(model, batch):
-        x, y = pad_and_tensor(
-            batch["input_ids"],  # type: ignore
-            labels=batch.get("labels"),  # type: ignore
-            device=model.device,
-        )
+    def fwd_bwd(model, x: Tensor, y: Tensor, batch: dict):
         logits = model(x).logits[:, :-1]
         masks = y[:, 1:] != -100
         denoms = (
@@ -571,3 +608,68 @@ def fwd_bwd(model, batch):
         return losses

     return fwd_bwd
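+# A minimal usage sketch of the new calling convention (mirroring
+# run_with_collector_hooks above; `cfg`, `model`, `batch`, and `collector`
+# are assumed to be in scope):
+#
+#     fwd_bwd = fwd_bwd_factory(cfg)
+#     x, y, valid_mask = pad_and_tensor(
+#         batch["input_ids"], labels=batch.get("labels"), device=model.device
+#     )
+#     with collector.with_batch(valid_mask):
+#         losses = fwd_bwd(model, x, y, batch)  # shape [batch_size]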
+
+
+def fwd_bwd_hessian_factory(cfg: HessianConfig) -> Callable:
+    """Build a fwd_bwd(model, x, y, batch) callable for Hessian accumulation.
+
+    If cfg.use_dataset_labels is True, the loss is taken against the dataset
+    labels y; otherwise labels are sampled from the model's own distribution.
+    """
+
+    def fwd_bwd_hessian(model, x: Tensor, y: Tensor, batch: dict):
+        logits = model(x).logits[:, :-1]
+        masks = y[:, 1:] != -100
+        denoms = masks.sum(dim=1, dtype=model.dtype)
+
+        if cfg.use_dataset_labels:
+            losses = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)),
+                y[:, 1:].flatten(),
+                reduction="none",
+            ).reshape_as(y[:, 1:])
+        else:
+            # Sample one label per position from the model's own softmax.
+            with torch.no_grad():
+                probs = F.softmax(logits, dim=-1)
+                sampled_tokens = torch.multinomial(
+                    probs.reshape(-1, probs.size(-1)),
+                    num_samples=1,
+                    replacement=True,
+                ).reshape_as(y[:, 1:])
+            losses = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)),
+                sampled_tokens.flatten(),
+                reduction="none",
+            ).reshape_as(y[:, 1:])
+
+        # Zero out padded positions (sampled labels are real token ids even
+        # there) and average per sample, matching fwd_bwd above.
+        losses = (losses * masks).sum(1) / denoms
+
+        losses.sum().backward()
+        model.zero_grad()
+
+        return losses
+
+    return fwd_bwd_hessian
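+# Design note: sampling labels from the model's own softmax makes the expected
+# outer product of the resulting gradients an estimate of the Fisher
+# information, a standard positive semi-definite surrogate for the Hessian;
+# use_dataset_labels=True yields the "empirical Fisher" variant instead.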
+
+
+def create_projection_matrix(
+    identifier: str,
+    m: int,
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    projection_type: Literal["normal", "rademacher"] = "normal",
+) -> Tensor:
+    """Create a projection matrix deterministically based on identifier and side."""
+    # Seed the PRNG with the name of the layer and what "side" we are projecting
+    message = bytes(identifier, "utf-8")
+    digest = hashlib.md5(message).digest()
+    seed = int.from_bytes(digest, byteorder="big") % (2**63 - 1)
+
+    if projection_type == "normal":
+        prng = torch.Generator(device).manual_seed(seed)
+        A = torch.randn(m, n, device=device, dtype=dtype, generator=prng)
+    elif projection_type == "rademacher":
+        # Draw m * n random bits, then map {0, 1} -> {-1, +1}.
+        numpy_rng = np.random.Generator(np.random.PCG64(seed))
+        random_bytes = numpy_rng.bytes((m * n + 7) // 8)
+        random_bytes = np.frombuffer(random_bytes, dtype=np.uint8)
+        A = np.unpackbits(random_bytes)[: m * n].reshape((m, n))
+        A = torch.from_numpy(A).to(device, dtype=dtype)
+        A = A.add_(-0.5).mul_(2)
+    else:
+        raise ValueError(f"Unknown projection type: {projection_type}")
+    A /= A.norm(dim=1, keepdim=True)
+    return A
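+# Example (shapes illustrative): the same identifier always yields the same
+# matrix, so a projection can be regenerated on demand instead of stored.
+#
+#     A = create_projection_matrix(
+#         "model.layers.0.mlp/left", 16, 768, torch.float32, torch.device("cpu")
+#     )
+#     B = create_projection_matrix(
+#         "model.layers.0.mlp/left", 16, 768, torch.float32, torch.device("cpu")
+#     )
+#     assert torch.equal(A, B)  # deterministic in the identifier
+#     assert torch.allclose(A.norm(dim=1), torch.ones(16))  # unit-norm rows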