@DFoly
Last active March 8, 2019 14:28
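# Imports these methods rely on (not shown in the original gist)
import numpy as np
from scipy.stats import multivariate_normal as mvn


def predict(self, X):
    """Return hard cluster labels, using Bayes' rule to
    compute the posterior distribution over components.

    Parameters
    ----------
    X : (N, d) numpy array

    Returns
    -------
    labels : (N,) numpy array
        Predicted cluster for each row of X: the component
        with the highest responsibility gamma.
    """
    # Unnormalized posterior pi_c * N(x | mu_c, Sigma_c). The normalizing
    # constant p(x) is shared by all components, so argmax is unaffected.
    labels = np.zeros((X.shape[0], self.C))
    for c in range(self.C):
        labels[:, c] = self.pi[c] * mvn.pdf(X, self.mu[c, :], self.sigma[c])
    labels = labels.argmax(1)
    return labels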
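
def predict_proba(self, X):
    """Return posterior probabilities of cluster membership.

    Parameters
    ----------
    X : (N, d) numpy array

    Returns
    -------
    post_proba : (N, C) numpy array
        Responsibilities gamma: row i holds p(c | x_i) for each
        component c, and each row sums to 1.
    """
    post_proba = np.zeros((X.shape[0], self.C))
    for c in range(self.C):
        # Numerator of Bayes' rule: prior pi_c times likelihood N(x | mu_c, Sigma_c)
        post_proba[:, c] = self.pi[c] * mvn.pdf(X, self.mu[c, :], self.sigma[c])
    # Divide by the evidence p(x) = sum_c pi_c N(x | mu_c, Sigma_c) so each row sums to 1
    post_proba /= post_proba.sum(axis=1, keepdims=True)
    return post_proba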
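Because `predict` and `predict_proba` are methods of a GMM class that is not shown here, the following is a standalone sketch of the same Bayes-rule computation with the fitted parameters passed in explicitly. The function names `gmm_posterior` and `gmm_predict` are hypothetical, and the two-component parameters are picked by hand for illustration rather than fitted; `pi`, `mu` and `sigma` play the roles of the gist's `self.pi`, `self.mu` and `self.sigma`.

```python
import numpy as np
from scipy.stats import multivariate_normal as mvn


def gmm_posterior(X, pi, mu, sigma):
    """Posterior p(c | x) for each row of X under a GMM (Bayes' rule)."""
    C = len(pi)
    # Numerator of Bayes' rule: prior pi_c times likelihood N(x | mu_c, Sigma_c)
    joint = np.zeros((X.shape[0], C))
    for c in range(C):
        joint[:, c] = pi[c] * mvn.pdf(X, mu[c], sigma[c])
    # Denominator: the evidence p(x) = sum_c pi_c N(x | mu_c, Sigma_c)
    return joint / joint.sum(axis=1, keepdims=True)


def gmm_predict(X, pi, mu, sigma):
    """Hard labels: the component with the highest responsibility."""
    return gmm_posterior(X, pi, mu, sigma).argmax(axis=1)


# Hypothetical two-component model in 2-D (hand-picked, not fitted)
pi = np.array([0.5, 0.5])
mu = np.array([[0.0, 0.0], [5.0, 5.0]])
sigma = np.array([np.eye(2), np.eye(2)])

X = np.array([[0.1, -0.2], [4.8, 5.3], [2.0, 2.0]])
gamma = gmm_posterior(X, pi, mu, sigma)
labels = gmm_predict(X, pi, mu, sigma)
print(np.round(gamma, 3))
print(labels)
```

Since every row is divided by the same p(x), taking the argmax before or after normalizing gives the same labels, which is why `predict` can skip the division while `predict_proba` cannot.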
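Multiplying raw densities can underflow: for a point far from every component, each `pi[c] * mvn.pdf(...)` becomes 0.0, and dividing such a row by its sum yields NaN. A common remedy, swapped in here and not part of the original gist, is to work with `mvn.logpdf` and normalize with `scipy.special.logsumexp`. The function name `gmm_posterior_log` and the parameter values are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal as mvn


def gmm_posterior_log(X, pi, mu, sigma):
    """Posterior p(c | x) computed in log space to avoid underflow."""
    C = len(pi)
    log_joint = np.zeros((X.shape[0], C))
    for c in range(C):
        # log pi_c + log N(x | mu_c, Sigma_c)
        log_joint[:, c] = np.log(pi[c]) + mvn.logpdf(X, mu[c], sigma[c])
    # Subtract log p(x) (log-sum-exp over components), then exponentiate
    log_post = log_joint - logsumexp(log_joint, axis=1, keepdims=True)
    return np.exp(log_post)


# Hypothetical two-component model in 2-D (hand-picked, not fitted)
pi = np.array([0.5, 0.5])
mu = np.array([[0.0, 0.0], [5.0, 5.0]])
sigma = np.array([np.eye(2), np.eye(2)])

# A point far from both components: the raw weighted densities underflow
X_far = np.array([[100.0, 100.0]])
raw = np.array([pi[c] * mvn.pdf(X_far, mu[c], sigma[c]) for c in range(2)])
post = gmm_posterior_log(X_far, pi, mu, sigma)
print(raw)  # both entries underflow to 0.0
print(post)
```

An argmax over an all-zero row simply returns index 0, so the direct-product approach silently labels such points as component 0; the log-space posterior instead assigns the point to the nearer component (here component 1).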