PyTorch time series Dataset and rolling-window Dataset
import torch
from torch.utils.data import Dataset


class TimeSeriesDataset(Dataset):
    def __init__(
        self,
        ts: torch.Tensor,
        x_ts: torch.Tensor,
        normalize=True,
        x_mean=None,
        x_std=None,
        jitter=1e-12,
    ):
        """
        Takes timestamps `ts` of shape (N, T) and trajectories `x_ts` of shape
        (N, T, D). Normalizes `x_ts` per feature if `normalize` is True, using
        the precomputed `x_mean` and `x_std` when they are given.
        """
        N, T, _ = x_ts.shape
        assert ts.shape == (N, T)
        # compute mean and std over all sequences and timesteps if not specified
        if x_mean is None:
            x_mean = x_ts.mean((0, 1))
        if x_std is None:
            x_std = x_ts.std((0, 1))
        if normalize:
            x_ts = (x_ts - x_mean) / (x_std + jitter)
        self.ts = ts
        self.x_ts = x_ts
        self.normalize = normalize
        self.x_mean = x_mean
        self.x_std = x_std
        self.jitter = jitter

    def __repr__(self):
        kws = [
            f"{key}={tnsr.shape}" if torch.is_tensor(tnsr) else f"{key}={tnsr}"
            for key, tnsr in self.__dict__.items()
        ]
        return "{}({})".format(type(self).__name__, ", ".join(kws))

    def __getitem__(self, index):
        return (self.ts[index], self.x_ts[index])

    def __len__(self):
        return self.ts.shape[0]

    def unnormalize(self):
        if not self.normalize:
            return self.ts, self.x_ts
        x_ts = self.x_ts * (self.x_std + self.jitter) + self.x_mean
        return self.ts, x_ts


class RollingWindowDataset(Dataset):
    def __init__(self, ts_dataset: TimeSeriesDataset, window_size: int, stride: int):
        N, T, D = ts_dataset.x_ts.shape
        # compute size of new tensor
        windows_per_seq = 1 + (T - window_size) // stride
        num_windows = N * windows_per_seq
        chunked_shape = (N, windows_per_seq, window_size, D)
        self.ts_dataset = ts_dataset
        self.N = N
        self.T = T
        self.window_size = window_size
        self.stride = stride
        self.windows_per_seq = windows_per_seq
        self.num_windows = num_windows
        self.chunked_shape = chunked_shape

    def __getitem__(self, index):
        assert index < self.num_windows
        # Convert index over all windows to (seq_index, window_index)
        # index = i * windows_per_seq + j, where 0 <= i < N and 0 <= j < windows_per_seq
        i = index // self.windows_per_seq  # seq_index
        j = index % self.windows_per_seq  # window_index
        # equivalent to `i, j = np.unravel_index(index, (self.N, self.windows_per_seq))`
        start = j * self.stride
        end = start + self.window_size
        t = self.ts_dataset.ts[i, start:end]  # (window_size,)
        xz_t = self.ts_dataset.x_ts[i, start:end]  # (window_size, D)
        return t, xz_t

    def __len__(self):
        return self.num_windows


if __name__ == "__main__":
    # smoke test on a tiny example: N=4 sequences of length T=5 with D=2 features
    N, T = 4, 5
    window_size, stride = 3, 2
    ts = torch.arange(N * T).reshape(N, T)
    x_ts = torch.arange(N * T * 2, dtype=torch.float).reshape(N, T, 2)
    dataset = RollingWindowDataset(
        TimeSeriesDataset(ts, x_ts, normalize=False),
        window_size=window_size,
        stride=stride,
    )
    # the last window should be the tail of the last sequence
    last_ts, last_window = dataset[len(dataset) - 1]
    assert torch.equal(last_ts, ts[-1, -window_size:])
    assert torch.equal(last_window, x_ts[-1, -window_size:])
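
    # Minimal usage sketch: the fixed-size windows batch cleanly with a standard
    # torch DataLoader. `batch_size=4` and `shuffle=True` are arbitrary choices
    # for illustration, not requirements of the datasets above.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    t_batch, x_batch = next(iter(loader))
    assert t_batch.shape == (4, window_size)  # (batch, window_size)
    assert x_batch.shape == (4, window_size, 2)  # (batch, window_size, D)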