Skip to content

Instantly share code, notes, and snippets.

@j-adamczyk
Created September 12, 2020 12:14
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save j-adamczyk/74ee808ffd53cd8545a49f185a908584 to your computer and use it in GitHub Desktop.
Save j-adamczyk/74ee808ffd53cd8545a49f185a908584 to your computer and use it in GitHub Desktop.
k-nearest neighbors classifier built with the Faiss library
import numpy as np
import faiss
class FaissKNeighbors:
    """k-nearest-neighbours classifier backed by a Faiss ``IndexFlatL2`` index.

    Each query is assigned the label that occurs most often among its ``k``
    nearest training vectors. Ties are resolved in favour of the smallest
    label (in ``np.unique`` sort order), which matches the behaviour of
    ``np.argmax(np.bincount(...))`` for non-negative integer labels.
    """

    def __init__(self, k=5):
        """Create an unfitted classifier.

        Args:
            k: Number of neighbours that vote on each prediction.
        """
        self.index = None  # Faiss index; created by fit()
        self.y = None  # training labels, aligned with the rows added to the index
        self.k = k

    def fit(self, X, y):
        """Index the training vectors and remember their labels.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: 1-D array-like of n_samples labels (any dtype np.unique can sort).

        Returns:
            The fitted classifier itself, so calls can be chained.

        Raises:
            ValueError: If X is not 2-D, is empty, or its row count differs
                from the number of labels.
        """
        # Faiss works on C-contiguous float32 buffers; this also accepts lists.
        vectors = np.ascontiguousarray(X, dtype=np.float32)
        labels = np.asarray(y).reshape(-1)

        if vectors.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional, got {vectors.ndim} dimension(s)"
            )
        if vectors.shape[0] == 0:
            raise ValueError("X must contain at least one sample")
        if labels.shape[0] != vectors.shape[0]:
            raise ValueError(
                f"X has {vectors.shape[0]} samples but y has "
                f"{labels.shape[0]} labels"
            )

        self.index = faiss.IndexFlatL2(vectors.shape[1])
        self.index.add(vectors)
        self.y = labels
        return self

    def predict(self, X):
        """Predict a label for every row of X by majority vote.

        Args:
            X: 2-D array-like of shape (n_queries, n_features).

        Returns:
            1-D array of n_queries predicted labels, with the dtype of the
            training labels.

        Raises:
            RuntimeError: If called before fit().
            ValueError: If k is smaller than 1.
        """
        if self.index is None or self.y is None:
            raise RuntimeError("FaissKNeighbors must be fitted before predict()")
        if self.k < 1:
            raise ValueError(f"k must be at least 1, got {self.k}")

        labels = np.asarray(self.y).reshape(-1)
        queries = np.ascontiguousarray(X, dtype=np.float32)

        # Asking for more neighbours than there are stored vectors makes the
        # search pad its result with -1, which would silently index the LAST
        # training label. Clamp k so every returned index is a real neighbour.
        n_neighbors = min(int(self.k), labels.shape[0])
        _, indices = self.index.search(queries, k=n_neighbors)

        # Map labels to 0..n_classes-1 so that strings, negative numbers and
        # other non-bincount-able labels can be counted as well.
        classes, encoded = np.unique(labels, return_inverse=True)
        votes = encoded.reshape(-1)[indices]

        # counts[i, c] = number of neighbours of query i that carry class c.
        counts = np.zeros((votes.shape[0], classes.shape[0]), dtype=np.int64)
        rows = np.arange(votes.shape[0])[:, None]
        np.add.at(counts, (rows, votes), 1)

        return classes[np.argmax(counts, axis=1)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment