Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Last active Mar 28, 2022
Embed
What would you like to do?
kedro annoy index custom dataset
class KedroAnnoyIndex(AbstractDataSet):
    """Kedro dataset that persists an Annoy index at a fixed file path.

    Wrapping ``AnnoyIndex`` in an ``AbstractDataSet`` lets the index be
    registered in the Kedro data catalog like any other dataset.

    Args:
        filepath: Location of the index file; converted to a ``Path``.
        embedding_length: First argument given to ``AnnoyIndex`` when the
            index is re-created for loading (the length of each vector).
        metric: Second argument given to ``AnnoyIndex`` when loading.
    """

    def __init__(self, filepath, embedding_length, metric) -> None:
        self._filepath = Path(filepath)
        self.embedding_length = embedding_length
        self.metric = metric

    def _load(self) -> AnnoyIndex:
        """Create an ``AnnoyIndex`` with the stored settings and load the file."""
        annoy_index = AnnoyIndex(self.embedding_length, self.metric)
        annoy_index.load(self._filepath.as_posix())
        return annoy_index

    def _save(self, annoy_idx: AnnoyIndex) -> None:
        """Write ``annoy_idx`` to the configured file path."""
        # Make sure the target directory exists so the first save into a
        # fresh location does not fail on a missing parent directory.
        self._filepath.parent.mkdir(parents=True, exist_ok=True)
        annoy_idx.save(self._filepath.as_posix())

    def _describe(self) -> Dict[str, Any]:
        """Return the constructor arguments that define this dataset."""
        return {
            "filepath": self._filepath,
            "embedding_length": self.embedding_length,
            "metric": self.metric,
        }
def build_index(item_factors, params: Dict):
    """Build an Annoy index over the rows of ``item_factors`` and save it.

    Args:
        item_factors: Indexable collection exposing ``.shape``; entry ``i``
            is added to the index under item id ``i`` and ``shape[1]`` is
            used as the vector length.
        params: Mapping that must contain ``"metric"`` (passed to
            ``AnnoyIndex``) and ``"n_trees"`` (passed to ``build``).

    Returns:
        The ``MlflowArtifactDataSet`` (wrapping ``KedroAnnoyIndex`` at
        ``INDEX_PATH``) through which the built index was saved.
    """
    metric = params["metric"]
    n_trees = params["n_trees"]
    embedding_length = item_factors.shape[1]

    # The metric comes from params, so the index type is caller-defined.
    index = AnnoyIndex(embedding_length, metric)
    for item_id in range(item_factors.shape[0]):
        index.add_item(item_id, item_factors[item_id])
    index.build(n_trees)

    # Persist the index through the MLflow-aware dataset wrapper.
    dataset_config = {
        "type": KedroAnnoyIndex,
        "filepath": INDEX_PATH,
        "embedding_length": embedding_length,
        "metric": metric,
    }
    annoy_dataset = MlflowArtifactDataSet(data_set=dataset_config)
    annoy_dataset.save(data=index)
    return annoy_dataset
def validate_index(kedro_annoy_dataset: KedroAnnoyIndex, idx_to_names: Dict):
    """Load the saved index and run a nearest-neighbour lookup on sample items.

    Each sample id is passed to ``nearest_movies_annoy`` together with the
    loaded index and ``idx_to_names``; whatever that helper returns is
    discarded (presumably it prints or logs the neighbours -- confirm
    against its definition).

    Args:
        kedro_annoy_dataset: Dataset whose ``load()`` yields the Annoy index.
        idx_to_names: Passed through unchanged to ``nearest_movies_annoy``.

    Returns:
        The same ``kedro_annoy_dataset`` that was passed in.
    """
    sample_item_ids = (
        1558,  # Dark Knight
        1042,  # Ratatouille
        2196,  # Spy who loved me
        1246,  # Rambo
        818,  # Rashomon
        2481,  # The Haunting
    )

    loaded_index = kedro_annoy_dataset.load()
    for sample_id in sample_item_ids:
        nearest_movies_annoy(sample_id, loaded_index, idx_to_names)

    return kedro_annoy_dataset
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment