Skip to content

Instantly share code, notes, and snippets.

@nicolasdespres
Created March 8, 2017 08:12
Show Gist options
  • Save nicolasdespres/418279e152dca606a78584fc99630738 to your computer and use it in GitHub Desktop.
Save nicolasdespres/418279e152dca606a78584fc99630738 to your computer and use it in GitHub Desktop.
Iterate over slices of tensors.
class iter_tensors_slice(Iterator):
    """Yield slices of tensors.

    Sequentially slice `tensors` along `axis` using the indices provided
    by the `slices` iterable. The iterator is single-pass: `iter(slices)` is
    taken once at construction, so it yields nothing further once exhausted.

    Args:
      `tensors`: a list of tensors of the same size, or a single tensor
        (normalized with `as_list`; sizes are validated by
        `check_tensors_samesize`, both defined elsewhere).
      `slices`: an iterable providing the indices of each slice.
      `axis`: slice tensors along this axis; a non-negative Python ``int``
        or NumPy integer scalar.

    Output:
      A list with one array per input tensor, each with the same shape as
      the corresponding tensor except for the `axis` dimension, whose size
      is the length of the slice. (If a slice is a scalar index rather than
      a sequence, `np.take` drops the `axis` dimension instead.)

    Raises:
      TypeError: if `axis` is not an integer.
      ValueError: if `axis` is negative.
    """

    def __init__(self, tensors, slices, axis=0):
        # Validate `axis` before using it, so a bad value is reported with a
        # clear message instead of failing inside `check_tensors_samesize`.
        # NumPy integer scalars (e.g. np.int64 taken from a shape) are not
        # `int` subclasses, so they are accepted explicitly.
        if not isinstance(axis, (int, np.integer)):
            raise TypeError("axis must be int, not {}"
                            .format(type(axis).__name__))
        if axis < 0:
            raise ValueError("axis must be positive or null, not {}"
                             .format(axis))
        axis = int(axis)
        self._axis = axis
        self.tensors = as_list(tensors)
        check_tensors_samesize(self.tensors, axis=axis)
        self.slices = slices
        # Created once: iteration is not restartable.
        self._it = iter(self.slices)

    def __len__(self):
        """Return the number of slices.

        Requires `slices` to support ``len()`` (e.g. a list or range); a
        plain generator raises TypeError here.
        """
        return len(self.slices)

    @property
    def axis(self):
        """The (read-only) axis along which tensors are sliced."""
        return self._axis

    def __next__(self):
        """Return the next list of sliced tensors.

        Raises StopIteration (from the underlying iterator) when `slices`
        is exhausted.
        """
        indices = next(self._it)
        return [np.take(tensor, indices, axis=self._axis)
                for tensor in self.tensors]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment