# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from dataclasses import dataclass, field
from typing import Optional, List, Union
import numpy as np
from toolz import first, valmap
from gluonts import maybe
from gluonts.core.settings import Settings
from gluonts.itertools import (
    Cyclic,
    Map,
    batcher,
    IterableSlice,
    select,
)
from gluonts import zebras as zb
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.model.simple_feedforward import (
    SimpleFeedForwardLightningModule,
)
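

# ObservedValuesIndicator is a pipeline step that derives a missing-value
# mask: it adds a column named ``output`` that is 1.0 where ``ref`` is
# observed and 0.0 where it is NaN,
# e.g. target [1.0, nan, 3.0] -> observed_values [1.0, 0.0, 1.0].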
@dataclass
class ObservedValuesIndicator:
    ref: str
    output: str

    def __call__(self, frame):
        return frame.set_like(
            self.ref,
            self.output,
            np.array(~np.isnan(frame[self.ref]), dtype=np.float32),
        )


@dataclass
class InstanceSampler:
    n: int

    def __call__(self, data, future_length):
        max_idx = len(data) - future_length

        if max_idx > 0:
            return np.random.randint(max_idx, size=self.n)

        return []
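

# InstanceSplitter draws random split points from the sampler and cuts each
# series into a ``past_length``/``future_length`` training window around
# every sampled index; series too short to fit a future window yield nothing.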
@dataclass
class InstanceSplitter:
    sampler: InstanceSampler
    past_length: int
    future_length: int

    def __call__(self, data):
        for element in data:
            for split_index in self.sampler(element, self.future_length):
                yield element.split(
                    split_index,
                    past_length=self.past_length,
                    future_length=self.future_length,
                )
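

# ``Env`` collects process-wide defaults. Methods decorated with
# ``@env._inject(...)`` pick up these values for keyword arguments that were
# not passed explicitly, and ``env._let(...)`` overrides them inside a
# ``with`` block. A minimal sketch (``my_estimator``/``dataset`` are
# illustrative names):
#
#     with env._let(cache_data=True):
#         loader = my_estimator.training_dataloader(dataset)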
class Env(Settings):
    training_sampler: InstanceSampler = InstanceSampler(1)
    cache_data: bool = False


env = Env()
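

# DeepLearningEstimator wires the data flow for training and validation:
# raw entries are loaded through the schema, run through the pipeline steps,
# optionally cached, split into training/validation instances, and finally
# batched into dicts of tensors. ``IterableSlice`` caps an epoch at
# ``num_batches_per_epoch`` when that attribute is set.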
class DeepLearningEstimator:
    def get_schema(self):
        raise NotImplementedError

    def create_training_instances(self, dataset):
        raise NotImplementedError

    def create_validation_instances(self, dataset):
        raise NotImplementedError

    def training_pipeline(self):
        return []

    def validation_pipeline(self):
        return self.training_pipeline()

    def predictor_pipeline(self):
        return self.training_pipeline()

    def into_batches(self, instances):
        batches = batcher(instances, self.batch_size)
        batches = Map(zb.batch(type=self.tensor_type), batches)
        batches = Map(lambda batch: batch.as_dict(), batches)

        return batches

    @env._inject("cache_data")
    def training_dataloader(self, training_data, cache_data: bool = False):
        schema = self.get_schema()

        training_data = Map(
            lambda entry: schema.load_timeframe(
                entry,
                start=entry["start"],
                freq="M",
            ),
            training_data,
        )

        for step in self.training_pipeline():
            training_data = Map(step, training_data)

        if cache_data:
            training_data = list(training_data)

        training_data = Cyclic(training_data).stream()

        instances = self.create_training_instances(training_data)
        batches = self.into_batches(instances)

        return IterableSlice(
            batches,
            getattr(self, "num_batches_per_epoch", None),
        )

    @env._inject("cache_data")
    def validation_dataloader(self, validation_data, cache_data: bool = False):
        schema = self.get_schema()

        validation_data = Map(schema.load_timeframe, validation_data)

        for step in self.validation_pipeline():
            validation_data = Map(step, validation_data)

        if cache_data:
            validation_data = list(validation_data)

        instances = self.create_validation_instances(validation_data)
        batches = self.into_batches(instances)

        return IterableSlice(batches, None)
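

# SplittingEstimator turns whole series into fixed-size training windows via
# InstanceSplitter; for validation it takes a single split at
# ``-prediction_length``, holding out the last window of each series.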
class SplittingEstimator(DeepLearningEstimator):
    @env._inject("training_sampler")
    def create_training_instances(
        self, dataset, training_sampler=InstanceSampler(1)
    ):
        instance_splitter = InstanceSplitter(
            training_sampler,
            past_length=self.past_length,
            future_length=self.prediction_length,
        )

        return instance_splitter(dataset)

    def create_validation_instances(self, dataset):
        return Map(
            lambda frame: frame.split(
                -self.prediction_length,
                past_length=self.past_length,
                future_length=self.prediction_length,
            ),
            dataset,
        )

    def train(self, training_data, validation_data=None):
        return self.train_model(
            self.training_dataloader(training_data),
            maybe.map(validation_data, self.validation_dataloader),
        )

    def train_model(self, training_data, validation_data):
        raise NotImplementedError


import torch
import pytorch_lightning as pl


class TorchEstimator(DeepLearningEstimator):
    tensor_type = torch.tensor

    def train_model(self, training_data, validation_data):
        trainer = pl.Trainer(**self.trainer_kwargs)

        training_network = self.create_lightning_module()

        trainer.fit(
            model=training_network,
            train_dataloaders=training_data,
            val_dataloaders=validation_data,
        )

        return self.create_predictor(training_network)


from gluonts.util import lazy_property
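

# DistributionForecast wraps the predicted distribution parameters (kept as
# columns of a zebras time frame) and builds the torch distribution lazily;
# ``mean`` returns the predictive mean aligned with the forecast index, and
# iterating a batched forecast yields one DistributionForecast per series.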
@dataclass(eq=False)
class DistributionForecast:
    distr_output: DistributionOutput
    args: Union[zb.BatchTimeFrame, zb.TimeFrame]
    arg_names: list

    @lazy_property
    def dist(self):
        return self.distr_output.distribution(
            [self.args.columns[name] for name in self.arg_names],
            self.args.static["loc"],
            self.args.static["scale"],
        )

    @property
    def mean(self):
        return self.args.like(
            {"mean": self.dist.mean},
        )["mean"]

    def __iter__(self):
        if isinstance(self.args, zb.BatchTimeFrame):
            for args in self.args:
                yield DistributionForecast(
                    self.distr_output, args, self.arg_names
                )
        else:
            yield self
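

# DistributionPredictor runs inference: input entries are loaded as split
# frames, resized to ``past_length``, passed through the predictor pipeline,
# batched into tensors, and fed to the network; the returned distribution
# parameters are attached to the future part of the batch as a
# DistributionForecast.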
@dataclass
class DistributionPredictor:
    network: ...
    schema: ...
    pipeline: ...
    prediction_length: int
    past_length: int
    distr_output: DistributionOutput
    tensor_type: ...

    def predict_one(self, data):
        return first(self.predict_batch([data]))

    def predict_batch(self, data):
        data = map(
            lambda x: self.schema.load_splitframe(
                x,
                future_length=self.prediction_length,
                freq="M",
                start=x["start"],
            ),
            data,
        )
        data = map(lambda x: x.resize(past_length=self.past_length), data)

        for step in self.pipeline:
            data = map(step, data)

        inputs = zb.batch(list(data), type=torch.tensor)

        distr_args, loc, scale = self.network(
            **select(
                self.network.model.describe_inputs(),
                inputs.as_dict(),
            )
        )
        arg_names = list(range(len(distr_args)))

        return DistributionForecast(
            self.distr_output,
            inputs.future.like(
                valmap(torch.detach, dict(zip(arg_names, distr_args))),
                static={"scale": scale, "loc": loc},
            ),
            arg_names=arg_names,
        )
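

# SimpleFeedForwardEstimator plugs the pieces together for the simple
# feed-forward model: it declares the hyperparameters, the schema (a single
# 1-d, past-only ``target`` field), the training pipeline (observed-values
# mask), the lightning module, and the predictor. ``context_length`` defaults
# to ``10 * prediction_length``.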
@dataclass
class SimpleFeedForwardEstimator(TorchEstimator, SplittingEstimator):
    prediction_length: int
    context_length: Optional[int] = None
    hidden_dimensions: List[int] = field(default_factory=lambda: [20, 20])
    lr: float = 1e-3
    weight_decay: float = 1e-8
    distr_output: DistributionOutput = StudentTOutput()
    loss: DistributionLoss = NegativeLogLikelihood()
    batch_norm: bool = False
    batch_size: int = 32
    num_batches_per_epoch: int = 50
    trainer_kwargs: dict = field(default_factory=dict)

    def __post_init__(self):
        self.trainer_kwargs = dict(
            {
                "max_epochs": 100,
                "gradient_clip_val": 10.0,
            },
            **self.trainer_kwargs,
        )

        self.context_length = maybe.unwrap_or(
            self.context_length, 10 * self.prediction_length
        )

    @property
    def past_length(self):
        return self.context_length

    def get_schema(self):
        return zb.Schema(
            {
                "target": zb.Field(ndims=1, tdim=-1, past_only=True),
            }
        )

    def training_pipeline(self):
        return [ObservedValuesIndicator("target", "observed_values")]

    def predictor_pipeline(self):
        return []

    def create_lightning_module(self):
        return SimpleFeedForwardLightningModule(
            loss=self.loss,
            lr=self.lr,
            weight_decay=self.weight_decay,
            model_kwargs={
                "prediction_length": self.prediction_length,
                "context_length": self.context_length,
                "hidden_dimensions": self.hidden_dimensions,
                "distr_output": self.distr_output,
                "batch_norm": self.batch_norm,
            },
        )

    def create_predictor(self, model):
        return DistributionPredictor(
            model,
            schema=self.get_schema(),
            pipeline=self.predictor_pipeline(),
            prediction_length=self.prediction_length,
            past_length=self.past_length,
            distr_output=self.distr_output,
            tensor_type=torch.tensor,
        )
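

# End-to-end example: train the estimator for one epoch on the
# ``airpassengers`` dataset from the GluonTS dataset repository (downloaded on
# first use) and produce batched forecasts for the test split.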
from gluonts.dataset.repository.datasets import get_dataset

airpassengers = get_dataset("airpassengers")

my_estimator = SimpleFeedForwardEstimator(
    prediction_length=8, batch_size=32, trainer_kwargs={"max_epochs": 1}
)

with env._let(cache_data=True):
    predictor = my_estimator.train(
        airpassengers.train,
        airpassengers.test,
    )

test_data = list(airpassengers.test)

forecasts = predictor.predict_batch(test_data * 2)
print(forecasts.mean)

# for forecast in forecasts:
#     print(forecast.mean)

# print(
#     predictor.predict_one(test_data[0]).mean,
# )