@alper111
Last active October 21, 2019 13:48
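PyTorch and NumPy implementations of a pairwise squared Euclidean distance function (p2dist) for tensors with two or more dimensions. Distances are computed along the last dimension, between every row of x and every row of y; leading dimensions are treated as batch dimensions.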
import torch
import numpy as np


def p2dist_pytorch(x, y):
    # Squared Euclidean distances via ||x||^2 - 2 x.y + ||y||^2.
    # x: (..., n, d), y: (..., m, d) -> result: (..., n, m)
    x_sq = torch.pow(x, 2).sum(dim=-1, keepdim=True)   # (..., n, 1)
    y_sq = torch.pow(y, 2).sum(dim=-1).unsqueeze(-2)   # (..., 1, m)
    return x_sq - 2 * torch.matmul(x, y.transpose(-1, -2)) + y_sq


def p2dist_numpy(x, y):
    # Same computation with NumPy arrays.
    x_sq = np.power(x, 2).sum(axis=-1, keepdims=True)  # (..., n, 1)
    y_sq = np.power(y, 2).sum(axis=-1)[..., None, :]   # (..., 1, m)
    return x_sq - 2 * np.matmul(x, np.swapaxes(y, -1, -2)) + y_sq
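
A minimal usage sketch, NumPy only, that checks the expansion against an explicit difference-based computation. `p2dist_numpy` is restated so the block runs on its own; `p2dist_bruteforce`, the array shapes, and the random seed are illustrative assumptions, not part of the gist.

```python
import numpy as np


def p2dist_numpy(x, y):
    # Expansion used in the gist: ||x||^2 - 2 x.y + ||y||^2
    x_sq = np.power(x, 2).sum(axis=-1, keepdims=True)  # (..., n, 1)
    y_sq = np.power(y, 2).sum(axis=-1)[..., None, :]   # (..., 1, m)
    return x_sq - 2 * np.matmul(x, np.swapaxes(y, -1, -2)) + y_sq


def p2dist_bruteforce(x, y):
    # Hypothetical reference helper: materialize every pairwise difference.
    diff = x[..., :, None, :] - y[..., None, :, :]     # (..., n, m, d)
    return np.power(diff, 2).sum(axis=-1)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5, 3))  # batch of 4, 5 points each, 3 features
y = rng.standard_normal((4, 7, 3))  # batch of 4, 7 points each, 3 features

fast = p2dist_numpy(x, y)
slow = p2dist_bruteforce(x, y)
print(fast.shape)  # (4, 5, 7)
print(np.allclose(fast, slow))

# The result holds squared distances. Round-off in the expansion can leave
# tiny negative entries, so clip before taking the square root.
euclid = np.sqrt(np.clip(fast, 0.0, None))
```

The expansion avoids building the `(..., n, m, d)` difference tensor, which is presumably why it is preferred here; the trade-off is slightly lower numerical accuracy for points that are very close together.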
@alper111

Credits to @Fzaero.
