Skip to content

Instantly share code, notes, and snippets.

@lostmsu
Created March 12, 2022 00:46
Show Gist options
  • Save lostmsu/ddc9b274a2bf3f11985b3e55b645058c to your computer and use it in GitHub Desktop.
Save lostmsu/ddc9b274a2bf3f11985b3e55b645058c to your computer and use it in GitHub Desktop.
An infinite PyTorch "dataset" that simulates the behavior of a digital clock.
import torch
from torch import device as Device
from torch.utils.data import IterableDataset
from typing import Optional
# Number of milliseconds in each clock unit. A "year" here is a fixed
# 365-day year; leap days are not modelled.
MS_IN_SECOND = 1_000
MS_IN_MINUTE = MS_IN_SECOND * 60
MS_IN_HOUR = MS_IN_MINUTE * 60
MS_IN_DAY = MS_IN_HOUR * 24
MS_IN_YEAR = MS_IN_DAY * 365
def _msec_to_clock(msec):
msec_val = msec % 1000
seconds = torch.div(msec, 1000, rounding_mode='floor')
sec_val = seconds % 60
minutes = torch.div(seconds, 60, rounding_mode='floor')
min_val = minutes % 60
hours = torch.div(minutes, 60, rounding_mode='floor')
hour_val = hours % 24
days = torch.div(hours, 24, rounding_mode='floor')
return torch.cat([days, hour_val, min_val, sec_val, msec_val], dim=1)
class Clock(IterableDataset):
    """Endless stream of random timestamps paired with their clock readings.

    Each item produced by iteration is one freshly sampled batch (see
    ``sample``). Iteration never terminates on its own.

    Args:
        batch_size: number of timestamps drawn per batch.
        device: device the sampled tensors are created on; ``None`` uses
            PyTorch's default device.
        min_msec: inclusive lower bound of the sampled timestamps, in
            milliseconds. Defaults to minus 60 fixed 365-day years.
        max_msec: exclusive upper bound of the sampled timestamps, in
            milliseconds. Defaults to 60 fixed 365-day years.
    """

    def __init__(self, batch_size: int = 1, device: Optional[Device] = None,
                 min_msec: int = -60 * MS_IN_YEAR,
                 max_msec: int = 60 * MS_IN_YEAR):
        super().__init__()
        # Sampling configuration, read by ``sample`` and ``range``.
        self.min_msec = min_msec
        self.max_msec = max_msec
        self.batch_size = batch_size
        self.device = device

    def sample(self):
        """Draw one batch of timestamps and compute their clock readings.

        Timestamps are drawn uniformly from ``[min_msec, max_msec)``.

        Returns:
            A ``(seconds, clock)`` pair:

            - ``seconds`` is a float64 tensor of shape ``(batch_size, 1)``
              holding the timestamps in seconds.
            - ``clock`` is a float32 tensor of shape ``(batch_size, 5)``
              whose columns are days, hours, minutes, seconds and
              milliseconds.
        """
        shape = (self.batch_size, 1)
        # randint's upper bound is exclusive, so max_msec itself is never drawn.
        stamps_ms = torch.randint(self.min_msec, self.max_msec, shape,
                                  dtype=torch.int64, device=self.device)
        clock = _msec_to_clock(stamps_ms)
        seconds = stamps_ms.to(torch.float64) / 1000
        return seconds, clock.to(torch.float32)

    def range(self):
        """Return ``(min_seconds, max_seconds)``, the sampling bounds in seconds.

        As with ``max_msec``, the upper bound is exclusive.
        """
        low_sec = self.min_msec / 1000.0
        high_sec = self.max_msec / 1000.0
        return low_sec, high_sec

    def __iter__(self):
        """Yield batches from ``sample()`` forever."""
        while True:
            yield self.sample()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment