Last active
April 18, 2022 22:34
-
-
Save yoshiso/ec99ad8bebd898f8a3e1bb4a18375109 to your computer and use it in GitHub Desktop.
numerai leaderboard streamlit app
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Visualization app data from https://www.jofaichow.co.uk/numerati/data.html | |
# Usage: | |
Download data from the website and give it to app | |
$ streamlit app.py $CSV_FILE_PATH $BENCHMARK_MODEL_NAMES | |
# Example: | |
$ streamlit app.py \ | |
data/numerati.csv \ | |
'["integration_test"]' | |
""" | |
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
@st.cache
def load_data(path: str) -> pd.DataFrame:
    """Read the leaderboard CSV at ``path`` into a DataFrame.

    The ``round`` column is parsed as ``int``. Missing values in the
    ``model`` column are replaced with the string ``"nan"``.

    Args:
        path: Location of the CSV file to load.

    Returns:
        The loaded DataFrame with no missing values in ``model``.
    """
    df = pd.read_csv(path, dtype={"round": int})
    # A model name that read_csv parsed as missing (presumably a model that is
    # literally called "nan") would otherwise vanish in a later dropna, so
    # turn it back into a string.
    # The result is assigned back rather than using an in-place method on the
    # column selection, which is not guaranteed to modify ``df`` itself.
    df["model"] = df["model"].fillna("nan")
    return df
@st.cache
def rounds(path: str) -> List[int]:
    """Return the distinct values of the ``round`` column, largest first.

    Args:
        path: Location of the CSV file, forwarded to ``load_data``.

    Returns:
        The unique round numbers sorted in descending order.
    """
    return sorted(load_data(path)["round"].unique(), reverse=True)
@st.cache
def models(path: str) -> List[str]:
    """Return the distinct values of the ``model`` column in ascending order.

    Args:
        path: Location of the CSV file, forwarded to ``load_data``.
    """
    names = load_data(path)["model"].unique()
    return sorted(names)
def head(path):
    """Render the first 20 rows of the loaded data as a Streamlit table."""
    preview = load_data(path).head(20)
    st.table(preview)
def series_view(path: str, benchmarks: List[str]):
    """Plot per-round quantile bands of one score field with benchmarks overlaid.

    Sidebar widgets choose the first round, the last round and the field
    (``corrmmc``, ``corr`` or ``mmc``). For every round in the range the
    1/10/25/50/75/90/99 % quantiles across all models are drawn, each
    benchmark model present in the data is plotted on top, and a table of the
    benchmarks' mean score over the range is shown below the chart.

    Args:
        path: Location of the CSV file, forwarded to ``load_data``.
        benchmarks: Model names to highlight; names without a column in the
            pivoted table are skipped.
    """
    round_options = rounds(path)
    # Default the start to the 41st listed round, clamped so the widget does
    # not fail when the data holds fewer rounds than that.
    default_start = max(0, min(40, len(round_options) - 1))
    start = st.sidebar.selectbox(
        label="Start", options=round_options, index=default_start
    )
    end = st.sidebar.selectbox(label="End", options=round_options, index=0)
    field = st.sidebar.selectbox(label="Field", options=["corrmmc", "corr", "mmc"])
    # A reversed selection would filter every row away; treat it as the same
    # range given in the other order.
    if start > end:
        start, end = end, start

    # index = round, column = model
    table = load_data(path).set_index(["round", "model"])[field].unstack()
    table = table[(table.index >= start) & (table.index <= end)]
    mean = table.mean(axis=0).rename(field)

    # (quantile, legend label, line colour), in drawing order.
    bands = [
        (0.01, "1%tile", "gray"),
        (0.10, "10%tile", "gray"),
        (0.25, "25%tile", "gray"),
        (0.50, "median", "black"),
        (0.75, "75%tile", "gray"),
        (0.90, "90%tile", "gray"),
        (0.99, "99%tile", "gray"),
    ]
    quantiles = {
        label: table.quantile(q, axis=1).rename(label) for q, label, _ in bands
    }

    fig, ax = plt.subplots(1, 1, figsize=(10, 4))
    ax.axhline(0, color="gray", alpha=0.8)
    ax.fill_between(
        table.index, quantiles["10%tile"], quantiles["90%tile"], alpha=0.2, color="gray"
    )
    ax.fill_between(
        table.index, quantiles["25%tile"], quantiles["75%tile"], alpha=0.2, color="gray"
    )
    for _, label, color in bands:
        quantiles[label].plot(ax=ax, color=color, linewidth=0.3)

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap is no longer
    # available in recent Matplotlib releases.
    colors = plt.get_cmap("tab20").colors
    # The ranking does not depend on the benchmark, so compute it once.
    ranks = _rank(mean)
    present = [name for name in benchmarks if name in table.columns]
    for i, name in enumerate(benchmarks):
        if name not in table.columns:
            continue
        rank = ranks.loc[name] * 100
        # Cycle through the palette so any number of benchmarks can be drawn.
        table[name].plot(
            ax=ax, color=colors[i % len(colors)], label=f"{name}({rank:.1f}%tile)"
        )

    ax.set_xlim(start, end)
    ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
    st.pyplot(fig)
    # pyplot keeps every figure it created alive until it is closed.
    plt.close(fig)

    # Only tabulate benchmarks that have a column, matching the loop above.
    st.table(mean[present].sort_values(ascending=False).to_frame())
def round_view(path: str, benchmarks: List[str]):
    """Show the score distribution of a single round, one histogram per field.

    A sidebar widget selects the round. For each of ``corrmmc``, ``corr`` and
    ``mmc`` a 50-bin histogram over all models of that round is drawn, with a
    vertical line at the median and one at each benchmark model that took
    part in the round.

    Args:
        path: Location of the CSV file, forwarded to ``load_data``.
        benchmarks: Model names to highlight; names absent from the selected
            round are skipped.
    """
    # Named ``selected_round`` so the builtin ``round`` is not shadowed.
    selected_round = st.sidebar.selectbox(
        label="Round", options=rounds(path), index=0
    )
    data = load_data(path)
    table = data[data["round"] == selected_round].set_index("model")

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap is no longer
    # available in recent Matplotlib releases; the palette is the same for
    # every field, so look it up once.
    colors = plt.get_cmap("tab20").colors

    for field in ["corrmmc", "corr", "mmc"]:
        # The new figure becomes pyplot's current figure, which is what
        # ``_plot_line`` draws on.
        fig, ax = plt.subplots(1, 1, figsize=(10, 4))
        table[field].hist(bins=50, ax=ax, alpha=0.8, color="gray")
        ax.set_title(field)
        ax.axvline(table[field].median(), label="median", color="black", alpha=0.8)
        for i, name in enumerate(benchmarks):
            if name not in table.index:
                continue
            # Cycle through the palette so any number of benchmarks can be drawn.
            _plot_line(table[field], name, color=colors[i % len(colors)])
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
        st.pyplot(fig)
        # pyplot keeps every figure it created alive until it is closed.
        plt.close(fig)
def _plot_line(values: pd.Series, name: str, color):
    """Mark ``name``'s value with a vertical line on the current pyplot axes.

    The legend label carries the model name and its relative rank within
    ``values``, expressed as a percentage.
    """
    position = values.loc[name]
    percentile = _rank(values).loc[name] * 100
    legend_text = f"{name} ({percentile:.1f}%tile)"
    plt.axvline(position, label=legend_text, color=color)
def main(path: str, benchmarks: Optional[List[str]] = None):
    """Run the app: pick a page and benchmark models in the sidebar, then render.

    Args:
        path: Location of the CSV file to visualise.
        benchmarks: Model names pre-selected in the "Benchmarks" widget.
            ``None`` (the default) pre-selects nothing; a single name given
            as a plain string is treated as a one-element list.
    """
    # Normalise the argument without ever sharing a mutable default.
    if benchmarks is None:
        requested = []
    elif isinstance(benchmarks, str):
        requested = [benchmarks]
    else:
        requested = list(benchmarks)

    page = st.sidebar.selectbox(
        label="Page", options=["head", "round_view", "series_view"], index=0
    )

    available = models(path)
    known = set(available)
    # multiselect rejects defaults that are not among its options, so keep
    # only the names that exist and report the others instead of failing.
    defaults = [name for name in requested if name in known]
    missing = [name for name in requested if name not in known]
    selected = st.sidebar.multiselect(
        label="Benchmarks", options=available, default=defaults,
    )
    if missing:
        st.sidebar.warning(
            "Benchmark models not found in data: "
            + ", ".join(str(name) for name in missing)
        )

    if page == "head":
        head(path)
    elif page == "round_view":
        round_view(path, selected)
    elif page == "series_view":
        series_view(path, selected)
def _rank(series): | |
rank = series.rank() | |
return rank / rank.max() | |
if __name__ == "__main__":
    # Imported only when run as a script; Fire exposes ``main`` and its
    # parameters as a command-line interface.
    from fire import Fire

    Fire(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment