Skip to content

Commit 4dc7a7f

Browse files
authored
Add tuple comprehension support (#1190)
1 parent 0a9a19d commit 4dc7a7f

File tree

4 files changed

+277
-17
lines changed

4 files changed

+277
-17
lines changed

helion/_compiler/device_ir.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,36 +1036,43 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[object, ...]:
10361036
def visit_List(self, node: ast.List) -> list[object]:
    """Visit each element of a list literal and return the results as a list.

    Every entry of ``node.elts`` is dispatched through ``self.visit`` and the
    results are collected in source order.
    """
    return [self.visit(x) for x in node.elts]
10381038

1039-
def visit_ListComp(self, node: ast.ListComp) -> tuple[object, ...]:
1040-
"""Handle list comprehension unrolling similar to tuple unrolling."""
1039+
def _visit_comprehension(
    self, node: ast.ListComp | ast.GeneratorExp, name: str
) -> tuple[object, ...]:
    """Handle list comprehension or generator expression unrolling.

    Args:
        node: The comprehension node being visited.
        name: Human-readable construct name (e.g. ``"list comprehension"``),
            interpolated into the error messages below.

    Returns:
        The result of ``_handle_comprehension_unrolling`` for the
        comprehension's element expression and its single generator.

    Raises:
        exc.StatementNotSupported: If there is more than one generator, the
            generator has ``if`` conditions, or the iterable's type info is
            not a ``SequenceType``.
    """
    assert isinstance(node, ExtendedAST)

    # Only handle simple cases with single generator and no if conditions
    if len(node.generators) != 1 or node.generators[0].ifs:
        raise exc.StatementNotSupported(f"Complex {name}s are not supported")

    generator = node.generators[0]
    assert isinstance(generator.iter, ExtendedAST)
    iter_type = generator.iter._type_info

    # Check if we're iterating over a sequence (similar to tuple unrolling)
    if isinstance(iter_type, SequenceType):
        return self._handle_comprehension_unrolling(node.elt, generator)

    # For non-sequence iterables, we could extend this later
    raise exc.StatementNotSupported(
        f"{name.capitalize()}s over non-sequence types are not supported"
    )
10611061

1062-
def _handle_listcomp_unrolling(self, node: ast.ListComp) -> tuple[object, ...]:
1063-
"""Handle unrolling of list comprehensions over sequences."""
1064-
generator = node.generators[0]
1062+
def visit_ListComp(self, node: ast.ListComp) -> tuple[object, ...]:
    """Delegate to ``_visit_comprehension``, naming the construct for errors.

    The label ``"list comprehension"`` is passed through so that error
    messages raised by the shared handler mention the right construct.
    """
    return self._visit_comprehension(node, "list comprehension")
1064+
1065+
def visit_GeneratorExp(self, node: ast.GeneratorExp) -> tuple[object, ...]:
    """Delegate to ``_visit_comprehension``, naming the construct for errors.

    The label ``"generator expression"`` is passed through so that error
    messages raised by the shared handler mention the right construct.
    """
    return self._visit_comprehension(node, "generator expression")
1067+
1068+
def _handle_comprehension_unrolling(
1069+
self, elt: ast.expr, generator: ast.comprehension
1070+
) -> tuple[object, ...]:
1071+
"""Handle unrolling of comprehensions (list comp or generator exp) over sequences."""
10651072

10661073
def evaluate_expression() -> object:
10671074
# Evaluate the comprehension expression
1068-
result = self.visit(node.elt)
1075+
result = self.visit(elt)
10691076
# If the result is a SymInt that can be evaluated to a concrete value, do so
10701077
if isinstance(result, torch.SymInt):
10711078
try:

helion/_compiler/type_propagation.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2429,21 +2429,26 @@ def _evaluate_comprehension(
24292429
# Fallback to generic list type
24302430
return SequenceType(self.origin(), [element_result_type])
24312431

2432-
def visit_ListComp(self, node: ast.ListComp) -> TypeInfo:
2433-
"""Type propagation for list comprehensions."""
2432+
def _visit_comprehension(
    self, node: ast.ListComp | ast.GeneratorExp, name: str
) -> TypeInfo:
    """Type propagation for list comprehensions and generator expressions.

    Args:
        node: The comprehension node being typed.
        name: Human-readable construct name (e.g. ``"generator expression"``),
            capitalized and interpolated into the error message.

    Returns:
        The result of ``_evaluate_comprehension`` for the single generator
        and the element expression.

    Raises:
        exc.StatementNotSupported: If the comprehension does not have exactly
            one generator.
    """
    if len(node.generators) != 1:
        raise exc.StatementNotSupported(
            f"{name.capitalize()}s with multiple generators are not supported"
        )
    return self._evaluate_comprehension(node.generators[0], node.elt)
24402441

2442+
def visit_ListComp(self, node: ast.ListComp) -> TypeInfo:
    """Propagate types for a list comprehension via ``_visit_comprehension``.

    The label ``"list comprehension"`` is forwarded for use in error messages.
    """
    return self._visit_comprehension(node, "list comprehension")
2444+
2445+
def visit_GeneratorExp(self, node: ast.GeneratorExp) -> TypeInfo:
    """Propagate types for a generator expression via ``_visit_comprehension``.

    The label ``"generator expression"`` is forwarded for use in error messages.
    """
    return self._visit_comprehension(node, "generator expression")
2447+
24412448
# TODO(jansel): need to implement these
24422449
# pyrefly: ignore [bad-assignment, bad-param-name-override]
24432450
visit_SetComp: _VisitMethod = _not_supported
24442451
# pyrefly: ignore [bad-assignment, bad-param-name-override]
2445-
visit_GeneratorExp: _VisitMethod = _not_supported
2446-
# pyrefly: ignore [bad-assignment, bad-param-name-override]
24472452
visit_DictComp: _VisitMethod = _not_supported
24482453

24492454
# TODO(jansel): support closure functions defined on host

test/test_unroll_tuples.expected

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,152 @@ def kernel_static_range_with_start(x: torch.Tensor, *, _launcher=_default_launch
890890
# src[test_unroll_tuples.py:N]: return result
891891
return result
892892

893+
--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension)
894+
from __future__ import annotations
895+
896+
import torch
897+
import triton
898+
import triton.language as tl
899+
from helion.runtime import default_launcher as _default_launcher
900+
901+
@triton.jit
902+
def _helion_kernel_tuple_comprehension(x, result, multipliers_item_0, multipliers_item_1, multipliers_item_2, _BLOCK_SIZE_0: tl.constexpr):
903+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
904+
pid_0 = tl.program_id(0)
905+
offset_0 = pid_0 * _BLOCK_SIZE_0
906+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
907+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
908+
acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
909+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multiplier
910+
load = tl.load(x + indices_0 * 1, None)
911+
v_0 = tl.cast(multipliers_item_0, tl.float32)
912+
v_1 = load * v_0
913+
v_2 = acc + v_1
914+
load_1 = tl.load(x + indices_0 * 1, None)
915+
v_3 = tl.cast(multipliers_item_1, tl.float32)
916+
v_4 = load_1 * v_3
917+
v_5 = v_2 + v_4
918+
load_2 = tl.load(x + indices_0 * 1, None)
919+
v_6 = tl.cast(multipliers_item_2, tl.float32)
920+
v_7 = load_2 * v_6
921+
v_8 = v_5 + v_7
922+
# src[test_unroll_tuples.py:N]: result[tile_idx] = acc
923+
tl.store(result + indices_0 * 1, v_8, None)
924+
925+
def kernel_tuple_comprehension(x: torch.Tensor, *, _launcher=_default_launcher):
926+
"""Test tuple comprehension with generator expression."""
927+
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
928+
result = torch.zeros_like(x)
929+
# src[test_unroll_tuples.py:N]: multipliers = tuple(m * 2 for m in (1, 2, 3))
930+
multipliers = tuple((m * 2 for m in (1, 2, 3)))
931+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
932+
_BLOCK_SIZE_0 = 16
933+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
934+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
935+
# src[test_unroll_tuples.py:N]: for multiplier in multipliers:
936+
# src[test_unroll_tuples.py:N-N]: ...
937+
_launcher(_helion_kernel_tuple_comprehension, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], _BLOCK_SIZE_0, num_warps=4, num_stages=1)
938+
# src[test_unroll_tuples.py:N]: return result
939+
return result
940+
941+
--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension_with_static_range)
942+
from __future__ import annotations
943+
944+
import torch
945+
import helion.language as hl
946+
import triton
947+
import triton.language as tl
948+
from helion.runtime import default_launcher as _default_launcher
949+
950+
@triton.jit
951+
def _helion_kernel_tuple_comprehension_with_static_range(x, result, multipliers_item_0, multipliers_item_1, multipliers_item_2, multipliers_item_3, _BLOCK_SIZE_0: tl.constexpr):
952+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
953+
pid_0 = tl.program_id(0)
954+
offset_0 = pid_0 * _BLOCK_SIZE_0
955+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
956+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
957+
acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
958+
# src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[i]
959+
load = tl.load(x + indices_0 * 1, None)
960+
v_0 = tl.cast(multipliers_item_0, tl.float32)
961+
v_1 = load * v_0
962+
v_2 = acc + v_1
963+
load_1 = tl.load(x + indices_0 * 1, None)
964+
v_3 = tl.cast(multipliers_item_1, tl.float32)
965+
v_4 = load_1 * v_3
966+
v_5 = v_2 + v_4
967+
load_2 = tl.load(x + indices_0 * 1, None)
968+
v_6 = tl.cast(multipliers_item_2, tl.float32)
969+
v_7 = load_2 * v_6
970+
v_8 = v_5 + v_7
971+
load_3 = tl.load(x + indices_0 * 1, None)
972+
v_9 = tl.cast(multipliers_item_3, tl.float32)
973+
v_10 = load_3 * v_9
974+
v_11 = v_8 + v_10
975+
# src[test_unroll_tuples.py:N]: result[tile_idx] = acc
976+
tl.store(result + indices_0 * 1, v_11, None)
977+
978+
def kernel_tuple_comprehension_with_static_range(x: torch.Tensor, N: hl.constexpr, *, _launcher=_default_launcher):
979+
"""Test tuple comprehension with static_range for indexing."""
980+
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
981+
result = torch.zeros_like(x)
982+
# src[test_unroll_tuples.py:N]: multipliers = tuple(i + 1 for i in range(N))
983+
multipliers = tuple((i + 1 for i in range(4)))
984+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
985+
_BLOCK_SIZE_0 = 16
986+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
987+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
988+
# src[test_unroll_tuples.py:N]: for i in hl.static_range(N):
989+
# src[test_unroll_tuples.py:N-N]: ...
990+
_launcher(_helion_kernel_tuple_comprehension_with_static_range, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], multipliers[3], _BLOCK_SIZE_0, num_warps=4, num_stages=1)
991+
# src[test_unroll_tuples.py:N]: return result
992+
return result
993+
994+
--- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension_with_tensors)
995+
from __future__ import annotations
996+
997+
import torch
998+
import triton
999+
import triton.language as tl
1000+
from helion.runtime import default_launcher as _default_launcher
1001+
1002+
@triton.jit
1003+
def _helion_kernel_tuple_comprehension_with_tensors(scaled_item_0, scaled_item_1, scaled_item_2, result, _BLOCK_SIZE_0: tl.constexpr):
1004+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
1005+
pid_0 = tl.program_id(0)
1006+
offset_0 = pid_0 * _BLOCK_SIZE_0
1007+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1008+
mask_0 = indices_0 < 18
1009+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
1010+
acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
1011+
# src[test_unroll_tuples.py:N]: acc += tensor[tile_idx]
1012+
load = tl.load(scaled_item_0 + indices_0 * 1, mask_0, other=0)
1013+
v_0 = acc + load
1014+
load_1 = tl.load(scaled_item_1 + indices_0 * 1, mask_0, other=0)
1015+
v_1 = v_0 + load_1
1016+
load_2 = tl.load(scaled_item_2 + indices_0 * 1, mask_0, other=0)
1017+
v_2 = v_1 + load_2
1018+
# src[test_unroll_tuples.py:N]: result[tile_idx] = acc
1019+
tl.store(result + indices_0 * 1, v_2, mask_0)
1020+
1021+
def kernel_tuple_comprehension_with_tensors(tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor], *, _launcher=_default_launcher):
1022+
"""Test tuple comprehension that transforms a tuple of tensors."""
1023+
# src[test_unroll_tuples.py:N]: result = torch.zeros_like(tensors[0])
1024+
result = torch.zeros_like(tensors[0])
1025+
# src[test_unroll_tuples.py:N]: scales = (0.5, 1.0, 1.5)
1026+
scales = (0.5, 1.0, 1.5)
1027+
# src[test_unroll_tuples.py:N]: scaled = tuple(t * s for t, s in zip(tensors, scales, strict=False))
1028+
scaled = tuple((t * s for t, s in zip(tensors, scales, strict=False)))
1029+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
1030+
_BLOCK_SIZE_0 = 32
1031+
# src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
1032+
# src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
1033+
# src[test_unroll_tuples.py:N]: for tensor in scaled:
1034+
# src[test_unroll_tuples.py:N-N]: ...
1035+
_launcher(_helion_kernel_tuple_comprehension_with_tensors, (triton.cdiv(18, _BLOCK_SIZE_0),), scaled[0], scaled[1], scaled[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
1036+
# src[test_unroll_tuples.py:N]: return result
1037+
return result
1038+
8931039
--- assertExpectedJournal(TestUnrollTuples.test_tuple_with_scaling_factors)
8941040
from __future__ import annotations
8951041

test/test_unroll_tuples.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,57 @@ def kernel_simple_list_comprehension(
227227
return result
228228

229229

230+
@helion.kernel(autotune_effort="none")
def kernel_tuple_comprehension(
    x: torch.Tensor,
) -> torch.Tensor:
    """Test tuple comprehension with generator expression."""
    result = torch.zeros_like(x)
    # Create tuple using generator expression
    # (evaluates to (2, 4, 6), so each output element is x * 12)
    multipliers = tuple(m * 2 for m in (1, 2, 3))
    for tile_idx in hl.tile(result.size(0)):
        acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
        # Accumulate one scaled copy of the tile per multiplier
        for multiplier in multipliers:
            acc += x[tile_idx] * multiplier
        result[tile_idx] = acc
    return result
244+
245+
246+
@helion.kernel(autotune_effort="none")
def kernel_tuple_comprehension_with_static_range(
    x: torch.Tensor,
    N: hl.constexpr,
) -> torch.Tensor:
    """Test tuple comprehension with static_range for indexing."""
    result = torch.zeros_like(x)
    # Create tuple using generator expression with range
    # (evaluates to (1, 2, ..., N))
    multipliers = tuple(i + 1 for i in range(N))
    for tile_idx in hl.tile(result.size(0)):
        acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
        # Index the tuple by position instead of iterating over it directly
        for i in hl.static_range(N):
            acc += x[tile_idx] * multipliers[i]
        result[tile_idx] = acc
    return result
261+
262+
263+
@helion.kernel(autotune_effort="none")
def kernel_tuple_comprehension_with_tensors(
    tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """Test tuple comprehension that transforms a tuple of tensors."""
    result = torch.zeros_like(tensors[0])
    # Create scaled versions using generator expression
    # (each tensor is paired with its scale factor via zip)
    scales = (0.5, 1.0, 1.5)
    scaled = tuple(t * s for t, s in zip(tensors, scales, strict=False))

    for tile_idx in hl.tile(result.size(0)):
        acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
        # Sum the matching tile of every scaled tensor
        for tensor in scaled:
            acc += tensor[tile_idx]
        result[tile_idx] = acc
    return result
279+
280+
230281
@helion.kernel(autotune_effort="none")
231282
def kernel_list_comprehension_with_function(
232283
x: torch.Tensor,
@@ -623,6 +674,57 @@ def test_simple_list_comprehension(self):
623674
expected = x * 12
624675
torch.testing.assert_close(result, expected)
625676

677+
def test_tuple_comprehension(self):
    """Test tuple comprehension with generator expression."""
    size = (16,)
    x = torch.randn(size, device=DEVICE)

    code, result = code_and_output(kernel_tuple_comprehension, (x,))

    # Validate generated code
    self.assertExpectedJournal(code)

    # Test correctness - should be x * (2 + 4 + 6) = x * 12
    expected = x * 12
    torch.testing.assert_close(result, expected)
690+
691+
def test_tuple_comprehension_with_static_range(self):
    """Test tuple comprehension with static_range for indexing."""
    size = (16,)
    x = torch.randn(size, device=DEVICE)
    N = 4

    code, result = code_and_output(
        kernel_tuple_comprehension_with_static_range, (x, N)
    )

    # Validate generated code
    self.assertExpectedJournal(code)

    # Test correctness - should be x * (1 + 2 + 3 + 4) = x * 10
    expected = x * 10
    torch.testing.assert_close(result, expected)
707+
708+
def test_tuple_comprehension_with_tensors(self):
    """Test tuple comprehension that transforms a tuple of tensors."""
    size = (18,)
    tensor1 = torch.randn(size, device=DEVICE)
    tensor2 = torch.randn(size, device=DEVICE)
    tensor3 = torch.randn(size, device=DEVICE)

    tensors = (tensor1, tensor2, tensor3)

    code, result = code_and_output(
        kernel_tuple_comprehension_with_tensors, (tensors,)
    )

    # Validate generated code
    self.assertExpectedJournal(code)

    # Test correctness - should be tensor1*0.5 + tensor2*1.0 + tensor3*1.5
    expected = tensor1 * 0.5 + tensor2 * 1.0 + tensor3 * 1.5
    torch.testing.assert_close(result, expected)
727+
626728
def test_list_comprehension_with_function(self):
627729
"""Test list comprehension with expressions."""
628730
size = (14,)

0 commit comments

Comments
 (0)