Skip to content

Instantly share code, notes, and snippets.

@DuaneNielsen
Created January 13, 2020 06:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DuaneNielsen/1e2a9be4e7474b7eea34d2f92e629a6e to your computer and use it in GitHub Desktop.
Save DuaneNielsen/1e2a9be4e7474b7eea34d2f92e629a6e to your computer and use it in GitHub Desktop.
k-means clustering in 2D with PyTorch
import torch
import matplotlib.pyplot as plt
"""
K means 2D demo, in pytorch
"""
n = 30 # must be even number
k = 3
dims = 2
eps = torch.finfo(torch.float32).eps
def estimate_mu(x, mu):
    """Run one k-means (Lloyd) iteration: assign points, then recompute centroids.

    Sizes are taken from the inputs, not from module globals, so any number
    of points, clusters and dimensions is accepted.

    Args:
        x: points, shape (num_points, num_dims).
        mu: current centroids, shape (num_clusters, num_dims).

    Returns:
        Tuple ``(new_mu, labels)``:
        - ``new_mu`` has the same shape as ``mu``; each row is the mean of the
          points assigned to that centroid. A centroid that attracts no points
          keeps its previous position instead of collapsing to the origin.
        - ``labels`` has shape (num_points,) and holds, for each point, the
          index of its nearest centroid (Euclidean distance).
    """
    num_points, num_dims = x.shape
    num_clusters = mu.shape[0]

    # (num_clusters, num_points) matrix of Euclidean distances.
    diff = x.unsqueeze(0) - mu.view(num_clusters, 1, num_dims)
    dist = diff.pow(2).sum(dim=2).sqrt()
    labels = torch.argmin(dist, dim=0)

    # One-hot assignment matrix, (num_points, num_clusters), matching x's dtype/device.
    hot = torch.zeros(num_points, num_clusters, dtype=x.dtype, device=x.device)
    hot[torch.arange(num_points, device=x.device), labels] = 1.0

    sums = x.T.matmul(hot).T  # (num_clusters, num_dims) per-cluster coordinate sums
    counts = hot.sum(dim=0)  # (num_clusters,) number of points per cluster

    # Exact mean for populated clusters; clamp only avoids 0/0 for empty ones,
    # which are then replaced by their previous centroid.
    new_mu = sums / counts.clamp(min=1).unsqueeze(1)
    new_mu = torch.where(counts.unsqueeze(1) > 0, new_mu, mu)
    return new_mu, labels
def sample(mu, c, num_samples=None):
    """Draw points from a Gaussian centred on ``mu``.

    Standard-normal noise ``z`` is mapped through the matrix ``c``, so the
    samples have covariance ``c @ c.T``. The noise is subtracted from the
    mean; since it is symmetric, this does not change the distribution.

    Args:
        mu: mean vector; its elements are the point coordinates.
        c: (dims, dims) matrix applied to the standard-normal noise.
        num_samples: number of points to draw. Defaults to ``n // 3``, using
            the module-level ``n``, as in the original demo.

    Returns:
        Tensor of shape (num_samples, dims).
    """
    if num_samples is None:
        num_samples = n // 3
    # Noise needs one row per column of c (previously hard-coded to 2).
    z = torch.randn(c.shape[1], num_samples)
    return (mu.view(-1, 1) - c.matmul(z)).T
def _plot(points, centroids):
    """Scatter the data points, overlay the centroids, and show the figure."""
    plt.scatter(points[:, 0], points[:, 1])
    plt.scatter(centroids[:, 0], centroids[:, 1])
    plt.show()


# Three Gaussian blobs with different spreads.
x1 = sample(torch.tensor([-1.0, -1.0]), torch.eye(dims) * 0.2)
x2 = sample(torch.tensor([1.0, 1.0]), torch.eye(dims) * 0.3)
x3 = sample(torch.tensor([1.0, -1.0]), torch.eye(dims) * 0.1)
x = torch.cat((x1, x2, x3), dim=0)

# Random initial centroids drawn from a standard normal.
mu = torch.randn(k, dims)
_plot(x, mu)

# Iterate until the cluster labels stop changing. The cap is only a safety
# net against floating-point tie oscillation; normally it is never reached.
max_iters = 100
mu, i_prev = estimate_mu(x, mu)
converged = False
for _ in range(max_iters):
    mu, i = estimate_mu(x, mu)
    # Labels are integer indices, so compare them exactly.
    converged = torch.equal(i, i_prev)
    if converged:
        break
    i_prev = i
else:
    print(f"k-means did not converge within {max_iters} iterations")

_plot(x, mu)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment