Skip to content

Instantly share code, notes, and snippets.

@jrhone
Created November 1, 2019 13:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jrhone/c5b827b37295186d1d3ac11fb427870e to your computer and use it in GitHub Desktop.
Save jrhone/c5b827b37295186d1d3ac11fb427870e to your computer and use it in GitHub Desktop.
from typing import Tuple
import pandas as pd
import plotly.express as px
import streamlit as st
import pathlib
# Local copy of the indicators CSV, kept in the same directory as this script.
# get_data_from_url() reads it when present and writes it after a download.
DATA_LOCAL = pathlib.Path(__file__).parent.joinpath("country_indicators.csv")
# Remote source of the indicators CSV (raw file at a specific gist revision).
DATA_URL = "".join(
    (
        "https://gist.githubusercontent.com/chriddyp/cb5392c35661370d95f300086accea51/",
        "raw/8e0768211f6b747c0db42a9ce9a0937dafcbd8b2/indicators.csv",
    )
)
def create_plot(
    df: pd.DataFrame,
    x_indicator: str,
    y_indicator: str,
    year_range: Tuple[int, int],
) -> "plotly.graph_objects.Figure":
    """Build a scatter plot of one indicator against another.

    Rows of *df* whose ``Year`` lies inside *year_range* are kept. The bounds
    are inclusive on both ends, which is the default of ``Series.between``.
    Those rows are split into the x- and y-indicator subsets. The subsets are
    inner-joined on ``Country Name`` and ``Year``, so each plotted point pairs
    the two indicator values for the same country and year.

    Args:
        df: Long-format indicator data with at least the columns ``Year``,
            ``Indicator Name``, ``Country Name`` and ``Value``.
        x_indicator: Indicator name plotted on the x axis.
        y_indicator: Indicator name plotted on the y axis.
        year_range: ``(min_year, max_year)`` bounds, both inclusive.

    Returns:
        A Plotly figure, 400 px tall, titled "Country Indicators", with each
        axis titled after its indicator.
    """
    in_range = df[df["Year"].between(*year_range)]
    xs = in_range[in_range["Indicator Name"] == x_indicator]
    ys = in_range[in_range["Indicator Name"] == y_indicator]
    # Both sides carry a ``Value`` column; pandas' default merge suffixes turn
    # them into ``Value_x`` (x indicator) and ``Value_y`` (y indicator).
    dataframe = pd.merge(xs, ys, how="inner", on=["Country Name", "Year"])
    fig = px.scatter(
        dataframe,
        x="Value_x",
        y="Value_y",
        title="Country Indicators",
        height=400,
    )
    # Without this, the axes would be labelled with the raw merged column
    # names. ``xaxis_title`` is Plotly's shorthand for ``xaxis.title.text``.
    fig.update_layout(xaxis_title=x_indicator, yaxis_title=y_indicator)
    return fig
def prepare_plot(df: pd.DataFrame) -> None:
    """Render the indicator pickers, the scatter plot and the year-range slider.

    The chart placeholder (``st.empty()``) is created before the slider. This
    puts the plot above the slider on the page, even though the slider's value
    is needed to build the figure.

    Args:
        df: Indicator data in the shape ``create_plot`` expects. This function
            itself reads the ``Indicator Name`` and ``Year`` columns.
    """
    if df.empty:
        # Without this guard, min()/max() of an empty column give NaN,
        # which cannot become slider bounds.
        st.warning("No indicator data available.")
        return

    available_indicators = df["Indicator Name"].unique()
    # Series.min()/max() return numpy scalars; convert to plain int so the
    # slider receives native Python numbers (create_plot types the range as ints).
    min_value = int(df["Year"].min())
    max_value = int(df["Year"].max())

    x_indicator = st.selectbox("Select indicator x", available_indicators, 0)
    # Default y to the second indicator, but stay in range when there is only one.
    y_default = min(1, len(available_indicators) - 1)
    y_indicator = st.selectbox("Select indicator y", available_indicators, y_default)

    plotly_chart = st.empty()
    # HACK: extra vertical space to separate the plot from the slider.
    st.markdown("<br><br>", unsafe_allow_html=True)
    year_range = st.slider(
        "Select min and max Year",
        min_value=min_value,
        max_value=max_value,
        value=[min_value, max_value],
    )
    fig = create_plot(df, x_indicator, y_indicator, year_range)
    plotly_chart.plotly_chart(fig, width=0, height=300)
@st.cache(show_spinner=False)
def get_dataframe(url) -> pd.DataFrame:
    """Parse the CSV found at *url* into a DataFrame.

    Callers pass either a remote URL or a local POSIX path string. The result
    is memoized by ``st.cache`` keyed on *url*, so Streamlit reruns reuse the
    parsed frame. The loading spinner is suppressed.
    """
    frame = pd.read_csv(url)
    return frame
def get_data_from_url(url: str, local: pathlib.Path) -> pd.DataFrame:
    """Load the indicator data, preferring a local CSV copy over the network.

    If *local* exists, it is read through ``get_dataframe``. Otherwise the CSV
    is downloaded from *url* and a copy is saved to *local*, so later runs can
    skip the download.

    Saving the copy is best-effort:

    * The data is first written to a temporary sibling file, which is then
      renamed over *local*. An interrupted write therefore never leaves a
      truncated CSV that later runs would read.
    * If the write fails (for example, a read-only directory), the error is
      logged and the downloaded data is still returned.

    Args:
        url: Remote CSV location, used only when no local copy exists.
        local: Path of the local CSV copy.

    Returns:
        The loaded DataFrame.
    """
    if local.exists():
        return get_dataframe(local.as_posix())

    df = get_dataframe(url)
    # Keep the final suffix (e.g. ".csv" or ".gz") on the temp name, so pandas
    # infers the same compression it would for *local* itself.
    tmp = local.with_name(local.stem + ".partial" + local.suffix)
    try:
        df.to_csv(tmp, index=False)
        tmp.replace(local)
    except OSError:
        import logging

        logging.getLogger(__name__).warning(
            "Could not save a local copy of %s to %s", url, local, exc_info=True
        )
    return df
st.markdown("## Country Indicators - Streamlit version")

# Load the data (the local copy if one exists, else the remote CSV),
# then build the interactive plot UI from it.
data = get_data_from_url(DATA_URL, DATA_LOCAL)
prepare_plot(data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment