Skip to content

Instantly share code, notes, and snippets.

@emilemathieu
Created June 1, 2020 11:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save emilemathieu/327c6d6180b24272089fd341f79f638a to your computer and use it in GitHub Desktop.
Save emilemathieu/327c6d6180b24272089fd341f79f638a to your computer and use it in GitHub Desktop.
import os
from utils import query, load_experiments, lower_ci, upper_ci, inter_ci, mean, convert_to_latex, compute_ci_and_format
def process_data(data, rows=None, cols=None, scales=None):
    """Reshape aggregated results into the layout of the vMF LaTeX table.

    Pivots the innermost index level (the model, going by the ``groupby``
    keys used in ``make_table``) into the columns, orders rows and columns,
    and replaces raw keys with their display labels.

    Args:
        data: Frame with a 3-level row index (objective, target parameter
            scale, model) and a single value column -- assumed from the
            ``groupby`` in ``make_table``; TODO confirm against
            ``compute_ci_and_format``.
        rows: Mapping of model key -> display name, used (bolded) as the
            table columns. Defaults to the global ``rows_name``.
        cols: Mapping of objective key -> LaTeX label, used for the outer
            row level. Defaults to the global ``cols_name``.
        scales: Ordered target parameter scales for the inner row level.
            Defaults to the global ``target_param_scale``.

    Returns:
        The reshaped and relabelled DataFrame.
    """
    # Only fall back to the script-level configuration when an argument is
    # not supplied, so the function works without running ``__main__``.
    rows = rows_name if rows is None else rows
    cols = cols_name if cols is None else cols
    scales = target_param_scale if scales is None else scales

    # Move the model level into the columns, then drop the outer column level
    # so only the model keys remain as column labels.
    data = data.unstack(level=-1)
    data.columns = data.columns.droplevel(level=0)

    bold_rows = {key: "\\bf " + value for key, value in rows.items()}
    data = data.reindex(columns=list(rows))
    data.index.names = ["Loss", "${\\kappa}$"]
    # Order the inner (scale) level first, then the outer (objective) level.
    data = data.reindex(index=list(scales), level=1)
    data = data.reindex(index=list(cols), level=0)
    return data.rename(index=cols, columns=bold_rows)
def write_data_to_latex(data, latex_path="doc/tables/vmf.tex"):
    """Write the formatted vMF table to a LaTeX file.

    Args:
        data: DataFrame to render. Cells and labels are written unescaped
            (``escape=False``) so LaTeX markup such as ``\\bf`` or math-mode
            labels is preserved.
        latex_path: Output path, resolved against the current working
            directory when relative. Defaults to ``doc/tables/vmf.tex``.
    """
    filename = os.path.join(os.getcwd(), latex_path)
    # ``to_latex`` writes straight to the path and errors if the parent
    # directory does not exist, so create it up front.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    data.to_latex(buf=filename, escape=False, multirow=True, column_format="cccccc")
def make_table(path):
    """Build the vMF results table from the experiments stored under *path*.

    Loads the runs, keeps those matching the vMF configuration, aggregates
    ``loss_best`` per (objective, target parameter scale, model), formats the
    confidence intervals, reshapes the result and writes it as LaTeX.
    """
    experiments = load_experiments(path)

    # All conditions must hold simultaneously: they are joined with " & ".
    selection = " & ".join((
        f"target_param_scale in {target_param_scale}",
        "target == 'VMFm0'",
        "div == 'bf'",
        "solver == 'rk4'",
        "l2int == 0.",
    ))
    statistics = {"loss_best": ["mean", inter_ci, lower_ci, upper_ci, "count"]}

    grouped = query(experiments, selection).groupby(["obj", "target_param_scale", "model"])
    table = compute_ci_and_format(grouped.agg(statistics), level=(0, 1))
    write_data_to_latex(process_data(table))
def make_densities_plot(path):
    """Placeholder for the vMF density plots; currently does nothing.

    ``path`` is accepted but not used.
    """
    # Command recorded in the original notes, apparently used to produce the
    # density plots via the experiment runner instead of this function:
    # python3 geoflow/run_exp.py -m load=True model=stereo,riem target_manifold=Hypersphere target=VMFm0 obj=like,elbo target_param_scale=100,50,10 lr=1e-3 div=bf epochs=3000 test_freq=50 test_batch_size=2000 viz_freq=-2 name=final/vmf seed=-1 viz_plot=vmf viz_final=True run=0,1,2,3 gpu=0
    return None
# Table configuration. Defined at module level rather than inside the
# ``__main__`` guard because ``process_data`` and ``make_table`` read these
# globals, which would otherwise be undefined when the module is imported.
rows_name = {"stereo": "Stereographic", "riem": "Riemannian"}
cols_name = {"like": "$\\mathcal{L}^{\\text{Like}}$", "elbo": "$\\mathcal{L}^{\\text{KL}}$"}
target_param_scale = [100, 50, 10]

if __name__ == "__main__":
    # An alternative experiment directory used previously: "experiments/final/vmf".
    experiment = "experiments/vmf"
    path = os.path.join(os.getcwd(), experiment)
    make_table(path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment