@monatis
Created March 24, 2022 11:43
triplet_loss_-_advanced_intro1
import torch
import torch.nn.functional as F

# small constant used to avoid an infinite gradient of sqrt at 0
# (the value is assumed here; it is not defined in this snippet)
eps = 1e-8

def euclidean_distance_matrix(x):
    """Efficient computation of Euclidean distance matrix

    Args:
        x: Input tensor of shape (batch_size, embedding_dim)

    Returns:
        Distance matrix of shape (batch_size, batch_size)
    """
    # step 1 - compute the dot product
    # shape: (batch_size, batch_size)
    dot_product = torch.mm(x, x.t())

    # step 2 - extract the squared Euclidean norms from the diagonal
    # shape: (batch_size,)
    squared_norm = torch.diag(dot_product)

    # step 3 - compute squared Euclidean distances using
    # ||a - b||^2 = ||a||^2 - 2 * a.b + ||b||^2
    # shape: (batch_size, batch_size)
    distance_matrix = squared_norm.unsqueeze(0) - 2 * dot_product + squared_norm.unsqueeze(1)

    # get rid of negative distances due to numerical instabilities
    distance_matrix = F.relu(distance_matrix)

    # step 4 - compute the non-squared distances
    # handle numerical stability: the derivative of the square root at 0 is infinite,
    # so we temporarily replace exact zeros with eps
    mask = (distance_matrix == 0.0).float()
    # use this mask to set entries with a value of 0 to eps
    distance_matrix += mask * eps

    # now it is safe to take the square root
    distance_matrix = torch.sqrt(distance_matrix)

    # undo the trick for numerical stability
    distance_matrix *= (1.0 - mask)

    return distance_matrix
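
As a quick sanity check, the result can be compared against torch.cdist, which computes the same pairwise Euclidean distances directly. This is a minimal usage sketch, not part of the original gist; the batch size and embedding dimension are arbitrary, chosen only for illustration:

# usage sketch (shapes are assumed for illustration)
embeddings = torch.randn(8, 128)            # batch of 8 embeddings, 128-dim each
distances = euclidean_distance_matrix(embeddings)
reference = torch.cdist(embeddings, embeddings)
print(torch.allclose(distances, reference, atol=1e-4))  # expected to print True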