Created
March 8, 2017 08:12
-
-
Save nicolasdespres/418279e152dca606a78584fc99630738 to your computer and use it in GitHub Desktop.
Iterate over slices of tensors.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class iter_tensors_slice(Iterator):
    """Yield slices of tensors.

    Sequentially slice every tensor in `tensors` along `axis`, using the
    indices produced by the `slices` iterable. Each call to ``next()``
    consumes one element of `slices`.

    Args:
        `tensors`: a list of tensors, or a single tensor. They are checked
            by ``check_tensors_samesize`` along `axis` (presumably they
            must have the same size there — TODO confirm).
        `slices`: an iterable providing the indices of each slice; each
            element is passed as ``indices`` to ``np.take``.
        `axis`: non-negative integer axis along which tensors are sliced
            (a Python ``int`` or a NumPy integer scalar).

    Raises:
        TypeError: if `axis` is not an integer.
        ValueError: if `axis` is negative.

    Output:
        Each iteration yields a list with one array per tensor, in the
        same order as `tensors`. Each array has the same shape as its
        source tensor except along `axis`, whose size is the length of
        the slice when the slice is a sequence of indices (a scalar index
        drops that axis, as ``np.take`` does).
    """

    def __init__(self, tensors, slices, axis=0):
        # Validate `axis` before touching the tensors so callers get the
        # documented TypeError/ValueError instead of whatever the helpers
        # below would raise when handed an invalid axis.
        if not isinstance(axis, (int, np.integer)):
            raise TypeError("axis must be int, not {}"
                            .format(type(axis).__name__))
        if axis < 0:
            raise ValueError("axis must be positive or null, not {}"
                             .format(axis))
        # Normalise NumPy integer scalars to a plain Python int.
        self._axis = int(axis)
        # NOTE(review): as_list presumably wraps a single tensor into a
        # one-element list — confirm against its definition.
        self.tensors = as_list(tensors)
        check_tensors_samesize(self.tensors, axis=self._axis)
        self.slices = slices
        self._it = iter(self.slices)

    def __len__(self):
        """Return the total number of slices.

        This is ``len(slices)``: the total count, not the number of slices
        still to be yielded. It raises TypeError when `slices` has no
        length (e.g. a generator).
        """
        return len(self.slices)

    @property
    def axis(self):
        """The axis along which tensors are sliced."""
        return self._axis

    def __next__(self):
        """Return the next list of sliced tensors.

        Raises:
            StopIteration: when `slices` is exhausted.
        """
        indices = next(self._it)
        return [np.take(tensor, indices, axis=self._axis)
                for tensor in self.tensors]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment