Skip to content

Instantly share code, notes, and snippets.

@michaelosthege
Last active August 12, 2020 19:21
Show Gist options
  • Save michaelosthege/6cd14970dd789247176c4d4a1dd28051 to your computer and use it in GitHub Desktop.
Save michaelosthege/6cd14970dd789247176c4d4a1dd28051 to your computer and use it in GitHub Desktop.
Analysis of Rt.live model estimates of R0 - a comparison between US regions (alternative link: https://nbviewer.jupyter.org/gist/michaelosthege/6cd14970dd789247176c4d4a1dd28051/_R0_regions.ipynb)
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import arviz
import fastprogress
import logging
import matplotlib
from matplotlib import pyplot
import numpy
import os
import pandas
import pathlib
import pymc3
import typing
# Module-level logger. It is named after this file's path (``__file__``).
# NOTE(review): ``logging.getLogger(__name__)`` is the more common convention - confirm
# whether the path-based logger name is intentional before changing it.
_log = logging.getLogger(__file__)
# Maps the full English name of each US state (plus the District of Columbia)
# to its two-letter code.  The codes are used as region identifiers throughout
# this module, e.g. as directory names when locating ``trace.nc`` files and as
# the index of the population density table.
US_REGION_CODES = {
"Alabama": "AL",
"Alaska": "AK",
"Arizona": "AZ",
"Arkansas": "AR",
"California": "CA",
"Colorado": "CO",
"Connecticut": "CT",
"District of Columbia": "DC",
"Delaware": "DE",
"Florida": "FL",
"Georgia": "GA",
"Hawaii": "HI",
"Idaho": "ID",
"Illinois": "IL",
"Indiana": "IN",
"Iowa": "IA",
"Kansas": "KS",
"Kentucky": "KY",
"Louisiana": "LA",
"Maine": "ME",
"Maryland": "MD",
"Massachusetts": "MA",
"Michigan": "MI",
"Minnesota": "MN",
"Mississippi": "MS",
"Missouri": "MO",
"Montana": "MT",
"Nebraska": "NE",
"Nevada": "NV",
"New Hampshire": "NH",
"New Jersey": "NJ",
"New Mexico": "NM",
"New York": "NY",
"North Dakota": "ND",
"North Carolina": "NC",
"Ohio": "OH",
"Oklahoma": "OK",
"Oregon" : "OR",
"Pennsylvania" : "PA",
"Rhode Island": "RI",
"South Carolina": "SC",
"South Dakota": "SD",
"Tennessee": "TN",
"Texas": "TX",
"Utah": "UT",
"Vermont": "VT",
"Virginia": "VA",
"Washington": "WA",
"West Virginia": "WV",
"Wisconsin": "WI",
"Wyoming": "WY",
}
# All two-letter region codes, in the insertion order of US_REGION_CODES.
US_REGIONS = list(US_REGION_CODES.values())
# Name of the per-region sub-directory from which the plotting functions in
# this module load ``trace.nc`` (path layout: <country>/<region>/<DATA_DATE>/trace.nc).
DATA_DATE = '2020-06-25'
def plot_r_t(region: str, country: str="us"):
    """Plot the posterior of the reproduction number over time for one region.

    Loads ``<country>/<region>/<DATA_DATE>/trace.nc`` and draws the posterior
    samples of ``r_t`` against the ``date`` coordinate as a distribution band
    (via ``pymc3.gp.util.plot_gp_dist``), together with a dotted reference
    line at 1.

    Parameters
    ----------
    region : str
        Region code; second component of the path to the trace file.
    country : str
        Country code; first component of the path to the trace file.
        It is also used as the prefix of the legend label.

    Returns
    -------
    The return value of ``pyplot.show()``.
    """
    # read the data
    idata = arviz.from_netcdf(pathlib.Path(country, region, DATA_DATE, "trace.nc"))
    posterior = idata.posterior

    fig, ax = pyplot.subplots(
        dpi=140,
        figsize=(10, 4),
    )
    pymc3.gp.util.plot_gp_dist(
        ax=ax,
        x=posterior.date.values,
        # merge chains and draws into a single "sample" axis, then transpose
        # so that the sample axis comes first
        samples=posterior.r_t.stack(sample=("chain", "draw")).T.values,
    )
    ax.axhline(1, linestyle=":")
    ax.set_ylabel("$R_e(t)$ [-]", fontsize=15)
    ax.legend(
        handles=[
            # empty artist that only serves as a red legend entry;
            # the label follows the `country` argument instead of a hard-coded "us"
            ax.fill_between([], [], color="red", label=f"{country}/{region}")
        ],
        loc="upper left",
        frameon=False,
    )
    # major ticks on every Monday, minor ticks on every day
    ax.xaxis.set_major_locator(
        matplotlib.dates.WeekdayLocator(interval=1, byweekday=matplotlib.dates.MO)
    )
    ax.xaxis.set_minor_locator(matplotlib.dates.DayLocator())
    ax.xaxis.set_tick_params(rotation=90)
    ax.set_ylim(0, 8)
    fig.tight_layout()
    return pyplot.show()
def plot_r_0(regions: typing.Sequence[str], country: str="us"):
    """Plot kernel density estimates of the initial reproduction number for several regions.

    For every region the trace ``<country>/<region>/<DATA_DATE>/trace.nc`` is
    loaded and the ``r_t`` samples at index 0 of the first axis are taken
    (presumably the first date, i.e. R_0 - confirm against the model).
    One KDE curve is drawn per region, and an arrow labelled with the region
    code points at the median of its samples.

    Parameters
    ----------
    regions : sequence of str
        Region codes to include.
    country : str
        Country code; first component of the path to the trace files.

    Returns
    -------
    region_medians : dict
        Maps region code to the median of its samples.
    region_samples : dict
        Maps region code to the 1-D array of its samples.
    """
    region_samples = {}
    for region in fastprogress.progress_bar(regions, leave=False):
        idata = arviz.from_netcdf(pathlib.Path(country, region, DATA_DATE, "trace.nc"))
        # merge chains and draws, then keep only the first entry of the leading axis
        region_samples[region] = idata.posterior.r_t.stack(sample=('chain', 'draw')).values[0, :]
    region_medians = {
        region: numpy.median(samples)
        for region, samples in region_samples.items()
    }

    fig, ax = pyplot.subplots(
        dpi=140,
        figsize=(10, 6),
    )
    # iterate in order of ascending median, so that neighbouring labels get different heights
    for r, (region, median) in enumerate(sorted(region_medians.items(), key=lambda kv: kv[1])):
        arviz.plot_kde(
            ax=ax,
            values=region_samples[region],
            plot_kwargs=dict(linewidth=0.5)
        )
        # plot arrow to indicate the median
        ax.annotate(
            # The label is passed positionally: the keyword was renamed from
            # `s` to `text` in matplotlib 3.3, so neither name works everywhere.
            region,
            xy=(median, 0),
            # cycle through 9 label heights to reduce overlap
            xytext=(median, 0.15 + r % 9 * 0.15),
            horizontalalignment="center", fontweight="bold",
            arrowprops=dict(arrowstyle="-|>", facecolor="black", shrinkA=0, shrinkB=0)
        )
    # remove the (major) y ticks; `minor` is a boolean flag, not a list of ticks
    ax.set_yticks(ticks=[], minor=False)
    ax.axvline(1, linestyle=":")
    ax.set_xlabel("$R_0$ [-]", fontsize=15)
    ax.set_xlim(left=0)
    # raw string, because "\m" is not a valid escape sequence
    ax.set_ylabel(r"$p(R_0 \mid data)$", fontsize=15)
    ax.set_ylim(0)
    fig.tight_layout()
    pyplot.show()
    return region_medians, region_samples
def _get_US_population_densities() -> pandas.DataFrame:
    """Download population densities of US regions from Wikipedia.

    The first HTML table of the Wikipedia page is parsed.  Its two-level
    column labels are renamed and reduced to the two columns ``name`` and
    ``population_density``.  A ``code`` column is looked up from
    ``US_REGION_CODES``; names without a code get ``None``.  Cells equal to
    the string ``"<1"`` are replaced by ``0.49``.  Rows with any missing
    value (which includes rows without a code) are dropped.

    Returns
    -------
    pandas.DataFrame
        Indexed by region ``code`` (sorted), with the columns ``name`` and
        ``population_density``.
    """
    url = 'https://en.wikipedia.org/wiki/List_of_states_and_territories_of_the_United_States_by_population_density'
    first_table = pandas.read_html(url)[0]
    renamed = first_table.rename(columns={
        "State etc.": "name",
        "perkm2": "population_density",
    })
    # select by (level 0, level 1) label, then keep only the second level
    df = renamed[[("name", "name"), ("Population density", "population_density")]]
    df.columns = df.columns.droplevel(level=0)
    # names that are not a key of US_REGION_CODES map to None and are dropped below
    df["code"] = [US_REGION_CODES.get(name) for name in df.name]
    df = df.replace(to_replace="<1", value=0.49)
    return df.dropna().set_index("code").sort_index()
def plot_scatter_r_0(regions: typing.Sequence[str], on_x="population_density", country: str="us"):
    """Plot per-region violins of the initial reproduction number against population density.

    For every region the trace ``<country>/<region>/<DATA_DATE>/trace.nc`` is
    loaded and the ``r_t`` samples at index 0 of the first axis are taken
    (presumably the first date, i.e. R_0 - confirm against the model).
    Each region is drawn as a violin positioned at the log10 of its
    population density and labelled with its code at the sample median.

    Parameters
    ----------
    regions : sequence of str
        Region codes to include.  Each must be present in the index of the
        table returned by ``_get_US_population_densities``.
    on_x : str
        Currently unused: the x-axis always shows the ``population_density`` column.
    country : str
        Country code; first component of the path to the trace files.

    Returns
    -------
    The return value of ``pyplot.show()``.
    """
    df_densities = _get_US_population_densities()
    region_samples = {}
    for region in fastprogress.progress_bar(regions):
        idata = arviz.from_netcdf(pathlib.Path(country, region, DATA_DATE, "trace.nc"))
        # merge chains and draws, then keep only the first entry of the leading axis
        region_samples[region] = idata.posterior.r_t.stack(sample=('chain', 'draw')).values[0, :]

    fig, ax = pyplot.subplots(dpi=140, figsize=(7, 7))

    # x-position of each region, computed once and shared by violins and labels
    log_density = {
        region: numpy.log10(float(df_densities.loc[region, "population_density"]))
        for region in region_samples
    }
    ax.violinplot(
        dataset=[
            samples.flatten()
            for samples in region_samples.values()
        ],
        positions=[
            log_density[region]
            for region in region_samples
        ],
        showextrema=False,
        widths=0.3,
    )
    for region in regions:
        ax.text(
            s=region,
            x=log_density[region],
            y=numpy.median(region_samples[region]),
            horizontalalignment="center", fontweight="bold", fontsize=6,
        )
    # The axis is linear in log10(density); ticks are labelled as powers of ten.
    ax.xaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter("$10^{{{x:.0f}}}$"))
    # minor ticks: 10 linearly spaced values within each decade
    ax.xaxis.set_ticks([
        numpy.log10(x)
        for p in range(-1, 4)
        for x in numpy.linspace(10**p, 10**(p+1), 10)
    ], minor=True)
    # major ticks: the two end points of each decade
    ax.xaxis.set_ticks([
        numpy.log10(x)
        for p in range(-1, 4)
        for x in numpy.linspace(10**p, 10**(p+1), 2)
    ], minor=False)
    ax.set_ylim(0)
    ax.set_xlim(-1, 4)
    # raw string, because "\m" is not a valid escape sequence
    ax.set_ylabel(r"$p(R_0 \mid data)$ [-]")
    ax.set_xlabel("population density [1/km²]")
    return pyplot.show()
def read_nowcast_backcast_comparison(
    region: str,
    offset_now: int,
    offset_back: int,
    country: str="us",
    hdi_prob: float=0.94,
) -> pandas.DataFrame:
    """Collect nowcast and backcast summaries of ``r_t`` from all runs of a region.

    Every entry of the directory ``<country>/<region>`` that contains a
    ``trace.nc`` is treated as one model run, and the entry name is parsed
    as the run date.  From each run two dates are summarized by median and
    HDI of the ``r_t`` samples:

    * the nowcast date: run date + ``offset_now`` days
    * the backcast date: nowcast date + ``offset_back`` days

    The summaries are written into the row of the respective date.  Runs
    whose ``date`` coordinate does not contain the backcast date are skipped.
    Finally, only rows without missing values are kept.

    Parameters
    ----------
    region : str
        Region code; second component of the directory path.
    offset_now : int
        Number of days added to the run date to get the nowcast date.
    offset_back : int
        Number of days added to the nowcast date to get the backcast date.
    country : str
        Country code; first component of the directory path.
    hdi_prob : float
        Probability mass passed to ``arviz.hdi``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``date``, with run date, median and HDI bounds for both
        the nowcast and the backcast.  The median and HDI columns are floats.
    """
    table = pandas.DataFrame(columns=[
        "date",
        "nowcast_run_date", "nowcast_median", "nowcast_hdi_down", "nowcast_hdi_up",
        "backcast_run_date", "backcast_median", "backcast_hdi_down", "backcast_hdi_up",
    ]).set_index("date")

    region_dir = pathlib.Path(country, region)
    for run_name in os.listdir(region_dir):
        trace_path = pathlib.Path(region_dir, run_name, 'trace.nc')
        if not trace_path.exists():
            continue

        run_date = pandas.Timestamp(run_name)
        nowcast_date = run_date + pandas.DateOffset(offset_now)
        backcast_date = nowcast_date + pandas.DateOffset(offset_back)

        # merge chains and draws into a single "sample" axis
        r_t = arviz.from_netcdf(trace_path).posterior.r_t.stack(sample=('chain', 'draw'))
        if backcast_date not in list(r_t.date):
            continue

        # what this run says about the nowcast date
        now_samples = r_t.sel(date=nowcast_date).values
        table.loc[nowcast_date, "nowcast_run_date"] = run_date
        table.loc[nowcast_date, "nowcast_median"] = float(numpy.median(now_samples))
        table.loc[nowcast_date, ["nowcast_hdi_down", "nowcast_hdi_up"]] = tuple(
            arviz.hdi(now_samples, hdi_prob=hdi_prob)
        )

        # what this run says about the backcast date
        back_samples = r_t.sel(date=backcast_date).values
        table.loc[backcast_date, "backcast_run_date"] = run_date
        table.loc[backcast_date, "backcast_median"] = float(numpy.median(back_samples))
        table.loc[backcast_date, ["backcast_hdi_down", "backcast_hdi_up"]] = tuple(
            arviz.hdi(back_samples, hdi_prob=hdi_prob)
        )

    # keep only the dates for which both a nowcast and a backcast were found
    table = table.dropna()
    # the table was filled cell by cell, so the numeric columns are cast explicitly
    return table.astype({
        'nowcast_median': float,
        'nowcast_hdi_down': float,
        'nowcast_hdi_up': float,
        'backcast_median': float,
        'backcast_hdi_down': float,
        'backcast_hdi_up': float,
    })
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment