Commit 0f3e2d5

Add torch.stack support (#524)
1 parent 467cfee commit 0f3e2d5

File tree

4 files changed: +278 −2 lines changed

helion/_compiler/device_ir.py
helion/_compiler/inductor_lowering.py
test/test_views.expected
test/test_views.py

helion/_compiler/device_ir.py

Lines changed: 11 additions & 2 deletions
@@ -73,6 +73,15 @@ class _TLS(Protocol):
 tls: _TLS = cast("_TLS", threading.local())
 
 
+def _get_custom_decomp_table() -> dict[torch._ops.OpOverload, Callable[..., object]]:
+    decomp_table = select_decomp_table().copy()
+    # Normally, aten.stack is decomposed to aten.unsqueeze + aten.cat, but it's difficult to
+    # figure out the right Triton implementation for aten.cat. As a workaround, we disable
+    # the decomp for aten.stack and implement aten.stack in Triton (codegen_stack) instead.
+    decomp_table.pop(torch.ops.aten.stack.default, None)
+    return decomp_table
+
+
 def _make_fx(fn: Callable[..., object], *args: object) -> torch.fx.Graph:
     """
     We monkey patch get_proxy_slot to support Tensor/SymInt/SymFloat/SymBool in the
@@ -628,7 +637,7 @@ def run_subgraph(*args: object) -> list[object]:
 
         with self.disable_tracing() as tracer:
             graph = proxy_tensor.make_fx(
-                run_subgraph, decomposition_table=select_decomp_table()
+                run_subgraph, decomposition_table=_get_custom_decomp_table()
             )(*inputs.get_tensor_args()).graph
             graph_idx = self.device_ir.add_graph(
                 graph,
@@ -711,7 +720,7 @@ def run_body(*args: object) -> list[object]:
 
         with self.disable_tracing() as tracer:
            body_graph = proxy_tensor.make_fx(
-                run_body, decomposition_table=select_decomp_table()
+                run_body, decomposition_table=_get_custom_decomp_table()
             )(*inputs.get_tensor_args()).graph
             assert outputs is not None
             graph_idx = self.device_ir.add_graph(
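
For context, a minimal sketch (not part of this commit; stack_via_cat is a hypothetical name) of the decomposition being disabled above. The default decomp table rewrites aten.stack as aten.unsqueeze on each input followed by aten.cat:

import torch

def stack_via_cat(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    # the aten.unsqueeze + aten.cat form the default decomp table would produce
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

a, b = torch.randn(4, 8), torch.randn(4, 8)
assert torch.equal(stack_via_cat([a, b], dim=1), torch.stack([a, b], dim=1))

With the aten.stack entry popped, make_fx traces the op through unchanged and the codegen_stack lowering below handles it directly.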

helion/_compiler/inductor_lowering.py

Lines changed: 55 additions & 0 deletions
@@ -839,6 +839,61 @@ def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )
 
 
+@register_lowering(
+    torch.ops.aten.stack.default,  # pyright: ignore[reportAttributeAccessIssue]
+    masked_value_fn=passthrough_masked_value,
+)
+def codegen_stack(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    tensors = node.args[0]
+    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+
+    assert isinstance(tensors, (list, tuple))
+    tensor_asts = [ctx.env[t] for t in tensors]  # pyright: ignore[reportArgumentType]
+    n = len(tensor_asts)
+
+    if n == 0:
+        raise ValueError("Cannot stack empty tensor list")
+
+    # Round up to power of 2 for efficient masking
+    padded_size = 1 << (n - 1).bit_length()
+
+    # Create index array [0, 1, 2, 3, ...] for tensor selection
+    idx = ctx.cg.device_function.new_var("stack_idx")
+    ctx.cg.add_statement(statement_from_string(f"{idx} = tl.arange(0, {padded_size})"))
+
+    # Broadcast index to target dimension shape
+    # e.g., dim=0: [:, None, None], dim=1: [None, :, None], dim=2: [None, None, :]
+    bidx = ctx.cg.device_function.new_var("broadcast_idx")
+    assert isinstance(dim, int)
+    pattern = "[" + ", ".join(["None"] * dim + [":"] + ["None"] * max(0, 2 - dim)) + "]"
+    ctx.cg.add_statement(statement_from_string(f"{bidx} = {idx}{pattern}"))
+
+    # Expand each input tensor along the stack dimension
+    expanded = [ctx.cg.device_function.new_var(f"expanded_{i}") for i in range(n)]
+    for var, tensor in zip(expanded, tensor_asts, strict=False):
+        ctx.cg.add_statement(
+            statement_from_string(f"{var} = tl.expand_dims({{t}}, {dim})", t=tensor)
+        )
+
+    # Initialize result with zeros
+    result = ctx.cg.device_function.new_var("stacked_result")
+    ctx.cg.add_statement(
+        statement_from_string(f"{result} = tl.zeros_like({expanded[0]})")
+    )
+
+    # Select each tensor using masks
+    for i in range(n):
+        mask = ctx.cg.device_function.new_var(f"mask_{i}")
+        ctx.cg.add_statement(statement_from_string(f"{mask} = {bidx} == {i}"))
+        ctx.cg.add_statement(
+            statement_from_string(
+                f"{result} = tl.where({mask}, {expanded[i]}, {result})"
+            )
+        )
+
+    return expr_from_string(result)
+
+
 @register_lowering(
     torch.ops.aten.expand.default,  # pyright: ignore[reportAttributeAccessIssue]
     masked_value_fn=passthrough_masked_value,
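
The lowering selects by mask instead of concatenating. A rough PyTorch analogue (illustration only; stack_by_masks is a hypothetical helper, and the generated code operates on Triton tensors via tl.arange/tl.expand_dims/tl.where) of what codegen_stack emits for 2-D inputs:

import torch

def stack_by_masks(tensors: list[torch.Tensor], dim: int) -> torch.Tensor:
    n = len(tensors)
    # power-of-2 lane count, as in codegen_stack: n=3 gives padded_size=4
    padded_size = 1 << (n - 1).bit_length()
    shape = [1, 1, 1]
    shape[dim] = padded_size
    broadcast_idx = torch.arange(padded_size).reshape(shape)  # stack_idx, broadcast along dim
    expanded = [t.unsqueeze(dim) for t in tensors]  # tl.expand_dims counterpart
    result = torch.zeros_like(expanded[0])  # tl.zeros_like counterpart
    for i in range(n):
        # broadcasting against broadcast_idx grows result to padded_size lanes
        result = torch.where(broadcast_idx == i, expanded[i], result)
    # the generated kernel instead masks the padding lanes out at the tl.store
    return result.narrow(dim, 0, n)

a, b, c = (torch.randn(4, 8) for _ in range(3))
assert torch.equal(stack_by_masks([a, b, c], dim=0), torch.stack([a, b, c], dim=0))

Rounding n up to a power of 2 (per the comment in codegen_stack) means the expected outputs below carry one unused lane for three inputs, which the store mask drops.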

test/test_views.expected

Lines changed: 90 additions & 0 deletions
@@ -140,3 +140,93 @@ def fn(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_1 = 32
     _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, y, out, out.size(0), out.size(1), x.size(0), x.size(1), y.size(0), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestViews.test_stack_dim0)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test_stack_dim0_kernel(a, b, c, result, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 65
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 3
+    for offset_2 in tl.range(0, 129, _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < 129
+        a_tile = tl.load(a + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        b_tile = tl.load(b + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        c_tile = tl.load(c + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        stack_idx = tl.arange(0, 4)
+        broadcast_idx = stack_idx[:, None, None]
+        expanded_0 = tl.expand_dims(a_tile, 0)
+        expanded_1 = tl.expand_dims(b_tile, 0)
+        expanded_2 = tl.expand_dims(c_tile, 0)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
+        mask_5 = broadcast_idx == 2
+        stacked_result = tl.where(mask_5, expanded_2, stacked_result)
+        tl.store(result + (indices_3[:, None, None] * 8385 + indices_0[None, :, None] * 129 + indices_2[None, None, :] * 1), stacked_result, mask_2[:, None, None] & mask_0[None, :, None] & mask_1[None, None, :])
+
+def test_stack_dim0_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _launcher=_default_launcher):
+    M, N = a.shape
+    result = torch.zeros(3, M, N, dtype=a.dtype, device=a.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 4
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_test_stack_dim0_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return result
+
+--- assertExpectedJournal(TestViews.test_stack_non_power_of_2)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test_stack_non_power_of_2_kernel(a, b, c, result, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 65
+    indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    mask_2 = indices_3 < 3
+    for offset_2 in tl.range(0, 129, _BLOCK_SIZE_1):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_2 < 129
+        a_tile = tl.load(a + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        b_tile = tl.load(b + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        c_tile = tl.load(c + (indices_0[:, None] * 129 + indices_2[None, :] * 1), mask_0[:, None] & mask_1[None, :], other=0)
+        stack_idx = tl.arange(0, 4)
+        broadcast_idx = stack_idx[None, :, None]
+        expanded_0 = tl.expand_dims(a_tile, 1)
+        expanded_1 = tl.expand_dims(b_tile, 1)
+        expanded_2 = tl.expand_dims(c_tile, 1)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
+        mask_5 = broadcast_idx == 2
+        stacked_result = tl.where(mask_5, expanded_2, stacked_result)
+        tl.store(result + (indices_0[:, None, None] * 387 + indices_3[None, :, None] * 129 + indices_2[None, None, :] * 1), stacked_result, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :])
+
+def test_stack_non_power_of_2_kernel(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, *, _launcher=_default_launcher):
+    M, N = a.shape
+    result = torch.zeros(M, 3, N, dtype=a.dtype, device=a.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_2 = 4
+    _BLOCK_SIZE_1 = 32
+    _launcher(_helion_test_stack_non_power_of_2_kernel, (triton.cdiv(65, _BLOCK_SIZE_0),), a, b, c, result, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return result
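
The two journals differ only in the index broadcast and the store layout: dim=0 produces stack_idx[:, None, None], dim=1 produces stack_idx[None, :, None], and in both cases tl.arange spans 4 lanes with the store mask (indices_3 < 3) dropping the unused fourth. A small sketch (broadcast_pattern is a hypothetical name) of the pattern expression codegen_stack uses for 2-D inputs:

def broadcast_pattern(dim: int) -> str:
    # dim leading Nones, the slice, then trailing Nones, as in codegen_stack
    return "[" + ", ".join(["None"] * dim + [":"] + ["None"] * max(0, 2 - dim)) + "]"

assert broadcast_pattern(0) == "[:, None, None]"  # test_stack_dim0
assert broadcast_pattern(1) == "[None, :, None]"  # test_stack_non_power_of_2
assert broadcast_pattern(2) == "[None, None, :]"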

test/test_views.py

Lines changed: 122 additions & 0 deletions
@@ -209,6 +209,128 @@ def fn(x: torch.Tensor) -> torch.Tensor:
     torch.testing.assert_close(result, expected)
     self.assertExpectedJournal(code)
 
+    def test_stack_power_of_2(self):
+        @helion.kernel(use_default_config=True, static_shapes=True)
+        def test_stack_power_of_2_kernel(
+            a: torch.Tensor, b: torch.Tensor
+        ) -> torch.Tensor:
+            M, N = a.shape
+            result = torch.zeros(M * 2, N, dtype=a.dtype, device=a.device)
+
+            for tile_m in hl.tile(M):
+                for tile_n in hl.tile(N):
+                    a_tile = a[tile_m, tile_n]
+                    b_tile = b[tile_m, tile_n]
+
+                    # Stack tensors along dim=1 (creates [BLOCK_M, 2, BLOCK_N])
+                    stacked = torch.stack([a_tile, b_tile], dim=1)
+
+                    # Reshape to [BLOCK_M * 2, BLOCK_N]
+                    reshaped = stacked.reshape(tile_m.block_size * 2, tile_n.block_size)
+
+                    result[
+                        (tile_m.begin * 2) : (tile_m.begin * 2 + tile_m.block_size * 2),
+                        tile_n,
+                    ] = reshaped
+
+            return result
+
+        M, N = 64, 128
+        device = DEVICE
+
+        a = torch.randn(M, N, dtype=torch.float32, device=device)
+        b = torch.randn(M, N, dtype=torch.float32, device=device)
+
+        result = test_stack_power_of_2_kernel(a, b)
+        expected = torch.zeros(M * 2, N, dtype=torch.float32, device=device)
+        expected[0::2] = a  # Every 2nd row starting from 0
+        expected[1::2] = b  # Every 2nd row starting from 1
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+
+    def test_stack_non_power_of_2(self):
+        @helion.kernel(use_default_config=True, static_shapes=True)
+        def test_stack_non_power_of_2_kernel(
+            a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
+        ) -> torch.Tensor:
+            M, N = a.shape
+            result = torch.zeros(M, 3, N, dtype=a.dtype, device=a.device)
+
+            for tile_m in hl.tile(M):
+                for tile_n in hl.tile(N):
+                    a_tile = a[tile_m, tile_n]
+                    b_tile = b[tile_m, tile_n]
+                    c_tile = c[tile_m, tile_n]
+
+                    # Stack tensors along dim=1 (creates [BLOCK_M, 3, BLOCK_N])
+                    stacked = torch.stack([a_tile, b_tile, c_tile], dim=1)
+
+                    result[tile_m, :, tile_n] = stacked
+
+            return result
+
+        M, N = 65, 129
+        device = DEVICE
+
+        a = torch.randn(M, N, dtype=torch.float32, device=device)
+        b = torch.randn(M, N, dtype=torch.float32, device=device)
+        c = torch.randn(M, N, dtype=torch.float32, device=device)
+
+        code, result = code_and_output(test_stack_non_power_of_2_kernel, (a, b, c))
+        expected = torch.stack([a, b, c], dim=1)
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_stack_dim0(self):
+        @helion.kernel(use_default_config=True, static_shapes=True)
+        def test_stack_dim0_kernel(
+            a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
+        ) -> torch.Tensor:
+            M, N = a.shape
+            result = torch.zeros(3, M, N, dtype=a.dtype, device=a.device)
+
+            for tile_m in hl.tile(M):
+                for tile_n in hl.tile(N):
+                    a_tile = a[tile_m, tile_n]
+                    b_tile = b[tile_m, tile_n]
+                    c_tile = c[tile_m, tile_n]
+
+                    # Stack 3 tensors along dim=0
+                    # This creates [3, BLOCK_M, BLOCK_N]
+                    stacked = torch.stack([a_tile, b_tile, c_tile], dim=0)
+
+                    result[:, tile_m, tile_n] = stacked
+
+            return result
+
+        M, N = 65, 129
+        device = DEVICE
+
+        a = torch.randn(M, N, dtype=torch.float32, device=device)
+        b = torch.randn(M, N, dtype=torch.float32, device=device)
+        c = torch.randn(M, N, dtype=torch.float32, device=device)
+
+        code, result = code_and_output(test_stack_dim0_kernel, (a, b, c))
+        expected = torch.stack([a, b, c], dim=0)
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+        # Verify torch.compile still decomposes aten.stack to aten.cat
+        from torch._inductor import config as inductor_config
+
+        def capture_graph(graph):
+            self._graph = str(graph)
+            return graph
+
+        with inductor_config.patch(post_grad_custom_pre_pass=capture_graph):
+            torch.compile(
+                lambda x, y, z: torch.stack([x, y, z], dim=0), backend="inductor"
+            )(
+                torch.randn(4, 4, device=device),
+                torch.randn(4, 4, device=device),
+                torch.randn(4, 4, device=device),
+            )
+        assert "aten.cat" in self._graph and "aten.stack" not in self._graph
 
 
 if __name__ == "__main__":
     unittest.main()
