Created November 3, 2020 10:49
-
-
Save anderzzz/716cdb0ea1b21b6eb3989bdc8c9954b2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, codes, indices):
    """Forward pass for the local aggregation loss module.

    Args:
        codes: Tensor of embedding vectors with one row per sample (at least
            2-D; the row dimension is 0 and normalisation is along dim 1).
            Row ``i`` belongs to ``indices[i]``.
        indices: Sequence of sample indices, one per row of ``codes``.

    Returns:
        0-d tensor: the batch mean of ``log(d1) - log(d2)``. Here ``d1`` is
        the probability density of the normalised codes over
        ``background_neighbours`` and ``d2`` the density over
        ``neighbour_intersect``, so the result is ``-log(d2 / d1)``
        averaged over the batch.

    Raises:
        ValueError: If the number of rows in ``codes`` differs from
            ``len(indices)``.
    """
    n_codes = codes.shape[0]
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if n_codes != len(indices):
        raise ValueError(
            f"codes has {n_codes} rows but {len(indices)} indices were given"
        )

    # Cast to a float64 CPU tensor; the ``.numpy()`` call below needs CPU data.
    codes = codes.type(torch.DoubleTensor)
    # Detached copy, normalised row-wise (axis=1), used for the memory bank and
    # the neighbour searches. NOTE(review): ``normalize`` is imported elsewhere;
    # presumably sklearn-style L2 normalisation -- confirm.
    code_data = normalize(codes.detach().numpy(), axis=1)

    # Compute and collect the index arrays that act as constants in the loss
    # function. They are derived from detached data, so no gradients are
    # computed for them in the backward pass. The memory bank is updated with
    # this batch *before* the neighbour lookups (presumably so those lookups
    # see the current codes -- the helpers are defined elsewhere).
    self.memory_bank.update_memory(code_data, indices)
    background_neighbours = self._nearest_neighbours(code_data, indices)
    close_neighbours = self._close_grouper(indices)
    neighbour_intersect = self._intersecter(background_neighbours, close_neighbours)

    # Probability densities of the differentiable, L2-normalised codes with
    # respect to the memory-bank constants selected above.
    v = F.normalize(codes, p=2, dim=1)
    d1 = self._prob_density(v, background_neighbours)
    d2 = self._prob_density(v, neighbour_intersect)
    return torch.sum(torch.log(d1) - torch.log(d2)) / n_codes
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.