Created
March 24, 2022 11:43
-
-
Save monatis/5d9c62be6ef489fe9bae976ead34c8c1 to your computer and use it in GitHub Desktop.
triplet_loss_-_advanced_intro1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def euclidean_distance_matrix(x, eps=1e-8):
    """Compute the pairwise Euclidean distance matrix of a batch of embeddings.

    Uses the identity ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2 so the whole
    matrix is derived from a single matrix product.

    Args:
        x: Input tensor of shape (batch_size, embedding_dim).
        eps: Small positive value temporarily added to squared distances that
            are exactly 0 before the square root, so the sqrt gradient stays
            finite. Those entries are reset to exactly 0 afterwards.
            Defaults to 1e-8.

    Returns:
        Distance matrix of shape (batch_size, batch_size) where entry (i, j)
        is the Euclidean distance between rows i and j of ``x``.
    """
    # step 1 - pairwise dot products
    # shape: (batch_size, batch_size)
    dot_product = torch.mm(x, x.t())
    # step 2 - the squared Euclidean norms sit on the diagonal
    # shape: (batch_size,)
    squared_norm = torch.diag(dot_product)
    # step 3 - squared Euclidean distances
    # shape: (batch_size, batch_size)
    distance_matrix = squared_norm.unsqueeze(0) - 2 * dot_product + squared_norm.unsqueeze(1)
    # floating-point cancellation can produce tiny negative values; clamp to 0
    distance_matrix = F.relu(distance_matrix)
    # step 4 - non-squared distances
    # The derivative of sqrt at 0 is infinite, which would propagate inf/NaN
    # in the backward pass. Shift exact zeros to eps before the sqrt, then
    # multiply those entries back to 0 (which also zeroes their gradient).
    # Out-of-place ops are deliberate: relu and sqrt save their outputs for
    # backward, so mutating those outputs in place makes autograd raise.
    # The mask uses the matrix's own dtype so half/double inputs keep their dtype.
    mask = (distance_matrix == 0.0).to(distance_matrix.dtype)
    distance_matrix = torch.sqrt(distance_matrix + mask * eps)
    distance_matrix = distance_matrix * (1.0 - mask)
    return distance_matrix
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment