Skip to content

Instantly share code, notes, and snippets.

@dboyliao
Last active June 4, 2021 08:39
Show Gist options
  • Save dboyliao/d4a72cbbd3f3e62517865519b4c1e9c6 to your computer and use it in GitHub Desktop.
Save dboyliao/d4a72cbbd3f3e62517865519b4c1e9c6 to your computer and use it in GitHub Desktop.
Pure-Python implementation of strided slicing.
import numpy as np
class StridedIterator:
    """Iterate over the index tuples selected by a strided N-d slice.

    Given per-dimension ``begin`` (inclusive), ``end`` (exclusive) and
    ``strides``, yields every index tuple of
    ``x[begin[0]:end[0]:strides[0], begin[1]:end[1]:strides[1], ...]`` in
    row-major order: the last dimension varies fastest, like an odometer.
    The output equals
    ``itertools.product(*(range(b, e, s) for b, e, s in zip(begin, end, strides)))``.

    After raising ``StopIteration`` the iterator rewinds to ``begin``, so the
    same object can be iterated again.

    Reference:
    - https://github.com/python/cpython/blob/b2bf2bc1ece673d387341e06c8d3c2bc6e259747/Modules/itertoolsmodule.c#L2342
    """

    def __init__(self, begin, end, strides):
        """
        Args:
            begin: start index of each dimension (inclusive).
            end: stop index of each dimension (exclusive).
            strides: positive step of each dimension.

        Raises:
            ValueError: if the three sequences differ in length, or any stride
                is not positive (a zero stride would never terminate).
        """
        self._begin = list(begin)
        self._end = list(end)
        self._strides = list(strides)
        if not len(self._begin) == len(self._end) == len(self._strides):
            raise ValueError(
                "begin, end and strides must have the same length, "
                "got {}, {} and {}".format(
                    len(self._begin), len(self._end), len(self._strides)
                )
            )
        if any(s <= 0 for s in self._strides):
            raise ValueError("strides must be positive, got {}".format(self._strides))
        # Current odometer position; always within [begin, end) between calls.
        self._idx_cnt = list(self._begin)
        # Set once the final index tuple has been handed out.
        self._hit_last = False

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return the current index tuple and advance to the next one.

        Raises:
            StopIteration: once every index tuple has been produced, or
                immediately when some dimension is empty (``begin >= end``).
                The iterator is rewound to ``begin`` before raising.
        """
        if self._hit_last or self._is_empty():
            self._idx_cnt = list(self._begin)
            self._hit_last = False
            raise StopIteration
        result = tuple(self._idx_cnt)
        # Odometer increment: bump the last dimension; when it overflows,
        # rewind it to its begin value and carry into the dimension on its left.
        for i in reversed(range(len(self._idx_cnt))):
            self._idx_cnt[i] += self._strides[i]
            if self._idx_cnt[i] < self._end[i]:
                break
            self._idx_cnt[i] = self._begin[i]
        else:
            # Carried out of the first dimension (or there are no dimensions
            # at all): `result` was the final index tuple.
            self._hit_last = True
        return result

    def _is_empty(self):
        """Return True if some dimension selects no indices (begin >= end)."""
        return any(b >= e for b, e in zip(self._begin, self._end))

    def is_done(self):
        """Return True if every current index has reached or passed its end.

        Kept for backward compatibility; ``next`` no longer relies on it.
        """
        return all(idx >= e for idx, e in zip(self._idx_cnt, self._end))
if __name__ == "__main__":
    # Demo: compare NumPy's own strided slice with one gathered element by
    # element through StridedIterator.
    x = np.random.rand(10, 7, 5, 8)
    slices = (slice(0, 3, 2), slice(2, 4), slice(0, 3), slice(None))
    slice1 = x[slices]
    # Derive begin/end/stride from the very same slice objects, so the two
    # sides of the comparison cannot drift apart.
    begin, end, strides = zip(*(s.indices(dim) for s, dim in zip(slices, x.shape)))
    it = StridedIterator(begin, end, strides)
    slice2 = np.array([x[v] for v in it]).reshape(slice1.shape)
    # np.alltrue was deprecated in NumPy 1.25 and removed in 2.0.
    print("slice1 == slice2: ", np.array_equal(slice1, slice2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment