
Commit 9982041

Use int64 indexing for pids as well (#1195)
Parent: 1f2593c

File tree: 3 files changed, +78 −8 lines


helion/_compiler/program_id.py

Lines changed: 21 additions & 7 deletions
@@ -12,6 +12,20 @@
 from .device_function import DeviceFunction
 from .host_function import HostFunction
 
+
+def typed_program_id(dim: int = 0) -> str:
+    """Generate tl.program_id() with int64 casting when needed.
+
+    Only casts to int64 when index_dtype is int64, to avoid overhead
+    for the common int32 case.
+    """
+    env = CompileEnvironment.current()
+    dtype = env.triton_index_type()
+    if dtype != "tl.int32":
+        return f"tl.program_id({dim}).to({dtype})"
+    return f"tl.program_id({dim})"
+
+
 if TYPE_CHECKING:
     import sympy
 
@@ -108,7 +122,7 @@ def _setup_persistent_kernel_and_wrap_body(
     @property
     def virtual_program_id(self) -> str:
         """Get the virtual program ID expression for this strategy."""
-        return "tl.program_id(0)"
+        return typed_program_id(0)
 
     def _is_persistent(self) -> bool:
         """Check if this is a persistent strategy. Default False."""
@@ -157,7 +171,7 @@ def codegen_pid_init(self) -> list[ast.stmt]:
         pid_type = current_device_fn.config.get("pid_type", "flat")
         if isinstance(pid_type, str) and pid_type.startswith("persistent"):
             return []
-        return [statement_from_string(f"{self.shared_pid_var} = tl.program_id(0)")]
+        return [statement_from_string(f"{self.shared_pid_var} = {typed_program_id(0)}")]
 
     def _get_cdiv_blocks(
         self, state: CodegenState, exclude_last: bool = False
@@ -228,7 +242,7 @@ class XYZProgramIDs(ProgramIDs):
     def codegen(self, state: CodegenState) -> None:
         for i, pid in enumerate(self.pid_info):
             state.codegen.statements_stack[-1].insert(
-                i, statement_from_string(f"{pid.pid_var} = tl.program_id({i})")
+                i, statement_from_string(f"{pid.pid_var} = {typed_program_id(i)}")
             )
 
     def codegen_grid(self) -> ast.AST:
@@ -242,7 +256,7 @@ class FlatProgramIDs(ProgramIDs):
     """Only use the x grid and compute other dimensions"""
 
     def codegen(self, state: CodegenState) -> None:
-        pid_var = self.shared_pid_var or "tl.program_id(0)"
+        pid_var = self.shared_pid_var or typed_program_id(0)
         statements = self._decompose_pid_to_statements(pid_var, state)
         state.codegen.statements_stack[-1][:] = [
             *statements,
@@ -420,7 +434,7 @@ def __init__(self, is_blocked: bool = False) -> None:
             }
         else:
             self.range_kwargs: dict[str, str] = {
-                "begin": "tl.program_id(0)",
+                "begin": typed_program_id(0),
                 "end": self.total_pids_var,
                 "step": NUM_SM_VAR,
             }
@@ -471,7 +485,7 @@ def setup_persistent_kernel(
             ),
             (
                 self.start_pid_var,
-                f"tl.program_id(0) * {self.block_size_var}",
+                f"{typed_program_id(0)} * {self.block_size_var}",
             ),
             (
                 self.end_pid_var,
@@ -521,7 +535,7 @@ def _generate_pid_statements(self, state: CodegenState) -> list[ast.stmt]:
         if not self.virtual_pid_var:
            # Generate regular PID decomposition
            return self._decompose_pid_to_statements(
-                self.shared_pid_var or "tl.program_id(0)", state
+                self.shared_pid_var or typed_program_id(0), state
            )
 
        # Generate persistent PID decomposition
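For reference, here is a minimal standalone sketch of what the new helper emits under each index dtype. Unlike the real typed_program_id(), which reads the dtype from CompileEnvironment.current().triton_index_type(), this sketch takes the dtype as a parameter so it can run outside the compiler (typed_program_id_demo is a name invented here):

def typed_program_id_demo(dim: int = 0, dtype: str = "tl.int32") -> str:
    # Widen the program id only when the kernel indexes with a 64-bit type;
    # the common int32 case keeps the bare tl.program_id() call.
    if dtype != "tl.int32":
        return f"tl.program_id({dim}).to({dtype})"
    return f"tl.program_id({dim})"

assert typed_program_id_demo(0, "tl.int64") == "tl.program_id(0).to(tl.int64)"
assert typed_program_id_demo(0, "tl.int32") == "tl.program_id(0)"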

helion/_compiler/tile_strategy.py

Lines changed: 3 additions & 1 deletion
@@ -498,13 +498,15 @@ def _codegen_common(
         return block_size_var, offsets_var, total_numel, statements
 
     def codegen_grid(self, state: CodegenState) -> DeviceGridState:
+        from .program_id import typed_program_id
+
         block_size_var, offsets_var, total_numel, statements = self._codegen_common(
             state
         )
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
         state.add_statement(
-            f"{offsets_var} = tl.program_id(0) * ({block_size_var}) + tl.arange(0, {block_size_var}).to({dtype})"
+            f"{offsets_var} = {typed_program_id(0)} * ({block_size_var}) + tl.arange(0, {block_size_var}).to({dtype})"
        )
        state.codegen.statements_stack[-1].extend(statements)
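Concretely, for a kernel compiled with index_dtype=torch.int64, the offsets statement emitted by codegen_grid changes roughly as sketched below. This is an illustrative excerpt of generated Triton code, not output copied from Helion; offsets_0 and _BLOCK_SIZE_0 are made-up names:

# before: tl.program_id(0) is int32, so the pid * block_size product is computed
# in int32 and can wrap before the int64 tl.arange(...) promotes the sum
offsets_0 = tl.program_id(0) * (_BLOCK_SIZE_0) + tl.arange(0, _BLOCK_SIZE_0).to(tl.int64)

# after: the program id is widened first, so the multiply happens in int64
offsets_0 = tl.program_id(0).to(tl.int64) * (_BLOCK_SIZE_0) + tl.arange(0, _BLOCK_SIZE_0).to(tl.int64)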

test/test_indexing.py

Lines changed: 54 additions & 0 deletions
@@ -582,6 +582,60 @@ def passthrough_int64(x: torch.Tensor) -> torch.Tensor:
             passthrough_int64.specialization_key((large,)),
         )
 
+    @skipIfRefEager("Test checks generated code")
+    def test_program_id_cast_to_int64(self):
+        """Test that tl.program_id() is cast to int64 when index_dtype is int64."""
+
+        @helion.kernel(index_dtype=torch.int64)
+        def add_kernel_int64(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size(0)):
+                out[tile] = x[tile] + y[tile]
+            return out
+
+        @helion.kernel(index_dtype=torch.int32)
+        def add_kernel_int32(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size(0)):
+                out[tile] = x[tile] + y[tile]
+            return out
+
+        x = torch.randn(1024, device=DEVICE)
+        y = torch.randn(1024, device=DEVICE)
+
+        # Test int64 case: program_id should be cast to int64
+        code_int64, result_int64 = code_and_output(add_kernel_int64, (x, y))
+        self.assertIn("tl.program_id(0).to(tl.int64)", code_int64)
+
+        # Test int32 case: program_id should NOT be cast
+        code_int32, result_int32 = code_and_output(add_kernel_int32, (x, y))
+        self.assertNotIn(".to(tl.int64)", code_int32)
+        self.assertIn("tl.program_id(0)", code_int32)
+
+        # Both should produce correct results
+        expected = x + y
+        torch.testing.assert_close(result_int64, expected)
+        torch.testing.assert_close(result_int32, expected)
+
+    @skipIfRefEager("Test checks for no IMA")
+    @skipIfRocm("Test takes too long on ROCm")
+    @skipIfCpu("Test requires GPU")
+    @skipIfLowVRAM("Test requires large memory")
+    def test_large_tensor(self):
+        @helion.kernel(autotune_effort="none")
+        def f(x: torch.Tensor) -> torch.Tensor:
+            out = x.new_empty(x.shape)
+            for (b,) in hl.grid([x.shape[0]]):
+                for (x_tile,) in hl.tile([x.shape[1]]):
+                    out[b, x_tile] = x[b, x_tile]
+            return out
+
+        B = 2**15
+        D = 2**17
+        inp = torch.randn(B, D, device=DEVICE, dtype=torch.float16)
+        out = f(inp)
+        assert (out == inp).all()
+
     def test_assign_int(self):
         @helion.kernel
         def fn(x: torch.Tensor) -> torch.Tensor:
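The sizes in test_large_tensor are chosen so that pid-derived row offsets exceed the int32 range, which is the case the int64 cast is meant to cover. A quick back-of-the-envelope check (plain Python, not part of the test suite):

B, D = 2**15, 2**17                  # 2**32 elements in total
last_row_offset = (B - 1) * D        # 4_294_836_224 elements into the tensor
int32_max = 2**31 - 1                # 2_147_483_647
assert last_row_offset > int32_max   # a 32-bit pid * stride offset would wrap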

0 commit comments
