Cache a torch dataset to npy files using dask
"""
Cache a torch dataset to npy files using dask
url:https://gist.github.com/wassname/f38f8774b6f97977b660d20dfa0f0036
lic:MIT
author:wassname
usage:
    batch_size = 16
    chunk_size = batch_size * 4
    num_workers = 4
    cache_base_path = Path('/tmp/')

    dataset_train = NumpyDataset(np.zeros((256, 16, 3)), np.zeros((256, 1)))
    dataset_train_cached, loader_train = cache2loader(
        dataset_train,
        chunk_size=chunk_size,
        num_workers=num_workers,
        cache_base_path=cache_base_path,
        shuffle=True,
        batch_size=batch_size,
    )

    dataset_test = NumpyDataset(np.zeros((256, 16, 3)), np.zeros((256, 1)))
    dataset_test_cached, loader_test = cache2loader(
        dataset_test,
        chunk_size=chunk_size,
        num_workers=num_workers,
        cache_base_path=cache_base_path,
        shuffle=False,
        batch_size=batch_size,
    )
"""
import hashlib
import itertools
import logging
import pickle
import dask
import dask.array as da
import numpy as np
import pandas as pd
import torch
import torch.utils.data.sampler
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
logger = logging.getLogger(__name__)
def to_hash(obj):
    """Hash most python objects to an int that persists between sessions."""
    s = pd.io.json.dumps(obj).encode('utf-8')
    m = hashlib.md5(s)
    return int(m.hexdigest(), 16) % 10**8

# Test
assert to_hash('a') == to_hash('a')
assert isinstance(to_hash([1, 2]), int)
assert to_hash([pd.Timedelta('1h'), pd.Timestamp('2019'), {}, ['a'], np.array([1, 2])])
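# Illustrative check (not in the original gist): the hash is used as a cache key
# below, so changing the underlying data should change the key and therefore
# invalidate any cache written under the old one.
assert to_hash([1, 2]) != to_hash([1, 3])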
class ConcatNumpyDataset(torch.utils.data.ConcatDataset):
    def unique_keys(self):
        return itertools.chain(*[d.unique_keys() for d in self.datasets])

    def __add__(self, other):
        return ConcatNumpyDataset([self, other])
class NumpyDataset(torch.utils.data.Dataset):
    """Dataset wrapping numpy arrays.

    Each sample is retrieved by indexing the arrays along the first dimension.

    Arguments:
        *arrays (numpy.ndarray): arrays that have the same size in the first dimension.
    """

    def __init__(self, *arrays):
        assert all(arrays[0].shape[0] == array.shape[0] for array in arrays)
        self.arrays = arrays

    def __getitem__(self, index):
        return tuple(array[index] for array in self.arrays)

    def unique_keys(self):
        """Return an int hash of the input data; needed to invalidate the cache."""
        return to_hash(self.arrays)

    def __add__(self, other):
        return ConcatNumpyDataset([self, other])

    def __len__(self):
        return self.arrays[0].shape[0]
def write_info(
    dataset_length: int, x: np.ndarray, dirname: Path, chunk_size: int
) -> dict:
    """Write an "info" file similar to dask.array.to_npy_stack."""
    chunks = (chunk_size,) * (dataset_length // chunk_size)
    if dataset_length % chunk_size:
        # the final, smaller chunk (the caching loader uses drop_last=False)
        chunks += (dataset_length % chunk_size,)
    chunks = (chunks,) + tuple((s,) for s in x.shape)
    dirname.mkdir(exist_ok=True)
    meta = {"chunks": chunks, "dtype": x.dtype, "axis": 0}
    with dirname.joinpath("info").open("wb") as f:
        pickle.dump(meta, f)
    return meta
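# For example (illustrative numbers, assuming 64 samples of float32 data with
# shape (16, 3) each and chunk_size=16), the pickled "info" file would contain:
#     {"chunks": ((16, 16, 16, 16), (16,), (3,)), "dtype": dtype('float32'), "axis": 0}
# This mirrors the metadata that dask.array.to_npy_stack writes, which is what
# allows da.from_npy_stack() below to map the numbered "<i>.npy" files back into
# a single lazy dask array.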
class DaskDataset(torch.utils.data.Dataset):
    r"""Dataset wrapping dask arrays.

    Each sample is retrieved by indexing the arrays along the first dimension.

    Arguments:
        *tensors (dask.array.Array): dask arrays that have the same size in the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].shape[0] == tensor.shape[0] for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        # kept lazy; the collate_fn calls .compute() on the whole batch at once
        # return tuple(torch.from_numpy(tensor[index].compute()) for tensor in self.tensors)
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].shape[0]

    def __repr__(self):
        return f"{type(self)}(len={len(self)})"
# Write the cache so we can load it back with dask.array.from_npy_stack
def cache_dataset(
    dataset_cache: torch.utils.data.Dataset,
    cache_base_path: Path,
    chunk_size: int = 16,
    num_workers=0,
    shuffle=False,
):
    """
    Take a torch dataset, cache it in chunks to npy files, and load it back as a dask-backed dataset.
    """
    np_data_cache_x = cache_base_path.joinpath(
        f"np_data_cache_x_{to_hash(dataset_cache.unique_keys())}"
    )
    np_data_cache_y = cache_base_path.joinpath(
        f"np_data_cache_y_{to_hash(dataset_cache.unique_keys())}"
    )
    x, y = dataset_cache[0]
    if not np_data_cache_x.exists():
        # np.asarray handles samples that are either torch tensors or numpy arrays
        write_info(
            len(dataset_cache), np.asarray(x), dirname=np_data_cache_x, chunk_size=chunk_size
        )
        write_info(
            len(dataset_cache), np.asarray(y), dirname=np_data_cache_y, chunk_size=chunk_size
        )
        loader_cache = torch.utils.data.DataLoader(
            dataset_cache,
            batch_size=chunk_size,
            shuffle=shuffle,
            pin_memory=False,
            drop_last=False,
            num_workers=num_workers,
        )
        logger.info(f"Saving dataset to cache {np_data_cache_x}")
        for i, (x_batch, y_batch) in enumerate(
            tqdm(loader_cache, desc="Caching dataloader")
        ):
            np.save(np_data_cache_x.joinpath(f"{i}.npy").open("wb"), x_batch.numpy())
            np.save(np_data_cache_y.joinpath(f"{i}.npy").open("wb"), y_batch.numpy())
    else:
        logger.info(f"Cache already exists {np_data_cache_x}")
    return DaskDataset(
        da.from_npy_stack(np_data_cache_x), da.from_npy_stack(np_data_cache_y)
    )
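# Resulting on-disk layout (a sketch; the hash value is made up): for a dataset whose
# unique_keys() hash to 12345678 and cache_base_path=/tmp, the cache looks like
#     /tmp/np_data_cache_x_12345678/{info, 0.npy, 1.npy, ...}
#     /tmp/np_data_cache_y_12345678/{info, 0.npy, 1.npy, ...}
# Only the x directory is checked above, so delete that directory (or change the
# underlying data, which changes the hash) to force the cache to be rebuilt.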
class SequenceInChunkSampler(torch.utils.data.sampler.Sampler):
    """
    Samples elements sequentially within sequences, but takes the sequences in random order within each chunk.

    Arguments:
        data_source (Dataset): dataset to sample from
        seq_len (int): length of the sequential sequences
        chunksize (int): length of cached data to take random sequences from

    url: https://gist.github.com/wassname/8ae1f64389c2aaceeb84fcd34c3651c3
    """

    def __init__(self, data_source, seq_len=6, chunksize=6000):
        assert chunksize % seq_len == 0, "chunksize should be a multiple of seq_len"
        assert len(data_source) > chunksize
        self.data_source = data_source
        self.seq_len = seq_len
        self.chunksize = chunksize

    def __iter__(self):
        chunk_idxs = np.arange(0, len(self.data_source), self.chunksize)
        max_i = len(self.data_source)
        logger.debug("max_i %s", max_i)
        for chunk_idx in chunk_idxs:
            seqs = np.arange(
                chunk_idx, min(chunk_idx + self.chunksize, max_i), self.seq_len
            )
            np.random.shuffle(seqs)
            for seq_i in seqs:
                for i in np.arange(seq_i, min(seq_i + self.seq_len, max_i)):
                    yield i

    def __len__(self):
        return len(self.data_source)
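# Illustrative ordering (not in the original gist): with a 12-sample dataset,
# seq_len=3 and chunksize=6, the sampler walks chunk [0..5] then chunk [6..11],
# shuffling whole sequences within each chunk, e.g.
#     3, 4, 5, 0, 1, 2,   9, 10, 11, 6, 7, 8
# so samples inside a sequence stay contiguous (cache-friendly reads from the npy
# chunks) while the sequences themselves are randomised.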
def collate_fns(scheduler=None):
    def collate_fn(batch):
        """Collate a batch of uncomputed dask arrays into torch tensors."""
        x = da.stack([x for x, y in batch], 0).compute(scheduler=scheduler)
        y = da.stack([y for x, y in batch], 0).compute(scheduler=scheduler)
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        return x, y

    return collate_fn
def cache2loader(
    dataset: torch.utils.data.Dataset,
    cache_base_path: Path,
    chunk_size: int = 16,
    num_workers=0,
    shuffle=False,
    batch_size: int = 16,
):
    dataset_cached = cache_dataset(
        dataset,
        cache_base_path=cache_base_path,
        chunk_size=chunk_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )
    loader = DataLoader(
        dataset_cached,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
        sampler=SequenceInChunkSampler(dataset_cached, seq_len=batch_size, chunksize=batch_size * 4)
        if shuffle
        else None,
        # use the synchronous dask scheduler inside worker processes to avoid nested parallelism
        collate_fn=collate_fns(scheduler='synchronous' if num_workers else None),
        num_workers=num_workers,
    )
    return dataset_cached, loader
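# Minimal smoke test (an illustrative sketch, not part of the original gist): build a
# tiny random dataset, cache it under /tmp, and pull one batch back out.
if __name__ == "__main__":
    dataset = NumpyDataset(
        np.random.rand(64, 16, 3).astype(np.float32),
        np.random.rand(64, 1).astype(np.float32),
    )
    dataset_cached, loader = cache2loader(
        dataset,
        cache_base_path=Path("/tmp/"),
        chunk_size=16,
        num_workers=0,
        shuffle=False,
        batch_size=8,
    )
    x, y = next(iter(loader))
    print(x.shape, y.shape)  # expected: torch.Size([8, 16, 3]) torch.Size([8, 1])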
@wassname (Author)
Notes: it may be easier to write the cache as an xarray with named coordinates. This can make plotting, visualization, and custom metrics easier later on.
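For example, a rough sketch of what that could look like (the dimension and coordinate names here are illustrative, not part of the gist):

    import numpy as np
    import xarray as xr

    x = np.random.rand(64, 16, 3)  # hypothetical cached data: sample x timestep x feature
    da_x = xr.DataArray(
        x,
        dims=["sample", "timestep", "feature"],
        coords={"feature": ["open", "high", "low"]},  # illustrative feature names
    )
    da_x.to_netcdf("/tmp/cache_x.nc")   # named dims survive the round trip
    da_x.isel(sample=0).plot()          # labelled axes for free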
