Skip to content

Commit af6a557

Browse files
committed
mypy fixes
1 parent 8c9c18f commit af6a557

File tree

4 files changed

+29
-20
lines changed

4 files changed

+29
-20
lines changed

nbcc/cutile_backend/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def entry(
3232
occupancy=None,
3333
loc=None,
3434
ip=None,
35-
) -> Tile:
35+
) -> Any:
3636
"""
3737
from https://github.com/NVIDIA/cuda-tile/blob/8a775693b18303d6c696be6ffd06dadad1b32a8e/python/cuda_tile/dialects/cuda_tile_ops.py#L2C37-L2C44
3838
"""

nbcc/frontend/extra_spy_builtins.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Annotated, cast
1+
from typing import TYPE_CHECKING, Annotated, cast, Any
22

33
from spy.fqn import FQN
44
from ..mlir_utils import (
@@ -38,7 +38,7 @@ class W_MLIR_Value(W_Object):
3838
__spy_storage_category__ = "reference"
3939

4040

41-
_type_caches = {}
41+
_type_caches: dict[str, "W_MLIR_Type"] = {}
4242

4343

4444
@MLIR.builtin_type("MLIR_Type")
@@ -102,10 +102,10 @@ def w_opimpl(vm: "SPyVM", *args_w: W_Object) -> W_Object:
102102

103103
@MLIR.builtin_func("MLIR_unpack")
104104
def w_MLIR_unpack(
105-
vm: "SPyVM", w_fn: W_Object, w_idx: W_Object
105+
vm: "SPyVM", w_fn: W_Func, w_idx: W_Object
106106
) -> W_BuiltinFunc:
107107

108-
restype = w_fn.w_functype.w_restype
108+
restype = cast(W_MLIR_Type, w_fn.w_functype.w_restype)
109109

110110
assert restype.original_name.startswith("multivalues$")
111111
members = parse_composite_type(restype.original_name)
@@ -115,6 +115,8 @@ def decode_type(str_fqn: str) -> str:
115115
[enc] = fqn.parts[-1].qualifiers
116116
return decode_type_name(str(enc))
117117

118+
assert members is not None
119+
118120
types = list(
119121
map(lambda x: W_MLIR_Type.w_new(vm, vm.wrap(decode_type(x))), members)
120122
)
@@ -145,6 +147,7 @@ def w_MLIR_asm(
145147
vm: "SPyVM", w_asm: W_Str, w_restype: W_Object, w_argtypes: W_Tuple
146148
) -> W_BuiltinFunc:
147149

150+
RESTYPE: Any
148151
if isinstance(w_restype, W_Tuple):
149152
innernames = []
150153
for it_type in w_restype.items_w:

nbcc/mlir_lowering.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
from collections import defaultdict
55
from contextlib import contextmanager
66
from dataclasses import dataclass
7-
from typing import Any, Callable, Sequence, cast
7+
from typing import Any, Callable, Sequence, cast, Coroutine, TYPE_CHECKING
8+
9+
# Type aliases for MLIR types - backends will use their specific ir modules
10+
IRContext = Any
11+
IRType = Any
12+
IRLocation = Any
13+
IRInsertionPoint = Any
814

915

1016
from sealir import ase
@@ -70,8 +76,8 @@ class BackendInterface(ABC):
7076
"""
7177

7278
# Class attributes for MLIR context management
73-
Location: type # Set to ir.Location by implementations
74-
InsertionPoint: type # Set to ir.InsertionPoint by implementations
79+
Location: Any # Set to ir.Location by implementations
80+
InsertionPoint: Any # Set to ir.InsertionPoint by implementations
7581

7682
@classmethod
7783
@abstractmethod
@@ -87,47 +93,47 @@ def run_passes(self, module: Any, transforms: Any) -> Any: ...
8793
# Type Constants - Properties for clean access pattern
8894
@property
8995
@abstractmethod
90-
def context(self) -> ir.Context:
96+
def context(self) -> "IRContext":
9197
"""MLIR context for creating types and operations."""
9298

9399
@property
94100
@abstractmethod
95-
def i32(self) -> ir.Type:
101+
def i32(self) -> "IRType":
96102
"""32-bit integer type."""
97103

98104
@property
99105
@abstractmethod
100-
def i64(self) -> ir.Type:
106+
def i64(self) -> "IRType":
101107
"""64-bit integer type."""
102108

103109
@property
104110
@abstractmethod
105-
def f64(self) -> ir.Type:
111+
def f64(self) -> "IRType":
106112
"""64-bit float type."""
107113

108114
@property
109115
@abstractmethod
110-
def boolean(self) -> ir.Type:
116+
def boolean(self) -> "IRType":
111117
"""Boolean (1-bit integer) type."""
112118

113119
@property
114120
@abstractmethod
115-
def none_type(self) -> ir.Type:
121+
def none_type(self) -> "IRType":
116122
"""None/void type representation."""
117123

118124
@property
119125
@abstractmethod
120-
def io_type(self) -> ir.Type:
126+
def io_type(self) -> "IRType":
121127
"""IO token type for sequencing."""
122128

123129
@property
124130
@abstractmethod
125-
def llvm_ptr(self) -> ir.Type:
131+
def llvm_ptr(self) -> "IRType":
126132
"""LLVM pointer type for memory operations."""
127133

128134
# Core Methods
129135
@abstractmethod
130-
def lower_type(self, ty) -> tuple[ir.Type, ...]:
136+
def lower_type(self, ty) -> tuple["IRType", ...]:
131137
"""Convert SealIR types to backend IR types.
132138
133139
Returns a tuple of MLIR types. For single types, returns (type,).
@@ -136,7 +142,7 @@ def lower_type(self, ty) -> tuple[ir.Type, ...]:
136142
"""
137143

138144
@abstractmethod
139-
def get_ll_type(self, expr, mdmap) -> ir.Type | None:
145+
def get_ll_type(self, expr, mdmap) -> "IRType | None":
140146
"""Get backend type for expression with metadata."""
141147

142148
@abstractmethod
@@ -280,7 +286,7 @@ def lower(self, root: rg.Func) -> Any:
280286
"""
281287
context = self.be.context
282288
self.loc = loc = self.be.Location.name(
283-
f"{self}.lower()", context=context
289+
f"{self.__class__.__name__}.lower()", context=context
284290
)
285291
module = self.module
286292

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ mypy_path = ["deps/sealir", "deps/spy"]
3232
python_version = "3.12"
3333

3434
[[tool.mypy.overrides]]
35-
module = ["mlir.*", "sealir.*", "spy.*"]
35+
module = ["mlir.*", "sealir.*", "spy.*", "cuda_tile.*"]
3636
ignore_errors = true
3737
ignore_missing_imports = true
3838

0 commit comments

Comments
 (0)