Created November 3, 2020 14:34
def _prob_density(self, codes, indices):
    '''Compute the unnormalized probability density for the codes being in the sets defined by the indices

    Args:
        codes (Tensor): batch of code vectors, one per row
        indices: per-code boolean masks selecting a subset of the memory bank vectors

    Returns:
        prob_dens (Tensor): The unnormalized probability density of the vectors with given codes being part
            of the subset of codes specified by the indices. There is one dimension, the batch dimension

    '''
    ragged = len(set([np.count_nonzero(ind) for ind in indices])) != 1

    # In case the subsets of memory vectors are all of the same size, broadcasting can be used and the
    # batch dimension is handled concisely. This will always be true for the k-nearest neighbour density
    if not ragged:
        vals = torch.tensor([np.compress(ind, self.memory_bank.vectors, axis=0) for ind in indices],
                            requires_grad=False)
        v_dots = torch.matmul(vals, codes.unsqueeze(-1))
        exp_values = torch.exp(torch.div(v_dots, self.temperature))
        pdensity = torch.sum(exp_values, dim=1).squeeze(-1)

    # Broadcasting not possible if the subsets of memory vectors are of different size, so then manually loop
    # over the batch dimension and stack results
    else:
        xx_container = []
        for k_item in range(codes.size(0)):
            vals = torch.tensor(np.compress(indices[k_item], self.memory_bank.vectors, axis=0),
                                requires_grad=False)
            v_dots_prime = torch.mv(vals, codes[k_item])
            exp_values_prime = torch.exp(torch.div(v_dots_prime, self.temperature))
            xx_prime = torch.sum(exp_values_prime, dim=0)
            xx_container.append(xx_prime)
        pdensity = torch.stack(xx_container, dim=0)

    return pdensity
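For context, the quantity computed above is, for each code c in the batch, the sum of exp(v · c / temperature) over the memory bank vectors v selected by that code's mask. Below is a minimal, self-contained sketch of that same sum; it is not part of the gist, and the names memory_vectors, masks, and temperature are illustrative stand-ins for self.memory_bank.vectors, indices, and self.temperature.

import numpy as np
import torch

temperature = 0.07
memory_vectors = np.random.randn(100, 16).astype(np.float32)  # stand-in memory bank (rows are vectors)
codes = torch.randn(4, 16)                                     # batch of 4 code vectors
masks = np.random.rand(4, 100) < 0.1                           # one boolean subset per code (possibly ragged)

# Unnormalized density per code: sum over the selected memory vectors of exp(v . c / temperature)
pdensity = torch.stack([
    torch.exp(torch.from_numpy(memory_vectors[m]) @ codes[k] / temperature).sum()
    for k, m in enumerate(masks)
])
print(pdensity)  # one value per code in the batch

The looped form here mirrors the "ragged" branch of the method; the broadcasting branch gives the same result more efficiently when every mask selects the same number of vectors, as is the case for a k-nearest-neighbour density.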