1212from .device_function import DeviceFunction
1313from .host_function import HostFunction
1414
15+
def typed_program_id(dim: int = 0) -> str:
    """Return the source text of a ``tl.program_id`` call for ``dim``.

    When the current compile environment reports ``tl.int32`` as its index
    type, the bare ``tl.program_id(dim)`` expression is returned so the
    common int32 case carries no cast. For any other index type a
    ``.to(<index type>)`` cast is appended to the expression.

    Args:
        dim: Value passed as the argument of ``tl.program_id``.

    Returns:
        The expression as a string, ready to be embedded in generated code.
    """
    index_type = CompileEnvironment.current().triton_index_type()
    pid_expr = f"tl.program_id({dim})"
    if index_type == "tl.int32":
        return pid_expr
    return f"{pid_expr}.to({index_type})"
27+
28+
1529if TYPE_CHECKING :
1630 import sympy
1731
@@ -108,7 +122,7 @@ def _setup_persistent_kernel_and_wrap_body(
108122 @property
109123 def virtual_program_id (self ) -> str :
110124 """Get the virtual program ID expression for this strategy."""
111- return "tl.program_id (0)"
125+ return typed_program_id (0 )
112126
113127 def _is_persistent (self ) -> bool :
114128 """Check if this is a persistent strategy. Default False."""
@@ -157,7 +171,7 @@ def codegen_pid_init(self) -> list[ast.stmt]:
157171 pid_type = current_device_fn .config .get ("pid_type" , "flat" )
158172 if isinstance (pid_type , str ) and pid_type .startswith ("persistent" ):
159173 return []
160- return [statement_from_string (f"{ self .shared_pid_var } = tl.program_id (0)" )]
174+ return [statement_from_string (f"{ self .shared_pid_var } = { typed_program_id (0 )} " )]
161175
162176 def _get_cdiv_blocks (
163177 self , state : CodegenState , exclude_last : bool = False
@@ -228,7 +242,7 @@ class XYZProgramIDs(ProgramIDs):
228242 def codegen (self , state : CodegenState ) -> None :
229243 for i , pid in enumerate (self .pid_info ):
230244 state .codegen .statements_stack [- 1 ].insert (
231- i , statement_from_string (f"{ pid .pid_var } = tl.program_id( { i } ) " )
245+ i , statement_from_string (f"{ pid .pid_var } = { typed_program_id ( i ) } " )
232246 )
233247
234248 def codegen_grid (self ) -> ast .AST :
@@ -242,7 +256,7 @@ class FlatProgramIDs(ProgramIDs):
242256 """Only use the x grid and compute other dimensions"""
243257
244258 def codegen (self , state : CodegenState ) -> None :
245- pid_var = self .shared_pid_var or "tl.program_id (0)"
259+ pid_var = self .shared_pid_var or typed_program_id (0 )
246260 statements = self ._decompose_pid_to_statements (pid_var , state )
247261 state .codegen .statements_stack [- 1 ][:] = [
248262 * statements ,
@@ -420,7 +434,7 @@ def __init__(self, is_blocked: bool = False) -> None:
420434 }
421435 else :
422436 self .range_kwargs : dict [str , str ] = {
423- "begin" : "tl.program_id (0)" ,
437+ "begin" : typed_program_id (0 ),
424438 "end" : self .total_pids_var ,
425439 "step" : NUM_SM_VAR ,
426440 }
@@ -471,7 +485,7 @@ def setup_persistent_kernel(
471485 ),
472486 (
473487 self .start_pid_var ,
474- f"tl.program_id (0) * { self .block_size_var } " ,
488+ f"{ typed_program_id (0 )} * { self .block_size_var } " ,
475489 ),
476490 (
477491 self .end_pid_var ,
@@ -521,7 +535,7 @@ def _generate_pid_statements(self, state: CodegenState) -> list[ast.stmt]:
521535 if not self .virtual_pid_var :
522536 # Generate regular PID decomposition
523537 return self ._decompose_pid_to_statements (
524- self .shared_pid_var or "tl.program_id (0)" , state
538+ self .shared_pid_var or typed_program_id (0 ), state
525539 )
526540
527541 # Generate persistent PID decomposition
0 commit comments