
@noirmist
Forked from rkaplan/collate.py
Created March 5, 2019 04:44
PyTorch example of a custom collate function that uses shared memory when appropriate
import functools
import re

import torch
# These helpers ship with the PyTorch releases this gist targets (circa 1.0);
# later releases moved or removed torch._six.
from torch._six import container_abcs, int_classes, string_classes
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

# Maps NumPy scalar dtype names to the matching torch tensor constructors
# (mirrors the table used by older versions of torch.utils.data).
numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}
def my_collate(batch, use_shared_memory=False):
    r"""Puts each data field into a tensor with outer dimension batch size."""
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # arrays of strings or objects cannot be converted to tensors
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        # nested dicts: collate each field with torch's default_collate
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], container_abcs.Sequence):
        # nested lists/tuples: transpose, then collate each column
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
    # elif <your custom condition>:
    #     handle data frames (or any other custom element type) here

    raise TypeError(error_msg.format(type(batch[0])))
def main(args):
    collate_fn = functools.partial(my_collate, use_shared_memory=args.num_workers > 0)
    dataloader = DataLoader(..., num_workers=args.num_workers, collate_fn=collate_fn)
    ...
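
A minimal end-to-end sketch of how this might be wired up, assuming an argparse CLI with a --num_workers flag and a toy TensorDataset; both are illustrative additions, not part of the original gist:

if __name__ == '__main__':
    import argparse
    from torch.utils.data import TensorDataset

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=2)
    cli_args = parser.parse_args()

    # Toy dataset: 100 samples of a 3-dim feature vector plus an integer label.
    dataset = TensorDataset(torch.randn(100, 3), torch.arange(100))

    # Shared-memory collation only pays off when worker processes are used.
    collate_fn = functools.partial(my_collate,
                                   use_shared_memory=cli_args.num_workers > 0)
    loader = DataLoader(dataset, batch_size=8,
                        num_workers=cli_args.num_workers,
                        collate_fn=collate_fn)

    for features, labels in loader:
        print(features.shape, labels.shape)  # torch.Size([8, 3]) torch.Size([8])
        break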