Last active
June 1, 2020 11:19
-
-
Save emilemathieu/b4b6597ff740534c79b73caba18ce091 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import io | |
import pandas as pd | |
import numpy as np | |
import torch | |
import yaml | |
from scipy import stats | |
import math | |
from geoflow.utils import walklevel | |
def query(data_frame, query_string):
    """Filter ``data_frame`` with ``DataFrame.query``.

    The literal string ``"all"`` is a pass-through: the input frame itself is
    returned unchanged. Any other string is forwarded to ``DataFrame.query``.
    """
    if query_string != "all":
        return data_frame.query(query_string)
    return data_frame
def map_loss(losses):
    """Summarise a loss record into its final and best (minimum) test loss.

    ``losses`` must map ``"test_loss"`` to an indexable, non-empty sequence.
    """
    history = losses["test_loss"]
    last = history[-1]
    best = min(history)
    return {"loss_last": last, "loss_best": best}
def load_experiments(path, map_loss=map_loss):
    """Load all experiment runs found under the ``path`` folder.

    A run is any directory (as yielded by ``walklevel(path, level=2)``, minus
    its first entry) that contains a ``.hydra`` sub-directory. For each run:

    - ``.hydra/config.yaml`` is parsed with ``yaml.safe_load``;
    - ``losses.rar`` is loaded with ``torch.load``;
    - if ``map_loss`` is truthy, it is applied to the loaded losses.

    Config entries and losses are merged into one record per run, with loss
    keys overriding config keys on collision.

    Runs that fail to load are skipped with a ``RuntimeWarning`` rather than
    aborting the whole scan. Interrupts such as Ctrl-C still propagate.

    Returns:
        A DataFrame with one row per successfully loaded run.
    """
    import warnings

    records = []
    # NOTE(review): walklevel is assumed to yield os.walk-style tuples whose
    # first entry is ``path`` itself, hence the [1:] — confirm in geoflow.utils.
    runs = [entry[0] for entry in walklevel(path, level=2)][1:]
    for run in runs:
        if not os.path.isdir(os.path.join(run, ".hydra")):
            continue
        try:
            config_path = os.path.join(run, ".hydra", "config.yaml")
            with open(config_path, "r", encoding="utf-8") as stream:
                args = yaml.safe_load(stream)
            # NOTE(review): torch.load unpickles the file; only use on trusted runs.
            losses = torch.load(os.path.join(run, "losses.rar"))
            if map_loss:
                losses = map_loss(losses)
            records.append({**args, **losses})
        except Exception as err:  # best-effort: one broken run must not abort the scan
            warnings.warn(
                f"Skipping run {run!r}: {err!r}", RuntimeWarning, stacklevel=2
            )
    return pd.DataFrame(records)
def parse_agg(series):
    """Unwrap a per-column aggregate.

    Returns a plain ``float`` when ``series`` holds exactly one value (a single
    metric column), and a Python list of all values otherwise.
    """
    if len(series) == 1:
        # float(Series) is deprecated since pandas 2.1; extract the element
        # explicitly (np.asarray also accepts 1-element arrays and lists).
        return float(np.asarray(series)[0])
    return series.tolist()
def inter_ci(series):
    """Return the half-width of the confidence interval computed by ``confidence_interval``."""
    _, half_width, _, _ = confidence_interval(series)
    return half_width
def mean(series):
    """Return the sample mean as computed by ``confidence_interval``."""
    centre, _, _, _ = confidence_interval(series)
    return centre
def lower_ci(series):
    """Return the lower bound (mean minus half-width) from ``confidence_interval``."""
    _, _, lower, _ = confidence_interval(series)
    return lower
def upper_ci(series):
    """Return the upper bound (mean plus half-width) from ``confidence_interval``."""
    _, _, _, upper = confidence_interval(series)
    return upper
def confidence_interval(series, alpha=0.10):
    """Two-sided Student-t confidence interval of the mean, column by column.

    Each item of ``series`` becomes one row of a DataFrame. Scalar items give a
    single column; equal-length sequences give one interval per position. NaNs
    are ignored per column by the pandas reductions.

    Args:
        series: iterable of samples (e.g. one item per run).
        alpha: two-sided significance level; the default 0.10 yields a 90% CI.

    Returns:
        ``(mean, half_width, lower, upper)``, each passed through ``parse_agg``:
        a float for a single column, otherwise a list.
    """
    arr = pd.DataFrame(list(series))
    count = arr.count(0)
    centre = arr.mean(0)
    std = arr.std(0, ddof=1)
    # Degrees of freedom follow each column's non-NaN count, so they agree
    # with the sqrt(count) in the standard error below (identical to
    # len(arr) - 1 when there are no NaNs).
    t = pd.Series(stats.t.ppf(1 - alpha / 2, count - 1), index=count.index)
    sigma = t * std / np.sqrt(count)
    return tuple(map(parse_agg, (centre, sigma, centre - sigma, centre + sigma)))
def compute_ci_and_format(data, level=0):
    """Format per-row means with CI half-widths as LaTeX strings.

    ``data`` is expected to have two-level columns whose second level contains
    ``"mean"``, ``"inter_ci"``, ``"lower_ci"`` and ``"upper_ci"``. The first
    column level is dropped.

    Rows are grouped by index level ``level``. Within each group, the row with
    the smallest mean is wrapped in ``\\bm`` when its upper CI bound lies
    strictly below the lower CI bound of the runner-up. Single-row groups have
    no runner-up and are never bolded.

    Returns:
        A single-column DataFrame of strings such as ``$\\bm{1.00}_{\\pm 0.10}$``,
        indexed like ``data`` (groups concatenated in groupby order), or an
        empty DataFrame when ``data`` has no rows.
    """
    pieces = []
    for _, group in data.groupby(level=level):
        group = group.droplevel(0, axis=1)
        means = group["mean"].to_numpy()
        # Positional order by ascending mean; a plain ndarray so [0]/[1] are
        # positions, not index labels.
        order = np.argsort(means)
        bolds = [""] * len(order)
        if len(order) > 1:
            best, runner_up = order[0], order[1]
            if group["upper_ci"].iloc[best] < group["lower_ci"].iloc[runner_up]:
                bolds[best] = "\\bm"
        cells = [
            "${}{{{:.2f}}}_{{\\pm {:.2f}}}$".format(bold, mu, half_width)
            for bold, mu, half_width in zip(bolds, means, group["inter_ci"])
        ]
        pieces.append(pd.DataFrame(cells, index=group.index, dtype=pd.StringDtype()))
    if not pieces:
        return pd.DataFrame()
    return pd.concat(pieces)
def remove_useless_columns(data):
    """Drop, in place, every column that holds a single distinct value.

    Columns whose first entry is a ``list`` are always kept: lists are
    unhashable and cannot go through ``unique``. A frame with no rows (or no
    columns) is returned unchanged instead of raising.

    Returns:
        ``data`` itself, after mutation.
    """
    if data.empty:
        return data
    # Decide first, drop afterwards, so the frame is not mutated mid-iteration.
    constant = []
    for col in data.columns:
        values = data.loc[:, col]
        if isinstance(values.iloc[0], list):
            continue
        if len(values.unique()) == 1:
            constant.append(col)
    for col in constant:
        data.pop(col)
    return data
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment