Commit 5db4cdf

[Autotuner] Auto-checkpoint feature and ability to resume from checkpoint

1 parent 9cbfb30 commit 5db4cdf

File tree: 12 files changed, +1422 −130 lines changed

docs/api/settings.md

Lines changed: 8 additions & 0 deletions

@@ -197,6 +197,13 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
 Users can still override individual ``autotune_*`` settings; explicit values win over the preset. Controlled by ``HELION_AUTOTUNE_EFFORT``.
 
+.. autoattribute:: Settings.autotune_checkpoint_id
+
+    Checkpoint ID for resuming autotuning from a previous checkpoint. When set, the autotuner attempts to load
+    state from a checkpoint file matching this ID, allowing long-running autotuning sessions to be interrupted
+    and resumed. The checkpoint ID contains a hash prefix that identifies the kernel, hardware, and input shapes.
+    If the hash doesn't match, the checkpoint is ignored and autotuning starts fresh with a warning message.
+    Controlled by ``HELION_AUTOTUNE_CHECKPOINT_ID``.
 
 ```
 

@@ -295,6 +302,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
 | ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
 | ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | Continue autotuning even when recoverable runtime errors occur. |
 | ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
+| ``HELION_AUTOTUNE_CHECKPOINT_ID`` | ``autotune_checkpoint_id`` | Checkpoint ID for resuming autotuning from a previous checkpoint. |
 | ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
 | ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
 | ``HELION_ASSERT_CACHE_HIT`` | ``AutotuneCacheBase`` | When set to ``1``, require a cache hit; raises ``CacheAssertionError`` on cache miss with detailed diagnostics. |
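Reviewer note: for readers wiring the new setting up in code rather than the environment, a minimal usage sketch follows. The commit's log message says the ID can also be passed as ``autotune_checkpoint_id="..."`` in the kernel settings; the kernel body and the placeholder ID below are illustrative assumptions, not part of this diff.

```python
import torch
import helion
import helion.language as hl

# Placeholder ID; in practice, copy the ID logged by a previous autotuning run.
@helion.kernel(autotune_checkpoint_id="a1b2c3d4-1706123456-e5f6g7h8")
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + 1
    return out
```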

docs/deployment_autotuning.md

Lines changed: 25 additions & 0 deletions

@@ -104,6 +104,31 @@ tuning time versus coverage, or try different search algorithms.
 need more reproducibility; see {doc}`api/settings`. Note this only
 affects which configs are tried, not the timing results.
 
+### Checkpointing Long-Running Autotuning
+
+For very long autotuning sessions, you can save and resume state using
+checkpoints. This is useful when tuning might be interrupted (e.g., preemptible
+instances) or when you want to continue tuning from a previous unfinished run.
+
+The simplest approach is to use the `HELION_AUTOTUNE_CHECKPOINT_ID` environment
+variable. When autotuning runs, it periodically saves checkpoints and logs the
+checkpoint ID. To resume, set this environment variable to the checkpoint ID
+from a previous run.
+
+```bash
+# First run - autotuning will log checkpoint IDs as it progresses:
+# "Checkpoint saved: .../autotuner_checkpoints/a1b2c3d4-1706123456-e5f6g7h8.checkpoint"
+# "To resume from this checkpoint, set HELION_AUTOTUNE_CHECKPOINT_ID=a1b2c3d4-1706123456-e5f6g7h8 ..."
+python run_kernel.py
+
+# If interrupted, resume from the last checkpoint:
+HELION_AUTOTUNE_CHECKPOINT_ID=a1b2c3d4-1706123456-e5f6g7h8 python run_kernel.py
+```
+
+The checkpoint ID contains a hash prefix that identifies the kernel, hardware,
+and input shapes. If the hash doesn't match, the checkpoint is ignored and autotuning
+starts fresh with a warning message.
+
 ## Deploy a Single Config
 
 If one configuration wins for every production call, bake it into the decorator:

helion/_testing.py

Lines changed: 22 additions & 0 deletions

@@ -10,13 +10,15 @@
 import operator
 import os
 from pathlib import Path
+import random
 import re
 import sys
 from typing import TYPE_CHECKING
 from typing import Callable
 from typing import Generator
 import unittest
 
+import numpy as np
 from packaging import version
 import pytest
 import torch

@@ -39,6 +41,26 @@
 from .runtime.kernel import Kernel
 
 
+def seed_rng(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)  # noqa: NPY002
+    torch.manual_seed(seed)
+
+
+@contextlib.contextmanager
+def fork_rng() -> Generator[None, None, None]:
+    """Context manager that forks all RNGs and restores original state on exit."""
+    python_state = random.getstate()
+    numpy_state = np.random.get_state()  # noqa: NPY002
+
+    with torch.random.fork_rng():
+        try:
+            yield
+        finally:
+            random.setstate(python_state)
+            np.random.set_state(numpy_state)  # noqa: NPY002
+
+
 def _strip_launcher_args(value: str) -> str:
     strip_pairs = []
     if supports_amd_cdna_tunables():
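Reviewer note: these helpers make RNG-dependent tests hermetic across Python, NumPy, and PyTorch. A hypothetical usage sketch (the test itself is not part of this commit):

```python
import torch

from helion._testing import fork_rng, seed_rng

def test_rng_isolation() -> None:
    # Reseeding inside fork_rng is invisible outside the block: all three RNG
    # states are restored on exit, so identical seeds give identical draws.
    with fork_rng():
        seed_rng(0)
        a = torch.rand(3)
    with fork_rng():
        seed_rng(0)
        b = torch.rand(3)
    assert torch.equal(a, b)
```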

helion/autotuner/base_search.py

Lines changed: 199 additions & 0 deletions

@@ -156,6 +156,17 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()
 
+    def _get_checkpoint_dir(self) -> Path:
+        """Get checkpoint directory for autotuner checkpoints."""
+        from torch._inductor.runtime.cache_dir_utils import cache_dir
+
+        if (user_path := os.environ.get("HELION_CACHE_DIR", None)) is not None:
+            base = Path(user_path)
+        else:
+            base = Path(cache_dir()) / "helion"
+
+        return base / "autotuner_checkpoints"
+
     def _clone_args(self, args: Sequence[object]) -> Sequence[object]:
         def _clone_leaf(leaf: object) -> object:
             if isinstance(leaf, torch.Tensor):
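Reviewer note: with `HELION_CACHE_DIR` unset, inductor's `cache_dir()` typically resolves to `/tmp/torchinductor_<user>` (overridable via `TORCHINDUCTOR_CACHE_DIR`), so checkpoints land under `.../helion/autotuner_checkpoints`. A standalone sketch of the same resolution logic, reusing the diff's own import; the function name is illustrative:

```python
import os
from pathlib import Path

from torch._inductor.runtime.cache_dir_utils import cache_dir

def resolve_checkpoint_dir() -> Path:
    # Mirrors _get_checkpoint_dir above: HELION_CACHE_DIR wins; otherwise fall
    # back to the shared inductor cache directory plus a "helion" namespace.
    user_path = os.environ.get("HELION_CACHE_DIR")
    base = Path(user_path) if user_path is not None else Path(cache_dir()) / "helion"
    return base / "autotuner_checkpoints"

print(resolve_checkpoint_dir())  # e.g. /tmp/torchinductor_alice/helion/autotuner_checkpoints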
@@ -685,6 +696,43 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             torch.save(self.args, args_path)
             self._precompile_args_path = args_path
         exit_stack.callback(self.cleanup)
+
+        checkpoint_loaded = False
+        if self.settings.autotune_checkpoint_id is not None:
+            from .local_cache import LocalAutotuneCache
+
+            checkpoint_id = self.settings.autotune_checkpoint_id
+            current_hash = LocalAutotuneCache(self)._generate_key().stable_hash()
+
+            # Checkpoint ID format: {8-char-hash}-{timestamp}-{8-char-uuid}
+            # Extract hash prefix and check compatibility
+            hash_prefix = checkpoint_id.split("-")[0]
+            if hash_prefix != current_hash[:8]:
+                self.log(
+                    f"Warning: Checkpoint '{checkpoint_id}' is for a different kernel "
+                    f"(hash mismatch). Ignoring checkpoint and starting fresh autotuning run.",
+                    level=logging.WARNING,
+                )
+            else:
+                # Hash matches, load checkpoint
+                checkpoint_dir = self._get_checkpoint_dir()
+                checkpoint_file = checkpoint_dir / f"{checkpoint_id}.checkpoint"
+                if not checkpoint_file.exists():
+                    raise FileNotFoundError(
+                        f"Checkpoint file not found: {checkpoint_file}"
+                    )
+                self.log(f"Resuming from checkpoint: {checkpoint_file}")
+                with open(checkpoint_file, "rb") as f:
+                    state = pickle.load(f)
+                self.load_state_dict(state)
+                self.log(
+                    f"Resumed at generation {self._current_generation} with "
+                    f"{len(self.population)} configs"  # type: ignore[attr-defined]
+                )
+                checkpoint_loaded = True
+
+        if not checkpoint_loaded:
+            self._init_search()
         best = self._autotune()
         end = time.perf_counter()
         kernel_decorator = self.kernel.format_kernel_decorator(best, self.settings)
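Reviewer note: the ID format checked above is easy to validate by hand. A hypothetical helper (not part of this commit) that splits an ID into its three fields; only the hash prefix participates in the compatibility check:

```python
def parse_checkpoint_id(checkpoint_id: str) -> tuple[str, int, str]:
    # Format: {8-char-hash}-{unix-timestamp}-{8-char-uuid}
    hash_prefix, timestamp, short_uuid = checkpoint_id.split("-")
    return hash_prefix, int(timestamp), short_uuid

assert parse_checkpoint_id("a1b2c3d4-1706123456-e5f6g7h8") == (
    "a1b2c3d4",
    1706123456,
    "e5f6g7h8",
)
```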
@@ -701,6 +749,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
             print(triton_code, file=sys.stderr)
         return best
 
+    def _init_search(self) -> None:
+        """
+        Initialize the search state for a fresh autotuning run.
+
+        This method is called when starting autotuning without a checkpoint.
+        Subclasses should override this to set up initial population and state.
+        After this method, _current_generation should reflect the last completed generation.
+        """
+        raise NotImplementedError
+
     def _autotune(self) -> Config:
         """
         Abstract method to perform the actual autotuning.
@@ -712,6 +770,102 @@ def _autotune(self) -> Config:
         """
         raise NotImplementedError
 
+    def save_checkpoint(self) -> Path:
+        """
+        Save current autotuner state to checkpoint file.
+
+        Each call generates a new checkpoint ID for the saved checkpoint.
+
+        Returns:
+            Path to saved checkpoint file
+        """
+        from .local_cache import LocalAutotuneCache
+
+        # Checkpoint ID format: {8-char-hash}-{timestamp}-{8-char-uuid}
+        stable_hash = LocalAutotuneCache(self)._generate_key().stable_hash()[:8]
+        timestamp = int(time.time())
+        short_uuid = uuid.uuid4().hex[:8]
+        new_checkpoint_id = f"{stable_hash}-{timestamp}-{short_uuid}"
+        filename = f"{new_checkpoint_id}.checkpoint"
+
+        checkpoint_dir = self._get_checkpoint_dir()
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        checkpoint_path = checkpoint_dir / filename
+
+        state = self.state_dict()
+
+        # Atomic write using temp file
+        tmp = checkpoint_dir / f"tmp.{uuid.uuid4()!s}"
+        with open(tmp, "wb") as f:
+            pickle.dump(state, f)
+        os.rename(str(tmp), str(checkpoint_path))
+
+        self.log(f"Checkpoint saved: {checkpoint_path}")
+        self.log(
+            f"To resume from this checkpoint, set HELION_AUTOTUNE_CHECKPOINT_ID={new_checkpoint_id} "
+            f'or `autotune_checkpoint_id="{new_checkpoint_id}"` in the kernel settings'
+        )
+        return checkpoint_path
+
+    def state_dict(self) -> dict[str, Any]:
+        """
+        Return autotuner state as a dictionary.
+
+        Subclasses should call super().state_dict() first, then update with their own fields.
+        """
+        import numpy as np
+
+        from .local_cache import LocalAutotuneCache
+
+        rng_state: dict[str, Any] = {
+            "random": random.getstate(),
+            "torch": torch.random.get_rng_state(),
+            "numpy": np.random.get_state(),  # noqa: NPY002
+        }
+        if torch.cuda.is_available():
+            rng_state["torch_cuda"] = torch.cuda.get_rng_state()
+
+        return {
+            "algorithm": self.__class__.__name__,
+            "cache_key_stable_hash": LocalAutotuneCache(self)
+            ._generate_key()
+            .stable_hash(),
+            "counters": dict(self.counters),
+            "rng_state": rng_state,
+            "best_perf_so_far": self.best_perf_so_far,
+            "current_generation": self._current_generation,
+        }
+
+    def load_state_dict(self, state: dict[str, Any]) -> None:
+        """
+        Restore autotuner state from a dictionary.
+
+        Subclasses should call super().load_state_dict(state) first,
+        then restore their own fields.
+        """
+        from .local_cache import LocalAutotuneCache
+
+        current_hash = LocalAutotuneCache(self)._generate_key().stable_hash()
+        if state.get("cache_key_stable_hash") != current_hash:
+            raise exc.CheckpointError(
+                "State dict is incompatible: kernel, hardware, or input shapes may have changed"
+            )
+
+        import numpy as np
+
+        # Restore RNG state
+        rng_state = state["rng_state"]
+        random.setstate(rng_state["random"])
+        torch.random.set_rng_state(rng_state["torch"])
+        np.random.set_state(rng_state["numpy"])  # noqa: NPY002
+        if "torch_cuda" in rng_state and torch.cuda.is_available():
+            torch.cuda.set_rng_state(rng_state["torch_cuda"])
+
+        # Restore autotuner state
+        self.counters = collections.Counter(state["counters"])
+        self.best_perf_so_far = state["best_perf_so_far"]
+        self._current_generation = state["current_generation"]
+
 
 @dataclasses.dataclass
 class PopulationMember:
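Reviewer note: `save_checkpoint` uses the write-to-temp-then-rename pattern. A self-contained sketch of that pattern with illustrative names; on POSIX, a rename within one filesystem is atomic, so a crash or concurrent reader never observes a half-written checkpoint:

```python
import os
import pickle
import uuid
from pathlib import Path

def atomic_pickle(obj: object, path: Path) -> None:
    # Write to a temp file in the destination directory, then rename into place.
    tmp = path.parent / f"tmp.{uuid.uuid4()!s}"
    with open(tmp, "wb") as f:
        pickle.dump(obj, f)
    os.rename(tmp, path)
```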
@@ -817,6 +971,8 @@ def best(self) -> PopulationMember:
 
     def set_generation(self, generation: int) -> None:
         self._current_generation = generation
+        if generation > 0:
+            self.save_checkpoint()
 
     def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
         """
@@ -970,6 +1126,49 @@ def statistics(self) -> str:
         """
         return population_statistics(self.population)
 
+    def state_dict(self) -> dict[str, Any]:
+        state = super().state_dict()
+        # Serialize population (excluding fn which will be recompiled on load)
+        population_state = []
+        for member in self.population:
+            population_state.append(
+                {
+                    "perfs": member.perfs,
+                    "flat_values": member.flat_values,
+                    "config": member.config,
+                    "status": member.status,
+                    "compile_time": member.compile_time,
+                }
+            )
+        state["population"] = population_state
+        return state
+
+    def load_state_dict(self, state: dict[str, Any]) -> None:
+        super().load_state_dict(state)
+
+        # Restore population
+        self.population = []
+        for member_state in state["population"]:
+            member = PopulationMember(
+                fn=_unset_fn,
+                perfs=member_state["perfs"],
+                flat_values=member_state["flat_values"],
+                config=member_state["config"],
+                status=member_state["status"],
+                compile_time=member_state.get("compile_time"),
+            )
+            self.population.append(member)
+
+        # Recompile kernel functions for all population members
+        for member in self.population:
+            if member.fn is _unset_fn and member.status == "ok":
+                try:
+                    member.fn = self.kernel.compile_config(
+                        member.config, allow_print=False
+                    )
+                except Exception:
+                    member.fn = _unset_fn
+
 
 def population_statistics(population: list[PopulationMember]) -> str:
     """
