Last active
May 29, 2020 02:09
-
-
Save wassname/f38f8774b6f97977b660d20dfa0f0036 to your computer and use it in GitHub Desktop.
Cache a torch dataset to npy files using dask
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Cache a torch dataset to npy files using dask | |
url:https://gist.github.com/wassname/f38f8774b6f97977b660d20dfa0f0036 | |
lic:MIT | |
author:wassname | |
usage: | |
batch_size=16 | |
chunk_size=batch_size*4 | |
num_workers=4 | |
cache_base_path = Path('/tmp/') | |
dataset_train = NumpyDataset(np.zeros((4,16,3), np.zeros(4, 1)) | |
dataset_train_cached, loader_train = cache2loader( | |
dataset_train, | |
chunk_size=chunk_size, | |
num_workers=num_workers, | |
cache_base_path=cache_base_path, | |
shuffle=True, | |
batch_size=batch_size | |
) | |
dataset_test = NumpyDataset(np.zeros((4,16,3), np.zeros(4, 1)) | |
dataset_test_cached, loader_test = cache2loader( | |
dataset_test, | |
chunk_size=chunk_size, | |
num_workers=num_workers, | |
cache_base_path=cache_base_path, | |
shuffle=False, | |
batch_size=batch_size | |
) | |
""" | |
import hashlib
import itertools
import logging
import pickle
from pathlib import Path

import dask
import dask.array as da
import numpy as np
import pandas as pd
import torch.utils.data.sampler

logger = logging.getLogger(__name__)
def to_hash(obj):
    """Hash most python objects to an int that persists between sessions.

    The object is serialised with pandas' JSON encoder, the UTF-8 bytes of
    that JSON are digested with MD5, and the digest is reduced modulo 10**8.

    Args:
        obj: Anything pandas' JSON encoder can serialise.

    Returns:
        int: A value in ``range(10 ** 8)``.
    """
    json_api = pd.io.json
    # The encoder is named ``ujson_dumps`` in newer pandas releases and
    # ``dumps`` in older ones; accept whichever this installation provides.
    dumps = getattr(json_api, "ujson_dumps", None) or json_api.dumps
    payload = dumps(obj).encode("utf-8")
    digest = hashlib.md5(payload).hexdigest()
    return int(digest, 16) % 10 ** 8
# Import-time smoke checks for to_hash: equal inputs hash equally, and the
# result is an int even for pandas / numpy values nested in containers.
# The last check tests the type rather than truthiness, because 0 is a
# legitimate hash value.
assert to_hash("a") == to_hash("a")
assert isinstance(to_hash([1, 2]), int)
assert isinstance(
    to_hash([pd.Timedelta("1h"), pd.Timestamp("2019"), {}, ["a"], np.array([1, 2])]),
    int,
)
class ConcatNumpyDataset(torch.utils.data.ConcatDataset):
    """Concatenation of datasets that each expose ``unique_keys()``."""

    def unique_keys(self):
        """Return the cache keys of every wrapped dataset as one flat list.

        A child may return a single key (for example an int) or an iterable
        of keys (for example another concatenation); both are flattened into
        the result so it can be hashed as a whole.
        """
        keys = []
        for dataset in self.datasets:
            child_keys = dataset.unique_keys()
            is_scalar = isinstance(child_keys, (str, bytes)) or not hasattr(
                child_keys, "__iter__"
            )
            if is_scalar:
                keys.append(child_keys)
            else:
                keys.extend(child_keys)
        return keys

    def __add__(self, other):
        """Concatenate with ``other``, keeping the ``unique_keys`` interface."""
        return ConcatNumpyDataset([self, other])
class NumpyDataset(torch.utils.data.Dataset):
    """Dataset wrapping arrays.

    Each sample will be retrieved by indexing every array along the first
    dimension.

    Arguments:
        *arrays (numpy.array): arrays that have the same size of the first
            dimension.

    Raises:
        ValueError: if no array is given or the first dimensions differ.
    """

    def __init__(self, *arrays):
        # Explicit exceptions instead of ``assert`` so the checks survive
        # ``python -O`` and an empty dataset fails here, not later in __len__.
        if not arrays:
            raise ValueError("NumpyDataset needs at least one array")
        expected = arrays[0].shape[0]
        lengths = [array.shape[0] for array in arrays]
        if any(length != expected for length in lengths):
            raise ValueError(
                f"all arrays must share the first dimension, got lengths {lengths}"
            )
        self.arrays = arrays

    def __getitem__(self, index):
        """Return a tuple holding ``array[index]`` for every wrapped array."""
        return tuple(array[index] for array in self.arrays)

    def unique_keys(self):
        """
        This returns an int that is a hash of the input data and args.

        Needed to invalidate cache.
        """
        return to_hash(self.arrays)

    def __add__(self, other):
        """Concatenate with ``other`` into a ``ConcatNumpyDataset``."""
        return ConcatNumpyDataset([self, other])

    def __len__(self):
        """Number of samples, i.e. the size of the first dimension."""
        return self.arrays[0].shape[0]
def write_info(
    dataset_length: int, x: np.ndarray, dirname: Path, chunk_size: int
) -> dict:
    """Write info similar to dask.array.to_npy_stack.

    Pickles a ``{"chunks", "dtype", "axis"}`` dict to ``dirname/info``.

    Args:
        dataset_length: Total number of samples stacked along axis 0.
        x: One sample; its shape and dtype describe the remaining axes.
        dirname: Directory receiving the ``info`` file (created if missing).
        chunk_size: Number of samples per chunk along axis 0.

    Returns:
        dict: The metadata that was written.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    full_chunks, remainder = divmod(dataset_length, chunk_size)
    axis0_chunks = (chunk_size,) * full_chunks
    if remainder:
        # A trailing partial chunk must be listed too, otherwise the samples
        # beyond the last full chunk are missing from the described stack.
        axis0_chunks += (remainder,)
    chunks = (axis0_chunks,) + tuple((s,) for s in x.shape)

    dirname.mkdir(parents=True, exist_ok=True)
    meta = {"chunks": chunks, "dtype": x.dtype, "axis": 0}
    with dirname.joinpath("info").open("wb") as f:
        pickle.dump(meta, f)
    return meta
def to_hash(obj):
    """Hash most python objects to an int that persists between sessions.

    NOTE: this repeats the ``to_hash`` defined earlier in this module and
    rebinds the name; both definitions are meant to behave the same.

    The object is serialised with pandas' JSON encoder, the UTF-8 bytes of
    that JSON are digested with MD5, and the digest is reduced modulo 10**8.

    Args:
        obj: Anything pandas' JSON encoder can serialise.

    Returns:
        int: A value in ``range(10 ** 8)``.
    """
    json_api = pd.io.json
    # The encoder is named ``ujson_dumps`` in newer pandas releases and
    # ``dumps`` in older ones; accept whichever this installation provides.
    dumps = getattr(json_api, "ujson_dumps", None) or json_api.dumps
    payload = dumps(obj).encode("utf-8")
    digest = hashlib.md5(payload).hexdigest()
    # TEST should work for pandas timestamp, numpy, set etc
    return int(digest, 16) % 10 ** 8
class DaskDataset(torch.utils.data.Dataset):
    r"""Dataset over a group of array-likes, indexed along the first axis.

    Indexing returns one element from every wrapped array, exactly as the
    arrays' own ``__getitem__`` produces it; nothing is computed or
    converted here.

    Arguments:
        *tensors: arrays that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        # Every array must agree with the first one on its leading size.
        for candidate in tensors:
            assert tensors[0].shape[0] == candidate.shape[0]
        self.tensors = tensors

    def __getitem__(self, index):
        sample = []
        for source in self.tensors:
            sample.append(source[index])
        return tuple(sample)

    def __len__(self):
        leading = self.tensors[0]
        return leading.shape[0]

    def __repr__(self):
        cls, size = type(self), len(self)
        return f"{cls}(len={size})"
# Write it so we can load it using dask.from_npy
def cache_dataset(
    dataset_cache: torch.utils.data.Dataset,
    cache_base_path: Path,
    chunk_size: int = 16,
    num_workers=0,
    shuffle=False,
):
    """
    Takes torch dataset, caches it in chunks to npy files, and loads into a dataset using dask

    The dataset must yield ``(x, y)`` pairs and expose ``unique_keys()``; the
    hash of those keys names the two cache directories (one for x, one for
    y) under ``cache_base_path``.

    Args:
        dataset_cache: Dataset to cache.
        cache_base_path: Directory that holds the cache directories.
        chunk_size: Number of samples stored per ``.npy`` file.
        num_workers: Worker count for the DataLoader that reads the dataset.
        shuffle: Whether that DataLoader shuffles, i.e. whether the samples
            are stored in shuffled order.

    Returns:
        DaskDataset: Lazily loaded view of the cached x and y stacks.
    """
    # unique_keys() may hash the whole dataset, so compute the key once.
    cache_key = to_hash(dataset_cache.unique_keys())
    np_data_cache_x = cache_base_path.joinpath(f"np_data_cache_x_{cache_key}")
    np_data_cache_y = cache_base_path.joinpath(f"np_data_cache_y_{cache_key}")

    # The "info" files are written last, so their presence marks a cache
    # whose chunk files were all saved.
    is_cached = (
        np_data_cache_x.joinpath("info").exists()
        and np_data_cache_y.joinpath("info").exists()
    )
    if not is_cached:
        np_data_cache_x.mkdir(parents=True, exist_ok=True)
        np_data_cache_y.mkdir(parents=True, exist_ok=True)

        loader_cache = torch.utils.data.DataLoader(
            dataset_cache,
            batch_size=chunk_size,
            shuffle=shuffle,
            pin_memory=False,
            drop_last=False,
            num_workers=num_workers,
        )

        # tqdm only provides a progress bar; fall back to plain iteration
        # when it is not installed.
        try:
            from tqdm import tqdm
        except ImportError:
            batches = loader_cache
        else:
            batches = tqdm(loader_cache, desc="Caching dataloader")

        logger.info("Saving to cache %s", np_data_cache_x)
        for i, (x_batch, y_batch) in enumerate(batches):
            # Passing the path lets numpy open and close the file itself.
            np.save(np_data_cache_x.joinpath(f"{i}.npy"), x_batch.numpy())
            np.save(np_data_cache_y.joinpath(f"{i}.npy"), y_batch.numpy())

        # One sample supplies the per-sample shape and dtype. np.asarray
        # accepts numpy arrays as well as CPU tensors.
        x, y = dataset_cache[0]
        write_info(
            len(dataset_cache),
            np.asarray(x),
            dirname=np_data_cache_x,
            chunk_size=chunk_size,
        )
        write_info(
            len(dataset_cache),
            np.asarray(y),
            dirname=np_data_cache_y,
            chunk_size=chunk_size,
        )
    else:
        logger.info("Cache already exists %s", np_data_cache_x)

    return DaskDataset(
        da.from_npy_stack(np_data_cache_x), da.from_npy_stack(np_data_cache_y)
    )
class SequenceInChunkSampler(torch.utils.data.sampler.Sampler):
    """
    Samples sequences of elements sequentially, but random sequences in a chunk.

    Chunks are visited in order. Inside each chunk the sequences (runs of
    ``seq_len`` consecutive indices) are visited in random order, and the
    indices of a sequence are yielded consecutively. Every index of the data
    source is yielded exactly once per pass.

    Arguments:
        data_source (Dataset): dataset to sample from
        seq_len (int): length of sequential sequences
        chunksize (int): length of cached data to take random sequences from

    url: https://gist.github.com/wassname/8ae1f64389c2aaceeb84fcd34c3651c3
    """

    def __init__(self, data_source, seq_len=6, chunksize=6000):
        assert chunksize % seq_len == 0, "chunk size should be a multiple of seq_len"
        assert len(data_source) > chunksize, "data source should be longer than chunksize"
        self.data_source = data_source
        self.seq_len = seq_len
        self.chunksize = chunksize

    def __iter__(self):
        max_i = len(self.data_source)
        for chunk_idx in range(0, max_i, self.chunksize):
            chunk_end = min(chunk_idx + self.chunksize, max_i)
            # Start index of every sequence in this chunk, shuffled with the
            # global numpy RNG so np.random.seed controls the order.
            seqs = np.arange(chunk_idx, chunk_end, self.seq_len)
            np.random.shuffle(seqs)
            for seq_i in seqs:
                seq_start = int(seq_i)
                # The last sequence of the data may be shorter than seq_len.
                yield from range(seq_start, min(seq_start + self.seq_len, max_i))

    def __len__(self):
        return len(self.data_source)
def collate_fns(scheduler=None):
    """Build a collate function bound to the given dask ``scheduler``."""

    def collate_fn(batch):
        """Collate uncomputed dask arrays.

        Stacks the x parts and the y parts of ``batch`` along a new first
        axis, computes both stacks and returns them as float tensors.
        """
        features = [inputs for inputs, _ in batch]
        targets = [labels for _, labels in batch]
        stacked_x = da.stack(features, 0)
        stacked_y = da.stack(targets, 0)
        x_np = stacked_x.compute(scheduler=scheduler)
        y_np = stacked_y.compute(scheduler=scheduler)
        return torch.from_numpy(x_np).float(), torch.from_numpy(y_np).float()

    return collate_fn
def cache2loader(
    dataset: torch.utils.data.Dataset,
    cache_base_path: Path,
    chunk_size: int = 16,
    num_workers=0,
    shuffle=False,
    batch_size: int = 16,
):
    """Cache ``dataset`` to disk and build a DataLoader over the cached copy.

    Args:
        dataset: Dataset to cache.
        cache_base_path: Directory that holds the cache.
        chunk_size: Number of samples per cached file.
        num_workers: Worker count used for caching and for the loader.
        shuffle: Shuffle while caching, and sample batch-sized sequences in
            random order within chunks when loading.
        batch_size: Batch size of the returned loader.

    Returns:
        tuple: ``(dataset_cached, loader)``.
    """
    # Keyword arguments: the parameter order of cache_dataset differs from
    # the order used here, so positional passing would mix them up.
    dataset_cached = cache_dataset(
        dataset,
        cache_base_path=cache_base_path,
        chunk_size=chunk_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )

    sampler = None
    if shuffle:
        # NOTE(review): the sampler chunk is fixed at four batches and does
        # not follow ``chunk_size`` - confirm that this is intended.
        sampler = SequenceInChunkSampler(
            dataset_cached, seq_len=batch_size, chunksize=batch_size * 4
        )

    loader = torch.utils.data.DataLoader(
        dataset_cached,
        batch_size,
        pin_memory=True,
        drop_last=True,
        sampler=sampler,
        collate_fn=collate_fns(scheduler="synchronous" if num_workers else None),
        num_workers=num_workers,
    )
    return dataset_cached, loader
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Notes: it may be easier to write xarray with named coordinates. This can make plotting, visualization, and custom metrics easier later on.