Skip to content

Commit bd64602

Browse files
authored
Autotune persistent kernels for multi occupancy (#1307)
1 parent ddc9f96 commit bd64602

File tree

8 files changed

+550
-42
lines changed

8 files changed

+550
-42
lines changed

helion/_compiler/device_function.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,13 @@ def codegen_function_call(self) -> ast.AST:
693693
for key in ("waves_per_eu", "matrix_instr_nonkdim"):
694694
if key in self.config:
695695
args.append(f"{key}={self.config[key]}")
696+
# Only pass maxnreg if it's set to a non-None value and not on AMD
697+
if (
698+
"maxnreg" in self.config
699+
and self.config["maxnreg"] is not None
700+
and torch.version.hip is None
701+
):
702+
args.append(f"maxnreg={self.config['maxnreg']}")
696703
pid = self.pid
697704
assert pid is not None
698705
# TODO(jansel): we should run CSE this statement

helion/_compiler/program_id.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,14 @@ def __init__(self, is_blocked: bool = False) -> None:
583583
device_function = DeviceFunction.current()
584584
self.virtual_pid_var: str = device_function.new_var("virtual_pid")
585585
self.total_pids_var: str = device_function.new_var("total_pids")
586+
# Get num_sm_multiplier from config for multi-occupancy support
587+
# pyrefly: ignore [bad-assignment]
588+
self.num_sm_multiplier: int = device_function.config.get("num_sm_multiplier", 1)
589+
# Compute grid size expression based on multiplier
590+
if self.num_sm_multiplier == 1:
591+
self.grid_size_expr: str = NUM_SM_VAR
592+
else:
593+
self.grid_size_expr = f"({NUM_SM_VAR} * {self.num_sm_multiplier})"
586594
# Generate variables and range expression based on strategy type
587595
if self.is_blocked:
588596
self.block_size_var: str = device_function.new_var("block_size")
@@ -596,7 +604,7 @@ def __init__(self, is_blocked: bool = False) -> None:
596604
self.range_kwargs: dict[str, str] = {
597605
"begin": typed_program_id(0),
598606
"end": self.total_pids_var,
599-
"step": NUM_SM_VAR,
607+
"step": self.grid_size_expr,
600608
}
601609
if device_function.constexpr_arg(NUM_SM_VAR):
602610
reserved_sms = CompileEnvironment.current().settings.persistent_reserved_sms
@@ -619,8 +627,8 @@ def get_device_str(self) -> str:
619627
return f"torch.{device!r}"
620628

621629
def codegen_grid(self) -> ast.AST:
622-
# Use num_sms for persistent kernels
623-
return expr_from_string(f"({NUM_SM_VAR},)")
630+
# Use num_sms * multiplier for persistent kernels (multi-occupancy)
631+
return expr_from_string(f"({self.grid_size_expr},)")
624632

625633
def setup_persistent_kernel(
626634
self, device_function: DeviceFunction, total_pids_expr: str | None = None
@@ -641,7 +649,7 @@ def setup_persistent_kernel(
641649
assignments = [
642650
(
643651
self.block_size_var,
644-
f"tl.cdiv({self.total_pids_var}, {NUM_SM_VAR})",
652+
f"tl.cdiv({self.total_pids_var}, {self.grid_size_expr})",
645653
),
646654
(
647655
self.start_pid_var,

helion/autotuner/config_spec.py

Lines changed: 104 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING
77
from typing import cast
88

9+
import torch
910
from torch._inductor.runtime.runtime_utils import next_power_of_2
1011

1112
from .._compat import supports_amd_cdna_tunables
@@ -52,12 +53,21 @@
5253
"num_warps",
5354
"num_stages",
5455
"pid_type",
56+
"num_sm_multiplier",
57+
"maxnreg",
5558
"indexing",
5659
"load_eviction_policies",
5760
*AMD_CDNA_TUNABLES,
5861
]
5962
)
6063
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
64+
MIN_NUM_SM_MULTIPLIER = 1
65+
MAX_NUM_SM_MULTIPLIER = 128
66+
DEFAULT_NUM_SM_MULTIPLIER = 1
67+
# maxnreg values: None means no limit, otherwise limit to this many registers per thread
68+
# Lower values allow higher occupancy but may hurt performance for register-heavy kernels
69+
VALID_MAXNREG = (None, 32, 64, 128, 256)
70+
DEFAULT_MAXNREG = None
6171
VALID_EVICTION_POLICIES = ("", "first", "last")
6272
VALID_WAVES_PER_EU = (1, 2, 3, 4)
6373
VALID_MATRIX_INSTR_NONKDIM = (0, 16, 32)
@@ -158,10 +168,18 @@ def disallow_pid_type(self, pid_type: PidTypeLiteral) -> None:
158168
)
159169
assert self.allowed_pid_types
160170

161-
def normalize(self, config: helion.Config | dict[str, object]) -> None:
162-
"""Normalize the config to match the block_sizes and validate the config."""
171+
def normalize(
172+
self, config: helion.Config | dict[str, object], *, _fix_invalid: bool = False
173+
) -> None:
174+
"""Normalize the config to match the block_sizes and validate the config.
175+
176+
Args:
177+
config: The config to normalize (modified in place).
178+
_fix_invalid: If True, silently fix invalid combinations instead of raising
179+
errors. Used internally during autotuning config generation.
180+
"""
163181
if isinstance(config, helion.Config):
164-
self.normalize(config.config)
182+
self.normalize(config.config, _fix_invalid=_fix_invalid)
165183
return
166184

167185
for name in (
@@ -250,19 +268,84 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
250268
elif key in config:
251269
raise InvalidConfig(f"{key} is not supported on this target hardware")
252270

253-
# TODO(jansel): include num_ctas and max_nreg
271+
if "pid_type" in config:
272+
if config["pid_type"] not in VALID_PID_TYPES:
273+
raise InvalidConfig(
274+
f"Invalid value for 'pid_type': {config['pid_type']!r} must be one of {list(VALID_PID_TYPES)!r}"
275+
)
276+
else:
277+
config["pid_type"] = VALID_PID_TYPES[0]
278+
279+
# Validate num_sm_multiplier is a power of two in range
280+
if "num_sm_multiplier" in config:
281+
val = config["num_sm_multiplier"]
282+
if (
283+
not isinstance(val, int)
284+
or val < MIN_NUM_SM_MULTIPLIER
285+
or val > MAX_NUM_SM_MULTIPLIER
286+
or (val & (val - 1)) != 0 # not a power of two
287+
):
288+
raise InvalidConfig(
289+
f"Invalid value for 'num_sm_multiplier': {val!r} must be a power of two between {MIN_NUM_SM_MULTIPLIER} and {MAX_NUM_SM_MULTIPLIER}"
290+
)
291+
else:
292+
config["num_sm_multiplier"] = DEFAULT_NUM_SM_MULTIPLIER
254293

255-
for name, values in (("pid_type", VALID_PID_TYPES),):
256-
if name in config:
257-
if config[name] not in values:
294+
# Only validate maxnreg on non-AMD devices (not supported on AMD)
295+
if torch.version.hip is None:
296+
if "maxnreg" in config:
297+
if config["maxnreg"] not in VALID_MAXNREG:
258298
raise InvalidConfig(
259-
f"Invalid value for {name!r}: {config[name]!r} must be one of {[*values]!r}"
299+
f"Invalid value for 'maxnreg': {config['maxnreg']!r} must be one of {list(VALID_MAXNREG)!r}"
260300
)
261301
else:
262-
config[name] = values[0]
302+
config["maxnreg"] = VALID_MAXNREG[0]
303+
else:
304+
# Remove maxnreg on AMD if present
305+
config.pop("maxnreg", None)
263306

264-
# Set default values for grid indices when pid_type is not persistent
307+
# Handle num_sm_multiplier and maxnreg for non-persistent pid_types
308+
# These options only make sense for persistent kernels
265309
pid_type = config["pid_type"]
310+
if pid_type in ("flat", "xyz"):
311+
# Handle num_sm_multiplier
312+
num_sm_multiplier = config.get(
313+
"num_sm_multiplier", DEFAULT_NUM_SM_MULTIPLIER
314+
)
315+
if num_sm_multiplier != DEFAULT_NUM_SM_MULTIPLIER:
316+
if _fix_invalid:
317+
# Silently fix during autotuning config generation
318+
config.pop("num_sm_multiplier", None)
319+
else:
320+
# Raise error for user-specified invalid combinations
321+
raise InvalidConfig(
322+
f"num_sm_multiplier={num_sm_multiplier} can only be used with persistent "
323+
f"pid_type ('persistent_blocked' or 'persistent_interleaved'), "
324+
f"got pid_type={pid_type!r}"
325+
)
326+
else:
327+
# Remove default value from config
328+
config.pop("num_sm_multiplier", None)
329+
330+
# Handle maxnreg - only makes sense for persistent kernels (and only on non-AMD)
331+
if torch.version.hip is None:
332+
maxnreg = config.get("maxnreg", DEFAULT_MAXNREG)
333+
if maxnreg != DEFAULT_MAXNREG:
334+
if _fix_invalid:
335+
# Silently fix during autotuning config generation
336+
config.pop("maxnreg", None)
337+
else:
338+
# Raise error for user-specified invalid combinations
339+
raise InvalidConfig(
340+
f"maxnreg={maxnreg} can only be used with persistent "
341+
f"pid_type ('persistent_blocked' or 'persistent_interleaved'), "
342+
f"got pid_type={pid_type!r}"
343+
)
344+
else:
345+
# Remove default value from config
346+
config.pop("maxnreg", None)
347+
348+
# Set default values for grid indices when pid_type is not persistent
266349
if pid_type in ("flat", "xyz") and self.grid_block_ids:
267350
for name, mapping in (
268351
("range_unroll_factors", self.range_unroll_factors),
@@ -322,8 +405,18 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
322405
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
323406
"indexing": fn(self.indexing),
324407
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
408+
"num_sm_multiplier": fn(
409+
PowerOfTwoFragment(
410+
MIN_NUM_SM_MULTIPLIER,
411+
MAX_NUM_SM_MULTIPLIER,
412+
DEFAULT_NUM_SM_MULTIPLIER,
413+
)
414+
),
325415
"load_eviction_policies": fn(self.load_eviction_policies),
326416
}
417+
# Only include maxnreg on non-AMD devices (not supported on AMD)
418+
if torch.version.hip is None:
419+
config["maxnreg"] = fn(EnumFragment(VALID_MAXNREG))
327420
# Add tunable parameters
328421
config.update(
329422
{key: fn(fragment) for key, fragment in self.user_defined_tunables.items()}
@@ -345,7 +438,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
345438
):
346439
if not config.get(name):
347440
config.pop(name, None)
348-
self.normalize(config)
441+
self.normalize(config, _fix_invalid=True)
349442
# pyrefly: ignore [bad-argument-type]
350443
return helion.Config(**config)
351444

helion/runtime/config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
IndexingLiteral = Literal["pointer", "tensor_descriptor", "block_ptr"]
1313
PidTypeLiteral = Literal["flat", "xyz", "persistent_blocked", "persistent_interleaved"]
1414
EvictionPolicyLiteral = Literal["", "first", "last"]
15+
# Every power of two that config validation accepts for num_sm_multiplier (1 through 128).
NumSmMultiplierLiteral = Literal[1, 2, 4, 8, 16, 32, 64, 128]
16+
MaxnregLiteral = Literal[32, 64, 128, 256] | None
1517

1618

1719
class Config(Mapping[str, object]):
@@ -36,6 +38,8 @@ def __init__(
3638
num_warps: int | None = None,
3739
num_stages: int | None = None,
3840
pid_type: PidTypeLiteral | None = None,
41+
num_sm_multiplier: NumSmMultiplierLiteral | None = None,
42+
maxnreg: MaxnregLiteral | None = None,
3943
indexing: IndexingLiteral | list[IndexingLiteral] | None = None,
4044
# For user-defined properties
4145
**kwargs: object,
@@ -58,6 +62,11 @@ def __init__(
5862
num_warps: Number of warps per block.
5963
num_stages: Number of stages for software pipelining.
6064
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
65+
num_sm_multiplier: Multiplier for the number of SMs in persistent kernels (a power of two from 1 to 128).
66+
Controls multi-occupancy by launching N * num_sms thread blocks instead of just num_sms.
67+
maxnreg: Maximum number of registers per thread (None, 32, 64, 128, 256).
68+
Lower values allow higher occupancy but may hurt performance. Used with persistent kernels
69+
to ensure multi-occupancy can be achieved.
6170
indexing: Indexing strategy for load and store operations. Can be:
6271
- A single strategy string (all loads/stores use this strategy):
6372
indexing="block_ptr" # backward compatible
@@ -85,6 +94,8 @@ def __init__(
8594
"num_stages": num_stages,
8695
"indexing": indexing,
8796
"pid_type": pid_type,
97+
"num_sm_multiplier": num_sm_multiplier,
98+
"maxnreg": maxnreg,
8899
}
89100
for key, value in core_props.items():
90101
if value is not None:
@@ -182,6 +193,20 @@ def l2_groupings(self) -> list[int]:
182193
def pid_type(self) -> PidTypeLiteral:
183194
return cast("PidTypeLiteral", self.config.get("pid_type", "flat"))
184195

196+
@property
197+
def num_sm_multiplier(self) -> int:
198+
from ..autotuner.config_spec import DEFAULT_NUM_SM_MULTIPLIER
199+
200+
return cast(
201+
"int", self.config.get("num_sm_multiplier", DEFAULT_NUM_SM_MULTIPLIER)
202+
)
203+
204+
@property
205+
def maxnreg(self) -> int | None:
206+
from ..autotuner.config_spec import DEFAULT_MAXNREG
207+
208+
return cast("int | None", self.config.get("maxnreg", DEFAULT_MAXNREG))
209+
185210
@property
186211
def range_unroll_factors(self) -> list[int]:
187212
return cast("list[int]", self.config.get("range_unroll_factors", []))

0 commit comments

Comments
 (0)