|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import math |
| 4 | +from typing import TYPE_CHECKING |
| 5 | + |
| 6 | +from .. import exc |
| 7 | +from .base_search import FlatConfig |
| 8 | +from .base_search import PopulationBasedSearch |
| 9 | +from .base_search import PopulationMember |
| 10 | +from .base_search import performance |
| 11 | + |
| 12 | +if TYPE_CHECKING: |
| 13 | + from collections.abc import Iterator |
| 14 | + from collections.abc import Sequence |
| 15 | + |
| 16 | + from ..runtime.config import Config |
| 17 | + from ..runtime.kernel import BoundKernel |
| 18 | + |
| 19 | + |
class PatternSearch(PopulationBasedSearch):
    """Autotuner that repeatedly probes small perturbations of the best configs found so far.

    A random population is benchmarked first; the top ``copies`` members that
    benchmarked successfully each seed an independent pattern search. The
    searches advance in lock-step, one generation at a time, so that their
    candidates can be benchmarked together.
    """

    def __init__(
        self,
        kernel: BoundKernel,
        args: Sequence[object],
        *,
        initial_population: int = 200,
        copies: int = 5,
        max_generations: int = 100,
    ) -> None:
        """
        Set up a PatternSearch autotuner.

        Args:
            kernel: The kernel to be autotuned.
            args: The arguments to be passed to the kernel.
            initial_population: How many random configurations seed the search.
            copies: How many of the top seed configs get their own pattern search.
            max_generations: Upper bound on the number of generations to run.
        """
        super().__init__(kernel, args)
        self.initial_population = initial_population
        self.copies = copies
        self.max_generations = max_generations

    def _autotune(self) -> Config:
        """
        Run the full search and return the best config found.

        Raises:
            exc.NoConfigFound: If none of the top ``copies`` seed members has a
                finite benchmark result.
        """
        self.log(
            f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}"
        )

        # Phase 1: build a de-duplicated random seed population.
        seen = set()
        self.population = []
        random_flats = self.config_gen.random_population_flat(self.initial_population)
        for flat in random_flats:
            seed = self.make_unbenchmarked(flat)
            if seed.config in seen:
                continue
            seen.add(seed.config)
            self.population.append(seed)

        # Benchmark once, then a second pass for higher accuracy before ranking.
        self.parallel_benchmark_population(self.population)
        self.rebenchmark_population(self.population)
        self.population.sort(key=performance)

        # Only the leading `copies` members are considered; a non-finite perf
        # marks a config that failed to compile, so those are dropped.
        starting_points = [
            seed for seed in self.population[: self.copies] if math.isfinite(seed.perf)
        ]
        self.log(
            f"Initial random population of {len(self.population)}, {len(starting_points)} starting points:",
            self.statistics,
        )
        if not starting_points:
            raise exc.NoConfigFound

        # Phase 2: one generator per starting point, all sharing `seen` so the
        # same config is never proposed twice across searches.
        searches = [self._pattern_search_from(start, seen) for start in starting_points]
        for generation in range(1, self.max_generations + 1):
            incumbent = self.best
            # Keyed by id() so a member contributed more than once is kept once;
            # the incumbent best always survives into the next population.
            next_population = {id(incumbent): incumbent}
            num_neighbors = 0
            num_active = 0
            for search in searches:
                batch = next(search, ())  # () once this search has finished
                if not batch:
                    continue
                # A live search always yields its current point plus >= 1 neighbor.
                assert len(batch) > 1
                num_active += 1
                num_neighbors += len(batch) - 1
                next_population.update((id(entry), entry) for entry in batch)
            if num_active == 0:
                break  # every search has converged or run out of candidates

            self.population = list(next_population.values())
            # Members with no recorded perfs have not been compiled/benchmarked yet.
            pending = [entry for entry in self.population if len(entry.perfs) == 0]
            self.parallel_benchmark_population(pending)
            # Second, higher-accuracy pass over the whole generation.
            self.rebenchmark_population(self.population)
            self.log(
                f"Generation {generation}, {num_neighbors} neighbors, {num_active} active:",
                self.statistics,
            )
        return self.best.config

    def _pattern_search_from(
        self, current: PopulationMember, visited: set[Config]
    ) -> Iterator[list[PopulationMember]]:
        """
        Drive one pattern search starting at ``current``.

        Implemented as a generator: each step yields ``[current, *new_neighbors]``
        and expects the caller to benchmark those members before resuming it.
        This lets several searches be advanced together, one generation at a time.

        The generator ends when no unvisited neighbor exists, when no yielded
        neighbor beats the current point, or after ``max_generations`` steps.
        ``visited`` is updated in place with every config proposed here.
        """
        for _ in range(self.max_generations):
            fresh = []
            for flat in self._generate_neighbors(current.flat_values):
                neighbor = self.make_unbenchmarked(flat)
                if neighbor.config in visited:
                    continue
                visited.add(neighbor.config)
                fresh.append(neighbor)
            if not fresh:
                return  # nothing new to try around this point

            batch = [current, *fresh]
            yield batch  # caller benchmarks the batch before we continue

            winner = min(batch, key=performance)
            if winner is current:
                return  # no neighbor improved on the current point
            current = winner

    def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]:
        """
        Build the neighborhood of ``base``.

        Returns every config that differs from ``base`` in exactly one
        parameter, followed by every config that differs in exactly two
        block-size parameters. Replacement values come from each spec's
        ``pattern_neighbors``.
        """
        options = [
            spec.pattern_neighbors(base[position])
            for position, spec in enumerate(self.config_gen.flat_spec)
        ]
        assert len(options) == len(base)

        def variant(*changes):
            # Copy of `base` with each (index, value) pair applied.
            flat = [*base]
            for position, value in changes:
                flat[position] = value
            return flat

        neighbors: list[FlatConfig] = []

        # Single-parameter changes.
        for index, choices in enumerate(options):
            neighbors.extend(variant((index, value)) for value in choices)

        # Block sizes are important enough to also try changing them in pairs.
        block_indices = self.config_gen.block_size_indices
        for offset, first in enumerate(block_indices):
            first_choices = options[first]
            if not first_choices:
                continue
            for second in block_indices[offset + 1 :]:
                neighbors.extend(
                    variant((first, first_value), (second, second_value))
                    for first_value in first_choices
                    for second_value in options[second]
                )

        return neighbors
0 commit comments