Skip to content

Instantly share code, notes, and snippets.

@omaraflak
Last active October 18, 2020 13:47
Show Gist options
  • Save omaraflak/e5fd267d4b2dda54c4cb279595279aa6 to your computer and use it in GitHub Desktop.
Gaussian Mixture Model - From scratch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# ---- synthetic data generation ----
I = 300  # samples per class
N = 3    # features per sample
J = 2    # clusters

# Ground truth: a random integer mean vector and an axis-aligned (diagonal)
# covariance matrix for each of the J clusters.
mean_true = np.random.randint(-20, 20, (J, N))
variance_true = np.array([
    np.identity(N) * np.random.randint(1, 20, N)
    for _ in range(J)
])

# Draw I points per cluster, store each point as an (N, 1) column vector,
# stack the clusters into one data set, then shuffle away the class order.
per_cluster = [
    np.random.multivariate_normal(mu, cov, I).reshape(I, N, 1)
    for mu, cov in zip(mean_true, variance_true)
]
X = np.vstack(per_cluster)
np.random.shuffle(X)
def pdf(x, mean, variance):
    """Multivariate normal density of a sample under N(mean, variance).

    Parameters
    ----------
    x, mean : (d, 1) numpy arrays — column vectors.
    variance : (d, d) numpy array — covariance matrix (must be invertible).

    Returns
    -------
    Scalar density value.

    FIX: the original read the module-level global ``N`` for the
    dimensionality, which silently breaks for inputs of any other dimension;
    derive it from ``x`` instead (identical result when x has N rows).
    """
    d = x.shape[0]
    diff = x - mean
    num = np.exp(-0.5 * np.dot(np.dot(diff.T, np.linalg.inv(variance)), diff))
    den = np.sqrt(np.power(2 * np.pi, d) * np.linalg.det(variance))
    val = num / den  # (1, 1) array
    return val[0][0]
def predict_proba(x):
    """Posterior cluster probabilities P(cluster j | x) for sample x.

    Uses the fitted module-level parameters ``phi``, ``mean``, ``variance``.

    Returns a list of J floats summing to 1.

    FIX: the original multiplied pdf(x) by its own normalizing constant
    sqrt((2*pi)^N * det(S)), returning just the exponential kernel.  That
    ignores both the mixture weights phi and the covariance determinant, so
    argmax over it can select the wrong cluster when covariances differ.
    Return the proper responsibilities instead (same length-J list interface).
    """
    weighted = [phi[j] * pdf(x, mean[j], variance[j]) for j in range(J)]
    total = sum(weighted)
    if total == 0:
        # every component density underflowed — fall back to uniform
        return [1.0 / J] * J
    return [w / total for w in weighted]
def predict_class(x):
    """Return (index, score) of the most likely cluster for sample x."""
    scores = predict_proba(x)
    best = np.argmax(scores)
    return best, scores[best]
# ---- EM fitting of the mixture parameters ----
epochs = 100
phi = np.ones(J) / J               # mixture weights, uniform init
mean = np.random.rand(J, N, 1)     # cluster means, random (N, 1) columns
# diagonal covariances seeded with the per-feature variance of the data
variance = np.array([np.identity(N) * np.var(X, axis=0) for j in range(J)])

# FIX: X holds J * I samples, but the original loops ran over range(I) only,
# so half the data was never used in fitting.  Iterate over every sample.
n_samples = len(X)
likelihood = np.zeros((n_samples, J))  # responsibilities gamma[i, j]

for e in range(epochs):
    print('epoch %d/%d' % (e + 1, epochs))
    # e-step: gamma[i, j] = phi_j N(x_i|mu_j, S_j) / sum_k phi_k N(x_i|mu_k, S_k)
    for i in range(n_samples):
        s = sum(phi[j] * pdf(X[i], mean[j], variance[j]) for j in range(J))
        for j in range(J):
            likelihood[i, j] = phi[j] * pdf(X[i], mean[j], variance[j]) / s
    # m-step: responsibility-weighted re-estimation of each cluster
    for j in range(J):
        s = sum(likelihood[i, j] for i in range(n_samples))
        mean[j] = sum(likelihood[i, j] * X[i] for i in range(n_samples)) / s
        variance[j] = sum(
            likelihood[i, j] * np.dot(X[i] - mean[j], (X[i] - mean[j]).T)
            for i in range(n_samples)
        ) / s
        phi[j] = s / n_samples
# ---- report true vs. fitted parameters, then classify fresh samples ----
print('\nreal (mean, variance)')
for idx, (mu, cov) in enumerate(zip(mean_true, variance_true)):
    print(idx, '-', mu.reshape(N), cov.diagonal())

print('\nfound (mean, variance)')
for idx in range(J):
    print(idx, '-', mean[idx].reshape(N), variance[idx].diagonal())

# Sanity check: classify 5 fresh draws from each true distribution.
print('\npredictions on random samples from true distributions:')
for idx in range(J):
    draws = np.random.multivariate_normal(mean_true[idx], variance_true[idx], 5).reshape(5, N, 1)
    labels = [predict_class(d)[0] for d in draws]
    print(idx, '-', labels)
# ---- 3-D scatter of the data, fitted cluster means marked in red ----
figure = plt.figure()
axes = figure.add_subplot(111, projection='3d')
axes.scatter(X[:, 0], X[:, 1], X[:, 2], marker='.')
for centre in mean:
    axes.scatter(*centre, color='red', marker='+')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment