@jvlmdr
Last active September 1, 2015 21:01
function [x, mul_dx_dA, mul_dx_db] = solve_square(A, b)
% solve_square returns x = A^-1 b together with function handles that apply the
% derivatives of x with respect to A and b to a perturbation (Jacobian-vector
% products). It uses an LU decomposition of A instead of forming inv(A).
%
% Parameters:
% A has size [n, n] and rank(A) = n.
% b has size [n, 1].
%
% Returns:
% x has size [n, 1].
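% v = mul_dx_dA(V) returns the derivative of x along the perturbation V of A.
%   V has size [n, n].
%   v has size [n, 1].
% v = mul_dx_db(u) returns the derivative of x along the perturbation u of b.
%   u has size [n, 1].
%   v has size [n, 1].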
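%
% Derivation: differentiating A x = b gives dA x + A dx = db, so
% dx = A^-1 (db - dA x). Hence the derivative along a perturbation db of b is
% A^-1 db, and the derivative along a perturbation dA of A is -A^-1 dA x.
% Both reuse the factorization of A computed once below.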
[m, n] = size(A);
if m ~= n
  error('solve_square:notSquare', 'A must be square, but has size [%d, %d].', m, n);
end
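% Old version using the explicit inverse, kept for reference. Reusing the LU
% factors below is cheaper and numerically more stable than forming inv(A).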
% C = inv(A);
% x = C * b;
% mul_dx_db = @(v) C * v;
% mul_dx_dA = @(V) -C * V * x;
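% Factor with partial pivoting: P A = L U, with L unit lower triangular and
% U upper triangular, so every solve below is two triangular solves.
[L, U, P] = lu(A);
% A x = b
% L U x = P b
% x = U \ (L \ (P b))
x = U \ (L \ (P * b));
% dx = A^-1 db
mul_dx_db = @(v) U \ (L \ (P * v));
% dx = -A^-1 dA x
mul_dx_dA = @(V) -(U \ (L \ (P * (V * x))));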
end
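One way to sanity-check the two handles is to compare them against central finite differences of x. The script below is a sketch, not part of the original gist: it assumes solve_square.m (the function above) is on the MATLAB or Octave path, and the matrix size n, the step h, and the diagonal shift used to keep A well conditioned are arbitrary choices.

```matlab
% Finite-difference check of solve_square.
% Assumes solve_square.m is on the path; n, h and the shift are arbitrary.
rng(0);                     % fixed seed so the check is repeatable
n = 5;
A = randn(n) + n * eye(n);  % shift the diagonal so A is well conditioned
b = randn(n, 1);
[x, mul_dx_dA, mul_dx_db] = solve_square(A, b);

% Random perturbation directions for A and b.
dA = randn(n);
db = randn(n, 1);
h = 1e-6;

% Central differences: (x(A + h dA) - x(A - h dA)) / (2 h), likewise for b.
fd_A = ((A + h * dA) \ b - (A - h * dA) \ b) / (2 * h);
fd_b = (A \ (b + h * db) - A \ (b - h * db)) / (2 * h);

% Relative error between the analytic operators and the finite differences.
err_A = norm(mul_dx_dA(dA) - fd_A) / norm(fd_A);
err_b = norm(mul_dx_db(db) - fd_b) / norm(fd_b);
fprintf('relative error (A): %g\n', err_A);
fprintf('relative error (b): %g\n', err_b);
```

Both relative errors should be far below 1e-5; a large value would point to a sign or ordering mistake in the corresponding handle.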