Skip to content

Instantly share code, notes, and snippets.

@vxgmichel
Created May 18, 2020 21:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vxgmichel/03e7df1e35c1fd34dd871456e0217da3 to your computer and use it in GitHub Desktop.
Save vxgmichel/03e7df1e35c1fd34dd871456e0217da3 to your computer and use it in GitHub Desktop.
Fenwick tree - performing both fast sum and fast update on a list of numbers
"""
Fenwick tree - performing both fast sum and fast update on a list of numbers
This implementation is mostly compatible with regular list operations.
Time complexity table (n is the number of elements, k the length of the slice):
| Operation | List | Fenwick tree |
|-------------|------|--------------|
| Instantiate | O(n) | O(n) |
| Append | O(1) | O(1) amortized |
| Get item | O(1) | O(log n) |
| Set item | O(1) | O(log n) |
| Sum a slice | O(k) | O(log n) |
| Delete item | O(n) | O(n log n) |
| Insert | O(n) | O(n log n) |
"""
from collections.abc import MutableSequence
# Base operations
def fenwick_sum(data, index):
    """Return the sum of the original values at positions 0..index inclusive.

    *data* is the tree storage list built by ``fenwick_append``; *index* must
    lie in ``[0, len(data))`` (checked with an assertion).  The loop runs once
    per set bit of ``index + 1``, i.e. O(log n).
    """
    assert 0 <= index < len(data)
    total = 0
    position = index + 1  # the tree is addressed 1-based internally
    while position:
        total += data[position - 1]
        position &= position - 1  # drop the lowest set bit
    return total
def fenwick_add(data, index, value):
    """Add *value* to the original element at *index*, updating *data* in place.

    *index* must lie in ``[0, len(data))`` (checked with an assertion).  Every
    tree node whose range covers *index* is incremented; O(log n).
    """
    assert 0 <= index < len(data)
    size = len(data)
    position = index + 1  # 1-based position in the tree
    while position <= size:
        data[position - 1] += value
        position += position & -position  # move to the next covering node
def fenwick_append(data, value):
    """Append *value* as the new last element of the tree stored in *data*.

    The new node's sum is *value* plus the existing nodes found by walking the
    trailing one-bits of the old length (each such bit names a node whose
    range lies inside the new node's range).
    """
    position = len(data)
    bit = 1
    while position & bit:
        value += data[position - 1]
        position ^= bit  # the bit is known to be set, so this clears it
        bit <<= 1
    data.append(value)
# Slice operations
def fenwick_slice(data, start, stop):
    """Return the original values at positions ``start..stop-1`` as a list.

    *stop* is clamped to ``len(data)``; an empty or inverted range yields
    ``[]``.  Each value is recovered as the difference of consecutive prefix
    sums.
    """
    stop = min(stop, len(data))
    if start >= stop:
        return []
    assert 0 <= start < stop <= len(data)
    # Prefix sum just before the slice (0 when the slice starts at the front).
    before = 0 if start <= 0 else fenwick_sum(data, start - 1)
    prefixes = [fenwick_sum(data, position) for position in range(start, stop)]
    return [current - previous for previous, current in zip([before] + prefixes, prefixes)]
def fenwick_sumslice(data, start, stop):
    """Return the sum of the original values at positions ``start..stop-1``.

    *stop* is clamped to ``len(data)``; an empty or inverted range sums to 0.
    Costs two prefix-sum queries.
    """
    end = min(stop, len(data))
    if start >= end:
        return 0
    assert 0 <= start < end <= len(data)
    upper = fenwick_sum(data, end - 1)
    # Subtract everything before *start*, unless the range begins at the front.
    return upper - fenwick_sum(data, start - 1) if start > 0 else upper
def fenwick_delslice(data, start, stop):
    """Remove the original values at positions ``start..stop-1`` from *data*.

    *stop* is clamped to ``len(data)``; an empty or inverted range is a no-op.
    The tree is truncated at *start* (a prefix of a Fenwick tree is itself a
    valid tree) and the surviving tail values are re-appended.
    """
    end = min(stop, len(data))
    if start < end:
        assert 0 <= start < end <= len(data)
        survivors = fenwick_slice(data, end, len(data))
        del data[start:]
        for item in survivors:
            fenwick_append(data, item)
# Access operations
def fenwick_get(data, index):
    """Return the original value at *index*.

    Raises IndexError when *index* is outside ``[0, len(data))``.
    """
    if 0 <= index < len(data):
        # A single element is the sum of the one-element range [index, index + 1).
        return fenwick_sumslice(data, index, index + 1)
    raise IndexError(index)
def fenwick_set(data, index, value):
    """Overwrite the original value at *index* with *value*.

    Implemented as adding the difference between the new and current value.
    Raises IndexError when *index* is outside ``[0, len(data))``.
    """
    if 0 <= index < len(data):
        delta = value - fenwick_get(data, index)
        fenwick_add(data, index, delta)
        return
    raise IndexError(index)
def fenwick_del(data, index):
    """Remove the original value at *index* from *data*.

    Raises IndexError when *index* is outside ``[0, len(data))``.
    """
    in_range = 0 <= index < len(data)
    if not in_range:
        raise IndexError(index)
    return fenwick_delslice(data, index, index + 1)
def fenwick_insert(data, index, value):
    """Insert *value* before position *index* in *data*.

    *index* must be non-negative (checked with an assertion); an index at or
    past the end appends.  The tree is truncated at *index*, then *value* and
    the saved tail values are re-appended in order.
    """
    assert 0 <= index
    tail = fenwick_slice(data, index, len(data))
    del data[index:]
    for item in [value] + tail:
        fenwick_append(data, item)
class Fenwick(MutableSequence):
    """List-like sequence of numbers backed by a Fenwick (binary indexed) tree.

    Reading or writing one element costs O(log n), and ``sum(start, stop)``
    totals a contiguous range in O(log n).  Indexing follows ``list``
    semantics: negative indexes count from the end, slices may omit bounds or
    use a step, and step-1 slice assignment may change the length.
    """

    def __init__(self, iterable=()):
        # Tree storage, maintained exclusively through the fenwick_* helpers.
        self._data = []
        for value in iterable:
            fenwick_append(self._data, value)

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, list(self))

    def __len__(self):
        return len(self._data)

    def _position(self, index):
        """Map a possibly negative integer index onto ``[0, len)``.

        Raises IndexError, carrying the caller's original index, when out of range.
        """
        size = len(self._data)
        position = index + size if index < 0 else index
        if not 0 <= position < size:
            raise IndexError(index)
        return position

    def _replace_tail(self, position, values):
        """Drop every element from *position* onward and append *values* instead."""
        # A prefix of a Fenwick tree is itself a valid tree, so truncating and
        # re-appending keeps the structure consistent.
        del self._data[position:]
        for value in values:
            fenwick_append(self._data, value)

    def __getitem__(self, index):
        if not isinstance(index, slice):
            return fenwick_get(self._data, self._position(index))
        # Normalise None / negative / out-of-range bounds the way list does.
        start, stop, step = index.indices(len(self._data))
        if step == 1:
            return fenwick_slice(self._data, start, stop)
        return [fenwick_get(self._data, i) for i in range(start, stop, step)]

    def __setitem__(self, index, value):
        if not isinstance(index, slice):
            fenwick_set(self._data, self._position(index), value)
            return
        start, stop, step = index.indices(len(self._data))
        # Accept any iterable, and snapshot it in case it is this very sequence.
        values = list(value)
        if step == 1:
            # As with list, an inverted slice such as [5:2] inserts at 5.
            stop = max(start, stop)
            if len(values) != stop - start:
                # Resizing assignment: rebuild everything from *start* onward.
                tail = fenwick_slice(self._data, stop, len(self._data))
                self._replace_tail(start, values + tail)
                return
            positions = range(start, stop)
        else:
            positions = range(start, stop, step)
            if len(values) != len(positions):
                raise ValueError(
                    "attempt to assign sequence of size {} to extended slice "
                    "of size {}".format(len(values), len(positions))
                )
        for position, item in zip(positions, values):
            fenwick_set(self._data, position, item)

    def __delitem__(self, index):
        if not isinstance(index, slice):
            fenwick_del(self._data, self._position(index))
            return
        positions = range(*index.indices(len(self._data)))
        if not positions:
            return
        if positions.step == 1:
            fenwick_delslice(self._data, positions.start, positions.stop)
            return
        # Extended slice: rebuild from the lowest removed position, keeping
        # only the positions that are not part of the slice.
        first = min(positions[0], positions[-1])
        tail = fenwick_slice(self._data, first, len(self._data))
        kept = [item for offset, item in enumerate(tail, first) if offset not in positions]
        self._replace_tail(first, kept)

    def insert(self, index, value):
        """Insert *value* before *index*; negative indexes count from the end, like list.insert."""
        if index < 0:
            index = max(0, index + len(self._data))
        fenwick_insert(self._data, index, value)

    def clear(self):
        """Remove every element."""
        del self._data[:]

    def sum(self, start=None, stop=None):
        """Return the total of ``self[start:stop]`` using two prefix-sum queries.

        ``sum(n)`` totals the first *n* elements and ``sum()`` totals everything.
        Bounds follow slice semantics: negative values count from the end and
        out-of-range values are clamped.
        """
        if stop is None:
            start, stop = None, start
        start, stop, _ = slice(start, stop).indices(len(self._data))
        return fenwick_sumslice(self._data, start, stop)
def test():
    """Drive a Fenwick and a reference list through the same edits and compare them."""
    expected = list(range(1, 10))
    tree = Fenwick(expected)

    def same():
        assert list(tree) == expected

    same()

    expected[3] = 10
    tree[3] = 10
    same()

    del expected[3]
    del tree[3]
    same()

    expected.insert(3, 11)
    tree.insert(3, 11)
    same()

    # An insertion index past the end appends, as with list.insert.
    expected.insert(100, 12)
    tree.insert(100, 12)
    same()

    expected[1:3] = [13, 14]
    tree[1:3] = [13, 14]
    same()

    expected += [15, 16]
    tree += [15, 16]
    same()

    del expected[1:3]
    del tree[1:3]
    same()

    # Every forward slice (including empty ones) and its sum must agree.
    size = len(expected)
    for lo in range(size):
        for hi in range(size):
            assert tree[lo:hi] == expected[lo:hi]
            assert tree.sum(lo, hi) == sum(expected[lo:hi])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment