This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.distributions import Categorical, Distribution | |
from typing import Optional, List | |
from gluonts.core.component import validated | |
from gluonts.torch.distributions import DistributionOutput | |
class MixtureDistribution(Distribution): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
import mxnet as mx | |
from gluonts.dataset.pandas import PandasDataset | |
from gluonts.dataset.split import split | |
from gluonts.mx import DeepAREstimator, Trainer | |
from gluonts.model.predictor import Predictor | |
from pathlib import Path | |
def train_model(): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from gluonts.dataset.repository.datasets import get_dataset | |
from gluonts.transform import ExpectedNumInstanceSampler | |
from gluonts.torch.model.deepar import DeepAREstimator | |
import torch | |
dataset = get_dataset("electricity") | |
context_length = 2 * 7 * 24 | |
prediction_length = dataset.metadata.prediction_length | |
model = DeepAREstimator( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from gluonts.dataset.repository.datasets import get_dataset | |
from gluonts.torch.model.deepar import DeepARModel, DeepARLightningModule | |
from gluonts.transform import ( | |
AddObservedValuesIndicator, | |
InstanceSplitter, | |
ExpectedNumInstanceSampler, | |
) | |
from gluonts.dataset.field_names import FieldName | |
from gluonts.dataset.loader import TrainDataLoader | |
from gluonts.itertools import Cached |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
import pandas as pd | |
import cudf | |
import numpy as np | |
start = pd.Timestamp(datetime.strptime('2021-03-12 00:00+0000', '%Y-%m-%d %H:%M%z')) | |
end = pd.Timestamp(datetime.strptime('2021-03-12 11:00+0000', '%Y-%m-%d %H:%M%z')) | |
timestamps = pd.date_range(start, end, freq='1H') | |
value = np.random.normal(size=12) | |
df = pd.DataFrame(value, index=timestamps, columns=['value']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from gluonts.dataset.repository.datasets import get_dataset | |
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator | |
from gluonts.model.deepar import DeepAREstimator | |
from gluonts.mx.distribution.gaussian import GaussianOutput | |
from gluonts.mx import Trainer | |
from gluonts.mx.trainer.callback import TrainingHistory | |
from gluonts.evaluation import Evaluator | |
from gluonts.dataset.common import Dataset | |
from gluonts.mx import copy_parameters | |
from gluonts.model.predictor import Predictor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from gluonts.dataset.repository.datasets import get_dataset | |
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator | |
from gluonts.model.deepar import DeepAREstimator | |
from gluonts.mx.distribution.gaussian import GaussianOutput | |
from gluonts.mx import Trainer | |
from gluonts.mx.trainer.callback import TrainingHistory | |
from gluonts.evaluation import Evaluator | |
from gluonts.dataset.common import Dataset | |
from gluonts.mx import copy_parameters | |
from gluonts.model.predictor import Predictor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob | |
import numpy as np | |
import cupy as cp | |
import imageio | |
from random import shuffle | |
from nvidia.dali import Pipeline | |
import nvidia.dali.fn as fn | |
import nvidia.dali.plugin.tf as dali_tf | |
import tensorflow as tf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob | |
import numpy as np | |
import cupy as cp | |
import imageio | |
from random import shuffle | |
from nvidia.dali import Pipeline | |
import nvidia.dali.fn as fn | |
import nvidia.dali.plugin.tf as dali_tf | |
import tensorflow as tf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob | |
import numpy as np | |
import cupy as cp | |
import imageio | |
from random import shuffle | |
from nvidia.dali import Pipeline | |
import nvidia.dali.fn as fn | |
import nvidia.dali.plugin.tf as dali_tf | |
import tensorflow as tf |
NewerOlder