@aluo-x, created January 11, 2022
Fast PyTorch all-pairs (squared Euclidean) distance

import torch
@torch.jit.script
def distance1(input_pos_1: torch.Tensor, input_pos_2: torch.Tensor) -> torch.Tensor:
    # Naive approach: broadcast (n, 1, 3) - (1, m, 3) into an (n, m, 3)
    # intermediate, then reduce over the last dim to get the (n, m) matrix
    # of squared distances. The full (n, m, 3) intermediate is what makes
    # this memory hungry.
    return torch.sum(torch.square(input_pos_1.unsqueeze(1) - input_pos_2[None]), dim=2)
@torch.jit.script
def distance2(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # By jacobrgardner: expand ||x1 - x2||^2 = ||x1||^2 - 2*x1.x2 + ||x2||^2.
    # torch.addmm fuses the -2 * (x1 @ x2^T) product with the ||x2||^2 row
    # vector, and add_ accumulates ||x1||^2 in place, so no (n, m, 3)
    # intermediate is ever materialized.
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
    return res
@torch.jit.script
def distance3(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # By J. Emmanuel Johnson and Divakar: same expansion as distance2, but
    # with einsum for the squared norms and a plain matmul, letting
    # broadcasting assemble ||X||^2 - 2*X.Y + ||Y||^2.
    XX = torch.einsum('ij,ij->i', X, X)[:, None]
    YY = torch.einsum('ij,ij->i', Y, Y)
    XY = 2 * torch.matmul(X, Y.T)
    return XX + YY - XY
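
# A minimal sanity check (not part of the original gist; the helper name
# and sizes below are illustrative). All three functions return *squared*
# Euclidean distances, so they should match torch.cdist(...)**2 up to
# floating point error. distance2/distance3 can dip slightly negative from
# cancellation, hence the modest tolerance.
def _check_agreement(n: int = 512, m: int = 256) -> None:
    a = torch.randn(n, 3, dtype=torch.float64)
    b = torch.randn(m, 3, dtype=torch.float64)
    ref = torch.cdist(a, b).pow(2)
    assert torch.allclose(distance1(a, b), ref, atol=1e-8)
    assert torch.allclose(distance2(a, b), ref, atol=1e-8)
    assert torch.allclose(distance3(a, b), ref, atol=1e-8)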
# distance1 uses ~4x the memory of distance2; 8.67 ms
# distance2 is the most memory efficient; 4.28 ms
# distance3 uses ~3x the memory of distance2 (~2x with jit); 2.31 ms
# Comparing distance2 and distance3, torch.addmm accounts for the
# difference in both memory use and speed.
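
# A rough sketch of how numbers like the above could be reproduced. The
# gist does not include its benchmark, so the sizes, iteration count, and
# use of CUDA events here are assumptions.
def _bench(fn, n: int = 10000, m: int = 10000, iters: int = 10) -> None:
    a = torch.randn(n, 3, device='cuda')
    b = torch.randn(m, 3, device='cuda')
    fn(a, b)  # warm-up; also triggers jit compilation/specialization
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(a, b)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end) / iters, 'ms,',
          torch.cuda.max_memory_allocated() / 2**20, 'MiB peak')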