# GitHub gist by @ksadov, created January 5, 2025 04:12.
import argparse
from curses import window
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.statespace.sarimax import SARIMAX
from scipy.stats import norm
from pmdarima import auto_arima
"""
My best attempt at forecasting https://realtime.org/data/nyc-rat-index using SARIMA,
for https://manifold.markets/strutheo/will-the-nyc-rat-index-be-65-or-hig-pfs413f8ws.
Running with daily aggregation cooks my laptop so I made my bet based on weekly aggregation.
download the CSV here: https://data.cityofnewyork.us/Social-Services/Rat-Sightings/3q43-55fe/about_data
and save it as Rat_Sightings.csv in the same directory as this script.
"""
def remove_last_outliers(df, z_threshold, lookback=30, fill_window=6):
    """Replace the final observation of a series if it looks like an outlier.

    SARIMA is sensitive to an outlying value at the end of the series, so the
    last value is smoothed out when it lies more than ``z_threshold`` sample
    standard deviations from the mean of the last ``lookback`` observations
    (a window that includes the last value itself). It is then replaced by
    the mean of the ``fill_window`` observations immediately before it.
    If you disagree with this smoothing, modify the code and go bet against
    me on Manifold.

    All access is positional (``.iloc``), so any index type works, including
    the DatetimeIndex this script passes in.

    Args:
        df: Series of observations, assumed ordered in time.
        z_threshold: Number of standard deviations beyond which the last
            value counts as an outlier.
        lookback: Number of trailing observations used for the mean and std.
            These are observations, not calendar days: with weekly data the
            default of 30 covers 30 weeks.
        fill_window: Number (>= 1) of observations preceding the last one
            whose mean replaces an outlying last value.

    Returns:
        A copy of ``df`` with the last value replaced when it was an outlier,
        otherwise ``df`` itself. An empty series is returned unchanged.
    """
    if len(df) == 0:
        return df

    recent = df.iloc[-lookback:]
    recent_mean = recent.mean()
    recent_std = recent.std()
    last_value = df.iloc[-1]

    # A NaN std (a single observation) or a NaN last value makes this
    # comparison False, so such series are left untouched.
    is_outlier = abs(last_value - recent_mean) > z_threshold * recent_std
    if not is_outlier:
        print(
            f"No outlier detected in last value "
            f"(compared with last {lookback} periods)"
        )
        return df

    print(f"Note: Last value ({last_value:.1f}) appears to be an outlier")
    print(
        f"{lookback}-period mean: {recent_mean:.1f}, "
        f"{lookback}-period std: {recent_std:.1f}"
    )
    adjusted_series = df.copy()
    # Explicit .iloc: plain integer keys on a DatetimeIndex are deprecated as
    # positions in pandas 2.x and treated as labels in pandas 3.
    adjusted_series.iloc[-1] = df.iloc[-fill_window - 1:-1].mean()
    print(f"Adjusted last value: {adjusted_series.iloc[-1]:.1f}")
    return adjusted_series
def _load_daily_sightings(csv_file):
    """Load the rat-sightings CSV and count sightings per calendar day since 2020.

    Returns a DataFrame with one "Number of Sightings" column, indexed by a
    gap-free daily DatetimeIndex named "Date". Days with no reported
    sighting are filled with 0.
    """
    df = pd.read_csv(csv_file)
    df["Created Date"] = pd.to_datetime(df["Created Date"])
    df = df[df["Created Date"].dt.year >= 2020]
    daily_sightings = (
        df.groupby(df["Created Date"].dt.date)
        .size()
        .reset_index(name="Number of Sightings")
    )
    daily_sightings["Date"] = pd.to_datetime(daily_sightings["Created Date"])
    daily_sightings = daily_sightings.drop(columns=["Created Date"]).set_index("Date")
    # groupby() only emits days that had at least one sighting. Without this
    # zero-fill, the 28-row "daily" rolling window would span more than 28
    # calendar days, weekly/monthly means would silently skip zero days, and
    # the index would lack a regular frequency for statsmodels to forecast on.
    return daily_sightings.asfreq("D", fill_value=0)


def _aggregate_sightings(daily_sightings, aggregation):
    """Return (aggregated series, seasonal period) for the chosen aggregation.

    Raises:
        ValueError: if ``aggregation`` is not daily, weekly or monthly.
    """
    counts = daily_sightings["Number of Sightings"]
    if aggregation == "daily":
        # https://realtime.org/data/nyc-rat-index uses a 28-day rolling average
        return counts.rolling(window=28, min_periods=1).mean(), 365
    if aggregation == "weekly":
        # approximate the 28-day rolling average with a 4-week rolling mean
        weekly_means = counts.resample("W").mean()
        return weekly_means.rolling(window=4, min_periods=1).mean(), 52
    if aggregation == "monthly":
        return counts.resample("M").mean(), 12
    raise ValueError("Aggregation must be one of: daily, weekly, monthly.")


def _forecast_steps(last_date, end_date, aggregation):
    """Return how many periods to forecast past ``last_date`` to reach ``end_date``.

    Raises:
        ValueError: if the data already extends past ``end_date``.
    """
    if aggregation == "daily":
        steps = (end_date - last_date).days
    elif aggregation == "weekly":
        # +1 so the final weekly bucket is the one containing end_date
        steps = (end_date - last_date).days // 7 + 1
    else:
        steps = (end_date.year - last_date.year) * 12 + (
            end_date.month - last_date.month
        )
    if steps < 1:
        raise ValueError(
            f"Data already extends to {last_date.date()}, past the forecast "
            f"target {end_date.date()}; nothing to forecast."
        )
    return steps


def _fit_sarima(series, seasonal_period):
    """Pick SARIMA orders with auto_arima, then refit those orders with SARIMAX."""
    print("\nFitting auto_arima to find best model parameters...")
    auto_arima_model = auto_arima(
        series,
        start_p=0,
        start_q=0,
        max_p=5,
        max_q=5,
        start_P=0,
        max_P=5,
        max_Q=5,
        m=seasonal_period,
        seasonal=True,
        d=None,
        D=None,
        trace=True,
        error_action="ignore",
        suppress_warnings=True,
        stepwise=True,
    )
    order = auto_arima_model.order
    seasonal_order = auto_arima_model.seasonal_order
    print(f"Best model: ARIMA{order}x{seasonal_order}")
    sarima_model = SARIMAX(
        series,
        order=order,
        seasonal_order=seasonal_order,
        enforce_stationarity=False,
        enforce_invertibility=False,
    )
    return sarima_model.fit(disp=False)


def _plot_forecast(history, forecast_mean, forecast_ci, aggregation):
    """Plot the seasonally adjusted history, the forecast and its CI band."""
    plt.figure(figsize=(12, 6))
    plt.plot(history.index, history, label="Historical (SA)", color="blue")
    plt.plot(forecast_mean.index, forecast_mean, label="Forecast", color="red")
    # conf_int() defaults to alpha=0.05, hence the "95% CI" label
    lower_col, upper_col = forecast_ci.columns[0], forecast_ci.columns[1]
    plt.fill_between(
        forecast_mean.index,
        forecast_ci[lower_col],
        forecast_ci[upper_col],
        color="pink",
        alpha=0.3,
        label="95% CI",
    )
    plt.title(f"NYC Rat Sightings (SA, {aggregation.capitalize()}) + SARIMA Forecast")
    plt.xlabel("Date")
    plt.ylabel("Seasonally Adjusted Sightings")
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()


def plot_seasonally_adjusted_rat_sightings(
    csv_file, threshold=65, aggregation="weekly"
):
    """Forecast the seasonally adjusted NYC rat index to the end of 2025 and plot it.

    The daily sighting counts are aggregated and seasonally decomposed. The
    seasonal component is subtracted, and the last value is smoothed if it
    is an outlier. A SARIMA model (orders chosen by auto_arima) is then
    fitted and forecast to 2025-12-31. The probability that the final
    forecast exceeds ``threshold`` is computed under a normal approximation
    and printed.

    Args:
        csv_file: Path to the NYC Open Data "Rat Sightings" CSV export.
        threshold: Rat-index level whose exceedance probability is reported.
        aggregation: One of "daily", "weekly" or "monthly".

    Returns:
        Tuple ``(seasonally_adjusted, forecast_mean, p_exceed)``.

    Raises:
        ValueError: for an unknown ``aggregation``, or when the data already
            extends past the end of 2025.
    """
    daily_sightings = _load_daily_sightings(csv_file)
    agg_data, seasonal_period = _aggregate_sightings(daily_sightings, aggregation)

    decomposition = seasonal_decompose(
        agg_data.dropna(),  # defensive; the zero-filled daily index leaves no gaps
        period=seasonal_period,
        extrapolate_trend="freq",
    )
    # Drop NaNs before the outlier check so a trailing NaN cannot mask it.
    seasonally_adjusted = (agg_data - decomposition.seasonal).dropna()
    seasonally_adjusted = remove_last_outliers(seasonally_adjusted, 2.0)

    # Validate the horizon before the (slow) model search.
    end_of_2025 = pd.to_datetime("2025-12-31")
    steps_ahead = _forecast_steps(
        seasonally_adjusted.index[-1], end_of_2025, aggregation
    )

    sarima_results = _fit_sarima(seasonally_adjusted, seasonal_period)
    forecast_results = sarima_results.get_forecast(steps=steps_ahead)
    forecast_mean = forecast_results.predicted_mean
    forecast_ci = forecast_results.conf_int()

    final_forecast_value = forecast_mean.iloc[-1]
    final_forecast_se = forecast_results.se_mean.iloc[-1]
    z = (threshold - final_forecast_value) / final_forecast_se
    # sf(z) == 1 - cdf(z), but stays accurate far into the upper tail
    p_exceed = norm.sf(z)

    print(f"\n--- Aggregation: {aggregation} ---")
    print(
        f"Forecast for {forecast_mean.index[-1].date()} is {final_forecast_value:.2f}"
    )
    print(f"Std error of forecast: {final_forecast_se:.2f}")
    print(f"Probability Rat Index > {threshold} at end of 2025: {100*p_exceed:.2f}%\n")

    _plot_forecast(seasonally_adjusted, forecast_mean, forecast_ci, aggregation)
    return seasonally_adjusted, forecast_mean, p_exceed
def main():
    """Command-line entry point: read options, then run the forecast and plot.

    Options: --csv (path to the sightings CSV), --threshold (rat-index level
    whose exceedance probability is reported) and --aggregation
    (daily/weekly/monthly time-series frequency).
    """
    cli = argparse.ArgumentParser(
        description="Plot and forecast seasonally-adjusted NYC Rat Sightings."
    )
    # (flag, add_argument keyword options), registered in this order so the
    # generated --help output lists them the same way.
    option_specs = (
        (
            "--csv",
            dict(
                type=str,
                default="Rat_Sightings.csv",
                help="Path to the Rat Sightings CSV file.",
            ),
        ),
        (
            "--threshold",
            dict(
                type=float,
                default=65,
                help="Threshold for computing the probability of exceeding this rat index.",
            ),
        ),
        (
            "--aggregation",
            dict(
                choices=["daily", "weekly", "monthly"],
                default="weekly",
                help="Choose the frequency for time-series aggregation (daily, weekly, monthly).",
            ),
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    plot_seasonally_adjusted_rat_sightings(
        csv_file=opts.csv,
        threshold=opts.threshold,
        aggregation=opts.aggregation,
    )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment