Skip to content

Instantly share code, notes, and snippets.

@anderzzz
Created November 3, 2020 10:49
Show Gist options
  • Save anderzzz/716cdb0ea1b21b6eb3989bdc8c9954b2 to your computer and use it in GitHub Desktop.
def forward(self, codes, indices):
    '''Forward pass for the local aggregation loss module.

    Computes the local-aggregation loss for a batch of embedding vectors:
    the memory bank is refreshed with the (detached) codes, neighbour sets
    are gathered, and the loss is the batch mean of
    ``log P(background) - log P(background ∩ close)``.

    Args:
        codes: batch of embedding vectors from the encoder; indexed as
            ``codes[i]`` per sample, normalized along ``axis=1`` — assumed
            shape ``(batch, dim)``, TODO confirm with caller.
        indices: dataset indices for the batch items; must align one-to-one
            with the rows of ``codes`` (enforced by the assert below).

    Returns:
        Scalar tensor: the local aggregation loss averaged over the batch.
    '''
    assert codes.shape[0] == len(indices)

    # Cast to double precision; the detached numpy copy below feeds the
    # memory-bank bookkeeping, which must receive no gradients.
    codes = codes.type(torch.DoubleTensor)
    code_data = normalize(codes.detach().numpy(), axis=1)

    # Compute and collect arrays of indices that define the constants in the
    # loss function. Note that no gradients are computed for these data
    # values in the backward pass.
    self.memory_bank.update_memory(code_data, indices)
    background_neighbours = self._nearest_neighbours(code_data, indices)
    close_neighbours = self._close_grouper(indices)
    neighbour_intersect = self._intersecter(background_neighbours, close_neighbours)

    # Compute the probability density for the (L2-normalized) codes given
    # the constants of the memory bank.
    v = F.normalize(codes, p=2, dim=1)
    d1 = self._prob_density(v, background_neighbours)
    d2 = self._prob_density(v, neighbour_intersect)

    # Batch-mean of log d1 - log d2; gradients flow only through `v`.
    return torch.sum(torch.log(d1) - torch.log(d2)) / codes.shape[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment