Skip to content

Instantly share code, notes, and snippets.

@anderzzz
Created November 3, 2020 14:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anderzzz/a109b1d9fb676bb04e62c1bf6697ce54 to your computer and use it in GitHub Desktop.
Save anderzzz/a109b1d9fb676bb04e62c1bf6697ce54 to your computer and use it in GitHub Desktop.
def _prob_density(self, codes, indices):
    """Compute the unnormalized probability density for the codes being in the sets defined by the indices.

    For every batch item ``k`` the result is
    ``sum_j exp(<v_j, codes[k]> / self.temperature)``, where ``v_j`` runs over the
    rows of ``self.memory_bank.vectors`` selected by ``indices[k]``.

    Args:
        codes (Tensor): Batch of code vectors, one per item along dimension 0;
            presumably shape ``(batch, dim)`` -- TODO confirm against callers.
        indices (sequence of array-like): One selection mask per batch item. Each is
            passed to ``np.compress`` along axis 0 of ``self.memory_bank.vectors``, so
            its non-zero entries pick the memory rows that form that item's subset.

    Returns:
        prob_dens (Tensor): The unnormalized probability density of the vectors with given codes being part
            of the subset of codes specified by the indices. There is one dimension, the batch dimension.
            The memory vectors are converted to the dtype and device of ``codes`` before the dot products.
    """
    # Nothing to sum for an empty batch (torch.stack would reject an empty list below).
    if codes.size(0) == 0:
        return codes.new_zeros(0)

    memory = self.memory_bank.vectors
    subset_sizes = {int(np.count_nonzero(ind)) for ind in indices}
    ragged = len(subset_sizes) != 1

    # In case the subsets of memory vectors are all of the same size, broadcasting can be used and the
    # batch dimension is handled concisely. This will always be true for the k-nearest neighbour density
    if not ragged:
        # Stack into a single ndarray first: building a tensor from a list of ndarrays is very slow.
        # Matching codes' dtype/device lets the matmul work for e.g. float64 banks or GPU codes.
        vals = torch.as_tensor(
            np.stack([np.compress(ind, memory, axis=0) for ind in indices]),
            dtype=codes.dtype,
            device=codes.device,
        )
        v_dots = torch.matmul(vals, codes.unsqueeze(-1))
        exp_values = torch.exp(torch.div(v_dots, self.temperature))
        return torch.sum(exp_values, dim=1).squeeze(-1)

    # Broadcasting not possible if the subsets of memory vectors are of different size, so then manually loop
    # over the batch dimension and stack results
    densities = []
    for k_item, code in enumerate(codes):
        vals = torch.as_tensor(
            np.compress(indices[k_item], memory, axis=0),
            dtype=codes.dtype,
            device=codes.device,
        )
        exp_values = torch.exp(torch.div(torch.mv(vals, code), self.temperature))
        densities.append(torch.sum(exp_values, dim=0))
    return torch.stack(densities, dim=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment