Skip to content

Instantly share code, notes, and snippets.

@tuelwer
Last active March 10, 2021 14:31
Show Gist options
  • Save tuelwer/d04320d7d1acbddd1a76639ac49a1ee7 to your computer and use it in GitHub Desktop.
Save tuelwer/d04320d7d1acbddd1a76639ac49a1ee7 to your computer and use it in GitHub Desktop.
Gaussian mixture model in PyTorch
class GMM(torch.nn.Module):
    """Gaussian mixture model with full covariances, fitted by expectation-maximization.

    Parameters
    ----------
    n : int
        Number of samples; only sizes the initial ``member`` buffer
        (``fit`` re-derives the sample count from ``X``).
    d : int
        Dimensionality of each sample.
    k : int
        Number of mixture components.

    Attributes
    ----------
    mus : Tensor, shape (k, d)
        Component means.
    covs : Tensor, shape (k, d, d)
        Component covariance matrices.
    member : Tensor, shape (n, k)
        Responsibilities: posterior probability of each component per sample.
    prior : Tensor, shape (k,)
        Mixing weights.
    """

    def __init__(self, n, d=2, k=2):
        super().__init__()
        self.d = d
        self.k = k
        self.n = n
        self.covs = torch.eye(d).repeat(k, 1, 1)
        # Means are one d-vector per component; the previous (n, k) shape was
        # wrong (it was only hidden because ``fit`` overwrites it).
        self.mus = torch.zeros(k, d)
        self.member = torch.zeros(n, k)
        self.prior = torch.ones(k) / k

    def fit(self, X, tol=0.01, max_iter=1000, reg_covar=1e-6, verbose=True):
        """Fit the mixture to ``X`` with EM, starting from ``k`` distinct random samples.

        Parameters
        ----------
        X : Tensor, shape (n_samples, d)
            Floating-point data. All model state is re-created on X's dtype
            and device.
        tol : float
            Stop once the norm of the change in ``mus`` drops below this.
        max_iter : int
            Safety cap on the number of EM iterations.
        reg_covar : float
            Non-negative value added to every covariance diagonal so the
            matrices stay positive definite.
        verbose : bool
            Print the mixing weights each iteration and a convergence message.

        Returns
        -------
        GMM
            ``self``, so calls can be chained.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D with ``d`` columns or has fewer than ``k`` rows.
        """
        import warnings

        if X.dim() != 2 or X.shape[1] != self.d:
            raise ValueError(f"expected X of shape (n, {self.d}), got {tuple(X.shape)}")
        n = X.shape[0]
        if n < self.k:
            raise ValueError(f"need at least k={self.k} samples, got {n}")

        # (Re)initialise every piece of state on X's dtype/device so repeated
        # fits start fresh and float64 / GPU inputs do not hit mismatches.
        factory = {"dtype": X.dtype, "device": X.device}
        self.n = n
        self.covs = torch.eye(self.d, **factory).repeat(self.k, 1, 1)
        self.prior = torch.full((self.k,), 1.0 / self.k, **factory)
        self.member = torch.zeros(n, self.k, **factory)
        # Sample initial means without replacement: duplicate picks would make
        # two components identical from the start.
        self.mus = X[torch.randperm(n, device=X.device)[: self.k]].clone()

        for i in range(1, max_iter + 1):
            self.update_member(X)
            if verbose:
                print(self.prior)
            old_mus = self.mus.clone()
            self.update_mus(X)
            self.update_covs(X, reg_covar=reg_covar)
            # NaN means are a numerical failure, not convergence.
            if torch.isnan(self.mus).any():
                warnings.warn(
                    f"EM diverged: means became NaN after {i} iterations",
                    RuntimeWarning,
                )
                break
            if torch.norm(self.mus - old_mus) < tol:
                if verbose:
                    print('converged after', i, 'iterations')
                break
        else:
            warnings.warn(
                f"EM did not converge within {max_iter} iterations",
                RuntimeWarning,
            )
        return self

    def update_covs(self, X, reg_covar=1e-6):
        """M-step for covariances: responsibility-weighted scatter around each mean.

        ``reg_covar`` is added to each diagonal to keep the matrices invertible.
        """
        resp = self.member.to(X.dtype)
        # Tiny epsilon keeps a component with zero total responsibility from
        # producing 0/0.
        nk = resp.sum(dim=0) + 10 * torch.finfo(resp.dtype).eps
        eye = torch.eye(self.d, dtype=X.dtype, device=X.device)
        covs = []
        for j in range(self.k):
            diff = X - self.mus[j]
            weighted = resp[:, j, None] * diff
            covs.append(weighted.T @ diff / nk[j] + reg_covar * eye)
        self.covs = torch.stack(covs)

    def update_mus(self, X):
        """M-step for means: responsibility-weighted average of the samples."""
        resp = self.member.to(X.dtype)
        # Same zero-responsibility guard as in ``update_covs``.
        nk = resp.sum(dim=0) + 10 * torch.finfo(resp.dtype).eps
        # (k, n) @ (n, d) -> (k, d), without an (n, d, k) temporary.
        self.mus = resp.T @ X / nk[:, None]

    def update_member(self, X):
        """E-step: posterior responsibility of each component for each sample.

        Computed in log space and normalised with ``logsumexp``. Exponentiating
        first underflows to 0/0 = NaN for samples far from every component.
        Also refreshes ``prior`` with the mean responsibility per component.
        """
        from torch.distributions import MultivariateNormal

        # One batched distribution over the k components; log_prob of the
        # (n, 1, d) samples broadcasts against batch shape (k,) -> (n, k).
        components = MultivariateNormal(self.mus, covariance_matrix=self.covs)
        log_joint = components.log_prob(X[:, None, :]) + torch.log(self.prior)
        log_resp = log_joint - torch.logsumexp(log_joint, dim=1, keepdim=True)
        self.member = log_resp.exp()
        self.prior = self.member.mean(dim=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment