Skip to content

Commit 2a2d2b2

Browse files
authored
Support hl.barrier for mega-kernels (#1151)
1 parent 6ba8ef9 commit 2a2d2b2

22 files changed

+1354
-67
lines changed

examples/split_k_barrier.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""
2+
Helion Split-K Matmul with Barrier Example
3+
==========================================
4+
This example demonstrates a two-stage split-K matrix multiplication using
5+
hl.barrier() for grid-wide synchronization. The barrier approach ensures
6+
deterministic results as opposed to atomic_add approaches.
7+
8+
Stage 1: Split K dimension into chunks and compute partial products
9+
Barrier: Grid-wide synchronization to ensure all partials are written
10+
Stage 2: Reduce partials across the split dimension
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import torch
16+
17+
import helion
18+
from helion._testing import DEVICE
19+
from helion._testing import run_example
20+
from helion.autotuner import PowerOfTwoFragment
21+
import helion.language as hl
22+
23+
24+
@helion.kernel(static_shapes=True, dot_precision="ieee")
def split_k_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Two-stage split-K matmul using hl.barrier(). The barrier approach
    gives deterministic results as opposed to the atomic_add approach.

    Stage 1:
        - Split K into `split_k` contiguous chunks.
        - Each chunk computes a partial [tile_m, tile_n] product into its own slice of `tmp`.

    Barrier:
        - Grid-wide barrier to ensure all partials are written before reduction.

    Stage 2:
        - Reduce partials across the split dimension and write `out`.

    Shapes:
        a:   [M, K]
        b:   [K, N]
        tmp: [M, N, split_k]
        out: [M, N]

    Notes:
        - Static shapes keep codegen simpler.
        - `split_k` is fixed for clarity; autotuning could choose it instead.
    """
    m, k = a.shape
    _, n = b.shape
    # Tunable number of K-chunks; the fragment constrains it to powers of two in [16, 512].
    split_k = hl.register_tunable("split_k", PowerOfTwoFragment(16, 512, 64))
    # Each chunk covers ceil(k / split_k) elements, rounded up to a power of two for codegen.
    block_k = helion.next_power_of_2(helion.cdiv(k, split_k))
    # Scratch buffer holding one partial product per K-chunk.
    tmp = torch.zeros((m, n, split_k), device=a.device, dtype=a.dtype)
    out = torch.empty((m, n), device=a.device, dtype=a.dtype)

    # Stage 1: each (tile_m, tile_n, tile_k_outer) program computes a partial product
    # over its K-chunk and stores it in the chunk's slice of `tmp`.
    for tile_m, tile_n, tile_k_outer in hl.tile(
        [m, n, k], block_size=[None, None, block_k]
    ):
        acc = hl.zeros([tile_m, tile_n], device=a.device, dtype=a.dtype)
        for tile_k_inner in hl.tile(tile_k_outer.begin, tile_k_outer.end):
            acc = torch.addmm(acc, a[tile_m, tile_k_inner], b[tile_k_inner, tile_n])
        # this could be a hl.atomic_add to avoid the barrier, but that would be non-deterministic
        tmp[tile_m, tile_n, tile_k_outer.id] = acc

    # Grid-wide synchronization: all partials must be visible before the reduction reads them.
    hl.barrier()

    # Stage 2: sum the split_k partials for each output tile.
    for tile_m, tile_n in hl.tile([m, n]):
        out[tile_m, tile_n] = torch.sum(tmp[tile_m, tile_n, :], dim=-1)

    return out
72+
73+
74+
def check(m: int, k: int, n: int) -> None:
    """Validate split_k_matmul against torch.matmul on random [m, k] @ [k, n] inputs."""
    lhs = torch.randn(m, k, device=DEVICE)
    # The transposed view gives the right operand shape [k, n] with a
    # non-contiguous layout, exercising strided access in the kernel.
    rhs = torch.randn(n, k, device=DEVICE).T
    # Loose tolerance: long K reductions accumulate floating-point error.
    run_example(split_k_matmul, torch.matmul, args=(lhs, rhs), atol=5e-1)
84+
85+
86+
def main() -> None:
    """Seed the RNG for reproducibility, then exercise one problem size."""
    torch.manual_seed(0)
    m, k, n = 16, 4096, 16
    check(m, k, n)
89+
90+
91+
# Run the example when executed as a script.
if __name__ == "__main__":
    main()

helion/_compiler/compile_environment.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
from .. import exc
2929
from ..language.constexpr import ConstExpr
30-
from .loop_dependency_checker import LoopDependencyChecker
3130
from .source_location import SourceLocation
3231
from .source_location import current_location
3332
from .variable_origin import BlockSizeOrigin
@@ -118,19 +117,19 @@ def __init__(
118117
self.block_sizes: list[BlockSizeInfo] = []
119118
self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
120119
self.config_spec = ConfigSpec()
121-
if settings.autotune_force_persistent:
122-
for pid_type in ("flat", "xyz"):
123-
self.config_spec.disallow_pid_type(pid_type)
124120
self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
125121
collections.Counter()
126122
)
127123
self.specialized_vars: set[sympy.Symbol] = set()
128124
self.specialized_strides: set[tuple[str, int]] = set()
129-
self.loop_dependency_checker = LoopDependencyChecker()
130125
self._symint_cache: dict[object, torch.SymInt] = {}
131126
self.device_load_count = (
132127
0 # Track number of loads in all device code for eviction policy tuning
133128
)
129+
if settings.autotune_force_persistent:
130+
for pid_type in ("flat", "xyz"):
131+
self.config_spec.disallow_pid_type(pid_type)
132+
self.has_barrier: bool = False
134133

135134
def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr:
136135
"""Substitute any specialized vars with their concrete values."""
@@ -561,7 +560,6 @@ def __enter__(self) -> Self:
561560
assert getattr(tls, "env", None) is None, "CompileEnvironment already active"
562561
self.fake_mode.__enter__()
563562
tls.env = self
564-
self.loop_dependency_checker = LoopDependencyChecker()
565563
return self
566564

567565
def __exit__(

helion/_compiler/device_function.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,12 @@ def __init__(self, val: int) -> None:
198198

199199

200200
class DeviceFunction:
201-
def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
201+
def __init__(
202+
self,
203+
name: str,
204+
config: Config,
205+
codegen: GenerateAST,
206+
) -> None:
202207
super().__init__()
203208
self.name = name
204209
self.config = config
@@ -673,6 +678,11 @@ def codegen_function_call(self) -> ast.AST:
673678
[
674679
f"num_warps={num_warps}",
675680
f"num_stages={self.config.num_stages}",
681+
*(
682+
["launch_cooperative_grid=True"]
683+
if CompileEnvironment.current().has_barrier
684+
else []
685+
),
676686
]
677687
+ [
678688
f"{x.removeprefix('_triton_config_')}={self.config[x]}"

helion/_compiler/device_ir.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from .inductor_lowering import CodegenState
4949
from .inductor_lowering import codegen_call_with_graph
5050
from .inductor_lowering import prepare_graph_lowerings
51+
from .loop_dependency_checker import LoopDependencyChecker
5152
from .matmul_utils import tensor_matmul_replacement
5253
from .matmul_utils import torch_matmul_replacement
5354
from .node_masking import remove_unnecessary_masking
@@ -223,6 +224,8 @@ def codegen(self, state: CodegenState) -> list[object]:
223224

224225

225226
class RootGraphInfo(GraphInfo):
227+
phase_index: int = 0
228+
226229
@property
227230
def name(self) -> str:
228231
return f"root_graph_{self.graph_id}"
@@ -410,12 +413,22 @@ class RolledReductionInfo(NamedTuple):
410413
can_be_rolled_by_caller: bool
411414

412415

416+
@dataclasses.dataclass
417+
class KernelPhase:
418+
roots: list[int] # store root indices
419+
root_nodes: list[ast.For]
420+
loop_dependency_checker: LoopDependencyChecker = dataclasses.field(
421+
default_factory=LoopDependencyChecker
422+
)
423+
424+
413425
class DeviceIR:
414426
def __init__(self) -> None:
415427
super().__init__()
416428
self.graphs: list[GraphInfo] = []
417429
self.root_ids: list[int] = []
418430
self.rolled_reductions: list[RolledReductionInfo] = []
431+
self.phases: list[KernelPhase] = []
419432
self.grid_block_ids: list[list[int]] = []
420433

421434
def get_root(self, config: Config, graph_id: int) -> torch.fx.Graph:
@@ -469,6 +482,11 @@ def add_reduction_loop_graph(
469482
def add_root_graph(self, graph: torch.fx.Graph) -> None:
470483
self.root_ids.append(self.add_graph(graph, graph_info_cls=RootGraphInfo))
471484

485+
def phase_for_root(self, root_id: int) -> int:
486+
graph_info = self.graphs[self.root_ids[root_id]]
487+
assert isinstance(graph_info, RootGraphInfo)
488+
return graph_info.phase_index
489+
472490
def build_rolled_reductions(self) -> None:
473491
env = CompileEnvironment.current()
474492
rdims = [bs for bs in env.block_sizes if bs.reduction]
@@ -1354,6 +1372,10 @@ class WalkHostAST(NodeVisitor):
13541372
def __init__(self, device_ir: DeviceIR) -> None:
13551373
super().__init__()
13561374
self.device_ir = device_ir
1375+
self.root_index = 0
1376+
self.current_phase_roots: list[int] = []
1377+
self.phases: list[KernelPhase] = []
1378+
self.root_nodes: list[ast.For] = []
13571379

13581380
def visit_For(self, node: ast.For) -> None:
13591381
assert isinstance(node, ExtendedAST)
@@ -1372,9 +1394,44 @@ def visit_For(self, node: ast.For) -> None:
13721394
# pyrefly: ignore [missing-attribute]
13731395
block_ids = [inner.block_id]
13741396
self.device_ir.grid_block_ids.append(block_ids)
1397+
# store root index (position) not graph id
1398+
self.root_nodes.append(node)
1399+
self.current_phase_roots.append(len(self.device_ir.root_ids) - 1)
1400+
self.root_index += 1
13751401
else:
13761402
self.generic_visit(node)
13771403

1404+
def visit_Expr(self, node: ast.Expr) -> None:
1405+
# Record barrier placement between top-level loops.
1406+
from .type_propagation import BarrierResultType
1407+
1408+
assert isinstance(node, ExtendedAST)
1409+
assert isinstance(node.value, ExtendedAST)
1410+
is_barrier = isinstance(node.value._type_info, BarrierResultType)
1411+
1412+
if is_barrier:
1413+
if self.root_index == 0 or not self.current_phase_roots:
1414+
raise exc.BarrierOnlyAllowedAtTopLevel
1415+
self.phases.append(
1416+
KernelPhase(
1417+
roots=self.current_phase_roots,
1418+
root_nodes=[self.root_nodes[r] for r in self.current_phase_roots],
1419+
)
1420+
)
1421+
self.current_phase_roots = []
1422+
return
1423+
self.generic_visit(node)
1424+
1425+
def flush_phases(self) -> None:
1426+
if self.current_phase_roots:
1427+
self.phases.append(
1428+
KernelPhase(
1429+
roots=self.current_phase_roots,
1430+
root_nodes=[self.root_nodes[r] for r in self.current_phase_roots],
1431+
)
1432+
)
1433+
self.current_phase_roots = []
1434+
13781435

13791436
def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int, int]:
13801437
"""Count the number of load and store operations in device code for autotuning.
@@ -1466,6 +1523,18 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
14661523
visitor = WalkHostAST(device_ir)
14671524
for stmt in func.body:
14681525
visitor.visit(stmt)
1526+
visitor.flush_phases()
1527+
device_ir.phases = visitor.phases
1528+
# Run dependency checks once, per phase, so codegen does not redo it per-config.
1529+
for phase in device_ir.phases:
1530+
checker = phase.loop_dependency_checker
1531+
for loop_node in phase.root_nodes:
1532+
checker.register_loop(loop_node)
1533+
for phase_idx, phase in enumerate(device_ir.phases):
1534+
for ridx in phase.roots:
1535+
graph_info = device_ir.graphs[device_ir.root_ids[ridx]]
1536+
assert isinstance(graph_info, RootGraphInfo)
1537+
graph_info.phase_index = phase_idx
14691538
# If there are no top-level device loops, we cannot generate a valid kernel.
14701539
# Raise a friendly error instead of emitting an empty Triton function body.
14711540
if len(device_ir.root_ids) == 0:

helion/_compiler/generate_ast.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .. import exc
1212
from ..language._decorators import is_api_func
13+
from ..runtime.config import Config
1314
from .ast_extension import ExtendedAST
1415
from .ast_extension import LoopType
1516
from .ast_extension import NodeVisitor
@@ -24,6 +25,7 @@
2425
from .helper_function import CodegenInterface
2526
from .inductor_lowering import CodegenState
2627
from .inductor_lowering import codegen_call_with_graph
28+
from .loop_dependency_checker import LoopDependencyChecker
2729
from .program_id import ForEachProgramID
2830
from .tile_strategy import DeviceLoopState
2931
from .variable_origin import ArgumentOrigin
@@ -35,6 +37,7 @@
3537

3638
from ..runtime import Config
3739
from .host_function import HostFunction
40+
from .loop_dependency_checker import LoopDependencyChecker
3841
from .tile_strategy import DeviceLoopOrGridState
3942
from .type_propagation import TensorType
4043

@@ -56,7 +59,11 @@ def __init__(self, func: HostFunction, config: Config) -> None:
5659
self.next_else_block: list[ast.AST] | None = None
5760

5861
# Now create device function and initialize CodegenInterface
59-
self.device_function = DeviceFunction(f"_helion_{func.name}", config, self)
62+
self.device_function = DeviceFunction(
63+
f"_helion_{func.name}",
64+
config,
65+
self,
66+
)
6067
CodegenInterface.__init__(self, self.device_function)
6168

6269
def offset_var(self, block_idx: int) -> str:
@@ -70,6 +77,10 @@ def mask_var(self, block_idx: int) -> str | None:
7077
return loops[-1].strategy.mask_var(block_idx)
7178
return None
7279

80+
def _phase_checker(self, root_id: int) -> LoopDependencyChecker:
81+
phase_idx = self.host_function.device_ir.phase_for_root(root_id)
82+
return self.host_function.device_ir.phases[phase_idx].loop_dependency_checker
83+
7384
def add_statement(self, stmt: ast.AST | str | None) -> None:
7485
if stmt is None:
7586
return
@@ -227,17 +238,20 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
227238
if node._loop_type == LoopType.GRID:
228239
assert not node.orelse
229240

241+
assert node._root_id is not None
242+
# Loop dependency checks were already run during lowering; phase checker kept for symmetry/debug.
243+
self._phase_checker(node._root_id)
244+
230245
if len(self.host_function.device_ir.root_ids) == 1:
231246
body = self.device_function.body
232247
else:
233248
assert len(self.host_function.device_ir.root_ids) > 1
234-
assert node._root_id is not None
235249
# Multiple top level for loops
236250

237251
if node._root_id == 0:
238252
self.device_function.set_pid(
239253
ForEachProgramID(
240-
self.device_function.new_var("pid_shared", dce=False)
254+
self.device_function.new_var("pid_shared", dce=False),
241255
)
242256
)
243257
self.device_function.body.extend(
@@ -310,6 +324,11 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
310324
# This ensures block size and rdim vars are defined in the correct order
311325
self.device_function.flush_deferred_rdim_defs(self)
312326

327+
if isinstance(self.device_function.pid, ForEachProgramID):
328+
self.device_function.pid.case_phases.append(
329+
self.host_function.device_ir.phase_for_root(node._root_id)
330+
)
331+
313332
# If we are in a multi top level loop, for all loops except for the last one
314333
# emit ifthenelse blocks
315334
if node._root_id < len(self.host_function.device_ir.root_ids) - 1:
@@ -477,6 +496,9 @@ def generate_ast(
477496
func: HostFunction, config: Config, emit_repro_caller: bool
478497
) -> ast.AST:
479498
with func:
499+
if len(func.device_ir.phases) > 1:
500+
if not str(config.pid_type).startswith("persistent"):
501+
raise exc.BarrierRequiresPersistent(config.pid_type)
480502
codegen = GenerateAST(func, config)
481503
with codegen.device_function:
482504
for stmt in func.body:

0 commit comments

Comments
 (0)