Created
November 3, 2020 13:14
-
-
Save anderzzz/3cd7b8d0a004faeb2a06292291d2a4bf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _nearest_neighbours(self, codes_data, indices): | |
'''Ascertain indices in memory bank of the k-nearest neighbours to given codes | |
Returns: | |
indices_nearest (numpy.ndarray): Boolean array of k-nearest neighbours for the batch of codes | |
''' | |
self.neighbour_finder.fit(self.memory_bank.vectors) | |
indices_nearest = self.neighbour_finder.kneighbors(codes_data, return_distance=False) | |
return self.memory_bank.mask(indices_nearest) | |
def _close_grouper(self, indices): | |
'''Ascertain indices in memory bank of vectors that are in the same cluster as vectors of given indices | |
Returns: | |
indices_close (numpy.ndarray): Boolean array of close neighbours for the batch of codes | |
''' | |
memberships = [[]] * len(indices) | |
for clusterer in self.clusterer: | |
clusterer.fit(self.memory_bank.vectors) | |
for k_index, cluster_index in enumerate(clusterer.labels_[indices]): | |
other_members = np.where(clusterer.labels_ == cluster_index)[0] | |
other_members_union = np.union1d(memberships[k_index], other_members) | |
memberships[k_index] = other_members_union.astype(int) | |
return self.memory_bank.mask(np.array(memberships, dtype=object)) | |
def _intersecter(self, n1, n2): | |
'''Compute set intersection of two boolean arrays''' | |
return np.array([[v1 and v2 for v1, v2 in zip(n1_x, n2_x)] for n1_x, n2_x in zip(n1, n2)]) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment