@@ -156,6 +156,17 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()
 
+    def _get_checkpoint_dir(self) -> Path:
+        """Get checkpoint directory for autotuner checkpoints."""
+        from torch._inductor.runtime.cache_dir_utils import cache_dir
+
+        if (user_path := os.environ.get("HELION_CACHE_DIR", None)) is not None:
+            base = Path(user_path)
+        else:
+            base = Path(cache_dir()) / "helion"
+
+        return base / "autotuner_checkpoints"
+
     def _clone_args(self, args: Sequence[object]) -> Sequence[object]:
         def _clone_leaf(leaf: object) -> object:
             if isinstance(leaf, torch.Tensor):
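For context, a minimal sketch of how this lookup order resolves on disk; the `/var/tmp/helion-cache` value is an illustrative assumption, not part of this diff:

```python
import os
from pathlib import Path

# Mirror of the lookup order in _get_checkpoint_dir(): an explicit
# HELION_CACHE_DIR wins; otherwise inductor's cache_dir()/"helion" is used.
os.environ["HELION_CACHE_DIR"] = "/var/tmp/helion-cache"  # illustrative value
base = Path(os.environ["HELION_CACHE_DIR"])
print(base / "autotuner_checkpoints")
# /var/tmp/helion-cache/autotuner_checkpoints
```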
@@ -685,6 +696,43 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             torch.save(self.args, args_path)
             self._precompile_args_path = args_path
             exit_stack.callback(self.cleanup)
+
+        checkpoint_loaded = False
+        if self.settings.autotune_checkpoint_id is not None:
+            from .local_cache import LocalAutotuneCache
+
+            checkpoint_id = self.settings.autotune_checkpoint_id
+            current_hash = LocalAutotuneCache(self)._generate_key().stable_hash()
+
+            # Checkpoint ID format: {8-char-hash}-{timestamp}-{8-char-uuid}
+            # Extract the hash prefix and check compatibility
+            hash_prefix = checkpoint_id.split("-")[0]
+            if hash_prefix != current_hash[:8]:
+                self.log(
+                    f"Warning: Checkpoint '{checkpoint_id}' is for a different kernel "
+                    f"(hash mismatch). Ignoring checkpoint and starting a fresh autotuning run.",
+                    level=logging.WARNING,
+                )
+            else:
+                # Hash matches, load the checkpoint
+                checkpoint_dir = self._get_checkpoint_dir()
+                checkpoint_file = checkpoint_dir / f"{checkpoint_id}.checkpoint"
+                if not checkpoint_file.exists():
+                    raise FileNotFoundError(
+                        f"Checkpoint file not found: {checkpoint_file}"
+                    )
+                self.log(f"Resuming from checkpoint: {checkpoint_file}")
+                with open(checkpoint_file, "rb") as f:
+                    state = pickle.load(f)
+                self.load_state_dict(state)
+                self.log(
+                    f"Resumed at generation {self._current_generation} with "
+                    f"{len(self.population)} configs"  # type: ignore[attr-defined]
+                )
+                checkpoint_loaded = True
+
+        if not checkpoint_loaded:
+            self._init_search()
         best = self._autotune()
         end = time.perf_counter()
         kernel_decorator = self.kernel.format_kernel_decorator(best, self.settings)
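As a usage sketch of the resume path above: the checkpoint ID printed by `save_checkpoint()` can be fed back either through the environment or through kernel settings. The kernel body and the ID value below are illustrative; only the `HELION_AUTOTUNE_CHECKPOINT_ID` variable and the `autotune_checkpoint_id` setting come from this diff.

```python
# Option 1: environment variable
#   HELION_AUTOTUNE_CHECKPOINT_ID=1a2b3c4d-1730000000-deadbeef python run.py
#
# Option 2: pass the setting directly (placeholder ID, illustrative kernel)
import torch
import helion
import helion.language as hl

@helion.kernel(autotune_checkpoint_id="1a2b3c4d-1730000000-deadbeef")
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out
```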
@@ -701,6 +749,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             print(triton_code, file=sys.stderr)
         return best
 
+    def _init_search(self) -> None:
+        """
+        Initialize the search state for a fresh autotuning run.
+
+        This method is called when starting autotuning without a checkpoint.
+        Subclasses should override it to set up the initial population and state.
+        After it returns, _current_generation should reflect the last completed generation.
+        """
+        raise NotImplementedError
+
     def _autotune(self) -> Config:
         """
         Abstract method to perform the actual autotuning.
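For illustration, a population-based subclass might satisfy this contract roughly as follows; the base class name, `random_flat()`, and `population_size` are hypothetical stand-ins, while `benchmark_flat()` and `set_generation()` appear later in this diff.

```python
class MyEvolutionarySearch(PopulationBasedSearch):  # base class name assumed
    def _init_search(self) -> None:
        # Build and benchmark an initial random population, then record that
        # no generation has completed yet (generation 0 is not checkpointed).
        self.population = [
            self.benchmark_flat(self.config_gen.random_flat())  # hypothetical config helper
            for _ in range(self.population_size)                # hypothetical size attribute
        ]
        self.set_generation(0)
```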
@@ -712,6 +770,102 @@ def _autotune(self) -> Config:
712770 """
713771 raise NotImplementedError
714772
773+ def save_checkpoint (self ) -> Path :
774+ """
775+ Save current autotuner state to checkpoint file.
776+
777+ Each call generates a new checkpoint ID for the saved checkpoint.
778+
779+ Returns:
780+ Path to saved checkpoint file
781+ """
782+ from .local_cache import LocalAutotuneCache
783+
784+ # Checkpoint ID format: {8-char-hash}-{timestamp}-{8-char-uuid}
785+ stable_hash = LocalAutotuneCache (self )._generate_key ().stable_hash ()[:8 ]
786+ timestamp = int (time .time ())
787+ short_uuid = uuid .uuid4 ().hex [:8 ]
788+ new_checkpoint_id = f"{ stable_hash } -{ timestamp } -{ short_uuid } "
789+ filename = f"{ new_checkpoint_id } .checkpoint"
790+
791+ checkpoint_dir = self ._get_checkpoint_dir ()
792+ checkpoint_dir .mkdir (parents = True , exist_ok = True )
793+ checkpoint_path = checkpoint_dir / filename
794+
795+ state = self .state_dict ()
796+
797+ # Atomic write using temp file
798+ tmp = checkpoint_dir / f"tmp.{ uuid .uuid4 ()!s} "
799+ with open (tmp , "wb" ) as f :
800+ pickle .dump (state , f )
801+ os .rename (str (tmp ), str (checkpoint_path ))
802+
803+ self .log (f"Checkpoint saved: { checkpoint_path } " )
804+ self .log (
805+ f"To resume from this checkpoint, set HELION_AUTOTUNE_CHECKPOINT_ID={ new_checkpoint_id } "
806+ f'or `autotune_checkpoint_id="{ new_checkpoint_id } "` in the kernel settings'
807+ )
808+ return checkpoint_path
809+
810+ def state_dict (self ) -> dict [str , Any ]:
811+ """
812+ Return autotuner state as a dictionary.
813+
814+ Subclasses should call super().state_dict() first, then update with their own fields.
815+ """
816+ import numpy as np
817+
818+ from .local_cache import LocalAutotuneCache
819+
820+ rng_state : dict [str , Any ] = {
821+ "random" : random .getstate (),
822+ "torch" : torch .random .get_rng_state (),
823+ "numpy" : np .random .get_state (), # noqa: NPY002
824+ }
825+ if torch .cuda .is_available ():
826+ rng_state ["torch_cuda" ] = torch .cuda .get_rng_state ()
827+
828+ return {
829+ "algorithm" : self .__class__ .__name__ ,
830+ "cache_key_stable_hash" : LocalAutotuneCache (self )
831+ ._generate_key ()
832+ .stable_hash (),
833+ "counters" : dict (self .counters ),
834+ "rng_state" : rng_state ,
835+ "best_perf_so_far" : self .best_perf_so_far ,
836+ "current_generation" : self ._current_generation ,
837+ }
838+
839+ def load_state_dict (self , state : dict [str , Any ]) -> None :
840+ """
841+ Restore autotuner state from a dictionary.
842+
843+ Subclasses should call super().load_state_dict(state) first,
844+ then restore their own fields.
845+ """
846+ from .local_cache import LocalAutotuneCache
847+
848+ current_hash = LocalAutotuneCache (self )._generate_key ().stable_hash ()
849+ if state .get ("cache_key_stable_hash" ) != current_hash :
850+ raise exc .CheckpointError (
851+ "State dict is incompatible: kernel, hardware, or input shapes may have changed"
852+ )
853+
854+ import numpy as np
855+
856+ # Restore RNG state
857+ rng_state = state ["rng_state" ]
858+ random .setstate (rng_state ["random" ])
859+ torch .random .set_rng_state (rng_state ["torch" ])
860+ np .random .set_state (rng_state ["numpy" ]) # noqa: NPY002
861+ if "torch_cuda" in rng_state and torch .cuda .is_available ():
862+ torch .cuda .set_rng_state (rng_state ["torch_cuda" ])
863+
864+ # Restore autotuner state
865+ self .counters = collections .Counter (state ["counters" ])
866+ self .best_perf_so_far = state ["best_perf_so_far" ]
867+ self ._current_generation = state ["current_generation" ]
868+
715869
716870@dataclasses .dataclass
717871class PopulationMember :
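Since a checkpoint is just the pickled state dict written by `save_checkpoint()`, a quick debugging sketch for inspecting one might look like this; the path and ID below are placeholders, real files live in the `autotuner_checkpoints` directory described above:

```python
import pickle
from pathlib import Path

# Placeholder location and checkpoint ID for illustration only.
ckpt = Path("/var/tmp/helion-cache/autotuner_checkpoints/1a2b3c4d-1730000000-deadbeef.checkpoint")
with open(ckpt, "rb") as f:
    state = pickle.load(f)

# Base keys written by state_dict() above; subclasses may add more, e.g. "population".
print(state["algorithm"], state["current_generation"], state["best_perf_so_far"])
print(sorted(state))
```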
@@ -817,6 +971,8 @@ def best(self) -> PopulationMember:
 
     def set_generation(self, generation: int) -> None:
         self._current_generation = generation
+        if generation > 0:
+            self.save_checkpoint()
 
     def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
         """
@@ -970,6 +1126,49 @@ def statistics(self) -> str:
9701126 """
9711127 return population_statistics (self .population )
9721128
1129+ def state_dict (self ) -> dict [str , Any ]:
1130+ state = super ().state_dict ()
1131+ # Serialize population (excluding fn which will be recompiled on load)
1132+ population_state = []
1133+ for member in self .population :
1134+ population_state .append (
1135+ {
1136+ "perfs" : member .perfs ,
1137+ "flat_values" : member .flat_values ,
1138+ "config" : member .config ,
1139+ "status" : member .status ,
1140+ "compile_time" : member .compile_time ,
1141+ }
1142+ )
1143+ state ["population" ] = population_state
1144+ return state
1145+
1146+ def load_state_dict (self , state : dict [str , Any ]) -> None :
1147+ super ().load_state_dict (state )
1148+
1149+ # Restore population
1150+ self .population = []
1151+ for member_state in state ["population" ]:
1152+ member = PopulationMember (
1153+ fn = _unset_fn ,
1154+ perfs = member_state ["perfs" ],
1155+ flat_values = member_state ["flat_values" ],
1156+ config = member_state ["config" ],
1157+ status = member_state ["status" ],
1158+ compile_time = member_state .get ("compile_time" ),
1159+ )
1160+ self .population .append (member )
1161+
1162+ # Recompile kernel functions for all population members
1163+ for member in self .population :
1164+ if member .fn is _unset_fn and member .status == "ok" :
1165+ try :
1166+ member .fn = self .kernel .compile_config (
1167+ member .config , allow_print = False
1168+ )
1169+ except Exception :
1170+ member .fn = _unset_fn
1171+
9731172
9741173def population_statistics (population : list [PopulationMember ]) -> str :
9751174 """