Skip to content

Instantly share code, notes, and snippets.

@DRMacIver DRMacIver/intset.py
Last active Jul 29, 2019

Embed
What would you like to do?
# coding=utf-8
# This file is part of Hypothesis (https://github.com/DRMacIver/hypothesis)
# Most of this work is copyright (C) 2013-2015 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others. See
# https://github.com/DRMacIver/hypothesis/blob/master/CONTRIBUTING.rst for a
# full list of people who may hold copyright, and consult the git log if you
# need to determine who owns an individual contribution.
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.
# END HEADER
from abc import abstractmethod
try:
    from collections.abc import Sequence, Set
except ImportError:  # Python 2 only provides these ABCs in collections
    from collections import Sequence, Set
def interval(start, end):
    """
    Build the IntSet of every integer x with start <= x < end.

    A range with end <= start yields the shared empty IntSet. Otherwise the
    bounds are asserted to lie within [0, 2 ** 64].
    """
    if start >= end:
        return empty_intset
    assert start >= 0
    assert 2 ** 64 >= end
    return Interval(start, end)
def single(value):
    """
    Build the IntSet whose only element is ``value``.
    """
    upper = value + 1
    return Interval(value, upper)
def empty():
    """Return the shared IntSet instance that holds no values."""
    result = empty_intset
    return result
class IntSet(Sequence, Set):
    """
    An IntSet is a compressed representation of a sorted list of unsigned
    64-bit integers with fast membership, union and range restriction.
    IntSets are immutable.

    Concrete instances are one of three node types: ``Empty`` (no values),
    ``Interval`` (every x with start <= x < end) or ``Split`` (a binary trie
    node whose two children are selected by a single mask bit). The ordering
    operators (<, <=, >, >=) compare the sorted values lexicographically;
    they do not test subset inclusion.
    """

    def __len__(self):
        return self.size

    def __eq__(self, other):
        """
        Two IntSets are equal when they hold the same values, regardless of
        tree shape. An IntSet is never equal to a non-IntSet.
        """
        if self is other:
            return True
        if not isinstance(other, IntSet):
            return False
        if self.size != other.size:
            return False
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        return not self.__eq__(other)

    def __cmp__(self, other):
        """
        Compare the sorted values of self and other lexicographically and
        return -1, 0 or 1.

        Intervals are consumed pairwise. When two intervals share a start but
        differ in length, the unmatched tail of the longer one is pushed back
        so it is compared against the other side's next interval.
        """
        self_intervals = list(self.intervals())
        other_intervals = list(other.intervals())
        # Reverse so that pop() yields intervals in ascending order.
        self_intervals.reverse()
        other_intervals.reverse()
        while self_intervals and other_intervals:
            self_head = self_intervals.pop()
            other_head = other_intervals.pop()
            if self_head[0] < other_head[0]:
                return -1
            if self_head[0] > other_head[0]:
                return 1
            if self_head[1] < other_head[1]:
                other_intervals.append((self_head[1], other_head[1]))
            if self_head[1] > other_head[1]:
                self_intervals.append((other_head[1], self_head[1]))
        # One sequence is a prefix of the other; the longer one is greater.
        if self_intervals:
            return 1
        if other_intervals:
            return -1
        return 0

    def __lt__(self, other):
        return self.__cmp__(other) < 0

    def __gt__(self, other):
        return self.__cmp__(other) > 0

    def __le__(self, other):
        return self.__cmp__(other) <= 0

    def __ge__(self, other):
        return self.__cmp__(other) >= 0

    @classmethod
    def from_intervals(cls, intervals):
        """
        Return a new IntSet which contains precisely the intervals passed in.

        Each item is a (start, end) pair with the same half-open meaning as
        ``interval(start, end)``.
        """
        base = empty_intset
        for ints in intervals:
            base |= interval(*ints)
        return base

    @abstractmethod
    def insert(self, value):
        """
        Returns an IntSet which contains all the values of the current one
        plus the provided value
        """

    def __contains__(self, i):
        """
        Walk down the trie, choosing each Split's child by its mask bit,
        then test i against the bounds of the Interval leaf reached.
        """
        if isinstance(self, Empty):
            return False
        while not isinstance(self, Interval):
            assert isinstance(self, Split)
            if _is_zero(i, self.mask):
                self = self.left
            else:
                self = self.right
        assert isinstance(self, Interval)
        return self.start <= i < self.end

    def __iter__(self):
        for start, end in self.intervals():
            for i in range(start, end):
                yield i

    def __getitem__(self, i):
        """
        Return the i-th smallest value; a negative i counts from the end.

        Raises IndexError when i lies outside [-size, size).
        """
        if i < -self.size or i >= self.size:
            raise IndexError("IntSet index %d out of range for size %d" % (
                i, self.size,
            ))
        if i < 0:
            i += self.size
        assert i >= 0
        # Descend by subtree sizes, rebasing i whenever we go right.
        while isinstance(self, Split):
            if i < self.left.size:
                self = self.left
            else:
                i -= self.left.size
                self = self.right
        assert isinstance(self, Interval)
        assert 0 <= i < self.size
        return self.start + i

    def __hash__(self):
        # Sample the first, middle and last values so that equal sets with
        # different tree shapes hash alike. Empty overrides this because it
        # has no values to sample.
        mid = self.size // 2
        return hash((
            self[0], self[mid], self[-1]
        ))

    def __and__(self, other):
        """Return an IntSet of the values present in both self and other."""
        if min(self.size, other.size) == 0:
            return empty_intset
        if self.size > other.size:
            self, other = other, self
        if self.size == 1:
            if self.start in other:
                return self
            else:
                return empty_intset
        if isinstance(self, Interval):
            return other.restrict(self.start, self.end)
        if isinstance(other, Interval):
            return self.restrict(other.start, other.end)
        assert isinstance(self, Split)
        assert isinstance(other, Split)
        # Arrange for self to be the node that branches on the higher bit.
        if _shorter(other.mask, self.mask):
            self, other = other, self
        if _shorter(self.mask, other.mask):
            if _no_match(other.prefix, self.prefix, self.mask):
                return empty_intset
            elif _is_zero(other.prefix, self.mask):
                return self.left & other
            else:
                return self.right & other
        else:
            assert self.mask == other.mask
            if self.prefix == other.prefix:
                return self._new_split(
                    self.prefix, self.mask,
                    self.left & other.left,
                    self.right & other.right
                )
            else:
                return empty_intset

    def __sub__(self, other):
        """Return an IntSet of the values in self that are not in other."""
        if other.size == 0:
            return self
        if self.size == 0:
            return self
        if isinstance(other, Interval):
            return self.restrict(self.start, other.start) | \
                self.restrict(other.end, self.end)
        if self.size == 1:
            if self.start in other:
                return empty_intset
            else:
                return self
        if isinstance(self, Interval):
            # Subtract from the trie form of this interval, then collapse the
            # result back to an Interval when it is still contiguous.
            # Returning the split form directly would leak the uncompressed
            # Split built by _split_interval() whenever nothing was removed,
            # e.g. interval(0, 8) - (single(100) | single(200)).
            result = self._split_interval() - other
            if isinstance(result, Split):
                return result._compress()
            return result
        assert isinstance(self, Split)
        assert isinstance(other, Split)
        if _shorter(self.mask, other.mask):
            if _no_match(other.prefix, self.prefix, self.mask):
                return self
            elif _is_zero(other.prefix, self.mask):
                return self._new_split(
                    self.prefix, self.mask, self.left - other, self.right
                )
            else:
                return self._new_split(
                    self.prefix, self.mask, self.left, self.right - other
                )
        elif _shorter(other.mask, self.mask):
            if _is_zero(self.prefix, other.mask):
                return self - other.left
            else:
                return self - other.right
        else:
            if self.prefix == other.prefix:
                return self._new_split(
                    self.prefix, self.mask,
                    self.left - other.left,
                    self.right - other.right
                )
            else:
                return self

    def __xor__(self, other):
        return (self | other) - (self & other)

    def __or__(self, other):
        """Return an IntSet of the values present in self or other."""
        if self.size == 0:
            return other
        if other.size == 0:
            return self
        # Make self the larger operand.
        if other.size > self.size:
            other, self = self, other
        if isinstance(self, Interval) and isinstance(other, Interval):
            # Overlapping or touching intervals merge into one interval.
            if not (self.start > other.end or other.start > self.end):
                return self._new_interval(
                    min(self.start, other.start), max(self.end, other.end))
            elif self.size > 1:
                return self._split_interval() | other
            else:
                assert self.size == other.size == 1
                return _join(self.start, self, other.start, other)
        if isinstance(other, Interval):
            if other.start <= self.start < self.end <= other.end:
                return other
        if isinstance(self, Interval):
            if self.start <= other.start < other.end <= self.end:
                return self
        if isinstance(other, Interval):
            if other.size == 1:
                return self.insert(other.start)
            else:
                other = other._split_interval()
        if isinstance(self, Interval):
            self = self._split_interval()
        assert isinstance(self, Split)
        assert isinstance(other, Split)
        if _shorter(other.mask, self.mask):
            self, other = other, self
        if _shorter(self.mask, other.mask):
            if _no_match(other.prefix, self.prefix, self.mask):
                return _join(
                    self.prefix, self, other.prefix, other
                )
            elif _is_zero(other.prefix, self.mask):
                return self._new_split(
                    self.prefix, self.mask, self.left | other, self.right
                )
            else:
                return self._new_split(
                    self.prefix, self.mask, self.left, self.right | other
                )
        else:
            assert self.mask == other.mask
            if self.prefix == other.prefix:
                return self._new_split(
                    self.prefix, self.mask,
                    self.left | other.left,
                    self.right | other.right
                )
            else:
                return _join(self.prefix, self, other.prefix, other)

    def intervals(self):
        """
        Provide a sorted iterator over a sequence of values start < end which
        represent non-overlapping intervals such that for any start <= x < end
        x in self
        """
        stack = [self]
        while stack:
            head = stack.pop()
            if isinstance(head, Interval):
                yield (head.start, head.end)
            elif isinstance(head, Split):
                # Push right first so the left subtree is visited first.
                stack.append(head.right)
                stack.append(head.left)

    def _new_split(self, prefix, mask, left, right):
        """
        Build the simplest IntSet equal to Split(prefix, mask, left, right).

        Empty children are dropped, touching Interval children are merged,
        an unchanged node is reused, and contiguous results are compressed.
        """
        if left.size == 0:
            return right
        if right.size == 0:
            return left
        if (
            isinstance(left, Interval) and isinstance(right, Interval) and
            not (left.start > right.end or right.start > left.end)
        ):
            return self._new_interval(
                min(left.start, right.start), max(left.end, right.end))
        if (
            isinstance(self, Split) and
            prefix == self.prefix and mask == self.mask and left is self.left
            and right is self.right
        ):
            return self._compress()
        return Split(
            prefix=prefix, mask=mask, left=left, right=right)._compress()

    def _new_interval(self, start, end):
        return interval(start, end)
class Empty(IntSet):
    """The IntSet with no values. Use the shared ``empty_intset`` instance."""

    size = 0

    def __repr__(self):
        return "empty()"

    def __hash__(self):
        # All empty sets are equal, so they share a single hash value.
        return 0

    def restrict(self, start, end):
        # Nothing to trim: restricting an empty set leaves it empty.
        return self

    def discard(self, value):
        return self

    def insert(self, value):
        return single(value)


empty_intset = Empty()
class Split(IntSet):
    """
    Internal trie node of an IntSet.

    Values whose ``mask`` bit is 0 are stored in ``left`` and those whose
    bit is 1 in ``right``; ``prefix`` holds the bits above ``mask`` that
    the values under this node share.
    """

    def __init__(self, prefix, mask, left, right):
        self.mask = mask
        self.prefix = prefix
        self.left = left
        self.right = right
        # A child Split must branch on a strictly lower bit than its parent.
        for child in (left, right):
            if isinstance(child, Split):
                assert _shorter(self.mask, child.mask)
        self.size = left.size + right.size
        self.start, self.end = left.start, right.end
        assert mask > 0

    def _compress(self):
        # A node whose span has no gaps is really a single interval.
        contiguous = self.start + self.size == self.end
        return interval(self.start, self.end) if contiguous else self

    def __repr__(self):
        pairs = ["(%d, %d)" % bounds for bounds in self.intervals()]
        return "IntSet.from_intervals([%s])" % (", ".join(pairs),)

    def insert(self, value):
        if _no_match(value, self.prefix, self.mask):
            # value lies outside this subtree: hang both under a new node.
            return _join(value, single(value), self.prefix, self)
        if _is_zero(value, self.mask):
            new_left, new_right = self.left.insert(value), self.right
        else:
            new_left, new_right = self.left, self.right.insert(value)
        return self._new_split(
            prefix=self.prefix, mask=self.mask,
            left=new_left, right=new_right,
        )

    def discard(self, value):
        new_left, new_right = self.left, self.right
        if _is_zero(value, self.mask):
            new_left = new_left.discard(value)
        else:
            new_right = new_right.discard(value)
        return self._new_split(
            prefix=self.prefix, mask=self.mask,
            left=new_left, right=new_right,
        )

    def restrict(self, start, end):
        if start > self.start or self.end > end:
            return self._new_split(
                mask=self.mask, prefix=self.prefix,
                left=self.left.restrict(start, end),
                right=self.right.restrict(start, end),
            )
        # Bounds already cover the whole node.
        return self
class Interval(IntSet):
    """
    Leaf IntSet containing every integer x with start <= x < end.

    Never empty: the constructor asserts start < end.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        assert self.start < self.end

    @property
    def size(self):
        return self.end - self.start

    def __repr__(self):
        if self.size == 1:
            return "single(%d)" % (self.start,)
        else:
            return "interval(%d, %d)" % (self.start, self.end)

    def insert(self, value):
        """
        Return an IntSet equal to this one with value added.

        A value touching either end extends the interval. The adjacency
        checks must run before the single-element case: otherwise
        single(3).insert(4) goes through _join and yields an uncompressed
        Split of two singletons instead of interval(3, 5).
        """
        if self.start <= value < self.end:
            return self
        if value + 1 == self.start:
            return interval(self.start - 1, self.end)
        if value == self.end:
            return interval(self.start, self.end + 1)
        if self.size == 1:
            return _join(self.start, self, value, single(value))
        return self._split_interval().insert(value)

    def discard(self, value):
        """
        Return an IntSet equal to this one without value.

        Removing an endpoint shrinks the interval. Removing an interior value
        goes through the trie form, because the result has two pieces.
        """
        if value < self.start or value >= self.end:
            return self
        if value == self.start:
            return interval(self.start + 1, self.end)
        if value + 1 == self.end:
            return interval(self.start, self.end - 1)
        return self._split_interval().discard(value)

    def restrict(self, start, end):
        """Return the values x of this interval with start <= x < end."""
        if start <= self.start < self.end <= end:
            return self
        # interval() returns the empty set when the clipped range is empty.
        return interval(max(start, self.start), min(end, self.end))

    def _split_interval(self):
        """
        Return an equivalent Split with two Interval children, divided at
        the highest bit where start and end - 1 differ.
        """
        assert self.size >= 2
        split_mask = branch_mask(self.start, self.end - 1)
        split_prefix = _mask_off(self.start, split_mask)
        split_point = split_prefix | split_mask
        assert self.start < split_point < self.end
        return Split(
            prefix=split_prefix, mask=split_mask,
            left=Interval(self.start, split_point),
            right=Interval(split_point, self.end),
        )

    def _new_interval(self, start, end):
        # Reuse this instance when the requested bounds are unchanged.
        if self.start == start and self.end == end:
            return self
        else:
            return super(Interval, self)._new_interval(start, end)
def _right_fill_bits(key):
key |= (key >> 1)
key |= (key >> 2)
key |= (key >> 4)
key |= (key >> 8)
key |= (key >> 16)
key |= (key >> 32)
return key
def _highest_bit_mask(k):
    """Return k with only its highest set bit kept (0 maps to 0)."""
    filled = _right_fill_bits(k)
    return filled ^ (filled >> 1)
def branch_mask(p1, p2):
    """
    Return a mask with only the highest bit at which p1 and p2 differ set
    (0 when they are equal).
    """
    differing = p1 ^ p2
    return _highest_bit_mask(differing)
def _mask_off(i, m):
return i & (~(m-1) ^ m)
def _is_zero(i, m):
return (i & m) == 0
def _join(p1, t1, p2, t2):
    """
    Combine two non-empty trees with distinct representatives into a Split.

    p1 and p2 are representative values (prefixes or elements) of t1 and
    t2. The new node branches on the highest bit where p1 and p2 differ;
    the tree whose representative has a 0 in that bit becomes the left
    child.

    Split children are compressed first. Otherwise a contiguous Split, such
    as the one returned by Interval._split_interval(), would be embedded in
    the result in a non-canonical form.
    """
    assert t1.size
    assert t2.size
    assert p1 != p2
    if isinstance(t1, Split):
        t1 = t1._compress()
    if isinstance(t2, Split):
        t2 = t2._compress()
    m = branch_mask(p1, p2)
    p = _mask_off(p1, m)
    if not _is_zero(p1, m):
        t1, t2 = t2, t1
    return Split(prefix=p, mask=m, left=t1, right=t2)
def _no_match(i, p, m):
    """Return True when the bits of i above mask m differ from prefix p."""
    return p != _mask_off(i, m)
def _shorter(m1, m2):
return m1 > m2
# coding=utf-8
# This file is part of Hypothesis (https://github.com/DRMacIver/hypothesis)
# Most of this work is copyright (C) 2013-2015 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others. See
# https://github.com/DRMacIver/hypothesis/blob/master/CONTRIBUTING.rst for a
# full list of people who may hold copyright, and consult the git log if you
# need to determine who owns an individual contribution.
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.
# END HEADER
from hypothesis import given, assume
import hypothesis.strategies as st
from intset import interval, empty, single, IntSet
import operator as op
import pytest
from hypothesis.stateful import RuleBasedStateMachine, rule, Bundle
def test_not_equal_to_other_types():
    """An IntSet is never equal to a non-IntSet, even its only element."""
    value = 1
    assert single(value) != value
# Every value an IntSet may hold: unsigned 64-bit integers.
integers_in_range = st.integers(min_value=0, max_value=2 ** 64 - 1)

# [start, end] pairs with start <= end; equal bounds describe an empty range.
intervals = st.tuples(integers_in_range, integers_in_range).map(sorted)

# Size cap that keeps "small" sets cheap to iterate exhaustively.
SMALL = 100

# (start, start + length) pairs with length <= SMALL that stay in range.
short_intervals = st.builds(
    lambda start, length:
    assume(start + length <= 2 ** 64) and (start, start + length),
    integers_in_range,
    st.integers(0, SMALL)
)

# st.lists(average_size=...) was deprecated and then removed in Hypothesis
# 4; max_size is accepted by every version.
interval_list = st.lists(intervals, max_size=20)

IntSets = st.builds(
    IntSet.from_intervals, interval_list
) | integers_in_range.map(single)
@given(IntSets, st.integers())
def test_raises_index_error_out_of_bounds(x, i):
    """
    Indexing outside [-size, size) raises IndexError.

    The valid range is asymmetric, so i == size (the first invalid positive
    index) must be generated too; ``abs(i) > x.size`` excluded it.
    """
    assume(i >= x.size or i < -x.size)
    with pytest.raises(IndexError):
        x[i]
# IntSets built from a few short intervals, filtered to at most SMALL values
# so tests can iterate them exhaustively. max_size replaces the average_size
# argument that newer Hypothesis versions no longer accept.
SmallIntSets = st.builds(
    IntSet.from_intervals, st.lists(short_intervals, max_size=10)
).filter(lambda x: x.size <= SMALL)
@given(IntSets, IntSets)
def test_ordering_method_consistency(x, y):
    """<= is always the negation of >, and >= the negation of <."""
    assert (x > y) != (x <= y)
    assert (x < y) != (x >= y)
@given(SmallIntSets)
def test_an_intset_contains_all_its_values(imp):
    """Every value produced by iteration is reported as a member."""
    assert all(value in imp for value in imp)
@given(SmallIntSets)
def test_an_intset_iterates_in_sorted_order(imp):
    """Iteration yields strictly increasing values."""
    values = list(imp)
    for previous, current in zip(values, values[1:]):
        assert current > previous
@given(SmallIntSets)
def test_is_equal_to_sequential_insertion(imp):
    """Inserting the values one by one in ascending order rebuilds imp."""
    rebuilt = empty()
    for value in imp:
        rebuilt = rebuilt.insert(value)
    assert rebuilt == imp
@given(SmallIntSets)
def test_is_equal_to_reverse_insertion(imp):
    """Inserting the values in descending order rebuilds imp."""
    rebuilt = empty()
    for value in sorted(imp, reverse=True):
        rebuilt = rebuilt.insert(value)
    assert rebuilt == imp
@given(SmallIntSets, st.randoms())
def test_is_equal_to_random_insertion(imp, rnd):
    """Inserting the values in a random order rebuilds imp."""
    shuffled = list(imp)
    rnd.shuffle(shuffled)
    rebuilt = empty()
    for value in shuffled:
        rebuilt = rebuilt.insert(value)
    assert rebuilt == imp
@given(SmallIntSets)
def test_an_intset_is_consistent_with_its_index(imp):
    """imp[k] is the k-th value produced by iteration."""
    by_index = [imp[k] for k in range(imp.size)]
    assert by_index == list(imp)
@given(SmallIntSets)
def test_an_intset_is_consistent_with_its_negative_index(imp):
    """Negative indices count back from the end, as they do for a list."""
    values = list(imp)
    for offset in range(1, len(values) + 1):
        assert imp[-offset] == values[-offset]
@given(intervals, st.lists(integers_in_range, min_size=1))
def test_insert_into_interval(bounds, ints):
    """Inserted values are present immediately and stay after later inserts."""
    imp = interval(*bounds)
    for value in ints:
        imp = imp.insert(value)
        assert value in imp
    assert all(value in imp for value in ints)
@given(intervals, intervals)
def test_union_of_two_intervals_contains_each_start(i1, i2):
    """The union of two non-empty intervals contains both start points."""
    assume(i1[0] < i1[1])
    assume(i2[0] < i2[1])
    combined = interval(*i1) | interval(*i2)
    for start in (i1[0], i2[0]):
        assert start in combined
@given(interval_list, integers_in_range)
def test_unioning_a_value_in_includes_it(intervals, i):
    """Unioning in the one-value interval [i, i + 1) adds i to the set."""
    base = IntSet.from_intervals(intervals)
    assume(i not in base)
    assert i in (base | interval(i, i + 1))
@given(IntSets)
def test_restricting_bounds_reduces_size_by_one(imp):
    """Trimming the lowest or highest value off the bounds drops one value."""
    assume(imp.size > 0)
    lower, upper = imp[0], imp[-1] + 1
    expected = imp.size - 1
    assert imp.restrict(lower + 1, upper).size == expected
    assert imp.restrict(lower, upper - 1).size == expected
@given(SmallIntSets)
def test_restricting_bounds_splits_set(imp):
    """Cutting at any member splits imp into two parts whose union is imp."""
    assume(imp.size > 0)
    lower, upper = imp[0], imp[-1] + 1
    for pivot in imp:
        below = imp.restrict(lower, pivot)
        above = imp.restrict(pivot, upper)
        assert below.size + above.size == imp.size
        assert pivot in above
        assert pivot not in below
        rejoined = below | above
        assert rejoined.size == imp.size
        assert pivot in rejoined
@given(IntSets, short_intervals)
def test_restricting_bounds_restricts_bounds(imp, interval):
    """Restriction keeps only members of imp that lie inside the bounds."""
    lo, hi = interval
    smaller = imp.restrict(lo, hi)
    assert smaller.size <= hi - lo
    for value in smaller:
        assert value in imp
        assert lo <= value < hi
@given(SmallIntSets, intervals)
def test_restricting_bounds_does_not_remove_other_items(imp, interval):
    """
    Restriction keeps exactly the members of imp inside the bounds.

    The test checks both directions: nothing outside the bounds survives,
    and, as the name promises, every member of imp inside them is retained.
    """
    lo, hi = interval
    smaller = imp.restrict(lo, hi)
    assert smaller.size <= hi - lo
    for value in smaller:
        assert value in imp
        assert lo <= value < hi
    for value in imp:
        if lo <= value < hi:
            assert value in smaller
@given(SmallIntSets)
def test_equality_is_preserved(imp):
    """Re-inserting a member, or unioning imp with itself, changes nothing."""
    for value in imp:
        assert imp.insert(value) == imp
    assert (imp | imp) == imp
# max_size replaces average_size, which newer Hypothesis versions reject.
@given(st.lists(SmallIntSets, max_size=20))
def test_sorts_as_lists(intsets):
    """Sorting IntSets orders them the same way as their value lists."""
    expected = sorted(list(s) for s in intsets)
    intsets.sort()
    assert [list(s) for s in intsets] == expected
# max_size replaces average_size, which newer Hypothesis versions reject.
@given(st.lists(SmallIntSets, max_size=20))
def test_hashes_correctly(intsets):
    """Equal IntSets hash alike, so every one can be found in a set."""
    as_set = set(intsets)
    for s in intsets:
        assert s in as_set
@given(SmallIntSets)
def test_all_values_lie_between_bounds(imp):
    """Every member lies in the half-open range [imp.start, imp.end)."""
    assume(imp.size > 0)
    assert all(imp.start <= value < imp.end for value in imp)
@given(SmallIntSets, SmallIntSets)
def test_union_gives_union(x, y):
    """x | y contains every member of x and of y, and nothing else."""
    z = x | y
    for operand in (x, y):
        for t in operand:
            assert t in z
    for t in z:
        # Name y explicitly; the previous check used a loop variable leaked
        # from the loop above, which only happened to equal y.
        assert (t in x) or (t in y)
@given(SmallIntSets, SmallIntSets)
def test_intersection_gives_intersection(x, y):
    """x & y agrees with intersecting the equivalent Python sets."""
    expected = set(x).intersection(y)
    assert set(x & y) == expected
@given(SmallIntSets, SmallIntSets)
def test_subtraction_gives_subtraction(x, y):
    """x - y agrees with subtracting the equivalent Python sets."""
    expected = set(x).difference(y)
    assert set(x - y) == expected
@given(SmallIntSets, SmallIntSets)
def test_subtraction_cancels_union(x, y):
    """Subtracting y undoes a union with y, and vice versa."""
    difference = x - y
    union = x | y
    assert (union - y) == difference
    assert (difference | y) == union
@given(SmallIntSets, SmallIntSets, SmallIntSets)
def test_intersection_distributes_over_union(x, y, z):
    """x & (y | z) equals (x & y) | (x & z)."""
    lhs = x & (y | z)
    rhs = (x & y) | (x & z)
    assert lhs == rhs
@pytest.mark.parametrize('f', [op.and_, op.or_, op.xor])
@given(SmallIntSets, SmallIntSets, SmallIntSets)
def test_associative_operators(f, x, y, z):
    """&, | and ^ are associative."""
    left_first = f(f(x, y), z)
    right_first = f(x, f(y, z))
    assert left_first == right_first
@pytest.mark.parametrize('f', [op.and_, op.or_, op.xor])
@given(SmallIntSets, SmallIntSets)
def test_commutative_operators(f, x, y):
    """&, | and ^ are commutative."""
    forward = f(x, y)
    backward = f(y, x)
    assert forward == backward
@given(IntSets, SmallIntSets)
def test_subtract_is_sequential_discard(x, y):
    """x - y equals discarding each member of y from x one at a time."""
    remaining = x
    for value in y:
        remaining = remaining.discard(value)
    assert remaining == x - y
class SetModel(RuleBasedStateMachine):
    """
    Stateful model test for IntSet.

    Each entry in the ``intsets`` bundle is a pair (IntSet, sorted list of
    the same values). Rules build new pairs by applying one operation to
    both halves; the ``validate*`` rules check that the halves agree.
    """

    intsets = Bundle('IntSets')
    values = Bundle('values')

    @rule(target=values, i=integers_in_range)
    def int_value(self, i):
        return i

    @rule(target=values, i=integers_in_range, imp=intsets)
    def endpoint_value(self, i, imp):
        # Prefer the largest member of an existing set; otherwise use i.
        if len(imp[0]) > 0:
            return imp[0][-1]
        else:
            return i

    @rule(target=values, i=integers_in_range, imp=intsets)
    def startpoint_value(self, i, imp):
        # Prefer the smallest member of an existing set; otherwise use i.
        if len(imp[0]) > 0:
            return imp[0][0]
        else:
            return i

    @rule(target=intsets, bounds=short_intervals)
    def build_interval(self, bounds):
        return (
            interval(*bounds), list(range(*bounds))
        )

    @rule(target=intsets, v=values)
    def single_value(self, v):
        return (single(v), [v])

    @rule(target=intsets, v=values)
    def adjacent_values(self, v):
        # interval(v, v + 2) asserts v + 2 <= 2 ** 64. The old check,
        # v + 1 <= 2 ** 64, always held, so v == 2 ** 64 - 1 crashed.
        assume(v + 2 <= 2 ** 64)
        return (interval(v, v + 2), [v, v + 1])

    @rule(target=intsets, v=values)
    def three_adjacent_values(self, v):
        # interval(v, v + 3) asserts v + 3 <= 2 ** 64 (previously off by one).
        assume(v + 3 <= 2 ** 64)
        return (interval(v, v + 3), [v, v + 1, v + 2])

    @rule(target=intsets, v=values)
    def three_adjacent_values_with_hole(self, v):
        # v + 2 must itself be an unsigned 64-bit value: v + 2 < 2 ** 64.
        assume(v + 3 <= 2 ** 64)
        return (single(v) | single(v + 2), [v, v + 2])

    @rule(target=intsets, x=intsets, y=intsets)
    def union(self, x, y):
        return (
            x[0] | y[0],
            sorted(set(x[1] + y[1]))
        )

    @rule(target=intsets, x=intsets, y=intsets)
    def intersect(self, x, y):
        return (
            x[0] & y[0],
            sorted(set(x[1]) & set(y[1]))
        )

    @rule(target=intsets, x=intsets, y=intsets)
    def subtract(self, x, y):
        return (
            x[0] - y[0],
            sorted(set(x[1]) - set(y[1]))
        )

    @rule(target=intsets, x=intsets, i=values)
    def insert(self, x, i):
        return (
            x[0].insert(i),
            sorted(set(x[1] + [i]))
        )

    @rule(target=intsets, x=intsets, i=values)
    def discard(self, x, i):
        return (
            x[0].discard(i),
            sorted(set(x[1]) - set([i]))
        )

    @rule(target=intsets, source=intsets, bounds=intervals)
    def restrict(self, source, bounds):
        return (
            source[0].restrict(*bounds),
            [x for x in source[1] if bounds[0] <= x < bounds[1]]
        )

    @rule(target=intsets, x=intsets)
    def peel_left(self, x):
        # Drop the smallest value by raising the lower bound by one.
        if len(x[0]) == 0:
            return x
        return self.restrict(x, (x[0].start + 1, x[0].end))

    @rule(target=intsets, x=intsets)
    def peel_right(self, x):
        # Drop the largest value by lowering the upper bound by one.
        if len(x[0]) == 0:
            return x
        return self.restrict(x, (x[0].start, x[0].end - 1))

    @rule(x=intsets, y=intsets)
    def validate_order(self, x, y):
        # IntSet ordering must match lexicographic list ordering.
        assert (x[0] <= y[0]) == (x[1] <= y[1])

    @rule(x=intsets, y=intsets)
    def validate_equality(self, x, y):
        assert (x[0] == y[0]) == (x[1] == y[1])

    @rule(source=intsets)
    def validate(self, source):
        assert list(source[0]) == source[1]
        assert len(source[0]) == len(source[1])
        for i in range(-len(source[0]), len(source[0])):
            assert source[0][i] == source[1][i]
        if len(source[0]) > 0:
            for v in source[1]:
                assert source[0].start <= v < source[0].end


TestState = SetModel.TestCase
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.