 from ._tensors import Tensor, disassemble_tensors, assemble_tensors, wrap, specs_equal, equality_by_shape_and_value, backend_for, Dense, TensorStack
 from ._tree import disassemble_tree, assemble_tree, variable_attributes, NATIVE_TENSOR, object_dims, slice_, find_differences
 from ._magic_ops import stack, rename_dims, all_attributes
-from ._sparse import SparseCoordinateTensor
+from ._sparse import SparseCoordinateTensor, is_sparse
 from ._lin_trace import ShiftLinTracer, matrix_from_function, LinearTraceInProgress
 from .magic import PhiTreeNode, Shapable
 from ..backend import Backend
@@ -126,7 +126,17 @@ def key_from_args(args: tuple,
                   cache=False,
                   aux: Set[str] = (),
                   attr_type=variable_attributes,
-                  for_jit=False) -> Tuple[SignatureKey, List[Tensor], tuple, Dict[str, Any], Dict[str, Any]]:
+                  use: str = None) -> Tuple[SignatureKey, List[Tensor], tuple, Dict[str, Any], Dict[str, Any]]:
+    """
+    Args:
+        args:
+        kwargs:
+        parameters:
+        cache:
+        aux:
+        attr_type:
+        use: 'jit' or 'gradient' or 'linear'
+    """
     kwargs = {**kwargs, **{parameters[i]: v for i, v in enumerate(args)}}
     attached_aux_kwargs = {}
     detached_aux_kwargs = {}
@@ -140,8 +150,8 @@ def key_from_args(args: tuple,
     _, aux_tensors = disassemble_tree(detached_aux_kwargs, cache=cache, attr_type=variable_attributes)
     tracing = not math.all_available(*tensors, *aux_tensors)
     backend = backend_for(*tensors, *aux_tensors)
-    natives, shapes, specs = disassemble_tensors(tensors, expand=cache)
-    if for_jit and backend.name == 'torch':  # for PyTorch, add tracers from aux to natives, but keep aux. PyTorch does not support using tensors with grad inside jit otherwise.
+    natives, shapes, specs = disassemble_tensors(tensors, expand=cache, include_constants=use != 'gradient')
+    if use == 'jit' and backend.name == 'torch':  # for PyTorch, add tracers from aux to natives, but keep aux. PyTorch does not support using tensors with grad inside jit otherwise.
         _, aux_tensors = disassemble_tree(attached_aux_kwargs, cache=cache, attr_type=variable_attributes)
         aux_natives, aux_shapes, aux_specs = disassemble_tensors(aux_tensors, expand=False)
         from torch import Tensor
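Note on the new `use` flag: it replaces the old boolean `for_jit` and selects one of three tracing modes. Below is a minimal, library-independent sketch of the switches that the updated code derives from it; the helper name `tracing_flags` is hypothetical and only illustrates the mapping visible in the hunk above.

def tracing_flags(use: str) -> dict:
    """Hypothetical helper mirroring how key_from_args interprets `use` above."""
    assert use in (None, 'jit', 'gradient', 'linear')
    return {
        'include_constants': use != 'gradient',  # constant tensors are dropped from the natives when tracing gradients
        'torch_aux_tracers': use == 'jit',       # PyTorch only: aux tracers are appended to the natives for jit
    }

print(tracing_flags('jit'))       # {'include_constants': True, 'torch_aux_tracers': True}
print(tracing_flags('gradient'))  # {'include_constants': False, 'torch_aux_tracers': False}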
@@ -273,7 +283,7 @@ def jit_f_native(*natives):
 
     def __call__(self, *args, **kwargs):
         try:
-            key, _, natives, _, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=True, aux=self.auxiliary_args, for_jit=True)
+            key, _, natives, _, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=True, aux=self.auxiliary_args, use='jit')
         except LinearTraceInProgress:
             return self.f(*args, **kwargs)
         if isinstance(self.f, GradientFunction) and key.backend.supports(Backend.jit_compile_grad):
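For context, a hedged usage sketch of the branch above: when a jit-compiled function wraps a gradient function and the backend advertises `Backend.jit_compile_grad`, the two are compiled as one unit. The import path and wrapper signatures (`math.jit_compile`, `math.gradient`, `wrt`, `get_output`) are assumptions about the public API around these classes and may differ in your version.

from phiml import math  # assumed package layout; phi.math in older ΦFlow versions

def loss(x):
    return math.l2_loss(math.sin(x))

# A JitFunction wrapping a GradientFunction triggers the jit_compile_grad branch above.
fast_grad = math.jit_compile(math.gradient(loss, wrt='x', get_output=False))

x = math.tensor([0.5, 1.0], math.spatial('x'))
print(fast_grad(x))  # requires a backend with autodiff support, e.g. math.use('torch')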
@@ -438,7 +448,7 @@ def _get_or_trace(self, key: SignatureKey, args: tuple, f_kwargs: dict):
 
     def __call__(self, *args: X, **kwargs) -> Y:
         try:
-            key, tensors, natives, x, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args)
+            key, tensors, natives, x, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, use='linear')
         except LinearTraceInProgress:
             return self.f(*args, **kwargs)
         assert tensors, "Linear function requires at least one argument"
@@ -472,7 +482,7 @@ def sparse_matrix(self, *args, **kwargs):
         Returns:
             Sparse matrix representation with `values` property and `native()` method.
         """
-        key, *_, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args)
+        key, *_, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, use='linear')
         matrix, bias, *_ = self._get_or_trace(key, args, aux_kwargs)
         assert math.close(bias, 0), "This is an affine function and cannot be represented by a single matrix. Use sparse_matrix_and_bias() instead."
         return matrix
@@ -490,7 +500,7 @@ def sparse_matrix_and_bias(self, *args, **kwargs):
             matrix: Sparse matrix representation with `values` property and `native()` method.
             bias: `Tensor`
         """
-        key, *_, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args)
+        key, *_, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, use='linear')
         return self._get_or_trace(key, args, aux_kwargs)[:2]
 
     def __repr__(self):
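A hedged usage sketch of the two methods above, assuming the public `jit_compile_linear` decorator returns this LinearFunction wrapper (exact import path may differ); `scale_and_shift` is an illustrative affine function, not taken from the commit.

from phiml import math  # assumed package layout

@math.jit_compile_linear
def scale_and_shift(x):
    return 2 * x + 1  # affine: slope 2, bias 1

v = math.tensor([1., 2., 3.], math.spatial('x'))
matrix, bias = scale_and_shift.sparse_matrix_and_bias(v)  # the trace key is now built with use='linear'
# scale_and_shift.sparse_matrix(v) would fail the bias assert above because the bias is non-zero.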
@@ -624,7 +634,7 @@ def f_native(*natives):
         return in_key.backend.jacobian(f_native, wrt=wrt_natives, get_output=self.get_output, is_f_scalar=self.is_f_scalar)
 
     def __call__(self, *args, **kwargs):
-        key, tensors, natives, kwargs, _ = key_from_args(args, kwargs, self.f_params, cache=True, attr_type=variable_attributes)
+        key, tensors, natives, kwargs, _ = key_from_args(args, kwargs, self.f_params, cache=True, attr_type=variable_attributes, use='gradient')
         if not key.backend.supports(Backend.jacobian):
             if math.default_backend().supports(Backend.jacobian):
                 warnings.warn(f"Using {math.default_backend()} for gradient computation because {key.backend} does not support jacobian()", RuntimeWarning)
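A hedged sketch of GradientFunction via the assumed public wrapper `math.gradient`; the parameter names (`wrt`, `get_output`), the return convention and the backend selector `math.use` are assumptions. A backend with autodiff support is required, otherwise the fallback warning above is emitted.

from phiml import math  # assumed package layout

math.use('torch')  # assumed backend selector; gradients need an autodiff backend

def loss(x):
    return math.l2_loss(x)

dloss = math.gradient(loss, wrt='x', get_output=False)
x = math.tensor([1., -2., 3.], math.spatial('x'))
dx = dloss(x)  # key_from_args runs with use='gradient', so constant tensors are excluded from the traced natives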
@@ -787,7 +797,7 @@ def __init__(self, f: Callable, f_params, wrt: tuple, get_output: bool, get_grad
 #         return hessian_generator(f_native, wrt=wrt_natives, get_output=self.get_output, get_gradient=self.get_gradient)
 #
 #     def __call__(self, *args, **kwargs):
-#         key, tensors, natives, kwargs, batch_shape = key_from_args_pack_batch(args, kwargs, self.f_params, cache=True, attr_type=variable_attributes)
+#         key, tensors, natives, kwargs, batch_shape = key_from_args_pack_batch(args, kwargs, self.f_params, cache=True, attr_type=variable_attributes, use='gradient')
 #         if not key.backend.supports(Backend.jacobian):
 #             if math.default_backend().supports(Backend.jacobian):
 #                 warnings.warn(f"Using {math.default_backend()} for gradient computation because {key.backend} does not support jacobian()", RuntimeWarning)
@@ -937,7 +947,7 @@ def backward_native(x_natives, y_natives, dy_natives):
             result = self.gradient(kwargs, y, dy)
             assert isinstance(result, dict) and all(key in kwargs for key in result.keys()), f"gradient function must return a dict containing only parameter names of the forward function. Forward '{f_name(self.f)}' has arguments {kwargs}."
             full_result = tuple(result.get(name, None) for name in in_key.tree.keys())
-            result_natives = self.incomplete_tree_to_natives(full_result, tuple(in_key.tree.values()), list(in_key.specs))
+            result_natives = self.grad_natives(full_result, tuple(in_key.tree.values()), list(in_key.specs))
             ML_LOGGER.debug(f"Backward pass of custom op {backward_native.__name__} returned gradients for {tuple(result.keys())} out of {tuple(in_key.tree.keys())} containing {len(result_natives)} native tensors")
             return result_natives
 
@@ -947,7 +957,7 @@ def backward_native(x_natives, y_natives, dy_natives):
         return in_key.backend.custom_gradient(forward_native, backward_native, get_external_cache=lambda: self.recorded_mappings[in_key], on_call_skipped=partial(self.recorded_mappings.__setitem__, in_key))
 
     def __call__(self, *args, **kwargs):
-        key, _, natives, _, _ = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, attr_type=variable_attributes)
+        key, _, natives, _, _ = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, attr_type=variable_attributes, use='gradient')
         if not key.backend.supports(Backend.jacobian) and not key.backend.supports(Backend.jacobian):
             return self.f(*args, **kwargs)  # no need to use custom gradient if gradients aren't supported anyway
         elif not key.backend.supports(Backend.custom_gradient):
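A hedged sketch of the backward contract enforced in backward_native above: the user-supplied gradient callback is called with the forward keyword arguments, the output and the incoming gradient, and must return a dict keyed by parameter names of the forward function. The wrapper name `math.custom_gradient`, its argument order and the shape of `y`/`dy` are assumptions; the example function is purely illustrative.

from phiml import math  # assumed package layout

def clipped_square(x):
    return math.minimum(x ** 2, 1.0)

def clipped_square_grad(kwargs: dict, y, dy):
    x = kwargs['x']
    return {'x': 2 * x * dy}  # may only contain names of clipped_square's parameters

clipped_square = math.custom_gradient(clipped_square, clipped_square_grad)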
@@ -973,14 +983,14 @@ def __name__(self):
         return f"custom_grad({f_name(self.f)})"
 
     @staticmethod
-    def incomplete_tree_to_natives(incomplete, tree, complete_specs: List[dict]) -> list:
+    def grad_natives(incomplete, tree, complete_specs: List[dict]) -> list:
         """ Returns native tensors for required input gradients.
 
         Args:
-            incomplete: Computed gradient Tensors / composite types
+            incomplete: Computed gradient Tensors / composite types. This can include `None` for gradients that are not required.
             tree: Corresponding input data `x`, including non-Tensor attributes. For dataclasses, this will be an instance of DataclassTreeNode.
                 None means there is a tensor, not a composite type.
-            complete_shapes: Shapes of `x`.
+            complete_specs: Shapes of `x`.
         """
         if tree is None:
             c_spec = complete_specs.pop(0)
@@ -993,30 +1003,32 @@ def incomplete_tree_to_natives(incomplete, tree, complete_specs: List[dict]) ->
             elif isinstance(incomplete, TensorStack):
                 specs = c_spec['tensors']
                 return [t.native(s['names']) for t, s in zip(incomplete._tensors, specs)]
+            elif is_sparse(incomplete):  # sparse indices are never passed as native arguments to gradient functions
+                return [incomplete._values.native(c_spec['values']['names'])]
             warnings.warn(f"Unsupported tensor type for custom gradient: {type(incomplete)}. This might result in incorrectly transposed gradients.", RuntimeWarning)
             return list(incomplete._natives())
         elif isinstance(tree, str) and tree == NATIVE_TENSOR:
             complete_specs.pop(0)
             return [incomplete]
         elif isinstance(tree, (tuple, list)):
             if incomplete is None:
-                return sum([CustomGradientFunction.incomplete_tree_to_natives(None, item, complete_specs) for item in tree], [])
+                return sum([CustomGradientFunction.grad_natives(None, item, complete_specs) for item in tree], [])
             else:
                 assert type(tree) == type(incomplete) and len(tree) == len(incomplete)
-                return sum([CustomGradientFunction.incomplete_tree_to_natives(i_item, c_item, complete_specs) for i_item, c_item in zip(incomplete, tree)], [])
+                return sum([CustomGradientFunction.grad_natives(i_item, c_item, complete_specs) for i_item, c_item in zip(incomplete, tree)], [])
         elif isinstance(tree, dict):
             if incomplete is None:
-                return sum([CustomGradientFunction.incomplete_tree_to_natives(None, item, complete_specs) for item in tree.values()], [])
+                return sum([CustomGradientFunction.grad_natives(None, item, complete_specs) for item in tree.values()], [])
             else:
                 assert type(tree) == type(incomplete) and len(tree) == len(incomplete) and set(tree.keys()) == set(incomplete.keys())
-                return sum([CustomGradientFunction.incomplete_tree_to_natives(incomplete[key], c_item, complete_specs) for key, c_item in tree.items()], [])
+                return sum([CustomGradientFunction.grad_natives(incomplete[key], c_item, complete_specs) for key, c_item in tree.items()], [])
         elif dataclasses.is_dataclass(tree):
             from ._tree import DataclassTreeNode
             if isinstance(tree, DataclassTreeNode):
                 natives = []
                 for attr, n_val in tree.extracted.items():
                     i_val = getattr(incomplete, attr) if incomplete is not None else None
-                    natives_item = CustomGradientFunction.incomplete_tree_to_natives(i_val, n_val, complete_specs)
+                    natives_item = CustomGradientFunction.grad_natives(i_val, n_val, complete_specs)
                     natives.extend(natives_item)
                 return natives
         if isinstance(tree, PhiTreeNode):
@@ -1025,7 +1037,7 @@ def incomplete_tree_to_natives(incomplete, tree, complete_specs: List[dict]) ->
             for attr in attributes:
                 n_val = getattr(tree, attr)
                 i_val = getattr(incomplete, attr) if incomplete is not None else None
-                natives_item = CustomGradientFunction.incomplete_tree_to_natives(i_val, n_val, complete_specs)
+                natives_item = CustomGradientFunction.grad_natives(i_val, n_val, complete_specs)
                 natives.extend(natives_item)
             return natives
         else:
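The new `is_sparse` branch returns natives only for the stored values. A library-independent illustration of why: the index arrays of a sparse gradient are integer metadata that were never handed to the backend as differentiable natives, so the backward pass must not emit natives for them either. All names below are hypothetical and only mirror that branch.

import numpy as np

indices = np.array([[0, 0], [1, 2]])  # fixed sparsity pattern, not differentiable
grad_values = np.array([0.5, -1.25])  # one native value per stored entry

def sparse_grad_natives(grad_values, indices):
    """Mirror of the is_sparse branch: natives for the values only, indices omitted."""
    return [grad_values]

assert len(sparse_grad_natives(grad_values, indices)) == 1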
@@ -1135,11 +1147,13 @@ def trace_check(traced_function, *args, **kwargs) -> Tuple[bool, str]:
     f = traced_function
     if isinstance(f, (JitFunction, GradientFunction, HessianFunction, CustomGradientFunction)):
         keys = f.traces.keys()
+        use = 'jit' if isinstance(f, JitFunction) else 'gradient'
     elif isinstance(f, LinearFunction):
         keys = f.matrices_and_biases.keys()
+        use = 'linear'
     else:
         raise ValueError(f"{f_name(f)} is not a traceable function. Only supports jit_compile, jit_compile_linear, gradient, custom_gradient, jacobian, hessian")
-    key, *_ = key_from_args(args, kwargs, f.f_params, aux=f.auxiliary_args)
+    key, *_ = key_from_args(args, kwargs, f.f_params, aux=f.auxiliary_args, use=use)
     if not keys:
         return False, "Function has not yet been traced"
     if key in keys:
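A hedged usage sketch of trace_check with the per-type `use` selection added above. Treating `math.trace_check` and `math.jit_compile` as public exports of the math namespace is an assumption; recording a trace also requires a backend with jit support.

from phiml import math  # assumed package layout

@math.jit_compile
def step(x):
    return 2 * math.exp(x)

x = math.tensor([0.1, 0.2, 0.3], math.spatial('x'))
step(x)  # first call records a trace (on backends that support jit)
matches, reason = math.trace_check(step, x)  # the key is now built with use='jit' for JitFunction
print(matches, reason)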