Skip to content

Instantly share code, notes, and snippets.

@GenevieveBuckley
Last active June 28, 2021 12:58
Show Gist options
  • Save GenevieveBuckley/f913cef861a918c2e1c94862e7bdbc60 to your computer and use it in GitHub Desktop.
Save GenevieveBuckley/f913cef861a918c2e1c94862e7bdbc60 to your computer and use it in GitHub Desktop.
combine_slicing.py
import math
import numpy as np
import pytest
def combine_slices(slices, length=None):
    """Combine a sequence of slices into a single equivalent slice.

    ``x[combine_slices([s1, s2, ...], len(x))]`` selects the same elements,
    in the same order, as ``x[s1][s2]...``.

    Each inner slice's start/stop are positions in the *already sliced*
    view, so they must be scaled by the outer step and offset by the outer
    start (and stops clipped to the outer stop) -- simply summing starts
    and stops and multiplying steps is wrong in general.

    Parameters
    ----------
    slices : iterable of slice
        The slices, in the order they would be applied.
    length : int, optional
        Length of the sequence that will be indexed. Negative steps and
        negative start/stop values are resolved against the current length,
        so their composition depends on it; pass ``length`` whenever any
        slice uses them. If every start/stop is ``None`` or non-negative and
        every step is ``None`` or positive, the result does not depend on
        the length and ``length`` may be omitted.

    Returns
    -------
    slice
        A slice with integer start and step; stop may be ``None``.

    Raises
    ------
    ValueError
        If a slice has a zero step, if ``length`` is negative, or if
        ``length`` is omitted but some slice has a negative start, stop or
        step.
    """
    if length is None:
        return _combine_without_length(slices)
    return _combine_with_length(slices, length)


def _combine_with_length(slices, length):
    """Compose *slices* exactly for a sequence of the given *length*."""
    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")
    # Slicing a range applies Python's own slice rules (negative indices,
    # clipping, negative steps) and keeps the result as a range of
    # positions in the original sequence.
    positions = range(length)
    for s in slices:
        positions = positions[s]
    if not positions:
        # Any start == stop slice is empty, whatever the step.
        return slice(0, 0, 1)
    stop = positions.stop
    if positions.step < 0 and stop < 0:
        # A negative stop would be re-read as counting from the end;
        # ``None`` means "continue through index 0", which is what the
        # range describes here.
        stop = None
    return slice(positions.start, stop, positions.step)


def _combine_without_length(slices):
    """Compose *slices* whose combined result is independent of length."""
    start, stop, step = 0, None, 1
    for s in slices:
        s_start = 0 if s.start is None else s.start
        s_step = 1 if s.step is None else s.step
        if s_step == 0:
            raise ValueError("slice step cannot be zero")
        if s_start < 0 or s_step < 0 or (s.stop is not None and s.stop < 0):
            raise ValueError(
                "length is required to combine slices with a negative "
                "start, stop or step"
            )
        # Position ``i`` of the current view is ``start + i * step`` in the
        # original sequence; map the inner stop before updating start/step.
        if s.stop is not None:
            new_stop = start + s.stop * step
            stop = new_stop if stop is None else min(stop, new_stop)
        start += s_start * step
        step *= s_step
    return slice(start, stop, step)


def _slice_in_turn(x, slices):
    """Apply each slice in *slices* to *x* one after another."""
    for s in slices:
        x = x[s]
    return x


@pytest.mark.parametrize("length", [999, 1000])
@pytest.mark.parametrize("slices", [
    [slice(None, None, None), slice(None, None, None)],
    [slice(1, -1, None), slice(1, -1, None)],
    [slice(None, None, 1), slice(None, None, -1), slice(None, None, -1)],
    [slice(None, None, -1), slice(None, None, -2)],
    [slice(None, None, -1), slice(None, None, -2), slice(None, None, -2)],
    [slice(None, None, -1), slice(None, None, 2), slice(None, None, -2)],
    [slice(1, None, None), slice(1, None, None), slice(1, None, None)],
    # Stops are offset by the outer start and clipped, not summed.
    [slice(2, 10, None), slice(1, 5, None)],
    [slice(None, 5, None), slice(None, 3, None)],
    # Inner start/stop are scaled by the outer step.
    [slice(3, None, 2), slice(2, 8, 3)],
    [slice(-10, None, None), slice(None, None, -3)],
])
def test_combine_slices(slices, length):
    x = np.arange(length)
    expected = _slice_in_turn(x, slices)
    assert len(expected) > 0  # guard against vacuous comparisons
    combined = combine_slices(slices, length=len(x))
    # Exact comparison: also checks shape, unlike a broadcasting allclose.
    np.testing.assert_array_equal(x[combined], expected)


@pytest.mark.parametrize("slices", [
    [slice(None, None, None), slice(None, None, None)],
    [slice(1, None, None), slice(1, None, None), slice(1, None, None)],
    [slice(2, 10, None), slice(1, 5, None)],
    [slice(None, 5, None), slice(None, 3, None)],
    [slice(3, None, 2), slice(2, 8, 3)],
])
def test_combine_slices_without_length(slices):
    x = np.arange(1000)
    expected = _slice_in_turn(x, slices)
    np.testing.assert_array_equal(x[combine_slices(slices)], expected)


def test_combine_slices_requires_length_for_negative_slices():
    with pytest.raises(ValueError):
        combine_slices([slice(None, None, -1)])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment