Skip to content

Instantly share code, notes, and snippets.

@parvezmrobin
Last active November 25, 2022 03:18
Show Gist options
  • Save parvezmrobin/771ee424044676b1fea4fbbbe4665cd4 to your computer and use it in GitHub Desktop.
Save parvezmrobin/771ee424044676b1fea4fbbbe4665cd4 to your computer and use it in GitHub Desktop.
Simple and efficient Python implementation of a bit array
from __future__ import annotations
class BitArray(bytearray):
    """Mutable, compactly stored array of bits built on top of ``bytearray``.

    Bit ``i`` is stored in byte ``i // 8`` at bit position ``i % 8``, counted
    from the least-significant bit. Integer and slice indexing read and write
    single bits (values 0 or 1). Everything inherited from ``bytearray`` that
    is not overridden here -- ``len()``, iteration, ``repr()``, ``append`` and
    the other byte-level methods -- still operates on whole bytes, so
    ``len(array)`` is the number of bytes and the bit capacity is
    ``len(array) * 8``.
    """

    # mask_for[i] has only bit i set; OR-ing with it sets bit i.
    mask_for = [2**i for i in range(8)]
    # inverse_mask_for[i] has every bit except bit i set; AND-ing with it
    # clears exactly bit i. It must be 255 - 2**i: 256 - 2**i would clear the
    # bits below i and leave bit i untouched.
    inverse_mask_for = [255 - 2**i for i in range(8)]
    valid_values = (0, 1)

    def __init__(self, source, *args, **kwargs) -> None:
        """Create a bit array.

        Args:
            source: If an ``int``, the number of bits needed. Storage is rounded
                up to whole bytes, so ``BitArray(20)`` holds 24 bits, all 0.
                Any other value is passed to ``bytearray`` unchanged and each
                resulting byte contributes 8 bits.
            *args, **kwargs: Forwarded to ``bytearray.__init__``.

        Raises:
            ValueError: If *source* is a negative ``int``.
        """
        if isinstance(source, int):
            if source < 0:
                raise ValueError(f"negative bit count: {source}")
            # Integer ceiling division. Float division would leave exact
            # multiples of 8 as a float (e.g. 2.0), which bytearray rejects.
            source = (source + 7) // 8
        super().__init__(source, *args, **kwargs)

    def _check_value(self, val) -> None:
        """Raise ValueError unless *val* is one of ``valid_values`` (0 or 1)."""
        if val not in self.valid_values:
            raise ValueError(f"You want to put {val} in a BitArray!")

    def _get_bit(self, index: int) -> int:
        """Return bit *index* as 0 or 1; *index* must already be in range."""
        byte = super().__getitem__(index // 8)
        return int(bool(byte & self.mask_for[index % 8]))

    def _set_bit(self, index: int, val) -> None:
        """Clear bit *index* if *val* == 0, otherwise set it; *index* must be in range."""
        byte_idx, bit_idx = divmod(index, 8)
        byte = super().__getitem__(byte_idx)
        if val == 0:
            byte &= self.inverse_mask_for[bit_idx]
        else:
            byte |= self.mask_for[bit_idx]
        super().__setitem__(byte_idx, byte)

    def fill(self, val):
        """Set every stored bit (all ``len(self) * 8`` of them) to *val*.

        Returns:
            ``self``, so the call can be chained, e.g. ``BitArray(16).fill(1)``.

        Raises:
            ValueError: If *val* is not 0 or 1.
        """
        self._check_value(val)
        byte = 0 if val == 0 else 255
        # One C-level slice assignment instead of a Python loop over bytes.
        super().__setitem__(slice(None), bytes([byte]) * len(self))
        return self

    def _assert_key(self, key):
        """Validate an index before it is used.

        Raises:
            KeyError: If *key* is neither an ``int`` nor a ``slice``, or is a
                negative ``int``.
            IndexError: If an ``int`` *key* is past the last stored bit.
        """
        if not isinstance(key, (int, slice)):
            # type(self): instances have no ``__name__`` attribute.
            raise KeyError(f'{type(self).__name__} only support int and slices as index')
        if isinstance(key, int):
            if key < 0:
                raise KeyError('Only positive keys are supported')
            if key >= len(self) * 8:
                raise IndexError(key)

    def __getitem__(self, key) -> int | list[int]:
        """Return one bit (``int`` key) or a list of bits (``slice`` key).

        Slices follow normal Python semantics over the ``len(self) * 8``
        stored bits: omitted or negative bounds, any non-zero step, and bounds
        clamped to the bit capacity.

        Raises:
            KeyError: For a negative ``int`` key or a key of another type.
            IndexError: For an ``int`` key past the last stored bit.
        """
        if isinstance(key, slice):
            positions = range(*key.indices(len(self) * 8))
            return [self._get_bit(index) for index in positions]
        self._assert_key(key)
        return self._get_bit(key)

    def __setitem__(self, key, val):
        """Set one bit (``int`` key) or several bits (``slice`` key).

        For a slice, *val* is an iterable of 0/1 values consumed in order, one
        per selected position. Surplus values are ignored. All values are
        validated before any bit is written, so a bad value leaves the array
        unchanged.

        Raises:
            KeyError: For a negative ``int`` key or a key of another type.
            IndexError: For an ``int`` key past the last stored bit.
            ValueError: If a value is not 0 or 1, or a slice receives fewer
                values than it selects.
        """
        self._assert_key(key)
        if isinstance(key, slice):
            positions = range(*key.indices(len(self) * 8))
            # zip stops after the last position, so at most len(positions)
            # items are pulled from *val* (iterators are not over-consumed).
            values = [v for _, v in zip(positions, val)]
            if len(values) < len(positions):
                raise ValueError(
                    f"{type(self).__name__} slice selects {len(positions)} bits "
                    f"but only {len(values)} values were given"
                )
            for v in values:
                self._check_value(v)
            for index, v in zip(positions, values):
                self._set_bit(index, v)
            return
        self._check_value(val)
        self._set_bit(key, val)
if __name__ == "__main__":
    # Usage demo. It is guarded so that importing this module does not
    # print anything.
    array = BitArray(20)  # 20 bits requested -> 3 bytes (24 bits) allocated
    array[1] = 1
    array[3] = 1
    array[11] = 1
    print(array[1:13:2])  # [1, 1, 0, 0, 0, 1]
    array[1:7] = [1] * 6
    print(array[1:13])  # [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0]
    # The slice selects bits 11 and 13 only; the extra values are ignored.
    array[11:14:2] = [1, 1, 1, 1]
    print(array[1:15])  # [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment