Skip to content

Instantly share code, notes, and snippets.

@soonraah
Created October 5, 2014 18:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save soonraah/be02066bc45634df036d to your computer and use it in GitHub Desktop.
Save soonraah/be02066bc45634df036d to your computer and use it in GitHub Desktop.
A script comparing the EM algorithm and MCMC for training a GMM.
import numpy as np
from sklearn import cross_validation, mixture
import pickle
import os
import pystan
import time
import matplotlib.pyplot as plt
def dump_stan_model(stan_model, compiled_file_name):
    """
    Dump a compiled Stan model to disk with pickle.

    The file is opened in binary write mode, so an existing file at the
    same path is overwritten.

    :param stan_model: compiled stan model instance (any picklable object works)
    :param compiled_file_name: path of the pickled output file
    """
    # Open inside the with-statement so the file is always closed,
    # even if pickling raises.
    with open(compiled_file_name, 'wb') as f:
        pickle.dump(stan_model, f)
def load_stan_model(compiled_file_name):
    """
    Load a compiled Stan model that was pickled by dump_stan_model.

    NOTE(security): pickle.load can execute arbitrary code embedded in the
    file. Only load pickles this script produced itself; never point this
    at a file from an untrusted source.

    :param compiled_file_name: path of the pickled stan model file
    :return: the unpickled stan model instance
    """
    # The with-statement guarantees the file is closed after unpickling.
    with open(compiled_file_name, 'rb') as f:
        return pickle.load(f)
def convert_model(stan_gmm_model):
    """
    Build a scikit-learn diagonal-covariance GMM from Stan's optimizing output.

    :param stan_gmm_model: mapping returned by Stan's optimizing call; its
        'weights', 'mu' and 'sigma' entries are copied into the GMM
    :return: mixture.GMM instance whose number of components equals the
        number of elements in 'weights'
    """
    component_weights = stan_gmm_model.get('weights')
    component_means = stan_gmm_model.get('mu')
    component_sigmas = stan_gmm_model.get('sigma')

    converted = mixture.GMM(n_components=component_weights.size,
                            covariance_type='diag')
    converted.weights_ = component_weights
    converted.means_ = component_means
    # 'sigma' is squared element-wise before being stored as covars_.
    converted.covars_ = np.square(component_sigmas)
    return converted
def draw_result_graph(likelihoods_em, likelihoods_mcmc, diagonal_range=None):
    """
    Draw a graph that shows the likelihood of EM vs. MCMC.

    Each point is one cross-validation split: x is the EM score and y is the
    MCMC (Stan) score. A gray y = x reference line is added, so points above
    the line are splits where the Stan-based model scored higher.

    :param likelihoods_em: per-split average log likelihoods from EM
    :param likelihoods_mcmc: per-split average log likelihoods from the
        Stan-based model, in the same order as likelihoods_em
    :param diagonal_range: optional (low, high) extent of the y = x line.
        When None, the line spans the smallest to the largest plotted value,
        so it always covers the data. It falls back to (-7.5, -6.0) when there
        is nothing to plot.
    :return: None (the figure is displayed with plt.show())
    """
    _, ax = plt.subplots()
    ax.scatter(likelihoods_em, likelihoods_mcmc, marker='o')

    if diagonal_range is None:
        # Derive the reference-line extent from the data instead of a
        # dataset-specific constant, so the diagonal is never off-screen.
        values = np.concatenate([np.ravel(likelihoods_em),
                                 np.ravel(likelihoods_mcmc)])
        diagonal_range = (values.min(), values.max()) if values.size else (-7.5, -6.0)
    low, high = diagonal_range
    ax.plot([low, high], [low, high], color='gray', alpha=0.5)

    ax.set_xlabel("Average Log Likelihood (EM)")
    ax.set_ylabel("Average Log Likelihood (MCMC)")
    plt.show()
def _print_banner(title):
    """Print title framed by the dashed separator lines used in the log output."""
    print("--------------------------------")
    print(title)
    print("--------------------------------")


def _load_data_set(data_file_name):
    """Load the ';'-separated CSV (one header row) and keep its first 11 columns."""
    raw_data_set = np.loadtxt(data_file_name, delimiter=";", skiprows=1)
    return raw_data_set[:, :11]  # remove the "quality" column


def _get_stan_model(stan_code_file_name, stan_compiled_file_name):
    """Return the compiled Stan model, compiling and pickling it on first use."""
    if os.path.isfile(stan_compiled_file_name):
        return load_stan_model(stan_compiled_file_name)
    stan_model = pystan.StanModel(file=stan_code_file_name)
    # Cache the compiled model so later runs take the load branch above.
    dump_stan_model(stan_model, stan_compiled_file_name)
    return stan_model


def _evaluate_em(tr_data_set, ev_data_set, num_mixture_components):
    """Fit a diagonal GMM by EM; return (mean held-out score, fit time in sec)."""
    gmm_em = mixture.GMM(n_components=num_mixture_components, covariance_type='diag')
    t = time.time()
    gmm_em.fit(tr_data_set)
    elapsed = time.time() - t
    return gmm_em.score(ev_data_set).mean(), elapsed


def _evaluate_stan(stan_model, tr_data_set, ev_data_set, num_mixture_components):
    """Fit the GMM with Stan; return (mean held-out score, fit time in sec)."""
    data_dic = dict(D=tr_data_set.shape[1], N=tr_data_set.shape[0],
                    M=num_mixture_components, X=tr_data_set)
    t = time.time()
    # NOTE: StanModel.optimizing returns a point estimate rather than MCMC
    # draws; the "mcmc" naming in this script is kept for its output labels.
    optimizing_result = stan_model.optimizing(data=data_dic, iter=20000)
    elapsed = time.time() - t
    gmm_mcmc = convert_model(optimizing_result)
    return gmm_mcmc.score(ev_data_set).mean(), elapsed


def main(data_file_name='winequality-white.csv', num_validations=500,
         num_mixture_components=4):
    """
    Compare EM (scikit-learn) with Stan-based fitting for GMM training.

    On each of num_validations random 50/50 splits of the data, a
    diagonal-covariance GMM is fitted on the training half by both methods
    and scored on the evaluation half. The function then prints the
    per-split scores and average fit times, and plots EM vs. Stan scores.

    :param data_file_name: ';'-separated CSV with one header row
    :param num_validations: number of random train/evaluation splits
    :param num_mixture_components: number of GMM components for both methods
    """
    data_set = _load_data_set(data_file_name)
    stan_model = _get_stan_model('multi_dimensional_gmm_diagonal.stan',
                                 'multi_dimensional_gmm_diagonal.pkl')

    time_sec_em = 0.0
    time_sec_mcmc = 0.0
    likelihoods_em = []
    likelihoods_mcmc = []
    # Split over rows (samples), so use the row count len(data_set).
    # ndarray.data is the raw memory buffer, so its length is not a reliable
    # sample count. It can also fail for the non-contiguous column slice
    # returned by _load_data_set.
    ss = cross_validation.ShuffleSplit(n=len(data_set), n_iter=num_validations, test_size=0.5)
    for cnt, (training_indexes, evaluation_indexes) in enumerate(ss, start=1):
        _print_banner("ITERATION {0}".format(cnt))
        tr_data_set = data_set[training_indexes]
        ev_data_set = data_set[evaluation_indexes]

        likelihood, elapsed = _evaluate_em(tr_data_set, ev_data_set, num_mixture_components)
        likelihoods_em.append(likelihood)
        time_sec_em += elapsed

        likelihood, elapsed = _evaluate_stan(stan_model, tr_data_set, ev_data_set,
                                             num_mixture_components)
        likelihoods_mcmc.append(likelihood)
        time_sec_mcmc += elapsed

    _print_banner("COMPLETED")
    print("likelihoods_em:", likelihoods_em)
    print("likelihoods_mcmc:", likelihoods_mcmc)
    print("avg time em: {0:.3f} sec".format(time_sec_em / num_validations))
    print("avg time mcmc: {0:.3f} sec".format(time_sec_mcmc / num_validations))
    draw_result_graph(likelihoods_em, likelihoods_mcmc)
if __name__ == '__main__':
    # Run the EM vs. Stan comparison only when executed as a script,
    # not when this module is imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment