Last active: March 8, 2019 14:28
-
-
Save DFoly/37678ad6215cf7a4a6d363ce373157e1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict(self, X):
    """Return the most probable mixture component for each sample.

    Applies Bayes' rule: each sample x is assigned to the component c
    that maximises pi[c] * N(x | mu[c], sigma[c]). The comparison is done
    in log space, log(pi[c]) + log N(x | mu[c], sigma[c]). This has the
    same argmax, but it does not underflow to 0.0 for samples far from
    every component. With raw densities such a row would be all zeros,
    and argmax would silently pick component 0.

    Parameters:
    -------------
    X: N*d numpy array

    Returns:
    ----------
    labels: length-N integer array holding, for each sample, the index
        of the component with the highest responsibility gamma.
    """
    # NOTE(review): assumes `mvn` is scipy.stats.multivariate_normal,
    # which provides `logpdf` alongside `pdf` -- confirm at the import.
    log_joint = np.zeros((X.shape[0], self.C))
    # A zero mixing weight yields log(0) = -inf, which correctly rules the
    # component out; silence only the divide-by-zero warning for it.
    with np.errstate(divide="ignore"):
        log_pi = np.log(self.pi)
    for c in range(self.C):
        log_joint[:, c] = log_pi[c] + mvn.logpdf(
            X, self.mu[c, :], self.sigma[c]
        )
    return log_joint.argmax(axis=1)
def predict_proba(self, X):
    """Return the posterior probability of each component for each sample.

    Applies Bayes' rule including the evidence term:

        gamma[n, c] = pi[c] N(x_n | mu[c], sigma[c])
                      / sum_k pi[k] N(x_n | mu[k], sigma[k])

    so every row sums to 1. The computation runs in log space and is
    normalised with a log-sum-exp shift, so samples far from every
    component still get well-defined probabilities instead of 0/0.

    Parameters:
    -------------
    X: N*d numpy array

    Returns:
    ----------
    post_proba: N*C numpy array of posterior responsibilities gamma;
        each row sums to 1.
    """
    # NOTE(review): assumes `mvn` is scipy.stats.multivariate_normal,
    # which provides `logpdf` alongside `pdf` -- confirm at the import.
    log_joint = np.zeros((X.shape[0], self.C))
    # A zero mixing weight yields log(0) = -inf, i.e. posterior 0 for that
    # component; silence only the divide-by-zero warning for it.
    with np.errstate(divide="ignore"):
        log_pi = np.log(self.pi)
    for c in range(self.C):
        # Log of the Bayes-rule numerator: prior times likelihood.
        log_joint[:, c] = log_pi[c] + mvn.logpdf(
            X, self.mu[c, :], self.sigma[c]
        )
    # Log-sum-exp: shifting each row by its maximum keeps the largest term
    # at exp(0) = 1, so the normalising sum can never underflow to 0.
    log_joint -= log_joint.max(axis=1, keepdims=True)
    post_proba = np.exp(log_joint)
    post_proba /= post_proba.sum(axis=1, keepdims=True)
    return post_proba
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.