Skip to content

Instantly share code, notes, and snippets.

@nat-n
Last active February 26, 2022 21:51
Show Gist options
  • Save nat-n/02e6882ef8b6aa3e8efe143795a4eae0 to your computer and use it in GitHub Desktop.
Save nat-n/02e6882ef8b6aa3e8efe143795a4eae0 to your computer and use it in GitHub Desktop.
A gold-plated implementation of an immutable bit set in Python, including full test coverage.
from collections.abc import Set
from numpy import packbits
import math
import random
from typing import Any, Callable, Collection, Iterable, Iterator
class ImmutableBitSet(Set):
    """
    An immutable set of non-negative integers stored as a packed bit array.

    Value ``n`` is represented by the ``n % 8``-th most significant bit of byte
    ``n // 8`` of the underlying ``bytes``. Trailing zero bytes are always
    stripped, so two bitsets holding the same values have identical content;
    equality and hashing rely on that invariant.
    """

    # Packed membership bits. Invariant: empty, or the last byte is non-zero.
    _content: bytes

    @property
    def size(self) -> int:
        """
        The number of allocated bits in this bitset, i.e. one more than the
        largest value that fits in the currently allocated bytes.
        """
        return len(self._content) * 8

    def __init__(self, values: Iterable[int] = b""):
        """
        Build a bitset either from raw packed bytes or from integer members.

        Args:
            values: a ``bytes`` object is taken as already-packed content
                (trailing NUL bytes are stripped). Any other iterable is treated
                as the integer members of the set; duplicates are ignored.

        Raises:
            ValueError: if any member is not a non-negative ``int``.
        """
        if isinstance(values, bytes):
            # Strip trailing NUL bytes so equal sets always have equal content.
            self._content = values.rstrip(b"\x00")
            return
        # Materialise first so one-shot iterators (e.g. the generators that the
        # collections.abc.Set mixin operators pass to _from_iterable) work.
        values_set = set(values)
        for value in values_set:
            if not isinstance(value, int) or value < 0:
                raise ValueError(
                    f"ImmutableBitSet only accepts non-negative integers, not {value!r}"
                )
        if not values_set:
            self._content = b""
            return
        bitvalues = [0] * (max(values_set) + 1)
        for value in values_set:
            bitvalues[value] = 1
        # packbits fills each byte from its most significant bit, matching the
        # bit layout read by __contains__ and __iter__.
        self._content = bytes(packbits(bitvalues))

    def __contains__(self, value: object) -> bool:
        """Return True if ``value`` is a non-negative int whose bit is set."""
        # Negative ints must be rejected explicitly: a negative byte index would
        # otherwise read bytes from the end of _content.
        if not isinstance(value, int) or value < 0:
            return False
        byte_index, bit = divmod(value, 8)
        if byte_index >= len(self._content):
            return False
        return bool(self._content[byte_index] & (0x80 >> bit))

    def __iter__(self) -> Iterator[int]:
        """Yield the members in ascending order."""
        for byte_index, byte in enumerate(self._content):
            if not byte:
                continue
            base = byte_index * 8
            for bit in range(8):
                if byte & (0x80 >> bit):
                    yield base + bit

    def __bytes__(self) -> bytes:
        """Return the packed content (no trailing NUL bytes)."""
        return self._content

    def __len__(self) -> int:
        """
        Return the number of elements in the ImmutableBitSet.
        That is the number of bits set to 1 in self._content (aka the Hamming weight).
        """
        weight = 0
        for byte in self._content:
            while byte:
                weight += 1
                byte &= byte - 1  # clears the lowest set bit
        return weight

    def __hash__(self) -> int:
        # Defining __eq__ would otherwise set __hash__ to None. The content is
        # canonical (no trailing NULs), so hashing it is consistent with __eq__.
        return hash(self._content)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({list(self)!r})"

    def _aligned_pairs(self, other: "ImmutableBitSet") -> "Iterator[tuple[int, int]]":
        """
        Return an iterator of (self_byte, other_byte) pairs spanning the longer
        of the two bitsets; bytes missing from the shorter one count as zero.
        """
        mine = self._content
        theirs = bytes(other)
        width = max(len(mine), len(theirs))
        return zip(mine.ljust(width, b"\x00"), theirs.ljust(width, b"\x00"))

    def union(self, other: Iterable[Any]) -> "ImmutableBitSet":
        """
        Create a new ImmutableBitSet as the union of self and other.

        Raises ValueError if other contains anything but non-negative ints.
        """
        if isinstance(other, ImmutableBitSet):
            return ImmutableBitSet(
                bytes(b1 | b2 for b1, b2 in self._aligned_pairs(other))
            )
        return ImmutableBitSet(set(self) | set(other))

    def __or__(self, other: Set) -> "ImmutableBitSet":
        return self.union(other)

    def intersection(self, other: Iterable[Any]) -> "ImmutableBitSet":
        """
        Create a new ImmutableBitSet as the intersection of self and other.
        """
        if isinstance(other, ImmutableBitSet):
            # Plain zip is enough here: bytes beyond the shorter bitset would be
            # ANDed with zero anyway.
            return ImmutableBitSet(
                bytes(b1 & b2 for b1, b2 in zip(self._content, bytes(other)))
            )
        return ImmutableBitSet(set(self) & set(other))

    def __and__(self, other: Set) -> "ImmutableBitSet":
        return self.intersection(other)

    def isdisjoint(self, other: Iterable[Any]) -> bool:
        """
        Return True if the set has no elements in common with other. Sets are disjoint
        if and only if their intersection is the empty set.
        """
        return not self.intersection(other)

    def issubset(self, other: Iterable[Any]) -> bool:
        """
        Test whether every element in the set is in other.
        """
        if isinstance(other, ImmutableBitSet):
            # No bit may be set in self that is clear in other, including in
            # bytes past the end of the shorter bitset.
            return all((b1 & ~b2) == 0 for b1, b2 in self._aligned_pairs(other))
        return set(self) <= set(other)

    def __le__(self, other: Set) -> bool:
        return self.issubset(other)

    def __lt__(self, other: Set) -> bool:
        """
        Test whether the set is a proper subset of other, that is, set <= other and set
        != other.
        """
        if isinstance(other, ImmutableBitSet):
            return self.issubset(other) and self != other
        return set(self) < set(other)

    def issuperset(self, other: Set) -> bool:
        """
        Test whether every element in other is in the set.
        """
        if isinstance(other, ImmutableBitSet):
            # No bit may be set in other that is clear in self, including in
            # bytes past the end of the shorter bitset.
            return all((b2 & ~b1) == 0 for b1, b2 in self._aligned_pairs(other))
        return set(self) >= set(other)

    def __ge__(self, other: Set) -> bool:
        return self.issuperset(other)

    def __gt__(self, other: Set) -> bool:
        """
        Test whether the set is a proper superset of other, that is, set >= other and
        set != other.
        """
        if isinstance(other, ImmutableBitSet):
            return self.issuperset(other) and self != other
        return set(self) > set(other)

    def __eq__(self, other: object) -> bool:
        # Only other ImmutableBitSets compare equal; plain sets never do.
        return isinstance(other, ImmutableBitSet) and self._content == bytes(other)

    def __ne__(self, other: object) -> bool:
        return not isinstance(other, ImmutableBitSet) or self._content != bytes(other)

    def difference(self, other: Iterable[Any]) -> "ImmutableBitSet":
        """Return a new ImmutableBitSet with the elements of the set that are not in other."""
        if isinstance(other, ImmutableBitSet):
            return ImmutableBitSet(
                bytes(b1 & ~b2 for b1, b2 in self._aligned_pairs(other))
            )
        return ImmutableBitSet(set(self) - set(other))

    def __sub__(self, other: Set) -> "ImmutableBitSet":
        return self.difference(other)

    def symmetric_difference(self, other: Iterable[Any]) -> "ImmutableBitSet":
        """
        Return a new ImmutableBitSet with elements in either the set or other but not both.

        Raises ValueError if other contains anything but non-negative ints.
        """
        if isinstance(other, ImmutableBitSet):
            return ImmutableBitSet(
                bytes(b1 ^ b2 for b1, b2 in self._aligned_pairs(other))
            )
        return ImmutableBitSet(set(self) ^ set(other))

    def __xor__(self, other: Set) -> "ImmutableBitSet":
        return self.symmetric_difference(other)

    def __bool__(self) -> bool:
        # Content has no trailing NULs, so any non-empty content has a set bit.
        return bool(self._content)

    def copy(self) -> "ImmutableBitSet":
        """Return a new ImmutableBitSet equal to this one."""
        return ImmutableBitSet(self._content)
#
# Tests
# usage: pytest -v ./immutable_bitset.py
#
def test_immutable_bit_set_with_arbitrary_bits():
    """
    Generate 50 random lists of 20 integers (duplicates allowed) and check that
    the bitset agrees with a plain set built from the same values.
    """
    # A fixed seed keeps the test reproducible, so failures can be replayed.
    rng = random.Random(0)
    population = range(256)
    for _ in range(50):
        values = rng.choices(population, k=20)
        expected = set(values)
        bs = ImmutableBitSet(values)
        assert len(bs) == len(expected), "Bitset should know how many items it contains"
        assert (
            bs.size == math.ceil((max(values) + 1) / 8) * 8
        ), "Bitset should use the optimal number of bytes"
        for num in range(256):
            assert (num in bs) == (
                num in expected
            ), f"Bitset membership of {num} should match the given values"
        assert set(bs) == expected, "Expected Bitset to contain the original values"
def test_immutable_bit_set_with_single_bits():
    """Every value in 0..255 should round-trip through a one-element bitset."""
    for value in range(256):
        single = ImmutableBitSet([value])
        expected_size = math.ceil((value + 1) / 8) * 8
        assert len(single) == 1, "Bitset should know how many items it contains"
        assert single.size == expected_size, "Bitset should use the optimal number of bytes"
        assert value in single, f"Expected Bitset to contain {value}"
        assert set(single) == {value}, "Expected Bitset to contain the original values "
def test_boolean_cast():
    """An empty bitset has length zero and is falsy; a non-empty one is truthy."""
    empty = ImmutableBitSet([])
    assert len(empty) == 0
    assert not empty
    assert ImmutableBitSet([42])
def test_set_union_and_intersection():
    """
    Multiples of 3 ("fizz") intersected with multiples of 5 ("buzz") must give
    the multiples of 15 ("fizzbuzz"), and bitset unions must match set unions.
    """
    numbers = range(1, 256)
    fizz = {num for num in numbers if num % 3 == 0}
    buzz = {num for num in numbers if num % 5 == 0}
    # Numbers divisible by both 3 and 5 are exactly the multiples of 15.
    fizzbuzz = {num for num in numbers if num % 15 == 0}
    fizz_bs = ImmutableBitSet(fizz)
    buzz_bs = ImmutableBitSet(buzz)
    fizzbuzz_bs = ImmutableBitSet(fizzbuzz)
    # sanity checks
    assert fizz == set(fizz_bs)
    assert buzz == set(buzz_bs)
    assert fizzbuzz == set(fizzbuzz_bs)
    assert fizz_bs != buzz_bs
    assert fizz_bs != fizzbuzz_bs
    assert buzz_bs != fizzbuzz_bs
    assert fizz & buzz == fizzbuzz, "Set intersection should work as expected"
    assert fizz_bs & buzz_bs == fizzbuzz_bs, "Set intersection should work as expected"
    assert fizz | buzz == set(fizz_bs | buzz_bs), "Set union should work as expected"
    assert fizz_bs | buzz_bs == ImmutableBitSet(
        fizz | buzz
    ), "Set union should work as expected"
def test_set_comparisons():
    """
    Check ImmutableBitSet set algebra and ordering against other bitsets and
    against plain Python sets, including sets that hold non-int members.
    """
    low, mid, high = {1, 2, 3}, {3, 4, 5}, {4, 5, 6}
    # Plain sets mixing ints with a string: fine as operands, invalid as members.
    cow_near = {3, 4, "cow"}
    cow_far = {33, 4, "cow"}
    invalid = (cow_near, cow_far)
    low_bs, mid_bs, high_bs = (ImmutableBitSet(s) for s in (low, mid, high))
    empty_bs = ImmutableBitSet()
    low_high_bs = low_bs | high_bs

    # Construction rejects non-int members.
    for bad in invalid:
        assert_value_error_on_init(lambda bad=bad: ImmutableBitSet(bad))

    # Union and intersection agree for plain-set and bitset operands.
    for plain, bits in ((low, low_bs), (mid, mid_bs), (high, high_bs)):
        assert low_bs.union(plain) == ImmutableBitSet(low | plain)
        assert low_bs.union(bits) == ImmutableBitSet(low | plain)
        assert low_bs.intersection(plain) == ImmutableBitSet(low & plain)
        assert low_bs.intersection(bits) == ImmutableBitSet(low & plain)
    for bad in invalid:
        assert_value_error_on_init(lambda bad=bad: low_bs.union(bad))
    assert low_bs.intersection(cow_near) == ImmutableBitSet(low & cow_near)
    assert low_bs.intersection(cow_near) == ImmutableBitSet([3])
    assert low_bs.intersection("lol") == ImmutableBitSet(tuple())

    # Disjointness.
    disjoint_cases = (
        (high_bs, True),
        (mid_bs, False),
        (high, True),
        (mid, False),
        (cow_near, False),
        (cow_far, True),
        ("lol", True),
    )
    for other, expected in disjoint_cases:
        assert low_bs.isdisjoint(other) == expected

    # Subset.
    assert not low_bs <= empty_bs
    assert low_bs <= low_bs
    subset_cases = (
        (low_bs, empty_bs, False),
        (low_bs, {1}, False),
        (low_bs, low_bs, True),
        (mid_bs, low_bs, False),
        (mid_bs, low_high_bs, True),
        (mid_bs, {1, 2, 3, 4, 5, 6, "goose!"}, True),
        (mid_bs, high_bs, False),
        (mid_bs, "lol", False),
    )
    for left, right, expected in subset_cases:
        assert left.issubset(right) == expected

    # Proper subset.
    assert not low_bs < empty_bs
    assert not low_bs < ImmutableBitSet({0, 2, 3, 4})
    assert not low_bs < {0, 2, 3, 4}
    assert not low_bs < low_bs
    assert low_bs < low_high_bs
    assert mid_bs < low_high_bs

    # Superset.
    assert low_bs >= empty_bs
    assert low_bs >= low_bs
    superset_cases = (
        (low_bs, empty_bs, True),
        (low_bs, {1}, True),
        (low_bs, {1, 3}, True),
        (low_bs, low_bs, True),
        (mid_bs, low_bs, False),
        (mid_bs, low_high_bs, False),
        (mid_bs, high_bs, False),
        (mid_bs, "lol", False),
    )
    for left, right, expected in superset_cases:
        assert left.issuperset(right) == expected

    # Proper superset.
    assert low_bs > empty_bs
    assert low_bs > ImmutableBitSet({1, 3})
    assert low_bs > {1, 3}
    assert not low_bs > ImmutableBitSet({0, 2, 3, 4})
    assert not low_bs > low_bs
    assert not low_bs > low_high_bs
    assert not mid_bs > low_high_bs

    # Equality: bitsets never equal non-bitsets, even sets with equal members.
    assert low_bs == ImmutableBitSet({1, 2, 3})
    for other in (ImmutableBitSet({1, 3}), mid_bs, high_bs, "lol", low):
        assert low_bs != other

    # Symmetric difference.
    expected_xor = ImmutableBitSet((1, 2, 4, 5))
    assert not low_bs ^ low_bs
    assert low_bs ^ mid_bs == expected_xor
    assert not low_bs.symmetric_difference(low_bs)
    assert low_bs.symmetric_difference(mid) == expected_xor
    assert low_bs.symmetric_difference(mid_bs) == expected_xor
    assert_value_error_on_init(lambda: low_bs.symmetric_difference("lol"))

    # Difference.
    assert not low_bs - low_bs
    assert not mid_bs - low_high_bs
    assert low_bs - ImmutableBitSet({2}) == ImmutableBitSet({1, 3})
    assert not low_bs.difference(low_bs)
    assert not mid_bs - (low | high)
    assert low_bs - {2} == ImmutableBitSet({1, 3})
def test_contains_non_int():
    """Membership checks with absent or non-int probes return False instead of raising."""
    bits = ImmutableBitSet([0])
    assert 0 in bits
    for absent in (30, "0", "30", (0,)):
        assert absent not in bits
def test_copy():
    """copy() returns an equal but distinct ImmutableBitSet."""
    original = ImmutableBitSet([1, 2, 3])
    duplicate = original.copy()
    assert duplicate == original
    assert duplicate is not original
def assert_value_error_on_init(fun: Callable):  # pragma: no cover
    """
    Assert that calling ``fun`` raises ValueError.

    Used to check that building an ImmutableBitSet from non-int values is
    rejected, whether via the constructor or via an operation such as union.
    Exceptions other than ValueError propagate unchanged.

    Raises:
        AssertionError: if ``fun`` returns without raising ValueError.
    """
    try:
        fun()
    except ValueError:
        return
    # Raise explicitly rather than `assert False`, which `python -O` strips.
    raise AssertionError(
        "Instantiating an ImmutableBitSet with a non-int value should raise ValueError"
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment