Last active
April 6, 2020 12:33
-
-
Save DFoly/4e374ba9940bbcefcdbc3d48a9dba4a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GMM:
    """Gaussian Mixture Model.

    Parameters
    -----------
    C: int, number of Gaussians/clusters in the mixture
    n_runs: int, stored on the instance for the fitting code
        (NOTE(review): presumably the number of EM iterations; the code
        that consumes it is not part of this section - confirm)

    Attributes
    -----------
    C: int, number of Gaussians/clusters
    n_runs: int, as passed to the constructor
    initial_means: array (C, d), set by calculate_mean_covariance
    initial_cov: array (C, d, d), set by calculate_mean_covariance
    initial_pi: array (C,), set by calculate_mean_covariance
    mu, pi, sigma: read by get_params; they are not assigned anywhere in
        this section and must be set before get_params is called
    """

    def __init__(self, C, n_runs):
        self.C = C  # number of Gaussians/clusters
        self.n_runs = n_runs

    def get_params(self):
        """Return the current parameters as the tuple (mu, pi, sigma).

        Raises AttributeError if these attributes have not been assigned yet.
        """
        return (self.mu, self.pi, self.sigma)

    def calculate_mean_covariance(self, X, prediction):
        """Calculate per-cluster means, covariances and mixing weights
        from hard cluster labels (e.g. a k-means prediction).

        Parameters:
        ------------
        X: N*d numpy array of data points
        prediction: length-N cluster labels, one per row of X

        Returns:
        -------------
        initial_means: (C, d) array, mean of the points in each cluster
        initial_cov: (C, d, d) array, covariance of the points in each
            cluster (normalised by the cluster size Nk)
        initial_pi: (C,) array, fraction of the points in each cluster

        The three arrays are also stored on the instance under the same
        names. Clusters are written in the sorted order of their labels;
        if fewer than C distinct labels occur, the trailing rows stay zero.
        """
        # Flatten so that lists and (N, 1) column vectors behave like a 1-D
        # label array when building the per-cluster masks below.
        prediction = np.asarray(prediction).ravel()
        n_samples, d = X.shape

        self.initial_means = np.zeros((self.C, d))
        self.initial_cov = np.zeros((self.C, d, d))
        self.initial_pi = np.zeros(self.C)

        for index, label in enumerate(np.unique(prediction)):
            members = X[prediction == label]
            Nk = members.shape[0]  # number of data points in current gaussian

            self.initial_pi[index] = Nk / n_samples
            self.initial_means[index, :] = np.mean(members, axis=0)

            de_meaned = members - self.initial_means[index, :]
            # With hard assignments every member has responsibility 1, so the
            # covariance is the plain scatter matrix divided by Nk. It must
            # not be scaled by the mixing weight pi_k.
            self.initial_cov[index, :, :] = np.dot(de_meaned.T, de_meaned) / Nk

        # Compare with a tolerance: a sum of count/N fractions can differ
        # from 1.0 by floating-point rounding.
        assert np.isclose(np.sum(self.initial_pi), 1.0), \
            "mixing weights must sum to 1; check that prediction has one label per row of X"

        return (self.initial_means, self.initial_cov, self.initial_pi)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment