Skip to content

Commit d2bf84d

Browse files
Sibylauyf225
andauthored
Fixes #447: throw an error when printing output code in eager mode (#528)
Co-authored-by: Will Feng <yfeng.us@gmail.com>
1 parent 0f3e2d5 commit d2bf84d

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

helion/exc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ class GraphModuleUnsupportedOps(BaseError):
382382
message = "GraphModule contains unsupported operations: {0}. Only pure computation graphs are supported (no load_attr or call_module ops)."
383383

384384

385+
class RefEagerModeCodePrintError(BaseError):
386+
message = "No generated code to print out if ref eager mode is enabled."
387+
388+
385389
class NoDeviceLoopsInKernel(BaseError):
386390
message = (
387391
"Kernel contains no device loops. Add an hl.tile(...) or hl.grid(...) loop "

helion/runtime/settings.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,14 @@ def __init__(self, **settings: object) -> None:
128128
Args:
129129
settings: Keyword arguments representing various settings.
130130
"""
131+
131132
if defaults := getattr(_tls, "default_settings", None):
132133
settings = {**defaults.to_dict(), **settings}
133134

134135
super().__init__(**settings) # pyright: ignore[reportArgumentType]
135136

137+
self._check_ref_eager_mode_before_print_output_code()
138+
136139
def to_dict(self) -> dict[str, object]:
137140
"""
138141
Convert the Settings object to a dictionary.
@@ -162,6 +165,13 @@ def check_autotuning_disabled(self) -> None:
162165
if msg:
163166
raise exc.AutotuningDisallowedInEnvironment(msg)
164167

168+
def _check_ref_eager_mode_before_print_output_code(self) -> None:
169+
"""
170+
Check if ref eager mode is enabled before printing output code. If ref eager mode is enabled, raise an error.
171+
"""
172+
if self.ref_mode == RefMode.EAGER and self.print_output_code:
173+
raise exc.RefEagerModeCodePrintError
174+
165175
@staticmethod
166176
def default() -> Settings:
167177
"""

test/test_print_ref_eager_mode.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import annotations
2+
3+
import contextlib
4+
import io
5+
import unittest
6+
7+
import pytest
8+
import torch
9+
10+
import helion
11+
from helion import exc
12+
from helion._testing import TestCase
13+
import helion.language as hl
14+
15+
16+
class TestPrintOutputCode(TestCase):
17+
def test_ref_eager_mode_code_print_error(self):
18+
"""Test that RefEagerModeCodePrintError is raised when using @helion.kernel with both settings"""
19+
20+
with pytest.raises(exc.RefEagerModeCodePrintError):
21+
22+
@helion.kernel(
23+
use_default_config=True,
24+
print_output_code=True,
25+
ref_mode=helion.RefMode.EAGER,
26+
)
27+
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
28+
x, y = torch.broadcast_tensors(x, y)
29+
out = torch.empty(
30+
x.shape,
31+
dtype=torch.promote_types(x.dtype, y.dtype),
32+
device=x.device,
33+
)
34+
for tile in hl.tile(out.size()):
35+
out[tile] = x[tile] + y[tile]
36+
return out
37+
38+
x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
39+
y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
40+
torch.testing.assert_close(add(x, y), torch.add(x, y))
41+
42+
def test_normal_mode_code_print(self):
43+
"""Test that output code is in stderr when using @helion.kernel with normal mode"""
44+
45+
f = io.StringIO()
46+
with contextlib.redirect_stderr(f):
47+
48+
@helion.kernel(
49+
use_default_config=True,
50+
print_output_code=True,
51+
ref_mode=helion.RefMode.OFF,
52+
)
53+
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
54+
x, y = torch.broadcast_tensors(x, y)
55+
out = torch.empty(
56+
x.shape,
57+
dtype=torch.promote_types(x.dtype, y.dtype),
58+
device=x.device,
59+
)
60+
for tile in hl.tile(out.size()):
61+
out[tile] = x[tile] + y[tile]
62+
return out
63+
64+
x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
65+
y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
66+
torch.testing.assert_close(add(x, y), torch.add(x, y))
67+
68+
self.assertNotEqual(
69+
f.getvalue(),
70+
"",
71+
"Output code in stderr should not be empty at normal mode.",
72+
)
73+
74+
75+
if __name__ == "__main__":
76+
unittest.main()

0 commit comments

Comments
 (0)