
@justanhduc
Last active November 24, 2020 13:07
import torch as T


def batch_pairwise_sqdist(x: T.Tensor, y: T.Tensor) -> T.Tensor:
    """
    Calculates the pairwise squared distance between two sets of points.
    To get the Euclidean distance, apply an explicit square root
    to the output.

    :param x:
        a tensor of shape ``(m, nx, d)`` or ``(nx, d)``.
        If the tensor is 2-dimensional, it is broadcast along the batch dim.
    :param y:
        a tensor of shape ``(m, ny, d)`` or ``(ny, d)``.
        If the tensor is 2-dimensional, it is broadcast along the batch dim.
    :return:
        a tensor of shape ``(m, nx, ny)`` (or ``(nx, ny)`` when both inputs
        are 2-dimensional) containing the squared distance between every
        pair of points in `x` and `y` from the same batch.
    """
    # Squared norms ||x_i||^2 and ||y_j||^2 along the feature dim d.
    xx = T.sum(x ** 2, -1)
    yy = T.sum(y ** 2, -1)
    # Inner products <x_i, y_j>, shape (..., nx, ny).
    zz = T.matmul(x, y.transpose(-1, -2).contiguous())
    # Expand both norm vectors to the shape of zz (broadcasting the batch dim).
    rx = xx.unsqueeze(-2).expand_as(zz.transpose(-2, -1))
    ry = yy.unsqueeze(-2).expand_as(zz)
    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    P = (rx.transpose(-2, -1) + ry - 2. * zz)
    # Round-off in the expansion can leave tiny negative entries;
    # clamp them so that a later square root does not produce NaN.
    return P.clamp(min=0.)
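The vectorized code rests on the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2<x_i, y_j>. As a dependency-free sketch, the following compares that expansion against the direct definition. It assumes plain Python lists in place of tensors and a single unbatched pair of point sets; the helper names `sqdist_direct` and `sqdist_expanded` are hypothetical and not part of the gist, and the clamp at zero is an assumption added here to guard against round-off.

```python
from typing import List

Matrix = List[List[float]]


def sqdist_direct(x: Matrix, y: Matrix) -> Matrix:
    """Squared distance from the definition: sum_k (x_ik - y_jk)^2."""
    return [[sum((a - b) ** 2 for a, b in zip(xi, yj)) for yj in y] for xi in x]


def sqdist_expanded(x: Matrix, y: Matrix) -> Matrix:
    """Same quantity via ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, the identity
    that batch_pairwise_sqdist vectorizes with T.sum and T.matmul."""
    xx = [sum(a * a for a in xi) for xi in x]  # squared norms of x rows (like `xx`)
    yy = [sum(b * b for b in yj) for yj in y]  # squared norms of y rows (like `yy`)
    out = []
    for i, xi in enumerate(x):
        row = []
        for j, yj in enumerate(y):
            dot = sum(a * b for a, b in zip(xi, yj))  # entry (i, j) of `zz`
            # Hypothetical safeguard: cancellation can leave a tiny negative value.
            row.append(max(xx[i] + yy[j] - 2.0 * dot, 0.0))
        out.append(row)
    return out


x = [[0.0, 0.0], [1.0, 2.0], [3.0, -1.0]]  # nx = 3 points in d = 2
y = [[1.0, 1.0], [-2.0, 0.5]]              # ny = 2 points in d = 2
direct = sqdist_direct(x, y)
expanded = sqdist_expanded(x, y)
print(direct)    # [[2.0, 4.25], [1.0, 11.25], [8.0, 27.25]]
print(expanded)  # same values, shape (nx, ny) = (3, 2)
```

On these small, exactly representable inputs the two forms agree exactly; on larger or badly scaled data they can differ by round-off. In PyTorch itself, `torch.cdist(x, y) ** 2` offers a further cross-check that should agree with the tensor version up to round-off.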