Skip to content

Instantly share code, notes, and snippets.

@yoshiso
Last active April 18, 2022 22:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yoshiso/ec99ad8bebd898f8a3e1bb4a18375109 to your computer and use it in GitHub Desktop.
Save yoshiso/ec99ad8bebd898f8a3e1bb4a18375109 to your computer and use it in GitHub Desktop.
numerai leaderboard streamlit app
"""
Visualization app for data from https://www.jofaichow.co.uk/numerati/data.html
# Usage:
Download the CSV from the website, then pass its path and the benchmark model
names to highlight to the app (script arguments go after `--`):
$ streamlit run app.py -- $CSV_FILE_PATH $BENCHMARK_MODEL_NAMES
# Example:
$ streamlit run app.py -- \
data/numerati.csv \
'["integration_test"]'
"""
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
@st.cache
def load_data(path: str) -> pd.DataFrame:
    """Read the leaderboard CSV at *path* into a DataFrame.

    The ``round`` column is parsed as ``int``. Missing values in the ``model``
    column are replaced by the literal string ``"nan"``.

    Args:
        path: Path of the CSV file.

    Returns:
        The leaderboard table.
    """
    df = pd.read_csv(path, dtype={"round": int})
    # Someone named their model "nan" (presumably to slip past dropna), and
    # read_csv parses that name as a missing value. Map it back to a string so
    # the model keeps a usable name. NOTE(review): read_csv's other default NA
    # markers (e.g. "NA", "null") are collapsed to "nan" as well.
    # Assign the result instead of calling an inplace method on df["model"]:
    # that chained in-place form stops updating df under pandas Copy-on-Write.
    df["model"] = df["model"].fillna("nan")
    return df
@st.cache
def rounds(path: str) -> List[int]:
    """Return the distinct round numbers in the CSV at *path*, newest first.

    Args:
        path: Path of the leaderboard CSV.

    Returns:
        Unique values of the ``round`` column as Python ints, sorted descending.
    """
    unique_rounds = load_data(path)["round"].unique().tolist()
    return sorted(unique_rounds, reverse=True)
@st.cache
def models(path: str) -> List[str]:
    """Return every distinct model name in the CSV at *path*, in ascending order."""
    names = load_data(path)["model"].unique()
    return sorted(names)
def head(path: str, n: int = 20) -> None:
    """Render the first *n* rows of the CSV at *path* as a static Streamlit table.

    Args:
        path: Path of the leaderboard CSV.
        n: Number of leading rows to show (defaults to 20, the previous fixed value).
    """
    st.table(load_data(path).head(n))
def series_view(path: str, benchmarks: List[str]):
    """Plot per-round score quantiles over a range of rounds with benchmarks overlaid.

    Sidebar selectboxes choose the start round, the end round and the score
    field ("corrmmc", "corr" or "mmc"). For every round in the range, the
    1/10/25/50/75/90/99th percentiles of the field across all models are drawn
    as thin lines, with the 10-90 and 25-75 bands shaded. Each benchmark model
    present in the data is drawn as a coloured line. Its legend entry shows the
    percentile rank (via ``_rank``) of its mean field value over the range. A
    table of the benchmarks' mean values, sorted descending, follows the plot.

    Args:
        path: Path of the leaderboard CSV.
        benchmarks: Model names to highlight.
    """
    round_options = rounds(path)
    # The default start is the 41st most recent round. Clamp the index so a CSV
    # with fewer rounds does not ask selectbox for an out-of-range index.
    start_index = min(40, max(len(round_options) - 1, 0))
    start = st.sidebar.selectbox(label="Start", options=round_options, index=start_index)
    end = st.sidebar.selectbox(label="End", options=round_options, index=0)
    field = st.sidebar.selectbox(label="Field", options=["corrmmc", "corr", "mmc"])
    # Accept the two bounds in either order instead of rendering an empty range.
    start, end = min(start, end), max(start, end)

    # index = round, column = model
    table = load_data(path).set_index(["round", "model"])[field].unstack()
    table = table[(table.index >= start) & (table.index <= end)]

    # Mean of the field per model over the selected rounds, and its percentile
    # rank; computed once instead of once per benchmark.
    mean = table.mean(axis=0).rename(field)
    mean_rank = _rank(mean)

    # Per-round quantiles across models. The Series names become legend labels,
    # and the list order is the drawing order.
    quantile_specs = [
        (0.01, "1%tile"),
        (0.10, "10%tile"),
        (0.25, "25%tile"),
        (0.50, "median"),
        (0.75, "75%tile"),
        (0.90, "90%tile"),
        (0.99, "99%tile"),
    ]
    bands = {q: table.quantile(q, axis=1).rename(label) for q, label in quantile_specs}

    fig, ax = plt.subplots(1, 1, figsize=(10, 4))
    ax.axhline(0, color="gray", alpha=0.8)
    ax.fill_between(bands[0.90].index, bands[0.10], bands[0.90], alpha=0.2, color="gray")
    ax.fill_between(bands[0.90].index, bands[0.25], bands[0.75], alpha=0.2, color="gray")
    for q, _ in quantile_specs:
        bands[q].plot(ax=ax, color="black" if q == 0.50 else "gray", linewidth=0.3)

    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    colors = plt.get_cmap("tab20").colors
    for i, name in enumerate(benchmarks):
        if name not in table.columns:
            continue
        rank = mean_rank.loc[name] * 100
        # Cycle the 20-colour palette so more than 20 benchmarks do not raise.
        table[name].plot(
            ax=ax, color=colors[i % len(colors)], label=f"{name}({rank:.1f}%tile)"
        )
    ax.set_xlim(start, end)
    ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; release it after rendering.
    plt.close(fig)

    # Only benchmarks present in the data, matching the plotting loop's guard.
    present = [name for name in benchmarks if name in mean.index]
    st.table(mean[present].sort_values(ascending=False).to_frame())
def round_view(path: str, benchmarks: List[str]):
    """Plot, for one selected round, a histogram of each score field across models.

    A sidebar selectbox picks the round (options come from ``rounds``). For
    each of "corrmmc", "corr" and "mmc", the function draws:

    - a 50-bin histogram of that field over the models in the round;
    - a black median line;
    - one vertical line per benchmark model present in the round
      (drawn by ``_plot_line``).

    Args:
        path: Path of the leaderboard CSV.
        benchmarks: Model names to highlight.
    """
    # Named selected_round so the builtin round() is not shadowed.
    selected_round = st.sidebar.selectbox(label="Round", options=rounds(path), index=0)
    table = load_data(path)
    table = table[table["round"] == selected_round].set_index("model")
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    # The palette does not depend on the field, so look it up once.
    colors = plt.get_cmap("tab20").colors
    for field in ["corrmmc", "corr", "mmc"]:
        fig, ax = plt.subplots(1, 1, figsize=(10, 4))
        table[field].hist(bins=50, ax=ax, alpha=0.8, color="gray")
        ax.set_title(field)
        ax.axvline(table[field].median(), label="median", color="black", alpha=0.8)
        for i, name in enumerate(benchmarks):
            if name not in table.index:
                continue
            # _plot_line draws via pyplot on the current axes, which
            # plt.subplots above made ax. Cycle colours past 20 benchmarks.
            _plot_line(table[field], name, color=colors[i % len(colors)])
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
        st.pyplot(fig)
        # Release the figure; pyplot otherwise keeps it alive across reruns.
        plt.close(fig)
def _plot_line(values: pd.Series, name: str, color):
    """Draw a vertical line at the value of *name* in *values* on the current axes.

    The line is labelled "<name> (<pct>%tile)", where <pct> is *name*'s entry
    in ``_rank(values)`` scaled to 0-100 and shown with one decimal place.

    Args:
        values: Scores indexed by model name.
        name: Model whose value is marked.
        color: Matplotlib colour for the line.
    """
    x = values.loc[name]
    pct = 100 * _rank(values).loc[name]
    legend_label = f"{name} ({pct:.1f}%tile)"
    plt.axvline(x, label=legend_label, color=color)
def main(path: str, benchmarks: Optional[List[str]] = None):
    """Entry point: build the sidebar and render the selected page.

    Args:
        path: Path of the leaderboard CSV.
        benchmarks: Model names preselected in the "Benchmarks" multiselect.
            A single string is treated as a one-element list. Names that are
            not models in the CSV are dropped, with a sidebar warning, because
            ``st.multiselect`` rejects defaults that are not among its options.
    """
    # None instead of a shared mutable [] default.
    if benchmarks is None:
        benchmarks = []
    elif isinstance(benchmarks, str):
        benchmarks = [benchmarks]

    page = st.sidebar.selectbox(
        label="Page", options=["head", "round_view", "series_view"], index=0
    )
    options = models(path)
    known = set(options)
    unknown = [name for name in benchmarks if name not in known]
    if unknown:
        st.sidebar.warning(f"Ignoring benchmarks not found in the data: {unknown}")
    selected = st.sidebar.multiselect(
        label="Benchmarks",
        options=options,
        default=[name for name in benchmarks if name in known],
    )

    if page == "head":
        head(path)
    elif page == "round_view":
        round_view(path, selected)
    elif page == "series_view":
        series_view(path, selected)
def _rank(series):
    """Return each element's average rank divided by the largest rank.

    The largest value(s) map to 1.0. Tied values share their average rank, and
    NaN entries stay NaN.
    """
    ranks = series.rank()
    top = ranks.max()
    return ranks.div(top)
if __name__ == "__main__":
    # python-fire maps the script's command-line arguments onto main(); per the
    # module docstring, the benchmark list is passed as a literal such as
    # '["integration_test"]'.
    from fire import Fire

    Fire(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment