@@ -890,6 +890,152 @@ def kernel_static_range_with_start(x: torch.Tensor, *, _launcher=_default_launch
890890 # src[test_unroll_tuples.py:N]: return result
891891 return result
892892
893+ --- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension)
894+ from __future__ import annotations
895+
896+ import torch
897+ import triton
898+ import triton.language as tl
899+ from helion.runtime import default_launcher as _default_launcher
900+
901+ @triton.jit
902+ def _helion_kernel_tuple_comprehension(x, result, multipliers_item_0, multipliers_item_1, multipliers_item_2, _BLOCK_SIZE_0: tl.constexpr):
903+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
904+ pid_0 = tl.program_id(0)
905+ offset_0 = pid_0 * _BLOCK_SIZE_0
906+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
907+ # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
908+ acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
909+ # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multiplier
910+ load = tl.load(x + indices_0 * 1, None)
911+ v_0 = tl.cast(multipliers_item_0, tl.float32)
912+ v_1 = load * v_0
913+ v_2 = acc + v_1
914+ load_1 = tl.load(x + indices_0 * 1, None)
915+ v_3 = tl.cast(multipliers_item_1, tl.float32)
916+ v_4 = load_1 * v_3
917+ v_5 = v_2 + v_4
918+ load_2 = tl.load(x + indices_0 * 1, None)
919+ v_6 = tl.cast(multipliers_item_2, tl.float32)
920+ v_7 = load_2 * v_6
921+ v_8 = v_5 + v_7
922+ # src[test_unroll_tuples.py:N]: result[tile_idx] = acc
923+ tl.store(result + indices_0 * 1, v_8, None)
924+
925+ def kernel_tuple_comprehension(x: torch.Tensor, *, _launcher=_default_launcher):
926+ """Test tuple comprehension with generator expression."""
927+ # src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
928+ result = torch.zeros_like(x)
929+ # src[test_unroll_tuples.py:N]: multipliers = tuple(m * 2 for m in (1, 2, 3))
930+ multipliers = tuple((m * 2 for m in (1, 2, 3)))
931+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
932+ _BLOCK_SIZE_0 = 16
933+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
934+ # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
935+ # src[test_unroll_tuples.py:N]: for multiplier in multipliers:
936+ # src[test_unroll_tuples.py:N-N]: ...
937+ _launcher(_helion_kernel_tuple_comprehension, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], _BLOCK_SIZE_0, num_warps=4, num_stages=1)
938+ # src[test_unroll_tuples.py:N]: return result
939+ return result
940+
941+ --- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension_with_static_range)
942+ from __future__ import annotations
943+
944+ import torch
945+ import helion.language as hl
946+ import triton
947+ import triton.language as tl
948+ from helion.runtime import default_launcher as _default_launcher
949+
950+ @triton.jit
951+ def _helion_kernel_tuple_comprehension_with_static_range(x, result, multipliers_item_0, multipliers_item_1, multipliers_item_2, multipliers_item_3, _BLOCK_SIZE_0: tl.constexpr):
952+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
953+ pid_0 = tl.program_id(0)
954+ offset_0 = pid_0 * _BLOCK_SIZE_0
955+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
956+ # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
957+ acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
958+ # src[test_unroll_tuples.py:N]: acc += x[tile_idx] * multipliers[i]
959+ load = tl.load(x + indices_0 * 1, None)
960+ v_0 = tl.cast(multipliers_item_0, tl.float32)
961+ v_1 = load * v_0
962+ v_2 = acc + v_1
963+ load_1 = tl.load(x + indices_0 * 1, None)
964+ v_3 = tl.cast(multipliers_item_1, tl.float32)
965+ v_4 = load_1 * v_3
966+ v_5 = v_2 + v_4
967+ load_2 = tl.load(x + indices_0 * 1, None)
968+ v_6 = tl.cast(multipliers_item_2, tl.float32)
969+ v_7 = load_2 * v_6
970+ v_8 = v_5 + v_7
971+ load_3 = tl.load(x + indices_0 * 1, None)
972+ v_9 = tl.cast(multipliers_item_3, tl.float32)
973+ v_10 = load_3 * v_9
974+ v_11 = v_8 + v_10
975+ # src[test_unroll_tuples.py:N]: result[tile_idx] = acc
976+ tl.store(result + indices_0 * 1, v_11, None)
977+
978+ def kernel_tuple_comprehension_with_static_range(x: torch.Tensor, N: hl.constexpr, *, _launcher=_default_launcher):
979+ """Test tuple comprehension with static_range for indexing."""
980+ # src[test_unroll_tuples.py:N]: result = torch.zeros_like(x)
981+ result = torch.zeros_like(x)
982+ # src[test_unroll_tuples.py:N]: multipliers = tuple(i + 1 for i in range(N))
983+ multipliers = tuple((i + 1 for i in range(4)))
984+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
985+ _BLOCK_SIZE_0 = 16
986+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
987+ # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
988+ # src[test_unroll_tuples.py:N]: for i in hl.static_range(N):
989+ # src[test_unroll_tuples.py:N-N]: ...
990+ _launcher(_helion_kernel_tuple_comprehension_with_static_range, (triton.cdiv(16, _BLOCK_SIZE_0),), x, result, multipliers[0], multipliers[1], multipliers[2], multipliers[3], _BLOCK_SIZE_0, num_warps=4, num_stages=1)
991+ # src[test_unroll_tuples.py:N]: return result
992+ return result
993+
994+ --- assertExpectedJournal(TestUnrollTuples.test_tuple_comprehension_with_tensors)
995+ from __future__ import annotations
996+
997+ import torch
998+ import triton
999+ import triton.language as tl
1000+ from helion.runtime import default_launcher as _default_launcher
1001+
1002+ @triton.jit
1003+ def _helion_kernel_tuple_comprehension_with_tensors(scaled_item_0, scaled_item_1, scaled_item_2, result, _BLOCK_SIZE_0: tl.constexpr):
1004+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
1005+ pid_0 = tl.program_id(0)
1006+ offset_0 = pid_0 * _BLOCK_SIZE_0
1007+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1008+ mask_0 = indices_0 < 18
1009+ # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
1010+ acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
1011+ # src[test_unroll_tuples.py:N]: acc += tensor[tile_idx]
1012+ load = tl.load(scaled_item_0 + indices_0 * 1, mask_0, other=0)
1013+ v_0 = acc + load
1014+ load_1 = tl.load(scaled_item_1 + indices_0 * 1, mask_0, other=0)
1015+ v_1 = v_0 + load_1
1016+ load_2 = tl.load(scaled_item_2 + indices_0 * 1, mask_0, other=0)
1017+ v_2 = v_1 + load_2
1018+ # src[test_unroll_tuples.py:N]: result[tile_idx] = acc
1019+ tl.store(result + indices_0 * 1, v_2, mask_0)
1020+
1021+ def kernel_tuple_comprehension_with_tensors(tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor], *, _launcher=_default_launcher):
1022+ """Test tuple comprehension that transforms a tuple of tensors."""
1023+ # src[test_unroll_tuples.py:N]: result = torch.zeros_like(tensors[0])
1024+ result = torch.zeros_like(tensors[0])
1025+ # src[test_unroll_tuples.py:N]: scales = (0.5, 1.0, 1.5)
1026+ scales = (0.5, 1.0, 1.5)
1027+ # src[test_unroll_tuples.py:N]: scaled = tuple(t * s for t, s in zip(tensors, scales, strict=False))
1028+ scaled = tuple((t * s for t, s in zip(tensors, scales, strict=False)))
1029+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
1030+ _BLOCK_SIZE_0 = 32
1031+ # src[test_unroll_tuples.py:N]: for tile_idx in hl.tile(result.size(0)):
1032+ # src[test_unroll_tuples.py:N]: acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device)
1033+ # src[test_unroll_tuples.py:N]: for tensor in scaled:
1034+ # src[test_unroll_tuples.py:N-N]: ...
1035+ _launcher(_helion_kernel_tuple_comprehension_with_tensors, (triton.cdiv(18, _BLOCK_SIZE_0),), scaled[0], scaled[1], scaled[2], result, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
1036+ # src[test_unroll_tuples.py:N]: return result
1037+ return result
1038+
8931039--- assertExpectedJournal(TestUnrollTuples.test_tuple_with_scaling_factors)
8941040from __future__ import annotations
8951041
0 commit comments