
@bdsaglam
Created July 14, 2020 09:50
nth-order Taylor approximation of the matrix exponential for PyTorch
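For reference, the series being truncated is the power series of the matrix exponential. An nth-order approximation keeps the terms through $k = n$, and each term follows from the previous one by one matrix product and one scalar multiplication. The tail bound holds for any submultiplicative matrix norm:

```latex
e^{X} = \sum_{k=0}^{\infty} \frac{X^k}{k!}
\approx S_n(X) = \sum_{k=0}^{n} \frac{X^k}{k!},
\qquad X^0 = I,

X^k = X \, X^{k-1}, \qquad k! = k \, (k-1)!,

\bigl\| e^{X} - S_n(X) \bigr\| \le \frac{\|X\|^{n+1}}{(n+1)!} \, e^{\|X\|}.
```

The bound shows why a fixed order such as 10 is accurate only when $\|X\|$ is modest; larger norms need more terms.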
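import torch


def expm(x, order=10):
    '''
    nth-order Taylor approximation of the matrix exponential,
    exp(x) ~= sum_{k=0}^{order} x^k / k!, for x of shape (..., n, n)
    '''
    if order < 1:
        raise ValueError("order cannot be smaller than 1")
    I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
    result = I   # k = 0 term
    nom = I      # running power x^k
    denom = 1.0  # running factorial k!
    for i in range(1, order + 1):  # include the x^order term
        nom = x @ nom
        denom *= i
        # out-of-place add: an in-place `+=` would fail to broadcast the
        # (n, n) identity against a batched (..., n, n) input
        result = result + nom / denom
    return result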
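To check the recurrence numerically without needing PyTorch installed, here is a dependency-free sketch that uses plain Python lists instead of tensors. The helpers `matmul` and `taylor_expm` are hypothetical names introduced for illustration. It applies the same power/factorial update as `expm` and compares the result against a case with a known closed form: the exponential of `[[0, t], [-t, 0]]` is the rotation matrix `[[cos t, sin t], [-sin t, cos t]]`.

```python
import math


def matmul(a, b):
    """Multiply two square matrices given as nested lists (hypothetical helper)."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def taylor_expm(x, order):
    """Truncated series sum_{k=0}^{order} x^k / k!, pure-Python sketch."""
    n = len(x)
    identity = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    result = [row[:] for row in identity]  # k = 0 term
    power = identity                       # running x^k
    factorial = 1.0                        # running k!
    for k in range(1, order + 1):
        power = matmul(x, power)
        factorial *= k
        result = [[result[i][j] + power[i][j] / factorial for j in range(n)]
                  for i in range(n)]
    return result


# exp([[0, t], [-t, 0]]) is the rotation [[cos t, sin t], [-sin t, cos t]].
t = 0.5
x = [[0.0, t], [-t, 0.0]]
exact = [[math.cos(t), math.sin(t)], [-math.sin(t), math.cos(t)]]


def max_error(order):
    """Largest entrywise deviation of the truncated series from exact."""
    approx = taylor_expm(x, order)
    return max(abs(approx[i][j] - exact[i][j]) for i in range(2) for j in range(2))


print(f"max abs error at order 4:  {max_error(4):.2e}")
print(f"max abs error at order 10: {max_error(10):.2e}")
```

With `t = 0.5` the first omitted term at order 10 is on the order of `t**11 / 11!`, so the order-10 error is far below `1e-9`, while order 4 is visibly less accurate.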