@@ -84,9 +84,45 @@ def _(state: CodegenState) -> ast.AST:
     )


+@_decorators.ref(store)
+def _(
+    tensor: torch.Tensor,
+    index: list[object],
+    value: torch.Tensor | torch.SymInt | float,
+    extra_mask: torch.Tensor | None = None,
+) -> None:
+    # Convert index list to tuple for tensor indexing
+    index_tuple = tuple(index)
+
+    # Apply extra mask if provided
+    if extra_mask is not None:
+        # Only store where the mask is True
+        if isinstance(value, torch.Tensor):
+            tensor[index_tuple] = torch.where(extra_mask, value, tensor[index_tuple])  # pyright: ignore[reportArgumentType]
+        else:
+            # For scalar values, we need to create a tensor of the right shape
+            current = tensor[index_tuple]  # pyright: ignore[reportArgumentType]
+            # Cast value to a proper numeric type for full_like
+            if isinstance(value, torch.SymInt):
+                numeric_value = int(value)
+            else:
+                numeric_value = value
+            tensor[index_tuple] = torch.where(  # pyright: ignore[reportArgumentType]
+                extra_mask, torch.full_like(current, numeric_value), current
+            )
+    else:
+        # Handle SymInt case for assignment
+        if isinstance(value, torch.SymInt):
+            tensor[index_tuple] = int(value)  # pyright: ignore[reportArgumentType]
+        else:
+            tensor[index_tuple] = value  # pyright: ignore[reportArgumentType]
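+
+
+# A minimal usage sketch for the masked-store ref above (illustrative only;
+# `_ref_store` is a stand-in name for the anonymous function registered above):
+#
+#     t = torch.zeros(4)
+#     _ref_store(t, [slice(0, 2)], 1.0, extra_mask=torch.tensor([True, False]))
+#     # t is now tensor([1., 0., 0., 0.]): the masked-off element is untouched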
+
+
 @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def load(
-    tensor: torch.Tensor, index: list[object], extra_mask: torch.Tensor | None = None
+    tensor: torch.Tensor,
+    index: list[object],
+    extra_mask: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """Load a value from a tensor using a list of indices.

@@ -129,6 +165,83 @@ def _(node: torch.fx.Node) -> int:
     return 0  # loads are always masked to 0


+@_decorators.ref(load)
+def _(
+    tensor: torch.Tensor,
+    index: list[object],
+    extra_mask: torch.Tensor | None = None,
+) -> torch.Tensor:
+    from .ref_tile import RefTile
+
+    if extra_mask is None:
+        return tensor[tuple(index)]  # pyright: ignore[reportArgumentType]
+
+    # Create zero result matching mask shape
+    result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)
+
+    # Process indices: convert RefTiles and clamp tensor indices
+    orig_indices, safe_indices, is_tensor_mask = [], [], []
+    for i, idx in enumerate(index):
+        if isinstance(idx, RefTile):
+            idx = idx.index  # Convert RefTile to tensor
+
+        if isinstance(idx, torch.Tensor):
+            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
+            orig_indices.append(idx)
+            safe_indices.append(torch.clamp(idx, 0, dim_size - 1))
+            is_tensor_mask.append(True)
+        else:
+            orig_indices.append(idx)
+            safe_indices.append(idx)
+            is_tensor_mask.append(False)
+
+    # Apply broadcasting if we have multiple tensor indices
+    tensor_positions = [i for i, is_tensor in enumerate(is_tensor_mask) if is_tensor]
+
+    if len(tensor_positions) > 1:
+        # Add unsqueeze operations for broadcasting
+        broadcast_indices = []
+        for i, (idx, is_tensor) in enumerate(
+            zip(safe_indices, is_tensor_mask, strict=False)
+        ):
+            if is_tensor:
+                new_idx = idx
+                # Add dimension for each other tensor index
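+                # e.g. row indices of shape (M,) and column indices of shape
+                # (N,) become (M, 1) and (1, N), which broadcast to (M, N)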
+                for j, other_pos in enumerate(tensor_positions):
+                    if other_pos != i:
+                        new_idx = new_idx.unsqueeze(j if other_pos < i else -1)
+                broadcast_indices.append(new_idx)
+            else:
+                broadcast_indices.append(idx)
+        values = tensor[tuple(broadcast_indices)]
+    else:
+        values = tensor[tuple(safe_indices)]
+
+    # Build validity mask
+    valid_mask = extra_mask.clone()
+    for i, (orig_idx, is_tensor) in enumerate(
+        zip(orig_indices, is_tensor_mask, strict=False)
+    ):
+        if is_tensor:
+            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
+            in_bounds = (orig_idx >= 0) & (orig_idx < dim_size)
+            # Broadcast to match mask shape by adding dimensions
+            # Count how many tensor indices come before and after this one
+            n_before = sum(1 for j in range(i) if is_tensor_mask[j])
+            n_after = sum(
+                1 for j in range(i + 1, len(is_tensor_mask)) if is_tensor_mask[j]
+            )
+
+            # Add dimensions: n_after dimensions at the end, n_before at the beginning
+            for _ in range(n_after):
+                in_bounds = in_bounds.unsqueeze(-1)
+            for _ in range(n_before):
+                in_bounds = in_bounds.unsqueeze(0)
+            valid_mask = valid_mask & in_bounds
+
+    return torch.where(valid_mask, values, result)
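+
+
+# A usage sketch for the masked-load ref above (illustrative only; `_ref_load`
+# is a stand-in name for the anonymous function registered above):
+#
+#     t = torch.arange(4.0)
+#     idx = torch.tensor([2, 99])  # 99 is out of bounds
+#     _ref_load(t, [idx], extra_mask=torch.tensor([True, True]))
+#     # -> tensor([2., 0.]): the out-of-bounds lane is masked to zero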
+
+
 @has_side_effect
 @_decorators.api(allow_host_tensor=True)
 def atomic_add(
@@ -210,6 +323,59 @@ def _(
     return None


+@_decorators.ref(atomic_add)
+def _(
+    target: torch.Tensor,
+    index: list[object],
+    value: torch.Tensor | float,
+    sem: str = "relaxed",
+) -> None:
+    """Reference implementation of atomic_add for interpret mode."""
+    from .. import exc
+    from .ref_tile import RefTile
+
+    # Validate sem parameter
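+    # (sem is checked for parity with the compiled path, but the ordinary
+    # eager `+=` below carries no special memory-ordering semantics)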
338+ if sem not in ["relaxed" , "acquire" , "release" , "acq_rel" ]:
339+ raise exc .InternalError (
340+ ValueError (
341+ f"Invalid memory semantic '{ sem } '. Valid options are: relaxed, acquire, release, acq_rel"
342+ )
343+ )
344+
345+ # Convert indices to proper format
346+ processed_index = []
347+ for idx in index :
348+ if isinstance (idx , RefTile ):
349+ processed_index .append (idx ._slice )
350+ elif isinstance (idx , torch .Tensor ) and idx .numel () == 1 :
351+ processed_index .append (int (idx .item ()))
352+ else :
353+ processed_index .append (idx )
354+
355+ # Find tensor indices that need element-wise processing
356+ tensor_indices = [
357+ (i , idx )
358+ for i , idx in enumerate (processed_index )
359+ if isinstance (idx , torch .Tensor ) and idx .numel () > 1
360+ ]
361+
362+ if tensor_indices :
363+ # Element-wise processing for tensor indices
364+ i , tensor_idx = tensor_indices [0 ] # Handle first tensor index
365+ for j , elem in enumerate (tensor_idx ):
366+ new_index = processed_index .copy ()
367+ new_index [i ] = int (elem .item ())
368+ val = (
369+ value [j ]
370+ if isinstance (value , torch .Tensor ) and value .numel () > 1
371+ else value
372+ )
373+ target [tuple (new_index )] += val
374+ else :
375+ # Direct atomic add
376+ target [tuple (processed_index )] += value
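+
+
+# A usage sketch for the atomic_add ref above (illustrative only;
+# `_ref_atomic_add` is a stand-in name for the anonymous function registered
+# above). Duplicate indices each contribute their own add, unlike a plain
+# `t[idx] += v`, which drops duplicate contributions:
+#
+#     t = torch.zeros(3)
+#     _ref_atomic_add(t, [torch.tensor([0, 0, 2])], torch.tensor([1.0, 1.0, 5.0]))
+#     # t -> tensor([2., 0., 5.])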
+
+
 @_decorators.codegen(atomic_add)
 def _(state: CodegenState) -> ast.AST:
     target = state.proxy_arg(0)