Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active October 25, 2020 00:05
Show Gist options
  • Save xmodar/7c4aeb3d75bf1e0ab99b24cf2b3b37a3 to your computer and use it in GitHub Desktop.
Save xmodar/7c4aeb3d75bf1e0ab99b24cf2b3b37a3 to your computer and use it in GitHub Desktop.
Compute the square root of a positive definite matrix with differentiable operations in pytorch (supports batching).
"""Matrix square root: https://github.com/pytorch/pytorch/issues/25481"""
import torch
def psd_matrix_sqrt(matrix, num_iterations=20):
    """Compute the square root of a PSD matrix with a Newton-Schulz iteration.

    The input is first divided by its Frobenius norm (taken over the last two
    dimensions), the coupled iteration is run using only matrix products and
    diagonal shifts, and the result is rescaled by the square root of that
    norm. Leading dimensions are treated as batch dimensions.

    This implementation was adopted from code by @JonathanVacher.
    https://gist.github.com/ModarTensai/7c4aeb3d75bf1e0ab99b24cf2b3b37a3

    Args:
        matrix: Tensor whose last two dimensions hold square PSD matrices.
        num_iterations: Number of iteration steps. The first step is always
            performed, so values below 1 behave like 1.

    Returns:
        A tensor with the same shape as `matrix` approximating its matrix
        square root. All-zero matrices map to all-zero results.
    """
    norm = matrix.norm(dim=[-2, -1], keepdim=True)
    # An all-zero matrix has zero norm and dividing by it would fill the
    # output with NaNs; scale those matrices by one so they stay zero.
    safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
    matrix = matrix / safe_norm

    def mul_diag_add(inputs, scale=-0.5, diag=1.5):
        # multiply by a scalar then add a scalar to the diagonal (in place)
        inputs.mul_(scale).diagonal(0, -1, -2).add_(diag)
        return inputs

    # First step, specialised for the second iterate starting at identity.
    other = mul_diag_add(matrix.clone())  # clone: do not modify `matrix`
    matrix = matrix @ other
    for i in range(1, num_iterations):
        temp = mul_diag_add(other @ matrix)
        matrix = matrix @ temp
        if i + 1 < num_iterations:  # `other` is not needed after last step
            other = temp @ other
    return matrix * safe_norm.sqrt()
def decomposition_based(matrix):
    """Compute the square root of a positive (semi-)definite matrix via SVD.

    The result is `V @ diag(sqrt(s)) @ V^T`, where `s` are the singular
    values and `V` the right singular vectors of `matrix`. Singular values
    not exceeding `max(s) * n * eps` are treated as zero. Leading dimensions
    are treated as batch dimensions.

    Args:
        matrix: Tensor whose last two dimensions hold square matrices.

    Returns:
        A tensor with the same shape as `matrix`.
    """
    # `Tensor.svd()` is deprecated in favour of `torch.linalg.svd`, which
    # returns V already transposed; transpose it back to column vectors.
    _, s, vh = torch.linalg.svd(matrix)
    v = vh.transpose(-2, -1)

    tolerance = s.max(-1, True).values * s.size(-1) * torch.finfo(s.dtype).eps
    good = s > tolerance  # numerically non-zero singular values
    components = good.sum(-1)  # number of kept values per matrix
    common = int(components.max())  # largest count across the batch
    unbalanced = common != int(components.min())

    # Drop the trailing components that are negligible for every matrix.
    if common < s.size(-1):
        s = s[..., :common]
        v = v[..., :common]
        good = good[..., :common]
    # Matrices with fewer kept values than `common` still carry negligible
    # values in the remaining columns; zero them out explicitly.
    if unbalanced:
        s = s.where(good, torch.zeros((), device=s.device, dtype=s.dtype))
    return (v * s.sqrt().unsqueeze(-2)) @ v.transpose(-2, -1)
def special_sylvester(a, b):
    """Solve the equation `A @ X + X @ A = B` for a positive definite `A`.

    Imitate `scipy.linalg.solve_sylvester(a, a, b)`.
    Useful to explicitly define the backward pass of `sqrtm()`.
    The gradient of `sqrtm(a)` w.r.t. `a` is:
    `special_sylvester(sqrtm(a), grad_output)`

    With the eigendecomposition `A = V @ diag(s) @ V^T`, the equation becomes
    `(s_i + s_j) * C'_ij = (V^T @ B @ V)_ij` for `C' = V^T @ X @ V`, which is
    solved element-wise and mapped back with `V`.

    Args:
        a: Tensor whose last two dimensions hold symmetric matrices; only the
            upper triangle is read by the eigendecomposition.
        b: Right-hand side, matrix-multiplied with the eigenvectors of `a`.

    Returns:
        The solution `X`.
    """
    # https://math.stackexchange.com/a/820313
    # "A computational framework of gradient flows
    # for general linear matrix equations"
    # by Liqi Wang, Moody T. Chu, Yu Bo (2014)

    # `Tensor.symeig` has been removed from PyTorch; `torch.linalg.eigh` is
    # its replacement. UPLO="U" keeps symeig's default of reading the upper
    # triangle (eigh defaults to the lower one).
    s, v = torch.linalg.eigh(a, UPLO="U")
    d = s.unsqueeze(-1)
    d = d + d.transpose(-2, -1)  # d[i, j] = s[i] + s[j]
    vt = v.transpose(-2, -1)
    c = vt @ b @ v
    return v @ (c / d) @ vt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment