Skip to content

Instantly share code, notes, and snippets.

@wzjoriv
Last active February 13, 2024 15:13
Show Gist options
  • Star 34 You must be signed in to star a gist
  • Fork 8 You must be signed in to fork a gist
  • Save wzjoriv/7e89afc7f30761022d7747a501260fe3 to your computer and use it in GitHub Desktop.
Save wzjoriv/7e89afc7f30761022d7747a501260fe3 to your computer and use it in GitHub Desktop.
Nearest Neighbor, K Nearest Neighbor and K Means (NN, KNN, KMeans) implemented only using PyTorch
import torch as th
"""
Author: Josue N Rivera (github.com/wzjoriv)
Date: 7/3/2021
Description: Snippet of various clustering implementations only using PyTorch
Full project repository: https://github.com/wzjoriv/Lign (A graph deep learning framework that works alongside PyTorch)
"""
def random_sample(tensor, k):
    """Pick up to ``k`` entries of ``tensor`` at random, without replacement.

    A random permutation of the indices along the first dimension is drawn
    and its first ``k`` positions are used to index ``tensor``. When ``k``
    exceeds ``len(tensor)`` every entry is returned, in shuffled order.
    """
    shuffled = th.randperm(len(tensor))
    chosen = shuffled[:k]
    return tensor[chosen]
def distance_matrix(x, y=None, p=2):
    """Return the pairwise p-norm distances between the rows of ``x`` and ``y``.

    Args:
        x: Tensor of shape ``(n, d)``.
        y: Tensor of shape ``(m, d)``. Defaults to ``x`` when omitted.
        p: Order of the norm (2 gives the Euclidean distance).

    Returns:
        Tensor of shape ``(n, m)`` whose entry ``[i, j]`` is the p-norm of
        ``x[i] - y[j]``.
    """
    if y is None:
        y = x

    n, d = x.size(0), x.size(1)
    m = y.size(0)

    # Broadcast both inputs to (n, m, d) so every row pair is subtracted.
    diff = x.unsqueeze(1).expand(n, m, d) - y.unsqueeze(0).expand(n, m, d)

    # Detect the function itself rather than comparing version strings:
    # string comparison orders '1.10.0' before '1.7.0', and vector_norm is
    # not present in every release that a '>= 1.7.0' check would admit.
    vector_norm = getattr(getattr(th, "linalg", None), "vector_norm", None)
    if vector_norm is not None:
        return vector_norm(diff, ord=p, dim=2)

    # Manual fallback. The absolute value is required so that odd orders
    # (e.g. p=1) do not let negative differences cancel out.
    return diff.abs().pow(p).sum(2) ** (1 / p)
class NN():
    """Nearest-neighbour classifier: a query receives the label of its closest training point.

    Attributes:
        p: Order of the norm passed to ``distance_matrix``.
        train_pts: Stored training points (``None`` until trained).
        train_label: Stored training labels (``None`` until trained).
    """

    def __init__(self, X = None, Y = None, p = 2):
        """Store the norm order and immediately train on ``X`` / ``Y`` (which may be ``None``)."""
        self.p = p
        self.train(X, Y)

    def train(self, X, Y):
        """Remember the training points ``X`` and their labels ``Y``."""
        self.train_pts, self.train_label = X, Y

    def __call__(self, x):
        """Shorthand for ``predict(x)``."""
        return self.predict(x)

    def predict(self, x):
        """Return, for each row of ``x``, the label of the nearest training point.

        Raises:
            RuntimeError: If training points or labels have not been provided.
        """
        untrained = self.train_pts is None or self.train_label is None
        if untrained:
            cls_name = self.__class__.__name__
            raise RuntimeError(f"{cls_name} wasn't trained. Need to execute {cls_name}.train() first")

        pairwise = distance_matrix(x, self.train_pts, self.p)
        # Index of the closest training point for every query row.
        nearest = th.argmin(pairwise, dim=1)
        return self.train_label[nearest]
class KNN(NN):
    """K-nearest-neighbour classifier: each query gets the majority label of its ``k`` closest training points.

    Attributes:
        k: Number of neighbours that vote.
        unique_labels: Distinct values of the training labels, set by ``train``
            when labels are given.
    """

    def __init__(self, X = None, Y = None, k = 3, p = 2):
        """Store ``k`` and delegate the rest (norm order, training) to ``NN``."""
        self.k = k
        super().__init__(X, Y, p)

    def train(self, X, Y):
        """Store the training set and cache the distinct label values."""
        super().train(X, Y)
        if Y is not None:
            self.unique_labels = self.train_label.unique()

    def predict(self, x):
        """Return the majority-vote label for every row of ``x``.

        When several labels receive the same number of votes, the one that
        comes last in ``unique_labels`` wins (the comparison below is ``>=``).

        Raises:
            RuntimeError: If training points or labels have not been provided.
        """
        if self.train_pts is None or self.train_label is None:
            name = self.__class__.__name__
            raise RuntimeError(f"{name} wasn't trained. Need to execute {name}.train() first")

        dist = distance_matrix(x, self.train_pts, self.p)

        # k smallest distances per query row; .indices point into the training set.
        knn = dist.topk(self.k, largest=False)
        votes = self.train_label[knn.indices]

        n_queries = votes.size(0)
        winner = th.zeros(n_queries, dtype=votes.dtype, device=votes.device)
        # Keep the running best vote count in int64 regardless of the label
        # dtype: with unsigned labels (e.g. uint8) a "-1" start value would
        # wrap around to the maximum and no label could ever win.
        count = th.full((n_queries,), -1, dtype=th.long, device=votes.device)

        for lab in self.unique_labels:
            vote_count = (votes == lab).sum(1)
            who = vote_count >= count
            winner[who] = lab
            count[who] = vote_count[who]

        return winner
class KMeans(NN):
    """K-means clustering; the inherited ``predict`` assigns each point to its nearest centroid.

    The centroids are kept in ``train_pts`` and their cluster ids
    (``0 .. k-1``) in ``train_label``.

    Attributes:
        k: Number of clusters.
        n_iters: Number of assignment/update rounds run by ``train``.
        p: Order of the norm used for distances.
    """

    def __init__(self, X = None, k=2, n_iters = 10, p = 2):
        """Store the hyper-parameters and fit immediately when ``X`` is given."""
        self.k = k
        self.n_iters = n_iters
        self.p = p
        # Define these up front so that predict() on an unfitted model raises
        # the "wasn't trained" RuntimeError instead of an AttributeError.
        self.train_pts = None
        self.train_label = None
        if X is not None:
            self.train(X)

    def train(self, X):
        """Fit ``k`` centroids to the rows of ``X``.

        Centroids are initialised from ``k`` randomly sampled rows of ``X``
        and then refined for ``n_iters`` rounds.
        """
        # Indexing with an index tensor yields a new tensor, so the in-place
        # centroid updates below do not write into X.
        self.train_pts = random_sample(X, self.k)
        # Create the cluster ids on X's device so they can be indexed with
        # the argmin result produced on that device.
        self.train_label = th.arange(self.k, dtype=th.long, device=X.device)

        for _ in range(self.n_iters):
            labels = self.predict(X)
            for lab in range(self.k):
                select = labels == lab
                # An empty cluster would average to NaN and stay NaN for all
                # remaining iterations; keep its previous centroid instead.
                if select.any():
                    self.train_pts[lab] = th.mean(X[select], dim=0)
if __name__ == '__main__':
    # Small demo: two labelled groups of 2-D points, then classify two queries
    # with the default 3-nearest-neighbour vote and print the predicted labels.
    train_points = th.tensor([
        [1, 1],
        [0.88, 0.90],
        [-1, -1],
        [-1, -0.88],
    ])
    train_labels = th.tensor([3, 3, 5, 5])

    queries = th.tensor([
        [-0.5, -0.5],
        [0.88, 0.88],
    ])

    classifier = KNN(train_points, train_labels)
    print(classifier.predict(queries))
@salehafzoon
Copy link

Such a useful code.
I hope you continue to produce useful programming content.

@wzjoriv
Copy link
Author

wzjoriv commented Aug 27, 2022

Thanks. Will do once I have more time haha.

@wzjoriv
Copy link
Author

wzjoriv commented Dec 25, 2022

I want to note that these methods have a bottleneck. As the number of nodes increases, the distance matrices become sharply more expensive to compute. I would suggest grouping the points into sections (say, groups of k nodes); then, when new points need to be labeled, one can compute the distance matrix only for, and compare only against, the points within its section and the surrounding ones.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment