Skip to content

Instantly share code, notes, and snippets.

@minesh1291
Created April 27, 2024 11:45
Show Gist options
Save minesh1291/94ec783cf18d1cc8411c563e9fb24b39 to your computer and use it in GitHub Desktop.
Distribution Forecasting 
# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.signal
from gluonts.dataset.repository import get_dataset, dataset_names
from gluonts.dataset.util import to_pandas
from gluonts.dataset.common import ListDataset
from gluonts.torch import SimpleFeedForwardEstimator
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from gluonts.evaluation import make_evaluation_predictions
# Print available datasets
print(f"Available datasets: {dataset_names}")

# Load SciPy's bundled electrocardiogram sample.
# `scipy.misc.electrocardiogram` was deprecated in SciPy 1.10 and removed in
# SciPy 1.12; its replacement lives in `scipy.datasets` (which downloads the
# data through the optional `pooch` package). Fall back only on old SciPy.
try:
    from scipy.datasets import electrocardiogram
except ImportError:
    from scipy.misc import electrocardiogram

# Keep the first 360_000 samples and resample them down to 36_000 points
# (10x fewer samples) with scipy.signal.resample.
target_arr = scipy.signal.resample(electrocardiogram()[:360_000], 36_000)
plt.plot(target_arr)
plt.show()
# Windowing / forecasting hyper-parameters; every length is a number of samples.
window_size = 360                   # samples per sliced window
shift_size = window_size // 4       # stride between window starts (75% overlap)
prediction_length, context_length = 50, 250  # forecast horizon, model look-back
# Create time series windows
def create_time_series_window(arr, window_size, shift_size):
    """Slice ``arr`` into fixed-length, possibly overlapping windows.

    Args:
        arr: Array-like signal; windows are taken along its first axis.
        window_size: Number of samples in each window (must be > 0).
        shift_size: Stride between the starts of consecutive windows
            (must be > 0).

    Returns:
        An array whose rows are ``arr[i:i + window_size]`` for
        ``i = 0, shift_size, 2 * shift_size, ...``, including every window
        that fits entirely inside ``arr`` (the last one too). If ``arr`` is
        shorter than ``window_size`` the result has shape
        ``(0, window_size, ...)`` so callers can still slice it as 2-D.

    Raises:
        ValueError: If ``window_size`` or ``shift_size`` is not positive.
    """
    if window_size <= 0 or shift_size <= 0:
        raise ValueError(
            f"window_size and shift_size must be positive, "
            f"got {window_size} and {shift_size}"
        )
    arr = np.asarray(arr)
    # +1 so the final window, which starts exactly at len(arr) - window_size,
    # is not dropped.
    starts = range(0, len(arr) - window_size + 1, shift_size)
    if not starts:
        # Keep the result 2-D (or higher) so `windows[:, :-k]` still works.
        return np.empty((0, window_size) + arr.shape[1:], dtype=arr.dtype)
    return np.stack([arr[i:i + window_size] for i in starts])
# Slice the resampled ECG trace into overlapping windows, one series per row.
target_series_windows = create_time_series_window(target_arr, window_size, shift_size)

# Every window gets the same nominal start Period.
# NOTE(review): SciPy documents its electrocardiogram sample at 360 Hz; after
# the 10x resample above that is ~36 Hz (~27.8 ms/sample), not the 20 ms step
# implied by "0.02s". Only timestamps are affected — confirm the intended freq.
start = pd.Period("01-01-2019", freq="0.02s")

# Targets are scaled by 10. Training series drop their last
# `prediction_length` points; the test series keep the full window so the
# held-out tail can be scored.
train_ds = ListDataset(
    [
        {"target": series[:-prediction_length], "start": start, "item_id": item_id}
        for item_id, series in enumerate(10 * target_series_windows)
    ],
    freq="0.02s",
)
test_ds = ListDataset(
    [
        {"target": series, "start": start, "item_id": item_id}
        for item_id, series in enumerate(10 * target_series_windows)
    ],
    freq="0.02s",
)
# Configure the SimpleFeedForward estimator: four hidden layers, at most 300
# epochs of 5 batches (64 series each), Lightning logs under /tmp/gluonts.
estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length,
    context_length=context_length,
    hidden_dimensions=[512, 128, 32, 8],
    batch_size=64,
    num_batches_per_epoch=5,
    lr=3e-4,
    trainer_kwargs={
        "max_epochs": 300,
        "default_root_dir": "/tmp/gluonts",
        # Stop when "val_loss" has not improved for 20 validation checks.
        "callbacks": [EarlyStopping(monitor="val_loss", patience=20, verbose=True)],
    },
)
# Fit the model; the full-length test windows are passed as the second
# (validation) dataset so the "val_loss" early-stopping monitor has data.
predictor = estimator.train(train_ds, test_ds, num_workers=0)

# Generate 100-sample forecasts for every test series, together with the
# matching ground-truth series, then materialise both lazy iterators so they
# can be indexed when plotting.
forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_ds, predictor=predictor, num_samples=100
)
forecasts, tss = list(forecast_it), list(ts_it)
# Plot forecasts for up to five *distinct* randomly chosen series.
# A seeded Generator makes the selection reproducible across runs, and
# sampling without replacement avoids plotting the same series twice
# (and no longer fails when fewer than five series exist).
plot_rng = np.random.default_rng(seed=0)
n_plots = min(5, len(tss))
for index in plot_rng.choice(len(tss), size=n_plots, replace=False):
    ts_entry = tss[index]
    forecast_entry = forecasts[index]
    # Observed values: the last 4 * prediction_length points of the full series.
    plt.plot(ts_entry[-prediction_length * 4:].to_timestamp())
    forecast_entry.plot(show_label=True)
    plt.legend()
    plt.show()
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment