Created April 27, 2024 11:45
-
-
Save minesh1291/94ec783cf18d1cc8411c563e9fb24b39 to your computer and use it in GitHub Desktop.
Distribution Forecasting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import scipy.signal | |
from gluonts.dataset.repository import get_dataset, dataset_names | |
from gluonts.dataset.util import to_pandas | |
from gluonts.dataset.common import ListDataset | |
from gluonts.torch import SimpleFeedForwardEstimator | |
from lightning.pytorch.callbacks.early_stopping import EarlyStopping | |
from gluonts.evaluation import make_evaluation_predictions | |
# Print the names of the datasets available from the GluonTS repository.
print(f"Available datasets: {dataset_names}")

# Load SciPy's bundled sample electrocardiogram.
# `scipy.misc.electrocardiogram` was deprecated in SciPy 1.10 (scheduled for
# removal in 1.12) and moved to `scipy.datasets.electrocardiogram`, which needs
# the optional `pooch` package. Prefer the new location and fall back to the
# old one on SciPy versions where the new one is missing or unusable.
try:
    from scipy.datasets import electrocardiogram

    ecg_signal = electrocardiogram()
except ImportError:
    from scipy.misc import electrocardiogram

    ecg_signal = electrocardiogram()

# Keep at most the first 360,000 samples and resample them to 36,000 points.
# NOTE(review): SciPy documents this ECG as ~5 minutes at 360 Hz (~108,000
# samples). If so, the slice is a no-op and the effective resampling ratio is
# 36,000 / 108,000 rather than 1/10 — confirm the intended sampling rate.
target_arr = scipy.signal.resample(ecg_signal[:360_000], 36_000)
plt.plot(target_arr)
plt.show()
# Windowing and forecasting parameters, all measured in time steps.
window_size = 360  # length of each window sliced from the signal
shift_size = window_size // 4  # stride between consecutive window starts
# Forecast horizon and the history length the model conditions on.
prediction_length, context_length = 50, 250
# Create time series windows | |
def create_time_series_window(arr, window_size, shift_size): | |
return np.array([arr[i:i+window_size] for i in range(0, len(arr)-window_size, shift_size)]) | |
# Cut the resampled signal into overlapping windows; each window becomes one
# independent series in the GluonTS datasets below.
target_series_windows = create_time_series_window(
    target_arr, window_size=window_size, shift_size=shift_size
)

# Every series shares the same start period. pandas reads "0.02s" as a 20 ms step.
# NOTE(review): confirm 20 ms matches the sampling period of the resampled ECG.
series_freq = "0.02s"
start = pd.Period("01-01-2019", freq=series_freq)

# Amplitude scaling applied to every window before training and evaluation.
scaled_windows = 10 * target_series_windows

# Training series drop the last `prediction_length` points. The test series
# keep them, so the held-out horizon is available for evaluation.
train_ds = ListDataset(
    [
        {"target": series, "start": start, "item_id": idx}
        for idx, series in enumerate(scaled_windows[:, :-prediction_length])
    ],
    freq=series_freq,
)
test_ds = ListDataset(
    [
        {"target": series, "start": start, "item_id": idx}
        for idx, series in enumerate(scaled_windows)
    ],
    freq=series_freq,
)
# Stop training early when the monitored "val_loss" metric stops improving
# for 20 epochs. Assumes the model logs "val_loss" when validation data is
# supplied — confirm against the GluonTS version in use.
early_stopping = EarlyStopping(monitor="val_loss", patience=20, verbose=True)

# Keyword arguments forwarded to the Lightning trainer.
trainer_config = {
    "max_epochs": 300,
    "default_root_dir": "/tmp/gluonts",
    "callbacks": [early_stopping],
}

# Feed-forward probabilistic forecaster: sees `context_length` past steps and
# predicts the next `prediction_length` steps.
estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length,
    context_length=context_length,
    hidden_dimensions=[512, 128, 32, 8],
    batch_size=64,
    num_batches_per_epoch=5,
    lr=3e-4,
    trainer_kwargs=trainer_config,
)
# Fit the model. test_ds is also passed in as the validation data.
predictor = estimator.train(train_ds, test_ds, num_workers=0)

# Produce 100-sample forecasts for the test series, then materialise both
# lazy iterators so they can be indexed below.
forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_ds, predictor=predictor, num_samples=100
)
forecasts, tss = list(forecast_it), list(ts_it)
# Plot up to five randomly chosen test series together with their forecasts.
# Pass a seed to default_rng() for a reproducible selection.
rng = np.random.default_rng()
num_plots = min(5, len(tss))
# Draw distinct indices so the same series is never plotted twice. With no
# series, size=0 yields an empty selection instead of raising.
for index in rng.choice(len(tss), size=num_plots, replace=False):
    ts_entry = tss[index]
    forecast_entry = forecasts[index]
    # Show four horizons' worth of history. Convert the period index to
    # timestamps so matplotlib can plot it.
    plt.plot(ts_entry[-prediction_length * 4:].to_timestamp(), label="target")
    forecast_entry.plot(show_label=True)
    plt.legend()
    plt.show()
# %% |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment