@@ -140,3 +140,93 @@ def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_1 = 32
     _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, y, out, out.size(0), out.size(1), x.size(0), x.size(1), y.size(0), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestViews.test_stack_dim0)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test_stack_dim0_kernel(a, b, c, result, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 65
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 3
+    for offset_2 in tl.range(0, 129, _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < 129
+        a_tile = tl.load(a + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        b_tile = tl.load(b + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        c_tile = tl.load(c + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        stack_idx = tl.arange(0, 4)
+        broadcast_idx = stack_idx[:, None, None]
+        expanded_0 = tl.expand_dims(a_tile, 0)
+        expanded_1 = tl.expand_dims(b_tile, 0)
+        expanded_2 = tl.expand_dims(c_tile, 0)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
+        mask_5 = broadcast_idx == 2
+        stacked_result = tl.where(mask_5, expanded_2, stacked_result)
+        tl.store(result + (indices_3[:, None, None] * 8385 + indices_0[None, :, None] * 129 + indices_2[None, None, :] * 1), stacked_result, mask_2[:, None, None] & mask_0[None, :, None] & mask_1[None, None, :])
+
+def test_stack_dim0_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _launcher=_default_launcher):
+    M, N = a.shape
+    result = torch.zeros(3, M, N, dtype=a.dtype, device=a.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 4
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_test_stack_dim0_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return result
+
+--- assertExpectedJournal(TestViews.test_stack_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test_stack_non_power_of_2_kernel(a, b, c, result, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 65
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 3
+    for offset_2 in tl.range(0, 129, _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < 129
+        a_tile = tl.load(a + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        b_tile = tl.load(b + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        c_tile = tl.load(c + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        stack_idx = tl.arange(0, 4)
+        broadcast_idx = stack_idx[None, :, None]
+        expanded_0 = tl.expand_dims(a_tile, 1)
+        expanded_1 = tl.expand_dims(b_tile, 1)
+        expanded_2 = tl.expand_dims(c_tile, 1)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
+        mask_5 = broadcast_idx == 2
+        stacked_result = tl.where(mask_5, expanded_2, stacked_result)
+        tl.store(result + (indices_0[:, None, None] * 387 + indices_3[None, :, None] * 129 + indices_2[None, None, :] * 1), stacked_result, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :])
+
+def test_stack_non_power_of_2_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _launcher=_default_launcher):
+    M, N = a.shape
+    result = torch.zeros(M, 3, N, dtype=a.dtype, device=a.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 4
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_test_stack_non_power_of_2_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return result
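
Note on the lowering seen in both entries: Triton has no native stack primitive, so torch.stack is lowered by padding the stack dimension to the next power of two (4 slots for 3 inputs), broadcasting each input tile against a tl.arange slot index, selecting slot by slot with tl.where, and masking the padded fourth slot off at the store. A Helion kernel along the following lines would lower to generated code of this shape; this is a minimal sketch assuming the public helion / helion.language tiling API, and the kernel name, signature, and indexing pattern are illustrative, not taken from the actual TestViews source, which this diff omits:

import torch
import helion
import helion.language as hl

# Hypothetical source kernel for the test_stack_dim0 journal entry above
# (a sketch of the pattern, not the exact test source).
@helion.kernel()
def stack_dim0(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    m, n = a.shape
    result = torch.zeros(3, m, n, dtype=a.dtype, device=a.device)
    for tile_m, tile_n in hl.tile([m, n]):
        # torch.stack over device tiles is what becomes the padded
        # tl.arange / tl.where selection sequence in the journal.
        result[:, tile_m, tile_n] = torch.stack(
            [a[tile_m, tile_n], b[tile_m, tile_n], c[tile_m, tile_n]], dim=0
        )
    return result

The test_stack_non_power_of_2 variant would differ only in stacking along dim=1 into a (M, 3, N) result, which is why its generated kernel expands at dimension 1 and swaps the store's index/stride layout.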