Skip to content

Instantly share code, notes, and snippets.

@Kwentar
Last active July 25, 2018 20:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Kwentar/24b0cb4f8eb1eb6e44548c2cf8bcc5c7 to your computer and use it in GitHub Desktop.
Save Kwentar/24b0cb4f8eb1eb6e44548c2cf8bcc5c7 to your computer and use it in GitHub Desktop.
BayesGenerator
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal as mvn
from keras.datasets import mnist
class BayesClassifier:
    """Per-class multivariate Gaussian model used to generate samples.

    ``fit`` estimates a mean vector and a full covariance matrix for every
    distinct class label; ``sample_given_y`` and ``sample`` then draw new
    feature vectors from those Gaussians. Despite the name, no prediction
    method is provided -- the model is used purely as a generator.
    """

    def fit(self, X, Y):
        """Estimate one Gaussian (mean and covariance) per distinct label.

        Args:
            X: array-like of feature rows, one row per example.
            Y: array-like of class labels, one per row of ``X``. Labels do
                not have to be the contiguous integers ``0..K-1``.

        Sets:
            self.labels: sorted list of the distinct labels seen in ``Y``.
            self.K: number of distinct labels.
            self.gaussians: list of ``{'m': mean, 'c': covariance}`` dicts,
                aligned with ``self.labels``.
        """
        # Coerce to arrays so ``Y == label`` is an element-wise mask even
        # when the caller passes plain Python lists.
        X = np.asarray(X)
        Y = np.asarray(Y)
        # Iterate over the actual (sorted) labels instead of range(K): with
        # non-contiguous labels range(K) would select empty classes.
        self.labels = np.unique(Y).tolist()
        self.K = len(self.labels)
        self._label_to_index = {label: i for i, label in enumerate(self.labels)}
        self.gaussians = []
        for label in self.labels:
            Xk = X[Y == label]
            g = {
                'm': Xk.mean(axis=0),
                # rowvar=False: columns are variables, rows are observations
                # (equivalent to np.cov(Xk.T)).
                'c': np.cov(Xk, rowvar=False),
            }
            self.gaussians.append(g)

    def sample_given_y(self, y):
        """Draw one feature vector from the Gaussian fitted for label ``y``.

        Raises:
            IndexError: if ``y`` is not one of the labels seen by ``fit``.
        """
        index = self._label_to_index.get(y)
        if index is None:
            raise IndexError('unknown class label: {!r}'.format(y))
        g = self.gaussians[index]
        return mvn.rvs(mean=g['m'], cov=g['c'])

    def sample(self):
        """Draw one feature vector from a class chosen uniformly at random."""
        # Choose a class *position* (0..K-1), then translate it to its label,
        # so this works for any label set while consuming the global NumPy
        # RNG exactly as before.
        index = np.random.randint(self.K)
        return self.sample_given_y(self.labels[index])
if __name__ == '__main__':
    # Only the training split is needed: it is used to fit one Gaussian per
    # digit class. The test split is discarded.
    (x_train, y_train), _ = mnist.load_data()
    # Derive the image shape from the data instead of hard-coding 28x28 / 784.
    image_shape = x_train.shape[1:]
    clf = BayesClassifier()
    # Flatten each image into one feature row before fitting.
    clf.fit(x_train.reshape(len(x_train), -1), y_train)
    # Draw ten samples, each from a uniformly chosen class, onto a 2x5 grid.
    for index in range(10):
        sample = clf.sample()
        plt.subplot(2, 5, index + 1)
        plt.imshow(sample.reshape(image_shape), cmap='gray')
    # Show the whole grid once, after all subplots are drawn.
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment