Last active
October 25, 2020 00:05
-
-
Save xmodar/7c4aeb3d75bf1e0ab99b24cf2b3b37a3 to your computer and use it in GitHub Desktop.
Compute the square root of a positive definite matrix using differentiable PyTorch operations (supports batching).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Matrix square root: https://github.com/pytorch/pytorch/issues/25481""" | |
import torch | |
def psd_matrix_sqrt(matrix, num_iterations=20):
    """Compute the square root of a PSD matrix using Newton's method.

    This implementation was adopted from code by @JonathanVacher.
    https://gist.github.com/ModarTensai/7c4aeb3d75bf1e0ab99b24cf2b3b37a3

    The input is divided by its norm over the last two dimensions, refined
    with a coupled iteration that only uses matrix products, and finally
    rescaled by the square root of that norm.

    Args:
        matrix: Tensor whose last two dimensions hold square matrices; any
            leading dimensions are handled as batch dimensions.
        num_iterations: Number of iteration steps. One step is always
            performed before the loop, so values below 1 behave like 1.

    Returns:
        A tensor with the same shape as `matrix` approximating its square
        root. An all-zero matrix yields an all-zero result.
    """
    norm = matrix.norm(dim=[-2, -1], keepdim=True)
    # An all-zero matrix has zero norm; clamp it so the division below
    # produces zeros (whose square root is zero) instead of NaNs.
    norm = norm.clamp(min=torch.finfo(norm.dtype).tiny)
    matrix = matrix / norm

    def mul_diag_add(inputs, scale=-0.5, diag=1.5):
        # multiply by a scalar then add a scalar to the diagonal
        # (done in place, so callers must pass a tensor they own)
        inputs.mul_(scale).diagonal(0, -1, -2).add_(diag)
        return inputs

    # Clone first so the in-place helper does not modify `matrix` itself.
    other = mul_diag_add(matrix.clone())
    matrix = matrix @ other
    for i in range(1, num_iterations):
        # `other @ matrix` is a fresh tensor, so updating it in place is safe.
        temp = mul_diag_add(other @ matrix)
        matrix = matrix @ temp
        if i + 1 < num_iterations:
            # `other` is not needed after the final step; skip the product.
            other = temp @ other
    # Undo the initial normalization: sqrt(c * A) == sqrt(c) * sqrt(A).
    return matrix * norm.sqrt()
def decomposition_based(matrix):
    """Compute a matrix square root from the singular value decomposition.

    Singular values that are not larger than
    `largest * size * machine epsilon` are treated as zero. When every
    matrix in the batch has fewer significant singular values than the
    matrix size, the trailing components are dropped before the final
    product.

    Args:
        matrix: Tensor whose last two dimensions hold the matrices
            (presumably symmetric positive semi-definite; this is not
            checked here).

    Returns:
        `V @ diag(sqrt(S)) @ V^T` built from the retained components.
    """
    # An eigendecomposition was used here previously; only S and V of the
    # SVD are needed.
    _, singular, basis = matrix.svd()
    size = singular.size(-1)
    eps = torch.finfo(singular.dtype).eps
    largest = singular.max(dim=-1, keepdim=True).values
    significant = singular > largest * size * eps

    # Number of significant singular values for each matrix in the batch.
    rank = significant.sum(-1)
    max_rank = rank.max()
    ranks_differ = max_rank != rank.min()

    if max_rank < size:
        # No matrix uses the trailing components, so drop them for everyone.
        singular = singular[..., :max_rank]
        basis = basis[..., :max_rank]
        if ranks_differ:
            significant = significant[..., :max_rank]

    if ranks_differ:
        # Some matrices keep fewer components than others: zero the rest.
        zero = torch.zeros((), device=singular.device, dtype=singular.dtype)
        singular = singular.where(significant, zero)

    scaled_basis = basis * singular.sqrt().unsqueeze(-2)
    return scaled_basis @ basis.transpose(-2, -1)
def special_sylvester(a, b):
    """Solves the equation `A @ X + X @ A = B` for a positive definite `A`.

    Imitate `scipy.linalg.solve_sylvester(a, a, b)`.
    Useful to explicitly define the backward pass of `sqrtm()`.
    The gradient of `sqrtm(a)` w.r.t. `a` is:
    `special_sylvester(sqrtm(a), grad_output)`

    Args:
        a: Symmetric positive definite tensor; the last two dimensions hold
            the matrices. Only its upper triangle is read.
        b: Right-hand side with matching matrix dimensions.

    Returns:
        The solution `X`, computed in the eigenbasis of `a`.
    """
    # https://math.stackexchange.com/a/820313
    # "A computational framework of gradient flows
    # for general linear matrix equations"
    # by Liqi Wang, Moody T. Chu, Yu Bo (2014)
    # `Tensor.symeig` is unavailable in recent PyTorch releases; prefer
    # `torch.linalg.eigh` when it exists. UPLO="U" matches the triangle
    # that `symeig` read by default.
    eigh = getattr(getattr(torch, "linalg", None), "eigh", None)
    if eigh is not None:
        s, v = eigh(a, UPLO="U")
    else:
        s, v = a.symeig(eigenvectors=True)
    # With A = V diag(s) V^T the equation becomes, entrywise in the
    # eigenbasis, (s_i + s_j) * Y_ij = C_ij where C = V^T B V.
    d = s.unsqueeze(-1)
    d = d + d.transpose(-2, -1)
    vt = v.transpose(-2, -1)
    c = vt @ b @ v
    return v @ (c / d) @ vt
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment