-
-
Save ksadov/3c1e62fc09833c235991b52f227d697c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from curses import window | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from datetime import datetime | |
from statsmodels.tsa.seasonal import seasonal_decompose | |
from statsmodels.tsa.statespace.sarimax import SARIMAX | |
from scipy.stats import norm | |
from pmdarima import auto_arima | |
"""
My best attempt at forecasting https://realtime.org/data/nyc-rat-index using SARIMA,
for https://manifold.markets/strutheo/will-the-nyc-rat-index-be-65-or-hig-pfs413f8ws.
Running with daily aggregation cooks my laptop (a 365-step seasonal model is very
expensive to fit), so I made my bet based on weekly aggregation.
Download the CSV here: https://data.cityofnewyork.us/Social-Services/Rat-Sightings/3q43-55fe/about_data
and save it as Rat_Sightings.csv in the directory you run this script from
(or pass its path with --csv).
"""
def remove_last_outliers(df, z_threshold):
    """Replace the final value of a series if it looks like an outlier.

    SARIMA is sensitive to an outlier in the last observation of the series,
    so this just smooths it out. If you disagree, modify the code and go bet
    against me on Manifold.

    The last value is flagged when it lies more than ``z_threshold`` sample
    standard deviations from the mean of the last 30 observations (that
    window includes the last value itself). A flagged value is replaced by
    the mean of the 6 observations immediately before it.

    Note that the windows count observations, not days: with weekly
    aggregation the 30-observation window spans 30 weeks.

    Args:
        df: pandas Series to check. It is never modified in place.
        z_threshold: Number of standard deviations beyond which the last
            value is treated as an outlier.

    Returns:
        A copy of ``df`` with its last value replaced when it was an
        outlier; otherwise ``df`` itself. An empty series is returned as is.
    """
    if len(df) == 0:
        # Nothing to inspect; df.iloc[-1] would raise IndexError.
        return df

    recent = df.iloc[-30:]
    last_30_mean = recent.mean()
    last_30_std = recent.std()
    # Use .iloc: with a DatetimeIndex, a plain integer key like df[-1] relies
    # on a deprecated positional fallback (and becomes a label lookup later).
    last_value = df.iloc[-1]

    # A NaN last value (or NaN std) makes this comparison False, so such
    # series are left untouched.
    is_outlier = abs(last_value - last_30_mean) > z_threshold * last_30_std
    if not is_outlier:
        print("No outliers detected in last 30 days")
        return df

    print(f"Note: Last value ({last_value:.1f}) appears to be an outlier")
    print(f"30-day mean: {last_30_mean:.1f}, 30-day std: {last_30_std:.1f}")
    adjusted_series = df.copy()
    # Positional assignment; ser[-1] = ... could add a new label instead.
    adjusted_series.iloc[-1] = df.iloc[-7:-1].mean()
    print(f"Adjusted last value: {adjusted_series.iloc[-1]:.1f}")
    return adjusted_series
# Seasonal period (observations per year) for each supported aggregation.
_SEASONAL_PERIODS = {"daily": 365, "weekly": 52, "monthly": 12}


def _load_daily_sightings(csv_file):
    """Return rat-sighting counts per calendar day, from 2020 onward.

    Only the "Created Date" column is read from the CSV. Days with no
    sightings are included with a count of 0, so rolling windows and period
    means downstream are taken over calendar days rather than over only the
    days that happened to have at least one report.

    Raises:
        ValueError: if the file contains no sightings dated 2020 or later.
    """
    created = pd.to_datetime(
        pd.read_csv(csv_file, usecols=["Created Date"])["Created Date"]
    )
    created = created[created.dt.year >= 2020]
    if created.empty:
        raise ValueError(f"No rat sightings from 2020 onward in {csv_file}")
    counts = created.dt.normalize().value_counts().sort_index()
    # Insert the calendar days that had no reports, with a count of 0.
    counts = counts.asfreq("D", fill_value=0)
    counts.index.name = "Date"
    counts.name = "Number of Sightings"
    return counts


def _month_end_resampler(series):
    """Return ``series.resample`` at calendar-month-end frequency.

    pandas >= 2.2 spells month end "ME" and deprecates "M"; older pandas
    only understands "M" and rejects "ME" with a ValueError.
    """
    try:
        return series.resample("ME")
    except ValueError:
        return series.resample("M")


def _aggregate_sightings(daily, aggregation):
    """Aggregate daily counts; ``aggregation`` must be a _SEASONAL_PERIODS key."""
    if aggregation == "daily":
        # https://realtime.org/data/nyc-rat-index uses a 28-day rolling
        # average; ``daily`` has one row per calendar day, so 28 rows = 28 days.
        return daily.rolling(window=28, min_periods=1).mean()
    if aggregation == "weekly":
        # Weekly means smoothed over 4 weeks approximate the 28-day average.
        return daily.resample("W").mean().rolling(window=4, min_periods=1).mean()
    return _month_end_resampler(daily).mean()


def _steps_until(last_date, end_date, aggregation):
    """Return how many forecast steps after ``last_date`` reach ``end_date``.

    Weekly bins are labelled by their week-end date, so the weekly result
    lands on the first weekly label strictly after ``end_date``. Monthly
    bins are labelled at month end, so the monthly result lands on the end
    of ``end_date``'s month.
    """
    if aggregation == "daily":
        return (end_date - last_date).days
    if aggregation == "weekly":
        return (end_date - last_date).days // 7 + 1
    return (end_date.year - last_date.year) * 12 + (
        end_date.month - last_date.month
    )


def _plot_forecast(history, forecast_mean, forecast_ci, aggregation):
    """Plot the seasonally adjusted history, the forecast and its CI band."""
    plt.figure(figsize=(12, 6))
    plt.plot(history.index, history, label="Historical (SA)", color="blue")
    plt.plot(forecast_mean.index, forecast_mean, label="Forecast", color="red")
    lower_col = forecast_ci.columns[0]
    upper_col = forecast_ci.columns[1]
    plt.fill_between(
        forecast_mean.index,
        forecast_ci[lower_col],
        forecast_ci[upper_col],
        color="pink",
        alpha=0.3,
        label="95% CI",
    )
    plt.title(f"NYC Rat Sightings (SA, {aggregation.capitalize()}) + SARIMA Forecast")
    plt.xlabel("Date")
    plt.ylabel("Seasonally Adjusted Sightings")
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()


def plot_seasonally_adjusted_rat_sightings(
    csv_file, threshold=65, aggregation="weekly"
):
    """Forecast the seasonally adjusted NYC rat index to 2025-12-31 and plot it.

    Pipeline:
    1. Load daily sighting counts (2020 onward, zero-filled) from ``csv_file``.
    2. Aggregate them: ``daily`` uses a 28-day rolling mean, ``weekly`` a
       weekly mean smoothed over 4 weeks, ``monthly`` a monthly mean.
    3. Subtract the seasonal component from ``seasonal_decompose``.
    4. Smooth an outlying final value.
    5. Choose orders with ``auto_arima`` and refit with ``SARIMAX``.
    6. Forecast to the end of 2025 and print the probability that the
       forecast exceeds ``threshold``, using a normal approximation with
       the forecast's standard error.
    7. Show a plot.

    Args:
        csv_file: Path to the NYC Rat Sightings CSV (must contain a
            "Created Date" column).
        threshold: Index level whose exceedance probability is reported.
        aggregation: One of "daily", "weekly", "monthly".

    Returns:
        ``(seasonally_adjusted, forecast_mean, p_exceed)``.

    Raises:
        ValueError: for an unknown ``aggregation``, when the CSV has no
            sightings from 2020 onward, or when the data already reaches
            the end of 2025 so there is nothing left to forecast.
    """
    if aggregation not in _SEASONAL_PERIODS:
        # Validate before the (slow) CSV read.
        raise ValueError("Aggregation must be one of: daily, weekly, monthly.")
    seasonal_period = _SEASONAL_PERIODS[aggregation]

    daily_sightings = _load_daily_sightings(csv_file)
    agg_data = _aggregate_sightings(daily_sightings, aggregation)

    decomposition = seasonal_decompose(
        agg_data.dropna(),  # dropna in case early periods are incomplete
        period=seasonal_period,
        extrapolate_trend="freq",
    )
    seasonally_adjusted = agg_data - decomposition.seasonal
    seasonally_adjusted = remove_last_outliers(seasonally_adjusted, 2.0)
    seasonally_adjusted = seasonally_adjusted.dropna()

    # Work out the horizon before the expensive model fitting so an
    # impossible request fails fast.
    last_date_in_data = seasonally_adjusted.index[-1]
    end_of_2025 = pd.to_datetime("2025-12-31")
    steps_ahead = _steps_until(last_date_in_data, end_of_2025, aggregation)
    if steps_ahead < 1:
        raise ValueError(
            f"Data already extends to {last_date_in_data.date()}; "
            f"nothing left to forecast up to {end_of_2025.date()}."
        )

    print("\nFitting auto_arima to find best model parameters...")
    auto_arima_model = auto_arima(
        seasonally_adjusted,
        start_p=0,
        start_q=0,
        max_p=5,
        max_q=5,
        start_P=0,
        max_P=5,
        max_Q=5,
        m=seasonal_period,
        seasonal=True,
        d=None,
        D=None,
        trace=True,
        error_action="ignore",
        suppress_warnings=True,
        stepwise=True,
    )
    order = auto_arima_model.order
    seasonal_order = auto_arima_model.seasonal_order
    print(f"Best model: ARIMA{order}x{seasonal_order}")

    sarima_results = SARIMAX(
        seasonally_adjusted,
        order=order,
        seasonal_order=seasonal_order,
        enforce_stationarity=False,
        enforce_invertibility=False,
    ).fit(disp=False)

    forecast_results = sarima_results.get_forecast(steps=steps_ahead)
    forecast_mean = forecast_results.predicted_mean
    forecast_ci = forecast_results.conf_int()
    final_forecast_value = forecast_mean.iloc[-1]
    final_forecast_se = forecast_results.se_mean.iloc[-1]

    # NOTE(review): this compares the *seasonally adjusted* forecast with
    # ``threshold``. If the market resolves on the raw (unadjusted) index,
    # the seasonal component for the target date would need to be added
    # back first. TODO confirm which series the market uses.
    z = (threshold - final_forecast_value) / final_forecast_se
    p_exceed = norm.sf(z)  # == 1 - norm.cdf(z), more accurate in the tail

    print(f"\n--- Aggregation: {aggregation} ---")
    print(
        f"Forecast for {forecast_mean.index[-1].date()} is {final_forecast_value:.2f}"
    )
    print(f"Std error of forecast: {final_forecast_se:.2f}")
    print(f"Probability Rat Index > {threshold} at end of 2025: {100*p_exceed:.2f}%\n")

    _plot_forecast(seasonally_adjusted, forecast_mean, forecast_ci, aggregation)
    return seasonally_adjusted, forecast_mean, p_exceed
def main():
    """Command-line entry point: read options, then fit, forecast and plot."""
    cli = argparse.ArgumentParser(
        description="Plot and forecast seasonally-adjusted NYC Rat Sightings."
    )
    cli.add_argument(
        "--csv",
        default="Rat_Sightings.csv",
        type=str,
        help="Path to the Rat Sightings CSV file.",
    )
    cli.add_argument(
        "--threshold",
        default=65,
        type=float,
        help="Threshold for computing the probability of exceeding this rat index.",
    )
    cli.add_argument(
        "--aggregation",
        default="weekly",
        choices=["daily", "weekly", "monthly"],
        help="Choose the frequency for time-series aggregation (daily, weekly, monthly).",
    )
    options = vars(cli.parse_args())
    plot_seasonally_adjusted_rat_sightings(
        csv_file=options["csv"],
        threshold=options["threshold"],
        aggregation=options["aggregation"],
    )


if "__main__" == __name__:
    # Run the CLI only when executed directly, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment