
Commit 90a5130

Properly transpose matrix for implicit gradient solve
1 parent 2c098b6 commit 90a5130
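
For a linear solve x = A⁻¹·y, reverse-mode differentiation gives dL/dy = A⁻ᵀ·(dL/dx), so the backward (gradient) solve must run on the transposed matrix; passing the untransposed matrix is only correct when A is symmetric. That is what the new transpose_matrix helper added below supplies to the backward solve_with_grad call. The following is a minimal NumPy sketch of that identity, not part of this commit and independent of the Φ-ML API; the loss is taken as the linear probe L = g·x so that dL/dx = g exactly.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4)) + 4 * np.eye(4)   # non-symmetric, well-conditioned matrix
y = rng.standard_normal(4)
g = rng.standard_normal(4)                        # incoming gradient dL/dx

x = np.linalg.solve(A, y)                         # forward solve: x = A^-1 y
dL_dy = np.linalg.solve(A.T, g)                   # backward solve must use A^T, not A

# Finite-difference check of dL/dy for L = g . x:
eps = 1e-6
fd = np.array([(g @ np.linalg.solve(A, y + eps * e) - g @ x) / eps for e in np.eye(4)])
assert np.allclose(fd, dL_dy, atol=1e-4)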

File tree

1 file changed: +20 -7 lines changed

phiml/math/_optimize.py

Lines changed: 20 additions & 7 deletions
@@ -10,7 +10,7 @@
 from ..backend import get_precision, NUMPY, Backend
 from ..backend._backend import SolveResult, ML_LOGGER, default_backend, convert, Preconditioner, choose_backend
 from ..backend._linalg import IncompleteLU, incomplete_lu_dense, incomplete_lu_coo, coarse_explicit_preconditioner_coo
-from ._shape import EMPTY_SHAPE, Shape, merge_shapes, batch, non_batch, shape, dual, channel, non_dual, instance, spatial
+from ._shape import EMPTY_SHAPE, Shape, merge_shapes, batch, non_batch, shape, dual, channel, non_dual, instance, spatial, primal
 from ._tensors import Tensor, wrap, Dense, reshaped_tensor, preferred_backend_for
 from ._tree import layout, disassemble_tree, assemble_tree, NATIVE_TENSOR, variable_attributes
 from ._magic_ops import stack, copy_with, rename_dims, unpack_dim, unstack, expand, value_attributes
@@ -786,24 +786,25 @@ def attach_gradient_solve(forward_solve: Callable, auxiliary_args: str, matrix_a
     def implicit_gradient_solve(fwd_args: dict, x, dx):
         solve = fwd_args['solve']
         matrix = (fwd_args['matrix'],) if 'matrix' in fwd_args else ()
+        matrixT = (transpose_matrix(matrix[0], fwd_args['solve'].x0.shape, fwd_args['y'].shape),) if matrix else ()
         if matrix_adjoint:
-            assert matrix, "No matrix given but matrix_gradient=True"
+            assert matrix, "grad_for_f=True requires an explicit matrix but was given a function instead. Use @jit_compile_linear to build a matrix on the fly"
         grad_solve = solve.gradient_solve
-        x0 = grad_solve.x0 if grad_solve.x0 is not None else zeros_like(solve.x0)
+        x0 = grad_solve.x0 if grad_solve.x0 is not None else zeros_like(fwd_args['y'])
         grad_solve_ = copy_with(solve.gradient_solve, x0=x0)
         if 'is_backprop' in fwd_args:
             del fwd_args['is_backprop']
-        dy = solve_with_grad(dx, grad_solve_, *matrix, is_backprop=True, **fwd_args)  # this should hopefully result in implicit gradients for higher orders as well
+        dy = solve_with_grad(dx, grad_solve_, *matrixT, is_backprop=True, **fwd_args)  # this should hopefully result in implicit gradients for higher orders as well
         if matrix_adjoint:  # matrix adjoint = dy * x^T sampled at indices
             matrix = matrix[0]
             if isinstance(matrix, CompressedSparseMatrix):
                 matrix = matrix.decompress()
             if isinstance(matrix, SparseCoordinateTensor):
                 col = matrix.dual_indices(to_primal=True)
                 row = matrix.primal_indices()
-                _, dy_tensors = disassemble_tree(dy, cache=False, attr_type=value_attributes)
-                _, x_tensors = disassemble_tree(x, cache=False, attr_type=variable_attributes)
-                dm_values = dy_tensors[0][col] * x_tensors[0][row]
+                _, (dy_tensor,) = disassemble_tree(dy, cache=False, attr_type=value_attributes)
+                _, (x_tensor,) = disassemble_tree(x, cache=False, attr_type=variable_attributes)
+                dm_values = dy_tensor[row] * x_tensor[col]  # dense matrix gradient
                 dm_values = math.sum_(dm_values, dm_values.shape.non_instance - matrix.shape)
                 dm = matrix._with_values(dm_values)
                 dm = -dm
@@ -821,6 +822,18 @@ def implicit_gradient_solve(fwd_args: dict, x, dx):
     return solve_with_grad


+def transpose_matrix(matrix: Tensor, x_dims: Shape, y_dims: Shape):
+    # matrix: (~x,y)
+    # x_dims: (x:p)
+    # y_dims: (y:p)
+    assert not x_dims.dual
+    assert not y_dims.dual
+    x_dims = x_dims.non_batch
+    y_dims = y_dims.non_batch
+    transposed = rename_dims(matrix, x_dims.as_dual() + y_dims, x_dims + y_dims.as_dual())  # y -> ~y, ~x -> x
+    return transposed
+
+
 def compute_preconditioner(method: str, matrix: Tensor, rank_deficiency: Union[int, Tensor] = 0, target_backend: Backend = None, solver: str = None) -> Optional[Preconditioner]:
     rank_deficiency: Tensor = wrap(rank_deficiency)
     if method == 'auto':
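
The matrix-adjoint branch above computes the gradient with respect to the stored matrix values as dy·xᵀ sampled at the sparse indices, with dy gathered at the primal (row) indices and x at the dual (column) indices, then negated; the commit swaps the previous row/col pairing. Below is a small NumPy check of that formula under the same assumptions as the sketch above (L = g·x, small dense matrix, names chosen only for illustration), not part of the commit.

import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((3, 3)) + 3 * np.eye(3)   # non-symmetric, well-conditioned matrix
y = rng.standard_normal(3)
g = rng.standard_normal(3)                        # incoming gradient dL/dx

x = np.linalg.solve(A, y)                         # forward solve
dy = np.linalg.solve(A.T, g)                      # backward solve with the transposed matrix
dA = -np.outer(dy, x)                             # dL/dA[i, j] = -dy[i] * x[j]

# Finite-difference comparison for a single matrix entry (i, j):
i, j, eps = 1, 2, 1e-6
A_pert = A.copy()
A_pert[i, j] += eps
fd = (g @ np.linalg.solve(A_pert, y) - g @ x) / eps
assert abs(fd - dA[i, j]) < 1e-4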
