66from typing import TYPE_CHECKING
77from typing import cast
88
9+ import torch
910from torch ._inductor .runtime .runtime_utils import next_power_of_2
1011
1112from .._compat import supports_amd_cdna_tunables
5253 "num_warps" ,
5354 "num_stages" ,
5455 "pid_type" ,
56+ "num_sm_multiplier" ,
57+ "maxnreg" ,
5558 "indexing" ,
5659 "load_eviction_policies" ,
5760 * AMD_CDNA_TUNABLES ,
5861 ]
5962)
6063VALID_PID_TYPES = ("flat" , "xyz" , "persistent_blocked" , "persistent_interleaved" )
64+ MIN_NUM_SM_MULTIPLIER = 1
65+ MAX_NUM_SM_MULTIPLIER = 128
66+ DEFAULT_NUM_SM_MULTIPLIER = 1
67+ # maxnreg values: None means no limit, otherwise limit to this many registers per thread
68+ # Lower values allow higher occupancy but may hurt performance for register-heavy kernels
69+ VALID_MAXNREG = (None , 32 , 64 , 128 , 256 )
70+ DEFAULT_MAXNREG = None
6171VALID_EVICTION_POLICIES = ("" , "first" , "last" )
6272VALID_WAVES_PER_EU = (1 , 2 , 3 , 4 )
6373VALID_MATRIX_INSTR_NONKDIM = (0 , 16 , 32 )
@@ -158,10 +168,18 @@ def disallow_pid_type(self, pid_type: PidTypeLiteral) -> None:
158168 )
159169 assert self .allowed_pid_types
160170
161- def normalize (self , config : helion .Config | dict [str , object ]) -> None :
162- """Normalize the config to match the block_sizes and validate the config."""
171+ def normalize (
172+ self , config : helion .Config | dict [str , object ], * , _fix_invalid : bool = False
173+ ) -> None :
174+ """Normalize the config to match the block_sizes and validate the config.
175+
176+ Args:
177+ config: The config to normalize (modified in place).
178+ _fix_invalid: If True, silently fix invalid combinations instead of raising
179+ errors. Used internally during autotuning config generation.
180+ """
163181 if isinstance (config , helion .Config ):
164- self .normalize (config .config )
182+ self .normalize (config .config , _fix_invalid = _fix_invalid )
165183 return
166184
167185 for name in (
@@ -250,19 +268,84 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
250268 elif key in config :
251269 raise InvalidConfig (f"{ key } is not supported on this target hardware" )
252270
253- # TODO(jansel): include num_ctas and max_nreg
271+ if "pid_type" in config :
272+ if config ["pid_type" ] not in VALID_PID_TYPES :
273+ raise InvalidConfig (
274+ f"Invalid value for 'pid_type': { config ['pid_type' ]!r} must be one of { list (VALID_PID_TYPES )!r} "
275+ )
276+ else :
277+ config ["pid_type" ] = VALID_PID_TYPES [0 ]
278+
279+ # Validate num_sm_multiplier is a power of two in range
280+ if "num_sm_multiplier" in config :
281+ val = config ["num_sm_multiplier" ]
282+ if (
283+ not isinstance (val , int )
284+ or val < MIN_NUM_SM_MULTIPLIER
285+ or val > MAX_NUM_SM_MULTIPLIER
286+ or (val & (val - 1 )) != 0 # not a power of two
287+ ):
288+ raise InvalidConfig (
289+ f"Invalid value for 'num_sm_multiplier': { val !r} must be a power of two between { MIN_NUM_SM_MULTIPLIER } and { MAX_NUM_SM_MULTIPLIER } "
290+ )
291+ else :
292+ config ["num_sm_multiplier" ] = DEFAULT_NUM_SM_MULTIPLIER
254293
255- for name , values in (("pid_type" , VALID_PID_TYPES ),):
256- if name in config :
257- if config [name ] not in values :
294+ # Only validate maxnreg on non-AMD devices (not supported on AMD)
295+ if torch .version .hip is None :
296+ if "maxnreg" in config :
297+ if config ["maxnreg" ] not in VALID_MAXNREG :
258298 raise InvalidConfig (
259- f"Invalid value for { name !r } : { config [name ]!r} must be one of { [ * values ] !r} "
299+ f"Invalid value for 'maxnreg' : { config ['maxnreg' ]!r} must be one of { list ( VALID_MAXNREG ) !r} "
260300 )
261301 else :
262- config [name ] = values [0 ]
302+ config ["maxnreg" ] = VALID_MAXNREG [0 ]
303+ else :
304+ # Remove maxnreg on AMD if present
305+ config .pop ("maxnreg" , None )
263306
264- # Set default values for grid indices when pid_type is not persistent
307+ # Handle num_sm_multiplier and maxnreg for non-persistent pid_types
308+ # These options only make sense for persistent kernels
265309 pid_type = config ["pid_type" ]
310+ if pid_type in ("flat" , "xyz" ):
311+ # Handle num_sm_multiplier
312+ num_sm_multiplier = config .get (
313+ "num_sm_multiplier" , DEFAULT_NUM_SM_MULTIPLIER
314+ )
315+ if num_sm_multiplier != DEFAULT_NUM_SM_MULTIPLIER :
316+ if _fix_invalid :
317+ # Silently fix during autotuning config generation
318+ config .pop ("num_sm_multiplier" , None )
319+ else :
320+ # Raise error for user-specified invalid combinations
321+ raise InvalidConfig (
322+ f"num_sm_multiplier={ num_sm_multiplier } can only be used with persistent "
323+ f"pid_type ('persistent_blocked' or 'persistent_interleaved'), "
324+ f"got pid_type={ pid_type !r} "
325+ )
326+ else :
327+ # Remove default value from config
328+ config .pop ("num_sm_multiplier" , None )
329+
330+ # Handle maxnreg - only makes sense for persistent kernels (and only on non-AMD)
331+ if torch .version .hip is None :
332+ maxnreg = config .get ("maxnreg" , DEFAULT_MAXNREG )
333+ if maxnreg != DEFAULT_MAXNREG :
334+ if _fix_invalid :
335+ # Silently fix during autotuning config generation
336+ config .pop ("maxnreg" , None )
337+ else :
338+ # Raise error for user-specified invalid combinations
339+ raise InvalidConfig (
340+ f"maxnreg={ maxnreg } can only be used with persistent "
341+ f"pid_type ('persistent_blocked' or 'persistent_interleaved'), "
342+ f"got pid_type={ pid_type !r} "
343+ )
344+ else :
345+ # Remove default value from config
346+ config .pop ("maxnreg" , None )
347+
348+ # Set default values for grid indices when pid_type is not persistent
266349 if pid_type in ("flat" , "xyz" ) and self .grid_block_ids :
267350 for name , mapping in (
268351 ("range_unroll_factors" , self .range_unroll_factors ),
@@ -322,8 +405,18 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
322405 "num_stages" : fn (IntegerFragment (1 , 8 , DEFAULT_NUM_STAGES )),
323406 "indexing" : fn (self .indexing ),
324407 "pid_type" : fn (EnumFragment (self .allowed_pid_types )),
408+ "num_sm_multiplier" : fn (
409+ PowerOfTwoFragment (
410+ MIN_NUM_SM_MULTIPLIER ,
411+ MAX_NUM_SM_MULTIPLIER ,
412+ DEFAULT_NUM_SM_MULTIPLIER ,
413+ )
414+ ),
325415 "load_eviction_policies" : fn (self .load_eviction_policies ),
326416 }
417+ # Only include maxnreg on non-AMD devices (not supported on AMD)
418+ if torch .version .hip is None :
419+ config ["maxnreg" ] = fn (EnumFragment (VALID_MAXNREG ))
327420 # Add tunable parameters
328421 config .update (
329422 {key : fn (fragment ) for key , fragment in self .user_defined_tunables .items ()}
@@ -345,7 +438,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
345438 ):
346439 if not config .get (name ):
347440 config .pop (name , None )
348- self .normalize (config )
441+ self .normalize (config , _fix_invalid = True )
349442 # pyrefly: ignore [bad-argument-type]
350443 return helion .Config (** config )
351444
0 commit comments