Created
January 11, 2022 04:51
-
-
Save aluo-x/c85b3e306af2824e39366b041453917c to your computer and use it in GitHub Desktop.
Fast PyTorch all-pairs distance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
@torch.jit.script
def distance1(input_pos_1: torch.Tensor, input_pos_2: torch.Tensor) -> torch.Tensor:
    """All-pairs squared Euclidean distance, naive broadcasting version.

    Materialises the full pairwise difference tensor, so it is the simplest
    but most memory-hungry of the three implementations.

    Args:
        input_pos_1: points, expected shape (n, d) (e.g. d = 3).
        input_pos_2: points, expected shape (m, d).

    Returns:
        (n, m) tensor; entry [i, j] is the squared distance between
        input_pos_1[i] and input_pos_2[j].
    """
    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d)
    rows = input_pos_1.unsqueeze(1)
    cols = input_pos_2.unsqueeze(0)
    pairwise_diff = rows - cols
    # Reduce the coordinate axis to get (n, m)
    return pairwise_diff.square().sum(dim=2)
@torch.jit.script
def distance2(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """All-pairs squared Euclidean distance via ``torch.addmm`` (by jacobrgardner).

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the
    (n, m, d) difference tensor is never materialised.

    Args:
        x1: (n, d) matrix of points (``torch.addmm`` only accepts 2-D matrices).
        x2: (m, d) matrix of points with the same ``d``.

    Returns:
        (n, m) tensor whose entry [i, j] is the squared distance between x1[i]
        and x2[j]. Entries are clamped to be non-negative.
    """
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)  # (n, 1)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)  # (m, 1)
    # ||x2||^2 as a (1, m) row broadcast by addmm, minus 2 * x1 @ x2^T,
    # then ||x1||^2 added in place as an (n, 1) column.
    res = torch.addmm(
        x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2
    ).add_(x1_norm)
    # The expansion suffers catastrophic cancellation for nearby points and
    # can produce small negative values; a squared distance is never negative
    # (and a negative value would turn into NaN under a later sqrt).
    return res.clamp_min_(0.0)
@torch.jit.script
def distance3(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """All-pairs squared Euclidean distance via einsum + matmul.

    By J. Emmanuel Johnson and Divakar. Uses the expansion
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.

    Args:
        X: (n, d) matrix of points (the 'ij' einsum subscripts require 2-D).
        Y: (m, d) matrix of points with the same ``d``.

    Returns:
        (n, m) tensor whose entry [i, j] is the squared distance between X[i]
        and Y[j]. Entries are clamped to be non-negative.
    """
    XX = torch.einsum('ij,ij->i', X, X)[:, None]  # (n, 1) column of ||X_i||^2
    YY = torch.einsum('ij,ij->i', Y, Y)  # (m,), broadcasts as a row
    XY = 2 * torch.matmul(X, Y.T)  # (n, m)
    # Cancellation in XX + YY - XY can yield small negative values for nearby
    # points; clamp so results agree with the exact (non-negative) definition.
    return (XX + YY - XY).clamp_min_(0.0)
# Benchmark notes (input sizes and hardware were not recorded):
# distance1 uses ~4x the memory of distance2; 8.67 ms
# distance2 is the most memory-efficient; 4.28 ms
# distance3 uses ~3x the memory of distance2 (~2x when compiled with TorchScript); 2.31 ms
# Comparing distance2 and distance3, torch.addmm accounts for the difference in memory and speed.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment