 from ..backend import get_precision, NUMPY, Backend
 from ..backend._backend import SolveResult, ML_LOGGER, default_backend, convert, Preconditioner, choose_backend
 from ..backend._linalg import IncompleteLU, incomplete_lu_dense, incomplete_lu_coo, coarse_explicit_preconditioner_coo
-from ._shape import EMPTY_SHAPE, Shape, merge_shapes, batch, non_batch, shape, dual, channel, non_dual, instance, spatial
+from ._shape import EMPTY_SHAPE, Shape, merge_shapes, batch, non_batch, shape, dual, channel, non_dual, instance, spatial, primal
 from ._tensors import Tensor, wrap, Dense, reshaped_tensor, preferred_backend_for
 from ._tree import layout, disassemble_tree, assemble_tree, NATIVE_TENSOR, variable_attributes
 from ._magic_ops import stack, copy_with, rename_dims, unpack_dim, unstack, expand, value_attributes
@@ -786,24 +786,25 @@ def attach_gradient_solve(forward_solve: Callable, auxiliary_args: str, matrix_a
     def implicit_gradient_solve(fwd_args: dict, x, dx):
         solve = fwd_args['solve']
         matrix = (fwd_args['matrix'],) if 'matrix' in fwd_args else ()
+        matrixT = (transpose_matrix(matrix[0], fwd_args['solve'].x0.shape, fwd_args['y'].shape),) if matrix else ()
         if matrix_adjoint:
-            assert matrix, "No matrix given but matrix_gradient=True"
+            assert matrix, "grad_for_f=True requires an explicit matrix but was given a function instead. Use @jit_compile_linear to build a matrix on the fly"
         grad_solve = solve.gradient_solve
-        x0 = grad_solve.x0 if grad_solve.x0 is not None else zeros_like(solve.x0)
+        x0 = grad_solve.x0 if grad_solve.x0 is not None else zeros_like(fwd_args['y'])
         grad_solve_ = copy_with(solve.gradient_solve, x0=x0)
         if 'is_backprop' in fwd_args:
             del fwd_args['is_backprop']
-        dy = solve_with_grad(dx, grad_solve_, *matrix, is_backprop=True, **fwd_args)  # this should hopefully result in implicit gradients for higher orders as well
+        dy = solve_with_grad(dx, grad_solve_, *matrixT, is_backprop=True, **fwd_args)  # this should hopefully result in implicit gradients for higher orders as well
         if matrix_adjoint:  # matrix adjoint = dy * x^T sampled at indices
             matrix = matrix[0]
             if isinstance(matrix, CompressedSparseMatrix):
                 matrix = matrix.decompress()
             if isinstance(matrix, SparseCoordinateTensor):
                 col = matrix.dual_indices(to_primal=True)
                 row = matrix.primal_indices()
-                _, dy_tensors = disassemble_tree(dy, cache=False, attr_type=value_attributes)
-                _, x_tensors = disassemble_tree(x, cache=False, attr_type=variable_attributes)
-                dm_values = dy_tensors[0][col] * x_tensors[0][row]
+                _, (dy_tensor,) = disassemble_tree(dy, cache=False, attr_type=value_attributes)
+                _, (x_tensor,) = disassemble_tree(x, cache=False, attr_type=variable_attributes)
+                dm_values = dy_tensor[row] * x_tensor[col]  # dense matrix gradient
                 dm_values = math.sum_(dm_values, dm_values.shape.non_instance - matrix.shape)
                 dm = matrix._with_values(dm_values)
                 dm = -dm
@@ -821,6 +822,18 @@ def implicit_gradient_solve(fwd_args: dict, x, dx):
     return solve_with_grad


+def transpose_matrix(matrix: Tensor, x_dims: Shape, y_dims: Shape):
+    # matrix: (~x,y)
+    # x_dims: (x:p)
+    # y_dims: (y:p)
+    assert not x_dims.dual
+    assert not y_dims.dual
+    x_dims = x_dims.non_batch
+    y_dims = y_dims.non_batch
+    transposed = rename_dims(matrix, x_dims.as_dual() + y_dims, x_dims + y_dims.as_dual())  # y -> ~y, ~x -> x
+    return transposed
+
+
 def compute_preconditioner(method: str, matrix: Tensor, rank_deficiency: Union[int, Tensor] = 0, target_backend: Backend = None, solver: str = None) -> Optional[Preconditioner]:
     rank_deficiency: Tensor = wrap(rank_deficiency)
     if method == 'auto':
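
For context, the backward pass in this diff follows the standard implicit-differentiation identities for a linear solve A x = y: the gradient with respect to y is obtained by solving with the transposed matrix (hence transpose_matrix and *matrixT), and the matrix adjoint is -dy x^T, sampled at the sparse (row, col) indices. The following is a minimal NumPy sketch of those identities only; it is illustrative, not part of this commit, and uses none of the library's API.

# Illustrative only: adjoint identities for x = A^-1 y.
# dL/dy = A^-T dL/dx,  dL/dA = -(A^-T dL/dx) x^T
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3)) + 3 * np.eye(3)  # non-symmetric, well-conditioned
y = rng.standard_normal(3)
dx = rng.standard_normal(3)                      # incoming gradient dL/dx

x = np.linalg.solve(A, y)
dy = np.linalg.solve(A.T, dx)                    # backward solve uses A^T, motivating transpose_matrix
dA = -np.outer(dy, x)                            # matrix adjoint; the commit samples this at sparse indices

# Finite-difference check of one entry of dL/dy for L = dx . x
eps = 1e-6
e0 = np.eye(3)[0] * eps
fd = (dx @ np.linalg.solve(A, y + e0) - dx @ x) / eps
assert np.isclose(fd, dy[0], atol=1e-4)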