Created
June 1, 2020 11:18
-
-
Save emilemathieu/327c6d6180b24272089fd341f79f638a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from utils import query, load_experiments, lower_ci, upper_ci, inter_ci, mean, convert_to_latex, compute_ci_and_format | |
def process_data(data, row_names=None, col_names=None, param_scales=None):
    """Reshape an aggregated results frame into the layout of the LaTeX table.

    The innermost row-index level (the model, per the groupby keys used in
    ``make_table``) is pivoted into the columns. The two remaining row levels
    are renamed to ``"Loss"`` and ``"${\\kappa}$"``, restricted and ordered to
    the configured labels, and relabelled for display.

    Args:
        data: Aggregated results frame with a 3-level row index whose last
            level is the model key. After the pivot, the top column level is
            dropped. This assumes the columns carry a single top-level label;
            TODO confirm against ``compute_ci_and_format``.
        row_names: Mapping ``model key -> display name``. These become the
            table's *columns*, in mapping order, with labels prefixed by
            ``\\bf``. Defaults to the module-level ``rows_name``.
        col_names: Mapping ``objective key -> display name`` for the outer
            row level, in mapping order. Defaults to the module-level
            ``cols_name``.
        param_scales: Ordered ``target_param_scale`` values kept on the inner
            row level. Defaults to the module-level ``target_param_scale``.

    Returns:
        The reshaped, reordered and relabelled DataFrame.
    """
    # Fall back to the module-level configuration so existing callers that
    # pass only ``data`` keep working.
    if row_names is None:
        row_names = rows_name
    if col_names is None:
        col_names = cols_name
    if param_scales is None:
        param_scales = target_param_scale

    # Pivot the innermost index level (model) into the columns, then drop the
    # leftover top column level so only the model keys remain as columns.
    data = data.unstack(level=-1)
    data.columns = data.columns.droplevel(level=0)

    bold_row_names = {key: "\\bf " + value for key, value in row_names.items()}
    data = data.reindex(columns=list(row_names))
    data.index.names = ["Loss", "${\\kappa}$"]
    # Level-wise reindexes: keep only the configured kappa values and
    # objectives, in their configured order.
    data = data.reindex(index=list(param_scales), level=1)
    data = data.reindex(index=list(col_names), level=0)
    return data.rename(index=col_names, columns=bold_row_names)
def write_data_to_latex(data, latex_path="doc/tables/vmf.tex"):
    """Write ``data`` as a LaTeX table to ``latex_path``.

    Args:
        data: Table to export (as built by ``process_data``). Cells and labels
            are written with ``escape=False`` so their LaTeX markup is kept
            verbatim.
        latex_path: Destination file, resolved against the current working
            directory. Defaults to the original hard-coded location.
    """
    filename = os.path.join(os.getcwd(), latex_path)
    # pandas will not create missing directories when writing to a path, so
    # make sure the destination folder exists first.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # NOTE(review): "cccccc" declares six columns, while process_data yields
    # 2 index levels + one column per model (4 with the default rows_name).
    # Confirm this spec is still intended.
    data.to_latex(buf=filename, escape=False, multirow=True, column_format="cccccc")
def make_table(path):
    """Build and write the vMF results table from the experiments in ``path``.

    Pipeline:
    1. Load all runs.
    2. Keep those matching the fixed vMF configuration below.
    3. Aggregate ``loss_best`` per (obj, target_param_scale, model) group.
    4. Post-process the result with ``compute_ci_and_format``.
    5. Reshape it with ``process_data`` and write it with
       ``write_data_to_latex``.
    """
    experiments = load_experiments(path)

    # Each entry is one query clause; they are AND-ed together below.
    filters = (
        f"target_param_scale in {target_param_scale}",
        "target == 'VMFm0'",
        "div == 'bf'",
        "solver == 'rk4'",
        "l2int == 0.",
    )
    selected = query(experiments, " & ".join(filters))

    group_keys = ["obj", "target_param_scale", "model"]
    aggregations = {"loss_best": ["mean", inter_ci, lower_ci, upper_ci, "count"]}
    summary = selected.groupby(group_keys).agg(aggregations)

    summary = compute_ci_and_format(summary, level=(0, 1))
    table = process_data(summary)
    write_data_to_latex(table)
def make_densities_plot(path):
    """Placeholder for the vMF density plots; currently does nothing.

    ``path`` is accepted for symmetry with ``make_table`` but is unused.
    The plots were presumably produced by the experiment runner instead,
    using the recorded command::

        python3 geoflow/run_exp.py -m load=True model=stereo,riem target_manifold=Hypersphere target=VMFm0 obj=like,elbo target_param_scale=100,50,10 lr=1e-3 div=bf epochs=3000 test_freq=50 test_batch_size=2000 viz_freq=-2 name=final/vmf seed=-1 viz_plot=vmf viz_final=True run=0,1,2,3 gpu=0
    """
# Table configuration. This is defined at module level rather than inside the
# __main__ guard because process_data and make_table read these names as
# globals; defining them here keeps those functions usable when this module is
# imported instead of run as a script.
# rows_name: model key -> display label (rendered as table columns).
rows_name = {"stereo": "Stereographic", "riem": "Riemannian"}
# cols_name: objective key -> display label (outer row level).
cols_name = {"like": "$\\mathcal{L}^{\\text{Like}}$", "elbo": "$\\mathcal{L}^{\\text{KL}}$"}
# target_param_scale values selected by the query and shown, in this order.
target_param_scale = [100, 50, 10]

if __name__ == "__main__":
    # Alternative location used for the final runs:
    # experiment = "experiments/final/vmf"
    experiment = "experiments/vmf"
    path = os.path.join(os.getcwd(), experiment)
    make_table(path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment