4848from .inductor_lowering import CodegenState
4949from .inductor_lowering import codegen_call_with_graph
5050from .inductor_lowering import prepare_graph_lowerings
51+ from .loop_dependency_checker import LoopDependencyChecker
5152from .matmul_utils import tensor_matmul_replacement
5253from .matmul_utils import torch_matmul_replacement
5354from .node_masking import remove_unnecessary_masking
@@ -223,6 +224,8 @@ def codegen(self, state: CodegenState) -> list[object]:
223224
224225
225226class RootGraphInfo (GraphInfo ):
227+ phase_index : int = 0
228+
226229 @property
227230 def name (self ) -> str :
228231 return f"root_graph_{ self .graph_id } "
@@ -410,12 +413,22 @@ class RolledReductionInfo(NamedTuple):
410413 can_be_rolled_by_caller : bool
411414
412415
@dataclasses.dataclass
class KernelPhase:
    """A group of top-level root loops that belong to the same kernel phase.

    Each phase carries its own ``LoopDependencyChecker`` so that the loops
    of one phase can be registered with a checker that is not shared with
    any other phase.
    """

    # Positions into ``DeviceIR.root_ids`` (root indices, not graph ids).
    roots: list[int]
    # The ``for`` statements in the host AST that correspond to ``roots``.
    root_nodes: list[ast.For]
    # Fresh checker per instance; never shared between phases.
    loop_dependency_checker: LoopDependencyChecker = dataclasses.field(
        default_factory=LoopDependencyChecker,
    )
423+
424+
413425class DeviceIR :
414426 def __init__ (self ) -> None :
415427 super ().__init__ ()
416428 self .graphs : list [GraphInfo ] = []
417429 self .root_ids : list [int ] = []
418430 self .rolled_reductions : list [RolledReductionInfo ] = []
431+ self .phases : list [KernelPhase ] = []
419432 self .grid_block_ids : list [list [int ]] = []
420433
421434 def get_root (self , config : Config , graph_id : int ) -> torch .fx .Graph :
@@ -469,6 +482,11 @@ def add_reduction_loop_graph(
def add_root_graph(self, graph: torch.fx.Graph) -> None:
    """Register ``graph`` as a root graph and remember its id.

    The graph is added through ``add_graph`` with ``RootGraphInfo`` as the
    info class, and the value returned by ``add_graph`` is appended to
    ``self.root_ids``.
    """
    new_root = self.add_graph(graph, graph_info_cls=RootGraphInfo)
    self.root_ids.append(new_root)
471484
def phase_for_root(self, root_id: int) -> int:
    """Return the phase index recorded on a root graph.

    Args:
        root_id: Position of the root inside ``self.root_ids`` (it is used
            to index that list, and the resulting value indexes
            ``self.graphs``).

    Returns:
        The ``phase_index`` attribute of the matching ``RootGraphInfo``.
    """
    info = self.graphs[self.root_ids[root_id]]
    # Narrow the type: only root graphs carry a ``phase_index``.
    assert isinstance(info, RootGraphInfo)
    return info.phase_index
489+
472490 def build_rolled_reductions (self ) -> None :
473491 env = CompileEnvironment .current ()
474492 rdims = [bs for bs in env .block_sizes if bs .reduction ]
@@ -1354,6 +1372,10 @@ class WalkHostAST(NodeVisitor):
def __init__(self, device_ir: DeviceIR) -> None:
    """Create a visitor that records root loops and phases for ``device_ir``."""
    super().__init__()
    self.device_ir = device_ir
    # Number of root loops seen so far (incremented by ``visit_For``).
    self.root_index = 0
    # Root indices collected since the last phase was closed.
    self.current_phase_roots: list[int] = []
    # Phases that have already been closed, in encounter order.
    self.phases: list[KernelPhase] = []
    # Every root ``for`` node seen so far; indexed by the values stored in
    # ``current_phase_roots`` when a phase is built.
    self.root_nodes: list[ast.For] = []
13571379
13581380 def visit_For (self , node : ast .For ) -> None :
13591381 assert isinstance (node , ExtendedAST )
@@ -1372,9 +1394,44 @@ def visit_For(self, node: ast.For) -> None:
13721394 # pyrefly: ignore [missing-attribute]
13731395 block_ids = [inner .block_id ]
13741396 self .device_ir .grid_block_ids .append (block_ids )
1397+ # store root index (position) not graph id
1398+ self .root_nodes .append (node )
1399+ self .current_phase_roots .append (len (self .device_ir .root_ids ) - 1 )
1400+ self .root_index += 1
13751401 else :
13761402 self .generic_visit (node )
13771403
def visit_Expr(self, node: ast.Expr) -> None:
    """Close the current phase when an expression statement is a barrier.

    An expression statement whose value has a ``BarrierResultType`` type
    marks the boundary between two phases: the root loops collected so far
    are turned into a ``KernelPhase`` and a new, empty phase is started.
    Any other expression statement is visited generically.

    Raises:
        exc.BarrierOnlyAllowedAtTopLevel: If the barrier appears before any
            root loop has been recorded, or when no root loop has been
            recorded since the previous phase was closed.
    """
    # Imported here rather than at module level, as in the original code.
    from .type_propagation import BarrierResultType

    assert isinstance(node, ExtendedAST)
    assert isinstance(node.value, ExtendedAST)

    if not isinstance(node.value._type_info, BarrierResultType):
        self.generic_visit(node)
        return

    if self.root_index == 0 or not self.current_phase_roots:
        raise exc.BarrierOnlyAllowedAtTopLevel
    # The pending list is non-empty here, so flush_phases always emits
    # exactly one phase; sharing it keeps phase construction in one place.
    self.flush_phases()
1424+
def flush_phases(self) -> None:
    """Turn any pending root loops into a final ``KernelPhase``.

    Does nothing when no root indices are pending. Otherwise appends one
    phase built from the pending indices and their ``for`` nodes, then
    starts a fresh, empty pending list.
    """
    pending = self.current_phase_roots
    if not pending:
        return
    nodes = [self.root_nodes[idx] for idx in pending]
    self.phases.append(KernelPhase(roots=pending, root_nodes=nodes))
    # Rebind instead of clearing: the phase keeps the old list object.
    self.current_phase_roots = []
1434+
13781435
13791436def _count_device_loads_and_stores (device_ir : DeviceIR ) -> tuple [int , int , int ]:
13801437 """Count the number of load and store operations in device code for autotuning.
@@ -1466,6 +1523,18 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
14661523 visitor = WalkHostAST (device_ir )
14671524 for stmt in func .body :
14681525 visitor .visit (stmt )
1526+ visitor .flush_phases ()
1527+ device_ir .phases = visitor .phases
1528+ # Run dependency checks once, per phase, so codegen does not redo it per-config.
1529+ for phase in device_ir .phases :
1530+ checker = phase .loop_dependency_checker
1531+ for loop_node in phase .root_nodes :
1532+ checker .register_loop (loop_node )
1533+ for phase_idx , phase in enumerate (device_ir .phases ):
1534+ for ridx in phase .roots :
1535+ graph_info = device_ir .graphs [device_ir .root_ids [ridx ]]
1536+ assert isinstance (graph_info , RootGraphInfo )
1537+ graph_info .phase_index = phase_idx
14691538 # If there are no top-level device loops, we cannot generate a valid kernel.
14701539 # Raise a friendly error instead of emitting an empty Triton function body.
14711540 if len (device_ir .root_ids ) == 0 :
0 commit comments