Last active
July 29, 2019 00:10
-
-
Save DRMacIver/c5b93c08ef7ad109fcc6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# This file is part of Hypothesis (https://github.com/DRMacIver/hypothesis) | |
# Most of this work is copyright (C) 2013-2015 David R. MacIver | |
# (david@drmaciver.com), but it contains contributions by others. See | |
# https://github.com/DRMacIver/hypothesis/blob/master/CONTRIBUTING.rst for a | |
# full list of people who may hold copyright, and consult the git log if you | |
# need to determine who owns an individual contribution. | |
# This Source Code Form is subject to the terms of the Mozilla Public License, | |
# v. 2.0. If a copy of the MPL was not distributed with this file, You can | |
# obtain one at http://mozilla.org/MPL/2.0/. | |
# END HEADER | |
from collections import Sequence, Set | |
from abc import abstractmethod | |
def interval(start, end):
    """
    Build the IntSet of every integer x satisfying start <= x < end.

    A range with no members (end <= start) gives the shared empty set and
    is not bounds-checked. Otherwise the bounds are asserted to lie within
    0 <= start and end <= 2 ** 64.
    """
    if start < end:
        assert 0 <= start
        assert end <= 2 ** 64
        return Interval(start, end)
    return empty_intset
def single(value):
    """
    Build the IntSet whose only member is value.

    The result is the one-element Interval covering value <= x < value + 1.
    """
    successor = value + 1
    return Interval(value, successor)
def empty():
    """
    Give back the IntSet that has no members.

    Every call returns the same module-level empty_intset object.
    """
    return empty_intset
class IntSet(Sequence, Set):
    """
    An IntSet is a compressed representation of a sorted list of unsigned
    64-bit integers with fast membership, union and range restriction.

    IntSets are immutable.

    The methods here dispatch on three concrete node types defined in this
    module: Empty (no values), Interval (a leaf covering start <= x < end)
    and Split (a branch node carrying prefix, mask, left and right).
    """

    def __len__(self):
        """Return the number of values in the set (the node's size)."""
        return self.size

    def __eq__(self, other):
        """
        Return True when other is an IntSet for which __cmp__ reports 0.

        Objects that are not IntSets always compare unequal.
        """
        if self is other:
            return True
        if not isinstance(other, IntSet):
            return False
        # Different sizes can never be equal; skip the interval walk.
        if self.size != other.size:
            return False
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        """Return the negation of __eq__."""
        return not self.__eq__(other)

    def __cmp__(self, other):
        """
        Three-way comparison against another IntSet; returns -1, 0 or 1.

        Both interval sequences are consumed from the smallest interval
        upwards. The first differing interval start decides the result; if
        one side runs out of intervals first, that side is the smaller one.
        """
        self_intervals = list(self.intervals())
        other_intervals = list(other.intervals())
        # Reverse so pop()/append() work on the smallest remaining interval.
        self_intervals.reverse()
        other_intervals.reverse()
        while self_intervals and other_intervals:
            self_head = self_intervals.pop()
            other_head = other_intervals.pop()
            if self_head[0] < other_head[0]:
                return -1
            if self_head[0] > other_head[0]:
                return 1
            # Equal starts: push back the unmatched tail of the longer
            # interval so it is compared on the next iteration.
            if self_head[1] < other_head[1]:
                other_intervals.append((self_head[1], other_head[1]))
            if self_head[1] > other_head[1]:
                self_intervals.append((other_head[1], self_head[1]))
        # Whichever side still has intervals left compares greater.
        if self_intervals:
            return 1
        if other_intervals:
            return -1
        return 0

    # The four ordering operators below all delegate to __cmp__, so they
    # express the ordering computed there rather than subset relations.
    def __lt__(self, other):
        return self.__cmp__(other) < 0

    def __gt__(self, other):
        return self.__cmp__(other) > 0

    def __le__(self, other):
        return self.__cmp__(other) <= 0

    def __ge__(self, other):
        return self.__cmp__(other) >= 0

    @classmethod
    def from_intervals(cls, intervals):
        """
        Return a new IntSet which contains precisely the intervals passed in.

        intervals is an iterable of (start, end) pairs; each pair is passed
        to interval() and the results are unioned together.
        """
        base = empty_intset
        for ints in intervals:
            base |= interval(*ints)
        return base

    @abstractmethod
    def insert(self, value):
        """
        Returns an IntSet which contains all the values of the current one
        plus the provided value
        """

    def __contains__(self, i):
        """Return True if i lies inside the Interval leaf selected for it."""
        if isinstance(self, Empty):
            return False
        # Descend iteratively by rebinding self: at each Split the bit of i
        # selected by the node's mask chooses the left or right child.
        while not isinstance(self, Interval):
            assert isinstance(self, Split)
            if _is_zero(i, self.mask):
                self = self.left
            else:
                self = self.right
        assert isinstance(self, Interval)
        return self.start <= i < self.end

    def __iter__(self):
        """Yield every value, one interval from intervals() at a time."""
        for start, end in self.intervals():
            for i in range(start, end):
                yield i

    def __getitem__(self, i):
        """
        Return the value at position i; negative i counts from the end.

        Raises IndexError when i is outside -size <= i < size.
        """
        if i < -self.size or i >= self.size:
            raise IndexError("IntSet index %d out of range for size %d" % (
                i, self.size,
            ))
        if i < 0:
            i += self.size
        assert i >= 0
        # Descend by subtree size: skip the left child when the position
        # falls beyond it, adjusting i to be relative to the right child.
        while isinstance(self, Split):
            if i < self.left.size:
                self = self.left
            else:
                i -= self.left.size
                self = self.right
        assert isinstance(self, Interval)
        assert 0 <= i < self.size
        return self.start + i

    def __hash__(self):
        """
        Hash the first, middle and last values of the set.

        Indexing raises IndexError when size is 0; Empty provides its own
        __hash__ instead of using this one.
        """
        mid = self.size // 2
        return hash((
            self[0], self[mid], self[-1]
        ))

    def __and__(self, other):
        """Return the intersection of self and other as an IntSet."""
        if min(self.size, other.size) == 0:
            return empty_intset
        # Make self the smaller operand.
        if self.size > other.size:
            self, other = other, self
        if self.size == 1:
            if self.start in other:
                return self
            else:
                return empty_intset
        # Intersecting with a contiguous range is a range restriction.
        if isinstance(self, Interval):
            return other.restrict(self.start, self.end)
        if isinstance(other, Interval):
            return self.restrict(other.start, other.end)
        assert isinstance(self, Split)
        assert isinstance(other, Split)
        # Arrange for self to carry the larger (or equal) mask.
        if _shorter(other.mask, self.mask):
            self, other = other, self
        if _shorter(self.mask, other.mask):
            # other's prefix either fails to match self's prefix (nothing in
            # common) or selects exactly one child of self to recurse into.
            if _no_match(other.prefix, self.prefix, self.mask):
                return empty_intset
            elif _is_zero(other.prefix, self.mask):
                return self.left & other
            else:
                return self.right & other
        else:
            assert self.mask == other.mask
            # Same mask: only identical prefixes can share values, and then
            # the children intersect pairwise.
            if self.prefix == other.prefix:
                return self._new_split(
                    self.prefix, self.mask,
                    self.left & other.left,
                    self.right & other.right
                )
            else:
                return empty_intset

    def __sub__(self, other):
        """Return the IntSet of values in self that are not in other."""
        if other.size == 0:
            return self
        if self.size == 0:
            return self
        # Removing a contiguous range leaves the part of self below it
        # unioned with the part of self above it.
        if isinstance(other, Interval):
            return self.restrict(self.start, other.start) | \
                self.restrict(other.end, self.end)
        if self.size == 1:
            if self.start in other:
                return empty_intset
            else:
                return self
        # From here on other is a Split; turn self into one as well.
        if isinstance(self, Interval):
            self = self._split_interval()
        assert isinstance(self, Split)
        assert isinstance(other, Split)
        if _shorter(self.mask, other.mask):
            # self has the larger mask: other can only affect the child of
            # self selected by other's prefix, if its prefix matches at all.
            if _no_match(other.prefix, self.prefix, self.mask):
                return self
            elif _is_zero(other.prefix, self.mask):
                return self._new_split(
                    self.prefix, self.mask, self.left - other, self.right
                )
            else:
                return self._new_split(
                    self.prefix, self.mask, self.left, self.right - other
                )
        elif _shorter(other.mask, self.mask):
            # other has the larger mask: recurse with the child of other
            # selected by self's prefix.
            if _is_zero(self.prefix, other.mask):
                return self - other.left
            else:
                return self - other.right
        else:
            # Equal masks: subtract child by child when prefixes agree.
            if self.prefix == other.prefix:
                return self._new_split(
                    self.prefix, self.mask,
                    self.left - other.left,
                    self.right - other.right
                )
            else:
                return self

    def __xor__(self, other):
        """Return the symmetric difference: the union minus the intersection."""
        return (self | other) - (self & other)

    def __or__(self, other):
        """Return the union of self and other as an IntSet."""
        if self.size == 0:
            return other
        if other.size == 0:
            return self
        # Make self the larger operand.
        if other.size > self.size:
            other, self = self, other
        if isinstance(self, Interval) and isinstance(other, Interval):
            # Overlapping or touching intervals merge into a single one.
            if not (self.start > other.end or other.start > self.end):
                return self._new_interval(
                    min(self.start, other.start), max(self.end, other.end))
            elif self.size > 1:
                return self._split_interval() | other
            else:
                # Two separate single values become a two-leaf Split.
                assert self.size == other.size == 1
                return _join(self.start, self, other.start, other)
        # An Interval spanning the other operand's whole range is the union.
        if isinstance(other, Interval):
            if other.start <= self.start < self.end <= other.end:
                return other
        if isinstance(self, Interval):
            if self.start <= other.start < other.end <= self.end:
                return self
        if isinstance(other, Interval):
            if other.size == 1:
                return self.insert(other.start)
            else:
                other = other._split_interval()
        if isinstance(self, Interval):
            self = self._split_interval()
        assert isinstance(self, Split)
        assert isinstance(other, Split)
        # Arrange for self to carry the larger (or equal) mask.
        if _shorter(other.mask, self.mask):
            self, other = other, self
        if _shorter(self.mask, other.mask):
            # other either sits beside self (prefix mismatch) or merges into
            # the child of self selected by other's prefix.
            if _no_match(other.prefix, self.prefix, self.mask):
                return _join(
                    self.prefix, self, other.prefix, other
                )
            elif _is_zero(other.prefix, self.mask):
                return self._new_split(
                    self.prefix, self.mask, self.left | other, self.right
                )
            else:
                return self._new_split(
                    self.prefix, self.mask, self.left, self.right | other
                )
        else:
            assert self.mask == other.mask
            if self.prefix == other.prefix:
                return self._new_split(
                    self.prefix, self.mask,
                    self.left | other.left,
                    self.right | other.right
                )
            else:
                return _join(self.prefix, self, other.prefix, other)

    def intervals(self):
        """
        Provide a sorted iterator over a sequence of values start < end which
        represent non-overlapping intervals such that for any start <= x < end
        x in self
        """
        # Explicit stack instead of recursion; the right child is pushed
        # first so that the left child is popped (and yielded) first.
        stack = [self]
        while stack:
            head = stack.pop()
            if isinstance(head, Interval):
                yield (head.start, head.end)
            elif isinstance(head, Split):
                stack.append(head.right)
                stack.append(head.left)

    def _new_split(self, prefix, mask, left, right):
        """
        Return a node for the given prefix, mask and children, simplified.

        An empty child yields the other child; two Interval children that
        overlap or touch are merged into one interval; if the arguments
        describe self unchanged, self is reused instead of building a new
        Split. The Split result is passed through _compress().
        """
        if left.size == 0:
            return right
        if right.size == 0:
            return left
        if (
            isinstance(left, Interval) and isinstance(right, Interval) and
            not (left.start > right.end or right.start > left.end)
        ):
            return self._new_interval(
                min(left.start, right.start), max(left.end, right.end))
        if (
            isinstance(self, Split) and
            prefix == self.prefix and mask == self.mask and left is self.left
            and right is self.right
        ):
            return self._compress()
        return Split(
            prefix=prefix, mask=mask, left=left, right=right)._compress()

    def _new_interval(self, start, end):
        """Return interval(start, end); Interval overrides this to reuse self."""
        return interval(start, end)
class Empty(IntSet):
    """The IntSet node that holds no values at all."""

    # Constant: an empty set never has members.
    size = 0

    def __repr__(self):
        return "empty()"

    def __hash__(self):
        # Fixed hash; the generic IntSet hash indexes into the set.
        return 0

    def restrict(self, start, end):
        """Restricting nothing still gives nothing."""
        return self

    def discard(self, value):
        """There is nothing to remove, so the set is returned unchanged."""
        return self

    def insert(self, value):
        """Adding a value to the empty set gives the one-value set."""
        return single(value)
# Module-level Empty instance; interval() and empty() return this object
# whenever an empty result is needed.
empty_intset = Empty()
class Split(IntSet):
    """
    Branch node of an IntSet.

    Values whose mask bit is zero are routed to ``left`` and all others to
    ``right``; ``prefix`` is what _mask_off() yields for values handled by
    this node. size, start and end are cached from the two children.
    """

    def __init__(self, prefix, mask, left, right):
        self.mask = mask
        self.prefix = prefix
        self.left = left
        self.right = right
        # Any Split child must branch on a smaller mask than this node.
        for child in (left, right):
            assert not isinstance(child, Split) or \
                _shorter(self.mask, child.mask)
        self.size = left.size + right.size
        self.start = left.start
        self.end = right.end
        assert mask > 0

    def _compress(self):
        """Collapse to a plain interval when the values are contiguous."""
        contiguous = self.start + self.size == self.end
        if contiguous:
            return interval(self.start, self.end)
        return self

    def __repr__(self):
        pieces = ["(%d, %d)" % pair for pair in self.intervals()]
        return "IntSet.from_intervals([%s])" % (', '.join(pieces),)

    def insert(self, value):
        """Return an IntSet with value added to this one."""
        if _no_match(value, self.prefix, self.mask):
            # value belongs outside this subtree: hang both under a new node.
            return _join(value, single(value), self.prefix, self)
        new_left, new_right = self.left, self.right
        if _is_zero(value, self.mask):
            new_left = new_left.insert(value)
        else:
            new_right = new_right.insert(value)
        return self._new_split(
            prefix=self.prefix, mask=self.mask,
            left=new_left, right=new_right,
        )

    def discard(self, value):
        """Return an IntSet with value removed from this one."""
        new_left, new_right = self.left, self.right
        if _is_zero(value, self.mask):
            new_left = new_left.discard(value)
        else:
            new_right = new_right.discard(value)
        return self._new_split(
            prefix=self.prefix, mask=self.mask,
            left=new_left, right=new_right,
        )

    def restrict(self, start, end):
        """Return the IntSet of values x in self with start <= x < end."""
        covers_everything = start <= self.start and self.end <= end
        if covers_everything:
            return self
        return self._new_split(
            prefix=self.prefix, mask=self.mask,
            left=self.left.restrict(start, end),
            right=self.right.restrict(start, end),
        )
class Interval(IntSet):
    """
    Leaf node of an IntSet: every integer x with start <= x < end.

    The constructor asserts start < end, so an Interval is never empty.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        assert self.start < self.end

    @property
    def size(self):
        """Number of values covered by the interval."""
        return self.end - self.start

    def __repr__(self):
        if self.size == 1:
            return "single(%d)" % (self.start,)
        else:
            return "interval(%d, %d)" % (self.start, self.end)

    def insert(self, value):
        """
        Return an IntSet containing this interval's values plus value.

        A value already present returns self. A value directly adjacent to
        either end widens the interval, whatever its current size, so that
        neighbouring values always end up in one Interval. Anything else
        produces a Split.
        """
        if self.start <= value < self.end:
            return self
        # Adjacency is checked before the single-value case: otherwise
        # inserting a neighbour into a one-element interval would build a
        # two-leaf Split for what is really one contiguous run.
        if value + 1 == self.start:
            return interval(value, self.end)
        if value == self.end:
            return interval(self.start, value + 1)
        if self.size == 1:
            return _join(self.start, self, value, single(value))
        return self._split_interval().insert(value)

    def discard(self, value):
        """
        Return an IntSet containing this interval's values minus value.

        Removing an end value shrinks the interval; removing an interior
        value goes through the Split form of this interval.
        """
        if value < self.start or value >= self.end:
            return self
        if value == self.start:
            return interval(self.start + 1, self.end)
        if value + 1 == self.end:
            return interval(self.start, self.end - 1)
        return self._split_interval().discard(value)

    def restrict(self, start, end):
        """Return the IntSet of values x in self with start <= x < end."""
        if start <= self.start < self.end <= end:
            return self
        # Clamp the requested range to this interval; interval() returns
        # the empty set when the clamped range has no members.
        return interval(max(start, self.start), min(end, self.end))

    def _split_interval(self):
        """
        Return a Split with two Interval children covering the same values.

        The branch bit is branch_mask() of the smallest and largest values;
        the split point is the prefix with that bit set.
        """
        assert self.size >= 2
        split_mask = branch_mask(self.start, self.end - 1)
        split_prefix = _mask_off(self.start, split_mask)
        split_point = split_prefix | split_mask
        assert self.start < split_point < self.end
        return Split(
            prefix=split_prefix, mask=split_mask,
            left=Interval(self.start, split_point),
            right=Interval(split_point, self.end),
        )

    def _new_interval(self, start, end):
        """Return self when the bounds are unchanged, else a new interval."""
        if self.start == start and self.end == end:
            return self
        else:
            return super(Interval, self)._new_interval(start, end)
def _right_fill_bits(key): | |
key |= (key >> 1) | |
key |= (key >> 2) | |
key |= (key >> 4) | |
key |= (key >> 8) | |
key |= (key >> 16) | |
key |= (key >> 32) | |
return key | |
def _highest_bit_mask(k):
    """
    Reduce k to a mask derived from its right-filled form.

    The filled value is XORed with itself shifted right by one, which
    clears every bit that has a set bit directly above it.
    """
    filled = _right_fill_bits(k)
    return filled ^ (filled >> 1)
def branch_mask(p1, p2):
    """Return the _highest_bit_mask() of the bits in which p1 and p2 differ."""
    differing_bits = p1 ^ p2
    return _highest_bit_mask(differing_bits)
def _mask_off(i, m): | |
return i & (~(m-1) ^ m) | |
def _is_zero(i, m): | |
return (i & m) == 0 | |
def _join(p1, t1, p2, t2):
    """
    Combine two non-empty trees under a new Split.

    p1 and p2 are distinct keys standing for t1 and t2. The new node
    branches on branch_mask(p1, p2); the tree whose key has that bit clear
    becomes the left child.
    """
    assert t1.size
    assert t2.size
    assert p1 != p2
    mask = branch_mask(p1, p2)
    prefix = _mask_off(p1, mask)
    if _is_zero(p1, mask):
        left, right = t1, t2
    else:
        left, right = t2, t1
    return Split(prefix=prefix, mask=mask, left=left, right=right)
def _no_match(i, p, m):
    """Return True when masking i with m does not give the prefix p."""
    masked = _mask_off(i, m)
    return masked != p
def _shorter(m1, m2): | |
return m1 > m2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# This file is part of Hypothesis (https://github.com/DRMacIver/hypothesis) | |
# Most of this work is copyright (C) 2013-2015 David R. MacIver | |
# (david@drmaciver.com), but it contains contributions by others. See | |
# https://github.com/DRMacIver/hypothesis/blob/master/CONTRIBUTING.rst for a | |
# full list of people who may hold copyright, and consult the git log if you | |
# need to determine who owns an individual contribution. | |
# This Source Code Form is subject to the terms of the Mozilla Public License, | |
# v. 2.0. If a copy of the MPL was not distributed with this file, You can | |
# obtain one at http://mozilla.org/MPL/2.0/. | |
# END HEADER | |
from hypothesis import given, assume | |
import hypothesis.strategies as st | |
from intset import interval, empty, single, IntSet | |
import operator as op | |
import pytest | |
from hypothesis.stateful import RuleBasedStateMachine, rule, Bundle | |
def test_not_equal_to_other_types():
    """An IntSet holding 1 must not compare equal to the plain integer 1."""
    one_element_set = single(1)
    assert one_element_set != 1
# Integers from 0 up to and including 2 ** 64 - 1.
integers_in_range = st.integers(min_value=0, max_value=2 ** 64 - 1)
# Pairs of in-range integers, sorted so each example is a list [lo, hi]
# with lo <= hi (lo == hi is possible).
intervals = st.tuples(integers_in_range, integers_in_range).map(
    lambda x: sorted(tuple(x))
)
# Upper bound on the length of a "short" interval and on the size of sets
# that tests iterate over in full.
SMALL = 100
# (start, start + length) pairs with 0 <= length <= SMALL whose end does not
# exceed 2 ** 64. NOTE(review): the `and` relies on assume() returning a
# truthy value when the condition holds - confirm for the Hypothesis version
# in use.
short_intervals = st.builds(
    lambda start, length:
    assume(start + length <= 2 ** 64) and (start, start + length),
    integers_in_range,
    st.integers(0, SMALL)
)
# Lists of [lo, hi] pairs used to build IntSets.
interval_list = st.lists(intervals, average_size=10)
# IntSets built either from a list of intervals or from a single value.
IntSets = st.builds(
    IntSet.from_intervals, interval_list
) | integers_in_range.map(single)
@given(IntSets, st.integers())
def test_raises_index_error_out_of_bounds(x, i):
    """Indexing with a magnitude greater than the size raises IndexError."""
    size = x.size
    assume(abs(i) > size)
    with pytest.raises(IndexError):
        x[i]
# IntSets built from short intervals and filtered to at most SMALL values,
# so that tests can afford to iterate over every element.
SmallIntSets = st.builds(
    IntSet.from_intervals, st.lists(short_intervals, average_size=5)
).filter(lambda x: x.size <= SMALL)
@given(IntSets, IntSets)
def test_ordering_method_consistency(x, y):
    """<= must be the negation of >, and >= the negation of <."""
    at_most, above = x <= y, x > y
    assert at_most == (not above)
    at_least, below = x >= y, x < y
    assert at_least == (not below)
@given(SmallIntSets)
def test_an_intset_contains_all_its_values(imp):
    """Every value produced by iteration passes the membership test."""
    for member in imp:
        is_member = member in imp
        assert is_member
@given(SmallIntSets)
def test_an_intset_iterates_in_sorted_order(imp):
    """Iteration yields strictly increasing values."""
    values = list(imp)
    for earlier, later in zip(values, values[1:]):
        assert later > earlier
@given(SmallIntSets)
def test_is_equal_to_sequential_insertion(imp):
    """Re-inserting the values in ascending order rebuilds an equal set."""
    rebuilt = empty()
    for value in imp:
        rebuilt = rebuilt.insert(value)
    assert imp == rebuilt
@given(SmallIntSets)
def test_is_equal_to_reverse_insertion(imp):
    """Re-inserting the values in descending order rebuilds an equal set."""
    descending = list(imp)[::-1]
    rebuilt = empty()
    for value in descending:
        rebuilt = rebuilt.insert(value)
    assert imp == rebuilt
@given(SmallIntSets, st.randoms())
def test_is_equal_to_random_insertion(imp, rnd):
    """Re-inserting the values in a shuffled order rebuilds an equal set."""
    shuffled = list(imp)
    rnd.shuffle(shuffled)
    rebuilt = empty()
    for value in shuffled:
        rebuilt = rebuilt.insert(value)
    assert imp == rebuilt
@given(SmallIntSets)
def test_an_intset_is_consistent_with_its_index(imp):
    """Indexing by position agrees with the order of iteration."""
    for position, expected in enumerate(imp):
        found = imp[position]
        assert found == expected
@given(SmallIntSets)
def test_an_intset_is_consistent_with_its_negative_index(imp):
    """Negative indices agree with negative indexing into list(imp)."""
    values = list(imp)
    for distance_from_end in range(1, len(values) + 1):
        index = -distance_from_end
        assert values[index] == imp[index]
@given(intervals, st.lists(integers_in_range, min_size=1))
def test_insert_into_interval(bounds, ints):
    """Inserted values are present immediately and after later inserts."""
    grown = interval(*bounds)
    for value in ints:
        grown = grown.insert(value)
        assert value in grown
    # Later insertions must not have dropped earlier ones.
    for value in ints:
        assert value in grown
@given(intervals, intervals)
def test_union_of_two_intervals_contains_each_start(i1, i2):
    """The union of two non-empty intervals contains both start points."""
    first_start, first_end = i1[0], i1[1]
    assume(first_start < first_end)
    second_start, second_end = i2[0], i2[1]
    assume(second_start < second_end)
    x = interval(first_start, first_end) | interval(second_start, second_end)
    assert first_start in x
    assert second_start in x
@given(interval_list, integers_in_range)
def test_unioning_a_value_in_includes_it(intervals, i):
    """Unioning with the one-value interval [i, i + 1) adds i."""
    before = IntSet.from_intervals(intervals)
    assume(i not in before)
    one_value = interval(i, i + 1)
    after = before | one_value
    assert i in after
@given(IntSets)
def test_restricting_bounds_reduces_size_by_one(imp):
    """Cutting off the first or the last value removes exactly one value."""
    assume(imp.size > 0)
    lower = imp[0]
    upper = imp[-1] + 1
    expected_size = imp.size - 1
    without_first = imp.restrict(lower + 1, upper)
    without_last = imp.restrict(lower, upper - 1)
    assert without_first.size == expected_size
    assert without_last.size == expected_size
@given(SmallIntSets)
def test_restricting_bounds_splits_set(imp):
    """Restricting below and from each value partitions the set in two."""
    assume(imp.size > 0)
    lower = imp[0]
    upper = imp[-1] + 1
    for pivot in imp:
        below = imp.restrict(lower, pivot)
        from_pivot = imp.restrict(pivot, upper)
        assert below.size + from_pivot.size == imp.size
        assert pivot in from_pivot
        assert pivot not in below
        rejoined = below | from_pivot
        assert rejoined.size == imp.size
        assert pivot in rejoined
@given(IntSets, short_intervals)
def test_restricting_bounds_restricts_bounds(imp, interval):
    """A restricted set only holds values of imp that lie in the bounds."""
    lower, upper = interval[0], interval[1]
    smaller = imp.restrict(*interval)
    assert smaller.size <= upper - lower
    for value in smaller:
        assert value in imp
        assert lower <= value < upper
@given(SmallIntSets, intervals)
def test_restricting_bounds_does_not_remove_other_items(imp, interval):
    """
    Restricting keeps every value of imp that lies inside the bounds.

    The first half repeats the checks that nothing outside imp or outside
    the bounds appears in the result. The second half checks the property
    this test is named for: no in-bounds value of imp goes missing.
    """
    lower, upper = interval[0], interval[1]
    smaller = imp.restrict(*interval)
    assert smaller.size <= upper - lower
    for i in smaller:
        assert i in imp
        assert lower <= i < upper
    # Previously absent: without this loop the test was a copy of
    # test_restricting_bounds_restricts_bounds and a restrict() that
    # dropped values would still have passed.
    for i in imp:
        if lower <= i < upper:
            assert i in smaller
@given(SmallIntSets)
def test_equality_is_preserved(imp):
    """Re-inserting a member, or unioning with itself, leaves a set equal."""
    for member in imp:
        reinserted = imp.insert(member)
        assert imp == reinserted
    self_union = imp | imp
    assert imp == self_union
@given(st.lists(SmallIntSets, average_size=10))
def test_sorts_as_lists(intsets):
    """Sorting IntSets gives the same order as sorting their value lists."""
    expected = [list(s) for s in intsets]
    expected.sort()
    intsets.sort()
    actual = [list(s) for s in intsets]
    assert expected == actual
@given(st.lists(SmallIntSets, average_size=10))
def test_hashes_correctly(intsets):
    """Every IntSet put into a builtin set can be found in it again."""
    hashed = set(intsets)
    for candidate in intsets:
        assert candidate in hashed
@given(SmallIntSets)
def test_all_values_lie_between_bounds(imp):
    """Each value x of a non-empty set satisfies start <= x < end."""
    assume(imp.size > 0)
    for value in imp:
        within_bounds = imp.start <= value < imp.end
        assert within_bounds
@given(SmallIntSets, SmallIntSets)
def test_union_gives_union(x, y):
    """
    x | y contains every value of x and of y, and nothing else.

    The second check names y explicitly. It used to read ``t in u``, which
    only worked because the loop variable of the first loop was left bound
    to y after that loop finished.
    """
    z = x | y
    for operand in (x, y):
        for t in operand:
            assert t in z
    for t in z:
        assert (t in x) or (t in y)
@given(SmallIntSets, SmallIntSets)
def test_intersection_gives_intersection(x, y):
    """x & y holds the same values as the builtin set intersection."""
    expected = set(x) & set(y)
    actual = set(x & y)
    assert expected == actual
@given(SmallIntSets, SmallIntSets)
def test_subtraction_gives_subtraction(x, y):
    """x - y holds the same values as the builtin set difference."""
    expected = set(x) - set(y)
    actual = set(x - y)
    assert expected == actual
@given(SmallIntSets, SmallIntSets)
def test_subtraction_cancels_union(x, y):
    """Subtracting y undoes a union with y; re-adding y restores the union."""
    difference = x - y
    assert difference == (x | y) - y
    restored = (x - y) | y
    assert restored == (x | y)
@given(SmallIntSets, SmallIntSets, SmallIntSets)
def test_intersection_distributes_over_union(x, y, z):
    """x & (y | z) equals (x & y) | (x & z)."""
    combined_first = x & (y | z)
    combined_last = (x & y) | (x & z)
    assert combined_first == combined_last
@pytest.mark.parametrize('f', [op.and_, op.or_, op.xor])
@given(SmallIntSets, SmallIntSets, SmallIntSets)
def test_associative_operators(f, x, y, z):
    """&, | and ^ give the same result however the operands are grouped."""
    grouped_left = f(f(x, y), z)
    grouped_right = f(x, f(y, z))
    assert grouped_left == grouped_right
@pytest.mark.parametrize('f', [op.and_, op.or_, op.xor])
@given(SmallIntSets, SmallIntSets)
def test_commutative_operators(f, x, y):
    """&, | and ^ give the same result with the operands swapped."""
    forwards = f(x, y)
    backwards = f(y, x)
    assert forwards == backwards
@given(IntSets, SmallIntSets)
def test_subtract_is_sequential_discard(x, y):
    """x - y equals discarding each value of y from x in turn."""
    remaining = x
    for value in y:
        remaining = remaining.discard(value)
    assert (x - y) == remaining
class SetModel(RuleBasedStateMachine):
    """
    Stateful test comparing IntSet against a sorted-list model.

    The ``intsets`` bundle holds (IntSet, sorted list of values) pairs and
    the ``values`` bundle holds integers. Each construction rule applies
    the same operation to both halves of a pair; the validate_* rules then
    check that the two halves still agree.
    """
    intsets = Bundle('IntSets')
    values = Bundle('values')

    @rule(target=values, i=integers_in_range)
    def int_value(self, i):
        return i

    @rule(target=values, i=integers_in_range, imp=intsets)
    def endpoint_value(self, i, imp):
        # Prefer the largest value of an existing set; fall back to i.
        if len(imp[0]) > 0:
            return imp[0][-1]
        else:
            return i

    @rule(target=values, i=integers_in_range, imp=intsets)
    def startpoint_value(self, i, imp):
        # Prefer the smallest value of an existing set; fall back to i.
        if len(imp[0]) > 0:
            return imp[0][0]
        else:
            return i

    @rule(target=intsets, bounds=short_intervals)
    def build_interval(self, bounds):
        return (
            interval(*bounds), list(range(*bounds))
        )

    @rule(target=intsets, v=values)
    def single_value(self, v):
        return (single(v), [v])

    @rule(target=intsets, v=values)
    def adjacent_values(self, v):
        # interval() asserts end <= 2 ** 64, and the end here is v + 2.
        assume(v + 2 <= 2 ** 64)
        return (interval(v, v + 2), [v, v + 1])

    @rule(target=intsets, v=values)
    def three_adjacent_values(self, v):
        # interval() asserts end <= 2 ** 64, and the end here is v + 3.
        assume(v + 3 <= 2 ** 64)
        return (interval(v, v + 3), [v, v + 1, v + 2])

    @rule(target=intsets, v=values)
    def three_adjacent_values_with_hole(self, v):
        # v + 2 is stored as a value, so it must stay below 2 ** 64.
        assume(v + 2 < 2 ** 64)
        return (single(v) | single(v + 2), [v, v + 2])

    @rule(target=intsets, x=intsets, y=intsets)
    def union(self, x, y):
        return (
            x[0] | y[0],
            sorted(set(x[1] + y[1]))
        )

    @rule(target=intsets, x=intsets, y=intsets)
    def intersect(self, x, y):
        return (
            x[0] & y[0],
            sorted(set(x[1]) & set(y[1]))
        )

    @rule(target=intsets, x=intsets, y=intsets)
    def subtract(self, x, y):
        return (
            x[0] - y[0],
            sorted(set(x[1]) - set(y[1]))
        )

    @rule(target=intsets, x=intsets, i=values)
    def insert(self, x, i):
        return (
            x[0].insert(i),
            sorted(set(x[1] + [i]))
        )

    @rule(target=intsets, x=intsets, i=values)
    def discard(self, x, i):
        return (
            x[0].discard(i),
            sorted(set(x[1]) - {i})
        )

    @rule(target=intsets, source=intsets, bounds=intervals)
    def restrict(self, source, bounds):
        return (
            source[0].restrict(*bounds),
            [x for x in source[1] if bounds[0] <= x < bounds[1]]
        )

    @rule(target=intsets, x=intsets)
    def peel_left(self, x):
        # Drop the smallest value; an empty pair is passed through as-is.
        if len(x[0]) == 0:
            return x
        return self.restrict(x, (x[0].start + 1, x[0].end))

    @rule(target=intsets, x=intsets)
    def peel_right(self, x):
        # Drop the largest value; an empty pair is passed through as-is.
        if len(x[0]) == 0:
            return x
        return self.restrict(x, (x[0].start, x[0].end - 1))

    @rule(x=intsets, y=intsets)
    def validate_order(self, x, y):
        assert (x[0] <= y[0]) == (x[1] <= y[1])

    @rule(x=intsets, y=intsets)
    def validate_equality(self, x, y):
        assert (x[0] == y[0]) == (x[1] == y[1])

    @rule(source=intsets)
    def validate(self, source):
        # Iteration, length, indexing (both signs) and bounds must all
        # agree with the model list.
        assert list(source[0]) == source[1]
        assert len(source[0]) == len(source[1])
        for i in range(-len(source[0]), len(source[0])):
            assert source[0][i] == source[1][i]
        if len(source[0]) > 0:
            for v in source[1]:
                assert source[0].start <= v < source[0].end
# Bind the state machine's TestCase to a Test-prefixed module-level name;
# presumably this is what lets the test runner collect it - confirm against
# the project's test configuration.
TestState = SetModel.TestCase
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment