Skip to content

Instantly share code, notes, and snippets.

Last active December 31, 2015 08:39
Show Gist options
  • Save dlovell/7961878 to your computer and use it in GitHub Desktop.
Save dlovell/7961878 to your computer and use it in GitHub Desktop.
UPDATE: Uses commit 7a454761519ac721f39ab52a2602b2bcfd4a0e65 of the SHOW_DIAGNOSTICS branch (takes a dict of summary functions). UPDATE: Uses commit cd0aee975a5178eecfe0299b817269a7aa17fd22 of the SHOW_DIAGNOSTICS branch (adds a reprocess_summary_func argument to engine.analyze). Demonstrates passing a summary function to engines. Uses commit 1a17f7891fe2dfa42… (hash truncated in this copy).
# -*- coding: utf-8 -*-
# <nbformat>3.0</nbformat>
# <codecell>
import numpy
import pylab
import crosscat.LocalEngine as LE
import crosscat.MultiprocessingEngine as ME
import crosscat.IPClusterEngine as IPE
import crosscat.utils.data_utils as du
import crosscat.utils.convergence_test_utils as ctu
import crosscat.utils.timing_test_utils as ttu
import crosscat.utils.summary_utils as su
def plot_with_mean(data_arr, hline=None):
    """Plot every column of ``data_arr`` in black and their row-wise mean in red.

    Parameters
    ----------
    data_arr : 2-D array-like
        Converted with ``numpy.asarray`` and passed to ``pylab.plot``, which
        draws one line per column; the thick red line is ``mean(axis=1)``.
        Presumably rows are diagnostic checkpoints and columns are chains --
        TODO confirm against the shape of engine.analyze's summaries.
    hline : float or None
        Optional reference value drawn as a horizontal line; ``None`` draws
        nothing extra.
    """
    # asarray lets plain lists work too; ndarrays pass through unchanged.
    data_arr = numpy.asarray(data_arr)
    data_mean = data_arr.mean(axis=1)
    pylab.plot(data_arr, color='k')
    pylab.plot(data_mean, linewidth=3, color='r')
    if hline is not None:
        # NOTE(review): the body of this branch was lost in this copy of the
        # script; drawing a horizontal reference line is inferred from the
        # parameter name -- confirm against the original notebook.
        pylab.axhline(hline)
# <codecell>
# settings
# -- random seeds (data generation / inference)
gen_seed = 0
inf_seed = 0
# -- structure of the synthetic data
num_cols = 32
num_views = 4
num_clusters = 4
# -- generator parameters
data_max_mean = 1
data_max_std = 1.0
# -- inference schedule and test-set size (n_test is passed to
#    ctu.create_test_set below)
n_steps = 64
diagnostics_every_N = 2
n_test = 40
# -- problem size: small local run.
#    Larger-run values used previously on a remote IPython cluster:
#      num_rows = 1600
#      n_chains = 32
#      config_filename = '/home/dlovell/.config/ipython/profile_ssh/security/ipcontroller-client.json'
num_rows = 100
n_chains = 2
config_filename = None
# generate some data
# NOTE(review): the closing line of the gen_factorial_data_objects(...) call
# was lost in this copy. The keyword below is presumed from crosscat's
# data_utils API: it asks the generator to also return
# data_inverse_permutation_indices, which the 4-way unpacking requires --
# confirm against the crosscat version in use.
T, M_r, M_c, data_inverse_permutation_indices = du.gen_factorial_data_objects(
    gen_seed, num_clusters, num_cols, num_rows, num_views,
    max_mean=data_max_mean, max_std=data_max_std,
    send_data_inverse_permutation_indices=True)
# ground-truth view assignment (used by the ARI summary) and row clustering
view_assignment_truth, X_D_truth = ctu.truth_from_permute_indices(
    data_inverse_permutation_indices, num_rows, num_cols, num_views, num_clusters)
# latent state of the generating model, and the test-set mean log likelihood
# under it
X_L_gen, X_D_gen = ttu.get_generative_clustering(M_c, M_r, T,
    data_inverse_permutation_indices, num_clusters, num_views)
T_test = ctu.create_test_set(M_c, T, X_L_gen, X_D_gen, n_test, seed_seed=0)
generative_mean_test_log_likelihood = ctu.calc_mean_test_log_likelihood(M_c, T,
    X_L_gen, X_D_gen, T_test)
# <codecell>
# create the engine
# (alternative without an IPython cluster: ME.MultiprocessingEngine(seed=inf_seed))
engine = IPE.IPClusterEngine(seed=inf_seed, config_filename=config_filename)
# Work on a copy of LocalEngine's default summary functions so that adding
# custom entries does not mutate the library's dict.
# Every summary function must take only p_State as its argument.
summary_func_dict = dict(LE.default_summary_func_dict)
def get_ari(p_State):
    """Summary function: ARI of a chain's column/view assignment vs. the truth.

    Intended for ``summary_func_dict``, so it takes only ``p_State``.
    Returns ``get_column_ARI(p_State.get_X_L(), view_assignment_truth)`` from
    crosscat.utils.convergence_test_utils (presumably the adjusted Rand index
    of the column partition -- confirm in that module).

    Requires environment: ``view_assignment_truth`` must be a global in the
    namespace the function executes in (it is pushed to the engines).
    """
    # Import locally: the module-level `import crosscat.X as Y` statements bind
    # only their aliases, never the name `crosscat`, so looking up
    # `crosscat.utils.convergence_test_utils` as a global raised NameError in
    # any namespace that had not done a plain `import crosscat...` first.
    import crosscat.utils.convergence_test_utils as convergence_test_utils
    X_L = p_State.get_X_L()
    return convergence_test_utils.get_column_ARI(X_L, view_assignment_truth)
# push the function and any arguments needed from the surrounding environment
# get_ari reads the global view_assignment_truth, so it must exist in every
# engine's namespace before analyze calls the summary function there.
# NOTE(review): the original argument list of dict( was lost in this copy;
# reconstructed from get_ari's stated requirement and the commented-out
# variant further down -- confirm.
args_dict = dict(view_assignment_truth=view_assignment_truth)
engine.dview.push(args_dict, block=True)
summary_func_dict['ARI'] = get_ari
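# # Why does the variant below fail with
# #   IPython.parallel.error.RemoteError: NameError(global name 'view_assignment_truth' is not defined)
# # even when the push above has been run and succeeded?
# #
# # It fails with the SAME ERROR even if the version above is run first and
# # passes, which suggests the problem is the environment (globals) that
# # crosscat.utils.summary_utils.get_ari is bound to on the engines.
# #
# # Likely cause (to confirm): IPython.parallel serializes a function defined
# # in an importable module by re-importing that module on the engine and
# # binding the function to that module's globals, so it never sees values
# # pushed into the engine's interactive namespace. A function defined
# # interactively (in __main__), like get_ari above, is bound to the engine's
# # namespace instead, so it does see pushed values.
# # (The original note pointed to a StackOverflow post about this; the link
# # is not preserved in this copy.)
# #
# import crosscat.utils.summary_utils
# args_dict = dict(view_assignment_truth=view_assignment_truth)
# engine.dview.push(args_dict, block=True)
# summary_func_dict['ARI'] = crosscat.utils.summary_utils.get_ari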
# <codecell>
# run inference
# Initialize n_chains chains, then call engine.analyze for n_steps with three
# different values of do_diagnostics:
#   False             -> returns (X_L_list, X_D_list)
#   True              -> additionally returns summaries_dict
#   summary_func_dict -> additionally returns summaries_dict; presumably keyed
#                        by the names in summary_func_dict (the plotting cell
#                        looks up 'ARI' there) -- confirm
# Each call continues from the X_L_list / X_D_list produced by the previous one.
# NOTE(review): the closing line(s) of each analyze(...) call are missing in
# this copy, so this cell is not valid Python as-is. The lost arguments
# (possibly diagnostics_every_N, which the settings define but no visible line
# uses) must be restored from the original notebook.
X_L_list, X_D_list = engine.initialize(M_c, M_r, T, n_chains=n_chains)
X_L_list, X_D_list = engine.analyze(M_c, T, X_L_list, X_D_list,
n_steps=n_steps, do_diagnostics=False,
X_L_list, X_D_list, summaries_dict = engine.analyze(M_c, T, X_L_list, X_D_list,
n_steps=n_steps, do_diagnostics=True,
X_L_list, X_D_list, summaries_dict = engine.analyze(M_c, T, X_L_list, X_D_list,
n_steps=n_steps, do_diagnostics=summary_func_dict,
# <codecell>
# plot results
# For each named summary, fetch its array from summaries_dict (from the last
# analyze call) and draw it with plot_with_mean. The optional reference line
# comes from hline_lookup; .get returns None for summaries with no entry.
# NOTE(review): the contents and closing parenthesis of `hline_lookup = dict(`
# are missing in this copy. Any figure/show calls that may have separated the
# plots are missing too: as written, every summary is drawn on the same
# current pylab figure. Restore from the original notebook.
# plot_summaries_names = ['ARI', 'mean_test_ll', 'num_views']
plot_summaries_names = ['logscore', 'num_views', 'column_crp_alpha', 'ARI']
hline_lookup = dict(
for summaries_name in plot_summaries_names:
data_arr = summaries_dict[summaries_name]
hline = hline_lookup.get(summaries_name)
plot_with_mean(data_arr, hline=hline)
# import crosscat.utils.plot_utils as pu
# pu.plot_views(numpy.array(T), X_D_gen, X_L_gen, M_c)
# <codecell>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment