Skip to content

Commit 66cbf6a

Browse files
authored
Add PatternSearch autotuning algorithm (#696)
1 parent 257d3c9 commit 66cbf6a

File tree

5 files changed

+308
-3
lines changed

5 files changed

+308
-3
lines changed

helion/autotuner/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,12 @@
1111
from .finite_search import FiniteSearch as FiniteSearch
1212
from .local_cache import LocalAutotuneCache as LocalAutotuneCache
1313
from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
14+
from .pattern_search import PatternSearch as PatternSearch
1415
from .random_search import RandomSearch as RandomSearch
16+
17+
# Maps each search algorithm's class name to the class itself, in the order
# the algorithms are listed here.
search_algorithms = dict(
    DifferentialEvolutionSearch=DifferentialEvolutionSearch,
    FiniteSearch=FiniteSearch,
    PatternSearch=PatternSearch,
    RandomSearch=RandomSearch,
)

helion/autotuner/config_fragment.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import dataclasses
44
import enum
55
import random
6+
from typing import Iterable
67
from typing import TypeGuard
8+
from typing import cast
79

810
from ..exc import InvalidConfig
911

@@ -36,6 +38,10 @@ def random(self) -> object:
3638
"""Return the default value for this fragment."""
3739
raise NotImplementedError
3840

41+
def pattern_neighbors(self, current: object) -> list[object]:
    """Return the values PatternSearch should try next to ``current``.

    This implementation is a placeholder that always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
44+
3945
def differential_mutation(self, a: object, b: object, c: object) -> object:
4046
"""Create a new value by combining a, b, and c with something like: `a + (b - c)`"""
4147
if b == c:
@@ -62,6 +68,24 @@ def default(self) -> list[int]:
6268
def random(self) -> list[int]:
6369
return random.sample(range(self.length), self.length)
6470

71+
def pattern_neighbors(self, current: object) -> list[object]:
    """Return every permutation reachable from ``current`` by one swap.

    Args:
        current: An iterable holding a permutation of ``range(self.length)``.

    Returns:
        One list per index pair ``(i, j)`` with ``i < j``, ordered by ``i``
        and then by ``j``, each a copy of ``current`` with those two
        positions exchanged.

    Raises:
        ValueError: If ``current`` has the wrong length or is not a
            permutation of ``range(self.length)``.
    """
    sequence = [*cast("Iterable[int]", current)]
    size = self.length
    if len(sequence) != size:
        raise ValueError(
            f"Expected permutation of length {self.length}, got {len(sequence)}"
        )
    if set(sequence) != set(range(size)):
        raise ValueError(
            f"Expected permutation of range({self.length}), got {sequence!r}"
        )

    def exchanged(left: int, right: int) -> list[int]:
        # Work on a copy so the caller's sequence is never mutated.
        result = list(sequence)
        result[left], result[right] = result[right], result[left]
        return result

    return [
        exchanged(left, right)
        for left in range(size)
        for right in range(left + 1, size)
    ]
88+
6589

6690
@dataclasses.dataclass
6791
class BaseIntegerFragment(ConfigSpecFragment):
@@ -85,13 +109,37 @@ def clamp(self, val: int) -> int:
85109
def get_minimum(self) -> int:
86110
return self.low
87111

112+
def pattern_neighbors(self, current: object) -> list[object]:
    """Return the in-range integers directly adjacent to ``current``.

    Args:
        current: The present value; must be exactly ``int`` (``bool`` is
            rejected even though it subclasses ``int``).

    Returns:
        ``current - 1`` and/or ``current + 1``, in that order, keeping only
        the values that lie within ``[self.low, self.high]``.

    Raises:
        TypeError: If ``current`` is not a plain ``int``.
    """
    if type(current) is not int:  # bool is not allowed
        raise TypeError(f"Expected int, got {type(current).__name__}")
    # Check both bounds for each neighbor so that an out-of-range ``current``
    # can never produce a neighbor that is itself out of range.
    return [
        value
        for value in (current - 1, current + 1)
        if self.low <= value <= self.high
    ]
123+
88124

89125
class PowerOfTwoFragment(BaseIntegerFragment):
90126
def random(self) -> int:
91127
assert_integer_power_of_two(self.low)
92128
assert_integer_power_of_two(self.high)
93129
return 2 ** random.randrange(self.low.bit_length() - 1, self.high.bit_length())
94130

131+
def pattern_neighbors(self, current: object) -> list[object]:
    """Return the in-range powers of two adjacent to ``current``.

    Args:
        current: The present value; must be exactly ``int`` (not ``bool``),
            positive, and a power of two.

    Returns:
        ``current // 2`` and/or ``current * 2``, in that order, keeping only
        the values that lie within ``[self.low, self.high]``.

    Raises:
        TypeError: If ``current`` is not a positive power-of-two ``int``.
    """
    # ``n & (n - 1)`` is zero only for powers of two. Without this check a
    # value such as 3 would silently yield the non-power-of-two neighbors
    # 1 and 6, contradicting the error message below.
    if type(current) is not int or current <= 0 or current & (current - 1):
        raise TypeError(f"Expected positive power-of-two int, got {current!r}")
    # Check both bounds for each neighbor so that an out-of-range ``current``
    # can never produce a neighbor that is itself out of range.
    return [
        value
        for value in (current // 2, current * 2)
        if self.low <= value <= self.high
    ]
142+
95143
def differential_mutation(self, a: object, b: object, c: object) -> int:
96144
ai = assert_integer_power_of_two(a)
97145
assert isinstance(b, int)
@@ -132,6 +180,11 @@ def default(self) -> object:
132180
def random(self) -> object:
133181
return random.choice(self.choices)
134182

183+
def pattern_neighbors(self, current: object) -> list[object]:
    """Return every other choice as a neighbor of ``current``.

    Args:
        current: The present value; must be one of ``self.choices``.

    Returns:
        The entries of ``self.choices`` that differ from ``current``, in
        their original order.

    Raises:
        ValueError: If ``current`` is not among ``self.choices``.
    """
    if current in self.choices:
        alternatives: list[object] = []
        for option in self.choices:
            if option != current:
                alternatives.append(option)
        return alternatives
    raise ValueError(f"{current!r} not a valid choice")
187+
135188
def differential_mutation(self, a: object, b: object, c: object) -> object:
136189
if b == c:
137190
return a
@@ -148,6 +201,11 @@ def default(self) -> bool:
148201
def random(self) -> bool:
149202
return random.choice((False, True))
150203

204+
def pattern_neighbors(self, current: object) -> list[object]:
    """Return the opposite boolean as the sole neighbor of ``current``.

    Args:
        current: The present value; must be a ``bool``.

    Returns:
        A one-element list containing ``not current``.

    Raises:
        TypeError: If ``current`` is not a ``bool``.
    """
    # ``bool`` cannot be subclassed, so isinstance is an exact type check here.
    if isinstance(current, bool):
        flipped = not current
        return [flipped]
    raise TypeError(f"Expected bool, got {type(current).__name__}")
208+
151209
def differential_mutation(self, a: object, b: object, c: object) -> bool:
152210
assert isinstance(a, bool)
153211
if b is c:

helion/autotuner/pattern_search.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from __future__ import annotations
2+
3+
import math
4+
from typing import TYPE_CHECKING
5+
6+
from .. import exc
7+
from .base_search import FlatConfig
8+
from .base_search import PopulationBasedSearch
9+
from .base_search import PopulationMember
10+
from .base_search import performance
11+
12+
if TYPE_CHECKING:
13+
from collections.abc import Iterator
14+
from collections.abc import Sequence
15+
16+
from ..runtime.config import Config
17+
from ..runtime.kernel import BoundKernel
18+
19+
20+
class PatternSearch(PopulationBasedSearch):
    """Autotuner that perturbs one parameter at a time around the best configs found so far."""

    def __init__(
        self,
        kernel: BoundKernel,
        args: Sequence[object],
        *,
        initial_population: int = 200,
        copies: int = 5,
        max_generations: int = 100,
    ) -> None:
        """
        Create a PatternSearch autotuner.

        Args:
            kernel: The kernel to be autotuned.
            args: The arguments to be passed to the kernel.
            initial_population: How many random configurations seed the search.
            copies: How many of the top seed configs get their own pattern search.
            max_generations: Upper bound on the number of generations to run.
        """
        super().__init__(kernel, args)
        self.max_generations = max_generations
        self.copies = copies
        self.initial_population = initial_population

    def _autotune(self) -> Config:
        """
        Run the search and return the best config found.

        A deduplicated random population is benchmarked first; the best
        ``copies`` members with a finite result each start their own pattern
        search. The searches advance together, one generation at a time, so
        the new candidates of all of them are benchmarked in a single batch.

        Raises:
            exc.NoConfigFound: If none of the top initial members has a finite
                performance.
        """
        self.log(
            f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}"
        )
        seen: set[Config] = set()
        self.population = []
        for flat in self.config_gen.random_population_flat(self.initial_population):
            seed = self.make_unbenchmarked(flat)
            if seed.config in seen:
                continue
            seen.add(seed.config)
            self.population.append(seed)
        self.parallel_benchmark_population(self.population)
        # again with higher accuracy
        self.rebenchmark_population(self.population)
        self.population.sort(key=performance)
        # filter failed compiles
        starting_points = [
            seed for seed in self.population[: self.copies] if math.isfinite(seed.perf)
        ]
        self.log(
            f"Initial random population of {len(self.population)}, {len(starting_points)} starting points:",
            self.statistics,
        )
        if not starting_points:
            raise exc.NoConfigFound

        runners = [self._pattern_search_from(seed, seen) for seed in starting_points]
        for generation in range(1, self.max_generations + 1):
            incumbent = self.best
            # Keyed by id() so each member object appears at most once.
            survivors = {id(incumbent): incumbent}
            num_neighbors = 0
            num_active = 0
            for runner in runners:
                batch = next(runner, ())
                if not batch:
                    continue  # this copy has finished
                assert len(batch) > 1
                num_active += 1
                # The first entry is the copy's current point, not a neighbor.
                num_neighbors += len(batch) - 1
                survivors.update((id(entry), entry) for entry in batch)
            if num_active == 0:
                break
            self.population = list(survivors.values())
            # compile any unbenchmarked members in parallel
            pending = [entry for entry in self.population if len(entry.perfs) == 0]
            self.parallel_benchmark_population(pending)
            # higher-accuracy rebenchmark
            self.rebenchmark_population(self.population)
            self.log(
                f"Generation {generation}, {num_neighbors} neighbors, {num_active} active:",
                self.statistics,
            )
        return self.best.config

    def _pattern_search_from(
        self, current: PopulationMember, visited: set[Config]
    ) -> Iterator[list[PopulationMember]]:
        """
        Run one copy of pattern search starting at ``current``.

        This is a generator: each step yields the current point followed by
        its not-yet-visited neighbors, and the caller is expected to benchmark
        them before resuming. Yielding lets several copies be advanced in
        lockstep and benchmarked together. ``visited`` is updated in place.
        """
        for _ in range(self.max_generations):
            frontier = [current]
            for flat in self._generate_neighbors(current.flat_values):
                neighbor = self.make_unbenchmarked(flat)
                if neighbor.config in visited:
                    continue
                visited.add(neighbor.config)
                frontier.append(neighbor)
            if len(frontier) < 2:
                return  # no new candidates, stop searching
            yield frontier  # hand the batch to the caller for benchmarking
            winner = min(frontier, key=performance)
            if winner is current:
                return  # no improvement, stop searching
            current = winner

    def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]:
        """
        Build the configurations that differ from ``base`` in one parameter,
        plus those that change two block-size parameters at once.
        """
        options_by_index = [
            spec.pattern_neighbors(base[index])
            for index, spec in enumerate(self.config_gen.flat_spec)
        ]
        assert len(options_by_index) == len(base)

        def patched(*changes: tuple[int, object]) -> FlatConfig:
            # Copy ``base`` and overwrite the given (index, value) pairs in order.
            flat = [*base]
            for position, value in changes:
                flat[position] = value
            return flat

        # All single-parameter changes
        neighbors: list[FlatConfig] = [
            patched((index, value))
            for index, options in enumerate(options_by_index)
            for value in options
        ]

        # Block sizes are important enough to try pairs of changes at a time
        block_indices = self.config_gen.block_size_indices
        for offset, first in enumerate(block_indices):
            for second in block_indices[offset + 1 :]:
                neighbors.extend(
                    patched((first, first_value), (second, second_value))
                    for first_value in options_by_index[first]
                    for second_value in options_by_index[second]
                )

        return neighbors

helion/runtime/settings.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,17 @@ def __exit__(self, *args: object) -> None:
6363
def default_autotuner_fn(
    bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> BaseAutotuner:
    """
    Build the default autotuner for ``bound_kernel``.

    The search algorithm is chosen by the ``HELION_AUTOTUNER`` environment
    variable ("PatternSearch" when unset) and looked up in
    ``search_algorithms``; the resulting search is wrapped in a
    ``LocalAutotuneCache``.

    Raises:
        ValueError: If ``HELION_AUTOTUNER`` names an unknown algorithm.
    """
    from ..autotuner import LocalAutotuneCache
    from ..autotuner import search_algorithms

    autotuner_name = os.environ.get("HELION_AUTOTUNER", "PatternSearch")
    autotuner_cls = search_algorithms.get(autotuner_name)
    if autotuner_cls is not None:
        search = autotuner_cls(bound_kernel, args, **kwargs)  # pyright: ignore[reportArgumentType]
        return LocalAutotuneCache(search)
    raise ValueError(
        f"Unknown HELION_AUTOTUNER value: {autotuner_name}, valid options are: "
        f"{', '.join(search_algorithms.keys())}"
    )
7077

7178

7279
def _get_autotune_random_seed() -> int:

test/test_autotuner.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from pathlib import Path
77
import random
88
import tempfile
9+
from types import SimpleNamespace
910
import unittest
11+
from unittest import skip
1012
from unittest.mock import patch
1113

1214
import pytest
@@ -20,6 +22,11 @@
2022
from helion._testing import import_path
2123
from helion._testing import skipIfRocm
2224
from helion.autotuner import DifferentialEvolutionSearch
25+
from helion.autotuner import PatternSearch
26+
from helion.autotuner.config_fragment import BooleanFragment
27+
from helion.autotuner.config_fragment import EnumFragment
28+
from helion.autotuner.config_fragment import IntegerFragment
29+
from helion.autotuner.config_fragment import PowerOfTwoFragment
2330
from helion.autotuner.config_generation import ConfigGeneration
2431
from helion.autotuner.finite_search import FiniteSearch
2532
from helion.autotuner.random_search import RandomSearch
@@ -174,6 +181,68 @@ def test_differential_evolution_search(self):
174181
fn = bound_kernel.compile_config(best)
175182
torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)
176183

184+
@skip("too slow")
def test_pattern_search(self):
    """Run a tiny PatternSearch on the ``add`` kernel and check the tuned kernel's output."""
    lhs = torch.randn([64, 64], device=DEVICE)
    rhs = torch.randn([64, 64], device=DEVICE)
    args = (lhs, rhs)
    bound_kernel = basic_kernels.add.bind(args)
    random.seed(123)
    tuner = PatternSearch(
        bound_kernel, args, initial_population=10, max_generations=2, copies=1
    )
    best = tuner.autotune()
    compiled = bound_kernel.compile_config(best)
    torch.testing.assert_close(compiled(*args), sum(args), rtol=1e-2, atol=1e-1)
197+
198+
def test_pattern_search_neighbor_values(self):
    """Check ``pattern_neighbors`` for one representative value per fragment type."""
    power_of_two = PowerOfTwoFragment(1, 128, 32).pattern_neighbors(32)
    integer = sorted(IntegerFragment(1, 5, 3).pattern_neighbors(3))
    boolean = BooleanFragment().pattern_neighbors(True)
    enum = sorted(EnumFragment(("a", "b", "c")).pattern_neighbors("b"))

    # (actual, expected) pairs, asserted in the order they were computed.
    cases = [
        (power_of_two, [16, 64]),
        (integer, [2, 4]),
        (boolean, [False]),
        (enum, ["a", "c"]),
    ]
    for actual, expected in cases:
        self.assertEqual(actual, expected)
212+
213+
def test_pattern_search_block_size_pair_neighbors(self):
    """Check that ``_generate_neighbors`` emits every pairwise block-size change.

    The two block-size parameters each have two neighbors (half and double),
    so exactly four configs should differ from the base in both of them.
    """
    # Bypass __init__ (it needs a real kernel); _generate_neighbors only
    # reads ``config_gen``, so a stub namespace is enough.
    search = PatternSearch.__new__(PatternSearch)
    search.config_gen = SimpleNamespace(
        flat_spec=[
            PowerOfTwoFragment(16, 128, 32),
            PowerOfTwoFragment(16, 128, 64),
            EnumFragment(("a", "b")),
        ],
        block_size_indices=[0, 1],
    )

    base = [32, 64, "a"]
    neighbors = search._generate_neighbors(base)

    def diff_count(flat):
        # Number of positions where ``flat`` differs from ``base``.
        return sum(
            1
            for current, original in zip(flat, base, strict=False)
            if current != original
        )

    # Keep only neighbors that changed both block sizes and left the enum alone.
    pair_neighbors = [
        flat for flat in neighbors if diff_count(flat) == 2 and flat[2] == "a"
    ]
    expected = [
        [16, 32, "a"],
        [16, 128, "a"],
        [64, 32, "a"],
        [64, 128, "a"],
    ]
    self.assertEqual(sorted(pair_neighbors), sorted(expected))
245+
177246
def test_accuracy_check_filters_bad_config_wrong_output(self) -> None:
178247
bad_config = helion.Config(block_sizes=[1], num_warps=8)
179248
good_config = helion.Config(block_sizes=[1], num_warps=4)

0 commit comments

Comments
 (0)