Skip to content

Commit 2c6e526

Browse files
committed
Add cuda_tile tests
1 parent af6a557 commit 2c6e526

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

.github/workflows/nbcc_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ jobs:
5252
- name: Test
5353
run: |
5454
conda activate ./envs/dev
55+
conda install sklam::cuda_tile_mlir -y
5556
pytest ./nbcc
5657
57-
5858
- name: MyPy
5959
run: |
6060
conda activate ./envs/dev

nbcc/tests/test_cuda_tile.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import os
2+
import os.path
3+
import tempfile
4+
from contextlib import contextmanager
5+
from ctypes import CDLL, byref, c_double
6+
from pathlib import Path
7+
from typing import Generator, Any
8+
9+
import pytest
10+
from mlir.runtime import (
11+
get_ranked_memref_descriptor,
12+
make_nd_memref_descriptor,
13+
ranked_memref_to_numpy,
14+
)
15+
16+
import nbcc
17+
from nbcc.compiler import compile_to_mlir
18+
try:
19+
import cuda_tile
20+
except ImportError:
21+
HAS_CUDA_TILE = False
22+
else:
23+
HAS_CUDA_TILE = True
24+
25+
if HAS_CUDA_TILE:
26+
from nbcc.cutile_backend.backend import CuTileBackend
27+
28+
29+
example_dir = Path(os.path.dirname(nbcc.__file__)) / ".." / "examples" / "cuda_tile"
30+
31+
32+
@contextmanager
33+
def make_temp_directory() -> Generator[Path, None, None]:
34+
with tempfile.TemporaryDirectory(delete=False) as dirpath:
35+
yield Path(dirpath)
36+
37+
38+
@contextmanager
39+
def compile_mlir(filename: str) -> Generator[Any, None, None]:
40+
path = example_dir / filename
41+
assert path.exists()
42+
with make_temp_directory() as dir:
43+
mlir_mod = compile_to_mlir(str(path), be_type=CuTileBackend)
44+
yield mlir_mod
45+
46+
47+
def test_has_examples():
48+
assert example_dir.exists()
49+
50+
@pytest.mark.skipif(not HAS_CUDA_TILE, reason="no cuda_tile")
51+
def test_cuda_tile_to_mlir():
52+
with compile_mlir("tile_example.spy") as mlir_mod:
53+
mlir_text = mlir_mod.operation.get_asm()
54+
assert "entry @spy_tile_example$exported$export_foo" in mlir_text
55+
assert "ntry @spy_tile_example$exported$export_vecadd" in mlir_text

0 commit comments

Comments
 (0)