Skip to content

Instantly share code, notes, and snippets.

@ksadov
Created January 5, 2025 04:12
Show Gist options
  • Save ksadov/3c1e62fc09833c235991b52f227d697c to your computer and use it in GitHub Desktop.
Save ksadov/3c1e62fc09833c235991b52f227d697c to your computer and use it in GitHub Desktop.
import argparse
from curses import window
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.statespace.sarimax import SARIMAX
from scipy.stats import norm
from pmdarima import auto_arima
"""
My best attempt at forecasting https://realtime.org/data/nyc-rat-index using SARIMA,
for https://manifold.markets/strutheo/will-the-nyc-rat-index-be-65-or-hig-pfs413f8ws.
Running with daily aggregation cooks my laptop, so I made my bet based on weekly aggregation.
Download the CSV here: https://data.cityofnewyork.us/Social-Services/Rat-Sightings/3q43-55fe/about_data
and save it as Rat_Sightings.csv in the same directory as this script.
"""
def remove_last_outliers(df, z_threshold):
    """Smooth the final observation of a series when it looks like an outlier.

    SARIMA is sensitive to an outlier in the last value of the series, so the
    last point is compared with the mean of the trailing 30 observations (the
    last point itself included). If it deviates from that mean by more than
    ``z_threshold`` standard deviations, it is replaced by the mean of the six
    observations immediately before it.

    The printed messages say "30-day", but the window is 30 *observations*,
    whatever period each observation represents.

    Args:
        df: A pandas Series of observations in chronological order
            (the parameter is named ``df`` for historical reasons).
        z_threshold: Number of standard deviations beyond which the last
            value is treated as an outlier.

    Returns:
        ``df`` itself when no adjustment is made, otherwise a copy of ``df``
        whose last value has been replaced. The input is never mutated.
    """
    # Nothing to inspect in an empty series; indexing it would raise.
    if len(df) == 0:
        return df

    # Use .iloc for every positional access. Plain ``series[-1]`` is a label
    # lookup on an integer index (KeyError) and only a deprecated positional
    # fallback on other index types.
    recent = df.iloc[-30:]
    last_30_mean = recent.mean()
    last_30_std = recent.std()
    last_value = df.iloc[-1]

    # A NaN std (e.g. a single observation) makes this comparison False,
    # so very short series are simply passed through.
    is_outlier = abs(last_value - last_30_mean) > z_threshold * last_30_std
    if not is_outlier:
        print("No outliers detected in last 30 days")
        return df

    print(f"Note: Last value ({last_value:.1f}) appears to be an outlier")
    print(f"30-day mean: {last_30_mean:.1f}, 30-day std: {last_30_std:.1f}")
    adjusted_series = df.copy()
    # Replace the outlier with the mean of the six values that precede it.
    adjusted_series.iloc[-1] = df.iloc[-7:-1].mean()
    print(f"Adjusted last value: {adjusted_series.iloc[-1]:.1f}")
    return adjusted_series
def _aggregate_sightings(daily_counts, aggregation):
    """Return ``(aggregated_series, seasonal_period)`` for the chosen frequency."""
    if aggregation == "daily":
        # https://realtime.org/data/nyc-rat-index uses a 28-day rolling average
        return daily_counts.rolling(window=28, min_periods=1).mean(), 365
    if aggregation == "weekly":
        # A 4-week rolling mean of weekly means approximates the 28-day
        # rolling average.
        weekly_mean = daily_counts.resample("W").mean()
        return weekly_mean.rolling(window=4, min_periods=1).mean(), 52
    if aggregation == "monthly":
        # NOTE(review): newer pandas releases prefer the "ME" alias for
        # month-end; "M" is kept for older versions - confirm against the
        # installed pandas.
        return daily_counts.resample("M").mean(), 12
    raise ValueError("Aggregation must be one of: daily, weekly, monthly.")


def _forecast_steps(last_date, end_date, aggregation):
    """Return the number of periods to forecast from last_date up to end_date."""
    if aggregation == "daily":
        return (end_date - last_date).days
    if aggregation == "weekly":
        return (end_date - last_date).days // 7 + 1
    return (end_date.year - last_date.year) * 12 + (
        end_date.month - last_date.month
    )


def plot_seasonally_adjusted_rat_sightings(
    csv_file, threshold=65, aggregation="weekly"
):
    """Forecast the seasonally adjusted rat-sightings series to the end of 2025.

    The function reads the sightings CSV, counts sightings per day from 2020
    onwards, aggregates them, removes the seasonal component, fits a SARIMA
    model chosen by ``auto_arima`` and forecasts up to 2025-12-31. It prints
    the final forecast, its standard error and the probability (under a
    normal approximation) that the final value exceeds ``threshold``, then
    shows a plot of history, forecast and confidence band.

    The model is fitted to the seasonally adjusted series, so the reported
    probability compares ``threshold`` with the adjusted forecast value.

    Args:
        csv_file: Path to the CSV; it must contain a "Created Date" column.
        threshold: Value the final forecast is compared against.
        aggregation: One of "daily", "weekly" or "monthly".

    Returns:
        A tuple ``(seasonally_adjusted, forecast_mean, p_exceed)``.

    Raises:
        ValueError: If ``aggregation`` is not a supported value, or if the
            data already reaches the end of 2025 so there is nothing left
            to forecast.
    """
    df = pd.read_csv(csv_file)
    df["Created Date"] = pd.to_datetime(df["Created Date"])
    df = df[df["Created Date"].dt.year >= 2020]

    # One row per calendar day that has at least one sighting.
    daily_sightings = (
        df.groupby(df["Created Date"].dt.date)
        .size()
        .reset_index(name="Number of Sightings")
    )
    daily_sightings["Date"] = pd.to_datetime(daily_sightings["Created Date"])
    daily_sightings.drop(columns=["Created Date"], inplace=True)
    daily_sightings.set_index("Date", inplace=True)

    agg_data, seasonal_period = _aggregate_sightings(
        daily_sightings["Number of Sightings"], aggregation
    )

    decomposition = seasonal_decompose(
        agg_data.dropna(),  # dropna in case early periods are incomplete
        period=seasonal_period,
        extrapolate_trend="freq",
    )
    seasonally_adjusted = agg_data - decomposition.seasonal
    seasonally_adjusted = remove_last_outliers(seasonally_adjusted, 2.0)
    seasonally_adjusted = seasonally_adjusted.dropna()

    # Work out the forecast horizon before the expensive model search so a
    # data set that already covers the target date fails fast and clearly.
    last_date_in_data = seasonally_adjusted.index[-1]
    end_of_2025 = pd.to_datetime("2025-12-31")
    steps_ahead = _forecast_steps(last_date_in_data, end_of_2025, aggregation)
    if steps_ahead < 1:
        raise ValueError(
            f"Data already extends to {last_date_in_data.date()}; "
            f"nothing left to forecast before {end_of_2025.date()}."
        )

    print("\nFitting auto_arima to find best model parameters...")
    auto_arima_model = auto_arima(
        seasonally_adjusted,
        start_p=0,
        start_q=0,
        max_p=5,
        max_q=5,
        start_P=0,
        max_P=5,
        max_Q=5,
        m=seasonal_period,
        seasonal=True,
        d=None,
        D=None,
        trace=True,
        error_action="ignore",
        suppress_warnings=True,
        stepwise=True,
    )
    order = auto_arima_model.order
    seasonal_order = auto_arima_model.seasonal_order
    print(f"Best model: ARIMA{order}x{seasonal_order}")

    # Refit with statsmodels using the orders auto_arima selected.
    sarima_model = SARIMAX(
        seasonally_adjusted,
        order=order,
        seasonal_order=seasonal_order,
        enforce_stationarity=False,
        enforce_invertibility=False,
    )
    sarima_results = sarima_model.fit(disp=False)

    forecast_results = sarima_results.get_forecast(steps=steps_ahead)
    forecast_mean = forecast_results.predicted_mean
    forecast_ci = forecast_results.conf_int()

    # Probability that the last forecast point exceeds the threshold,
    # treating the forecast as normal with the reported standard error.
    final_forecast_value = forecast_mean.iloc[-1]
    final_forecast_se = forecast_results.se_mean.iloc[-1]
    z = (threshold - final_forecast_value) / final_forecast_se
    p_exceed = 1 - norm.cdf(z)

    print(f"\n--- Aggregation: {aggregation} ---")
    print(
        f"Forecast for {forecast_mean.index[-1].date()} is {final_forecast_value:.2f}"
    )
    print(f"Std error of forecast: {final_forecast_se:.2f}")
    print(f"Probability Rat Index > {threshold} at end of 2025: {100*p_exceed:.2f}%\n")

    plt.figure(figsize=(12, 6))
    plt.plot(
        seasonally_adjusted.index,
        seasonally_adjusted,
        label="Historical (SA)",
        color="blue",
    )
    plt.plot(forecast_mean.index, forecast_mean, label="Forecast", color="red")
    lower_col = forecast_ci.columns[0]
    upper_col = forecast_ci.columns[1]
    plt.fill_between(
        forecast_mean.index,
        forecast_ci[lower_col],
        forecast_ci[upper_col],
        color="pink",
        alpha=0.3,
        label="95% CI",
    )
    plt.title(f"NYC Rat Sightings (SA, {aggregation.capitalize()}) + SARIMA Forecast")
    plt.xlabel("Date")
    plt.ylabel("Seasonally Adjusted Sightings")
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()

    return seasonally_adjusted, forecast_mean, p_exceed
def main():
    """Read the command-line options and run the forecast with them."""
    cli = argparse.ArgumentParser(
        description="Plot and forecast seasonally-adjusted NYC Rat Sightings."
    )

    # Each entry is (flag, keyword arguments for add_argument).
    option_specs = (
        (
            "--csv",
            {
                "type": str,
                "default": "Rat_Sightings.csv",
                "help": "Path to the Rat Sightings CSV file.",
            },
        ),
        (
            "--threshold",
            {
                "type": float,
                "default": 65,
                "help": "Threshold for computing the probability of exceeding this rat index.",
            },
        ),
        (
            "--aggregation",
            {
                "choices": ["daily", "weekly", "monthly"],
                "default": "weekly",
                "help": "Choose the frequency for time-series aggregation (daily, weekly, monthly).",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    plot_seasonally_adjusted_rat_sightings(
        csv_file=opts.csv,
        threshold=opts.threshold,
        aggregation=opts.aggregation,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment