Commit 2c098b6

Don't pass sparse indices to native gradient functions
1 parent eb01a77 commit 2c098b6

1 file changed: +36 -22 lines

phiml/math/_functional.py (36 additions, 22 deletions)
@@ -14,7 +14,7 @@
 from ._tensors import Tensor, disassemble_tensors, assemble_tensors, wrap, specs_equal, equality_by_shape_and_value, backend_for, Dense, TensorStack
 from ._tree import disassemble_tree, assemble_tree, variable_attributes, NATIVE_TENSOR, object_dims, slice_, find_differences
 from ._magic_ops import stack, rename_dims, all_attributes
-from ._sparse import SparseCoordinateTensor
+from ._sparse import SparseCoordinateTensor, is_sparse
 from ._lin_trace import ShiftLinTracer, matrix_from_function, LinearTraceInProgress
 from .magic import PhiTreeNode, Shapable
 from ..backend import Backend
@@ -126,7 +126,17 @@ def key_from_args(args: tuple,
                   cache=False,
                   aux: Set[str] = (),
                   attr_type=variable_attributes,
-                  for_jit=False) -> Tuple[SignatureKey, List[Tensor], tuple, Dict[str, Any], Dict[str, Any]]:
+                  use: str = None) -> Tuple[SignatureKey, List[Tensor], tuple, Dict[str, Any], Dict[str, Any]]:
+    """
+    Args:
+        args:
+        kwargs:
+        parameters:
+        cache:
+        aux:
+        attr_type:
+        use: 'jit' or 'gradient' or 'linear'
+    """
     kwargs = {**kwargs, **{parameters[i]: v for i, v in enumerate(args)}}
     attached_aux_kwargs = {}
     detached_aux_kwargs = {}
@@ -140,8 +150,8 @@ def key_from_args(args: tuple,
     _, aux_tensors = disassemble_tree(detached_aux_kwargs, cache=cache, attr_type=variable_attributes)
     tracing = not math.all_available(*tensors, *aux_tensors)
     backend = backend_for(*tensors, *aux_tensors)
-    natives, shapes, specs = disassemble_tensors(tensors, expand=cache)
-    if for_jit and backend.name == 'torch':  # for PyTorch, add tracers from aux to natives, but keep aux. PyTorch does not support using tensors with grad inside jit otherwise.
+    natives, shapes, specs = disassemble_tensors(tensors, expand=cache, include_constants=use != 'gradient')
+    if use == 'jit' and backend.name == 'torch':  # for PyTorch, add tracers from aux to natives, but keep aux. PyTorch does not support using tensors with grad inside jit otherwise.
         _, aux_tensors = disassemble_tree(attached_aux_kwargs, cache=cache, attr_type=variable_attributes)
         aux_natives, aux_shapes, aux_specs = disassemble_tensors(aux_tensors, expand=False)
         from torch import Tensor
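
Note: the new `include_constants=use != 'gradient'` argument means that when the signature key is built for gradient computation, constant natives (such as the integer index tensor of a sparse matrix) are left out of the flattened native list, so only differentiable tensors reach the backend's gradient machinery. A minimal sketch of that idea, using a hypothetical helper rather than the phiml internals:

    # Sketch only (assumed helper, not phiml code): when collecting natives for a
    # gradient trace, constant components such as integer sparse indices are skipped.
    import numpy as np

    def collect_native_names(components, for_gradient: bool) -> list:
        """components: list of (name, array, is_constant) triples."""
        names = []
        for name, array, is_constant in components:
            if for_gradient and is_constant:
                continue  # constants get no gradient slot
            names.append(name)
        return names

    sparse_like = [('indices', np.array([[0, 1], [1, 2]]), True),   # integer coordinates: constant
                   ('values', np.array([3.0, 4.0]), False)]         # float entries: differentiable
    print(collect_native_names(sparse_like, for_gradient=True))   # ['values']
    print(collect_native_names(sparse_like, for_gradient=False))  # ['indices', 'values']
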
@@ -273,7 +283,7 @@ def jit_f_native(*natives):

     def __call__(self, *args, **kwargs):
         try:
-            key, _, natives, _, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=True, aux=self.auxiliary_args, for_jit=True)
+            key, _, natives, _, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=True, aux=self.auxiliary_args, use='jit')
         except LinearTraceInProgress:
             return self.f(*args, **kwargs)
         if isinstance(self.f, GradientFunction) and key.backend.supports(Backend.jit_compile_grad):
@@ -438,7 +448,7 @@ def _get_or_trace(self, key: SignatureKey, args: tuple, f_kwargs: dict):

     def __call__(self, *args: X, **kwargs) -> Y:
         try:
-            key, tensors, natives, x, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args)
+            key, tensors, natives, x, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, use='linear')
         except LinearTraceInProgress:
             return self.f(*args, **kwargs)
         assert tensors, "Linear function requires at least one argument"
@@ -472,7 +482,7 @@ def sparse_matrix(self, *args, **kwargs):
         Returns:
             Sparse matrix representation with `values` property and `native()` method.
         """
-        key, *_, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args)
+        key, *_, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, use='linear')
         matrix, bias, *_ = self._get_or_trace(key, args, aux_kwargs)
         assert math.close(bias, 0), "This is an affine function and cannot be represented by a single matrix. Use sparse_matrix_and_bias() instead."
         return matrix
@@ -490,7 +500,7 @@ def sparse_matrix_and_bias(self, *args, **kwargs):
             matrix: Sparse matrix representation with `values` property and `native()` method.
             bias: `Tensor`
         """
-        key, *_, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args)
+        key, *_, aux_kwargs = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, use='linear')
         return self._get_or_trace(key, args, aux_kwargs)[:2]

     def __repr__(self):
@@ -624,7 +634,7 @@ def f_native(*natives):
         return in_key.backend.jacobian(f_native, wrt=wrt_natives, get_output=self.get_output, is_f_scalar=self.is_f_scalar)

     def __call__(self, *args, **kwargs):
-        key, tensors, natives, kwargs, _ = key_from_args(args, kwargs, self.f_params, cache=True, attr_type=variable_attributes)
+        key, tensors, natives, kwargs, _ = key_from_args(args, kwargs, self.f_params, cache=True, attr_type=variable_attributes, use='gradient')
         if not key.backend.supports(Backend.jacobian):
             if math.default_backend().supports(Backend.jacobian):
                 warnings.warn(f"Using {math.default_backend()} for gradient computation because {key.backend} does not support jacobian()", RuntimeWarning)
@@ -787,7 +797,7 @@ def __init__(self, f: Callable, f_params, wrt: tuple, get_output: bool, get_grad
    #     return hessian_generator(f_native, wrt=wrt_natives, get_output=self.get_output, get_gradient=self.get_gradient)
    #
    # def __call__(self, *args, **kwargs):
-    #     key, tensors, natives, kwargs, batch_shape = key_from_args_pack_batch(args, kwargs, self.f_params, cache=True, attr_type=variable_attributes)
+    #     key, tensors, natives, kwargs, batch_shape = key_from_args_pack_batch(args, kwargs, self.f_params, cache=True, attr_type=variable_attributes, use='gradient')
    #     if not key.backend.supports(Backend.jacobian):
    #         if math.default_backend().supports(Backend.jacobian):
    #             warnings.warn(f"Using {math.default_backend()} for gradient computation because {key.backend} does not support jacobian()", RuntimeWarning)
@@ -937,7 +947,7 @@ def backward_native(x_natives, y_natives, dy_natives):
            result = self.gradient(kwargs, y, dy)
            assert isinstance(result, dict) and all(key in kwargs for key in result.keys()), f"gradient function must return a dict containing only parameter names of the forward function. Forward '{f_name(self.f)}' has arguments {kwargs}."
            full_result = tuple(result.get(name, None) for name in in_key.tree.keys())
-            result_natives = self.incomplete_tree_to_natives(full_result, tuple(in_key.tree.values()), list(in_key.specs))
+            result_natives = self.grad_natives(full_result, tuple(in_key.tree.values()), list(in_key.specs))
            ML_LOGGER.debug(f"Backward pass of custom op {backward_native.__name__} returned gradients for {tuple(result.keys())} out of {tuple(in_key.tree.keys())} containing {len(result_natives)} native tensors")
            return result_natives
@@ -947,7 +957,7 @@ def backward_native(x_natives, y_natives, dy_natives):
         return in_key.backend.custom_gradient(forward_native, backward_native, get_external_cache=lambda: self.recorded_mappings[in_key], on_call_skipped=partial(self.recorded_mappings.__setitem__, in_key))

     def __call__(self, *args, **kwargs):
-        key, _, natives, _, _ = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, attr_type=variable_attributes)
+        key, _, natives, _, _ = key_from_args(args, kwargs, self.f_params, cache=False, aux=self.auxiliary_args, attr_type=variable_attributes, use='gradient')
         if not key.backend.supports(Backend.jacobian) and not key.backend.supports(Backend.jacobian):
             return self.f(*args, **kwargs)  # no need to use custom gradient if gradients aren't supported anyway
         elif not key.backend.supports(Backend.custom_gradient):
@@ -973,14 +983,14 @@ def __name__(self):
         return f"custom_grad({f_name(self.f)})"

     @staticmethod
-    def incomplete_tree_to_natives(incomplete, tree, complete_specs: List[dict]) -> list:
+    def grad_natives(incomplete, tree, complete_specs: List[dict]) -> list:
         """ Returns native tensors for required input gradients.

         Args:
-            incomplete: Computed gradient Tensors / composite types
+            incomplete: Computed gradient Tensors / composite types. This can include `None` for gradients that are not required.
             tree: Corresponding input data `x`, including non-Tensor attributes. For dataclasses, this will be an instance of DataclassTreeNode.
                 None means there is a tensor, not a composite type.
-            complete_shapes: Shapes of `x`.
+            complete_specs: Shapes of `x`.
         """
         if tree is None:
             c_spec = complete_specs.pop(0)
@@ -993,30 +1003,32 @@ def incomplete_tree_to_natives(incomplete, tree, complete_specs: List[dict]) ->
             elif isinstance(incomplete, TensorStack):
                 specs = c_spec['tensors']
                 return [t.native(s['names']) for t, s in zip(incomplete._tensors, specs)]
+            elif is_sparse(incomplete):  # sparse indices are never passed as native arguments to gradient functions
+                return [incomplete._values.native(c_spec['values']['names'])]
             warnings.warn(f"Unsupported tensor type for custom gradient: {type(incomplete)}. This might result in incorrectly transposed gradients.", RuntimeWarning)
             return list(incomplete._natives())
         elif isinstance(tree, str) and tree == NATIVE_TENSOR:
             complete_specs.pop(0)
             return [incomplete]
         elif isinstance(tree, (tuple, list)):
             if incomplete is None:
-                return sum([CustomGradientFunction.incomplete_tree_to_natives(None, item, complete_specs) for item in tree], [])
+                return sum([CustomGradientFunction.grad_natives(None, item, complete_specs) for item in tree], [])
             else:
                 assert type(tree) == type(incomplete) and len(tree) == len(incomplete)
-                return sum([CustomGradientFunction.incomplete_tree_to_natives(i_item, c_item, complete_specs) for i_item, c_item in zip(incomplete, tree)], [])
+                return sum([CustomGradientFunction.grad_natives(i_item, c_item, complete_specs) for i_item, c_item in zip(incomplete, tree)], [])
         elif isinstance(tree, dict):
             if incomplete is None:
-                return sum([CustomGradientFunction.incomplete_tree_to_natives(None, item, complete_specs) for item in tree.values()], [])
+                return sum([CustomGradientFunction.grad_natives(None, item, complete_specs) for item in tree.values()], [])
             else:
                 assert type(tree) == type(incomplete) and len(tree) == len(incomplete) and set(tree.keys()) == set(incomplete.keys())
-                return sum([CustomGradientFunction.incomplete_tree_to_natives(incomplete[key], c_item, complete_specs) for key, c_item in tree.items()], [])
+                return sum([CustomGradientFunction.grad_natives(incomplete[key], c_item, complete_specs) for key, c_item in tree.items()], [])
         elif dataclasses.is_dataclass(tree):
             from ._tree import DataclassTreeNode
             if isinstance(tree, DataclassTreeNode):
                 natives = []
                 for attr, n_val in tree.extracted.items():
                     i_val = getattr(incomplete, attr) if incomplete is not None else None
-                    natives_item = CustomGradientFunction.incomplete_tree_to_natives(i_val, n_val, complete_specs)
+                    natives_item = CustomGradientFunction.grad_natives(i_val, n_val, complete_specs)
                     natives.extend(natives_item)
                 return natives
         if isinstance(tree, PhiTreeNode):
@@ -1025,7 +1037,7 @@ def incomplete_tree_to_natives(incomplete, tree, complete_specs: List[dict]) ->
             for attr in attributes:
                 n_val = getattr(tree, attr)
                 i_val = getattr(incomplete, attr) if incomplete is not None else None
-                natives_item = CustomGradientFunction.incomplete_tree_to_natives(i_val, n_val, complete_specs)
+                natives_item = CustomGradientFunction.grad_natives(i_val, n_val, complete_specs)
                 natives.extend(natives_item)
             return natives
         else:
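
Note: the added `is_sparse` branch in `grad_natives` is the backward-pass counterpart of the `include_constants` change above: for a sparse gradient tensor, only the native of its `_values` component is returned, in the dimension order recorded in the spec; the index tensor was never handed to the native gradient function, so no gradient slot exists for it. A rough sketch of that pairing with toy types (hypothetical, not phiml internals):

    # Sketch only (toy classes, not phiml internals): the backward pass returns one
    # native per differentiable input native, so a sparse gradient contributes only
    # its values array; the constant index array has no slot.
    from dataclasses import dataclass
    import numpy as np

    @dataclass
    class ToySparse:
        indices: np.ndarray  # integer coordinates, constant
        values: np.ndarray   # float entries, differentiable

    def toy_grad_natives(grad) -> list:
        if isinstance(grad, ToySparse):
            return [grad.values]  # sparse indices are never passed to gradient functions
        return [grad]             # dense case: one native per tensor

    g = ToySparse(indices=np.array([[0, 1], [1, 2]]), values=np.array([0.5, -1.0]))
    assert toy_grad_natives(g)[0] is g.values
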
@@ -1135,11 +1147,13 @@ def trace_check(traced_function, *args, **kwargs) -> Tuple[bool, str]:
     f = traced_function
     if isinstance(f, (JitFunction, GradientFunction, HessianFunction, CustomGradientFunction)):
         keys = f.traces.keys()
+        use = 'jit' if isinstance(f, JitFunction) else 'gradient'
     elif isinstance(f, LinearFunction):
         keys = f.matrices_and_biases.keys()
+        use = 'linear'
     else:
         raise ValueError(f"{f_name(f)} is not a traceable function. Only supports jit_compile, jit_compile_linear, gradient, custom_gradient, jacobian, hessian")
-    key, *_ = key_from_args(args, kwargs, f.f_params, aux=f.auxiliary_args)
+    key, *_ = key_from_args(args, kwargs, f.f_params, aux=f.auxiliary_args, use=use)
     if not keys:
         return False, "Function has not yet been traced"
     if key in keys:
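
Note: `trace_check` rebuilds a SignatureKey for the given arguments and compares it against the keys already recorded by the traced function, so it now has to pass the same `use` mode that the wrapper used ('jit' for jit_compile, 'linear' for jit_compile_linear, 'gradient' otherwise). A hedged usage sketch, assuming `jit_compile`, `tensor`, `channel` and `trace_check` are all exposed on `phiml.math`:

    from phiml import math

    @math.jit_compile
    def scale(x):
        return 2 * x

    x = math.tensor([1.0, 2.0, 3.0], math.channel('vector'))
    scale(x)                                      # first call records a trace for this signature
    matches, reason = math.trace_check(scale, x)  # trace_check returns (bool, str)
    print(matches, reason)                        # True once a matching trace exists
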
