
Commit 9d0b8bd

Fix unbacked symints in generated code (#1179)

1 parent 64061aa commit 9d0b8bd

File tree: 4 files changed, +61 −4 lines changed

helion/_compiler/compile_environment.py
test/test_misc.expected
test/test_specialize.expected
test/test_specialize.py

helion/_compiler/compile_environment.py

Lines changed: 6 additions & 2 deletions

@@ -383,8 +383,12 @@ def size_hint(self, n: int | torch.SymInt) -> int:
         if isinstance(n, torch.SymInt):
             expr = n._sympy_()
             if _has_unbacked(expr):
-                # If the size is a symbolic expression with unbacked symbols, then the shape environment
-                # hint will be wrong since we assign a default value to unbacked symbols. Return a default hint.
+                # For unbacked symbols, try to use the hint we stored in var_to_val
+                # when creating the symint (see create_unbacked_symint).
+                # This preserves the original value passed to the kernel.
+                if expr in self.shape_env.var_to_val:
+                    return int(self.shape_env.var_to_val[expr])
+                # Fall back to default hint if not found
                 return 8192

         # pyrefly: ignore [no-matching-overload]
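
As a reading aid, here is a minimal, self-contained sketch of the lookup pattern the patched size_hint relies on. It uses PyTorch's ShapeEnv (torch.fx.experimental.symbolic_shapes); the size_hint function below is a simplified stand-in for Helion's method, and the direct write into var_to_val stands in for what the comment says Helion's create_unbacked_symint does. This is an illustration under those assumptions, not Helion source.

# Sketch only: emulate recording a hint for an unbacked SymInt and then
# resolving it, with 8192 as the fallback when no hint was stored.
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
sym = shape_env.create_unbacked_symint()  # unbacked: has no hint by default
expr = sym._sympy_()

# Stand-in for what Helion's create_unbacked_symint is described as doing:
# remember the concrete value originally passed to the kernel (say, 1).
shape_env.var_to_val[expr] = sympy.Integer(1)

def size_hint(expr) -> int:
    # Prefer the stored hint; otherwise fall back to the default.
    if expr in shape_env.var_to_val:
        return int(shape_env.var_to_val[expr])
    return 8192

print(size_hint(expr))  # -> 1 (previously this would have been 8192)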

test/test_misc.expected

Lines changed: 2 additions & 2 deletions

@@ -491,7 +491,7 @@ def call():
     # src[test_misc.py:N]: ) -> tuple[torch.Tensor, torch.Tensor]:
     # src[test_misc.py:N-N]: ...
     t = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device=DEVICE)
-    i = 8192
+    i = 1
     s = 'foo'
     b = False
     f = 1.1
@@ -546,7 +546,7 @@ def call():
     # src[test_misc.py:N]: ) -> tuple[torch.Tensor, torch.Tensor]:
     # src[test_misc.py:N-N]: ...
     t = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device=DEVICE)
-    i = 8192
+    i = 1
     s = 'foo'
     b = False
     f = 1.1
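
These expected-output changes follow directly from the compile_environment.py fix above: the generated repro script now recreates the unbacked int with the value originally passed to the kernel (i = 1) instead of the 8192 default hint.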

test/test_specialize.expected

Lines changed: 32 additions & 0 deletions

@@ -335,6 +335,38 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_specialize.py:N]: return out
     return out

+--- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_foo(x, out, _BLOCK_SIZE_0: tl.constexpr):
+    # src[test_specialize.py:N]: for x_tile in hl.tile([x.shape[0]]):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[test_specialize.py:N]: out[x_tile] = x[x_tile] + (1 << (32 - val))
+    load = tl.load(x + indices_0 * 1, None)
+    v_0 = tl.full([], 65536, tl.int32)
+    v_1 = load + v_0
+    tl.store(out + indices_0 * 1, v_1, None)
+
+def foo(x: torch.Tensor, bitshift: tuple[int, int], *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = x.new_empty(x.shape)
+    out = x.new_empty(x.shape)
+    # src[test_specialize.py:N]: for x_tile in hl.tile([x.shape[0]]):
+    _BLOCK_SIZE_0 = 32
+    # src[test_specialize.py:N]: for x_tile in hl.tile([x.shape[0]]):
+    # src[test_specialize.py:N]: # compute_val equivalent: 1 << (32 - val)
+    # src[test_specialize.py:N]: out[x_tile] = x[x_tile] + (1 << (32 - val))
+    _launcher(_helion_foo, (triton.cdiv(64, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestSpecialize.test_sqrt_does_not_specialize)
 from __future__ import annotations

test/test_specialize.py

Lines changed: 21 additions & 0 deletions

@@ -305,6 +305,27 @@ def fn(
         )
         self.assertExpectedJournal(code)

+    def test_specialize_tuple_element(self):
+        """Test that hl.specialize works correctly with tuple elements."""
+
+        @helion.kernel(config=helion.Config(block_sizes=[32]))
+        def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
+            out = x.new_empty(x.shape)
+            val = hl.specialize(bitshift[0])
+            for x_tile in hl.tile([x.shape[0]]):
+                # compute_val equivalent: 1 << (32 - val)
+                out[x_tile] = x[x_tile] + (1 << (32 - val))
+            return out
+
+        x = torch.ones(64, dtype=torch.int32, device=DEVICE)
+        code, result = code_and_output(foo, (x, (16, 16)))
+        # 1 << (32 - 16) = 1 << 16 = 65536
+        expected = x + 65536
+        torch.testing.assert_close(result, expected)
+        # Verify that 65536 appears in the generated code as a constant
+        self.assertIn("65536", code)
+        self.assertExpectedJournal(code)
+

 if __name__ == "__main__":
     unittest.main()
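
A hypothetical usage note on the new test (the second call below is an assumption for illustration, not part of the commit): since hl.specialize bakes the tuple element's value into the generated Triton code as a constant, calling foo with a different bitshift should presumably produce a fresh specialization with a different embedded constant.

x = torch.ones(64, dtype=torch.int32, device=DEVICE)
out = foo(x, (16, 16))  # specialized kernel embeds 1 << (32 - 16) == 65536
out2 = foo(x, (8, 8))   # presumably re-specializes, embedding 1 << (32 - 8) == 16777216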
