Skip to content

Instantly share code, notes, and snippets.

@exyi
Created September 15, 2020 19:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save exyi/e275844121f0a7d37ede8b4e4fdf5773 to your computer and use it in GitHub Desktop.
Save exyi/e275844121f0a7d37ede8b4e4fdf5773 to your computer and use it in GitHub Desktop.
Intervaláče (interval / segment trees)
from __future__ import annotations
from dataclasses import dataclass
from typing import *
import unittest
@dataclass
class Node:
    """One node of a min-segment tree.

    A node covers the half-open index range ``interval = (start, end)`` of the
    original value list and stores the minimum of the values in that range.
    Leaves cover a single index ``(i, i + 1)`` and have no children, so
    ``left``/``right`` default to ``None``.
    """

    interval: Tuple[int, int]  # half-open [start, end) range of indices
    value: float  # minimum of the values inside ``interval``
    left: Optional[Node] = None  # child covering the lower half; None for a leaf
    right: Optional[Node] = None  # child covering the upper half; None for a leaf
def build_node(values: Sequence[float], index: int = 0) -> Node:
    """Build a min-segment tree over ``values`` and return its root.

    Args:
        values: The leaf values, in order. Must be non-empty.
        index: Position of ``values[0]`` in the overall array. Recursive calls
            pass it so that right subtrees get correctly offset intervals.

    Returns:
        A ``Node`` covering ``(index, index + len(values))`` whose ``value``
        is the minimum of ``values``.

    Raises:
        ValueError: If ``values`` is empty.
    """
    # Use len() rather than truthiness so array-like sequences also work.
    if len(values) == 0:
        # Without this guard an empty slice splits into itself and the
        # recursion never terminates.
        raise ValueError("build_node() requires at least one value")
    if len(values) == 1:
        return Node((index, index + 1), values[0], None, None)

    split = len(values) // 2
    left = build_node(values[:split], index)
    right = build_node(values[split:], index + split)
    return Node(
        (index, index + len(values)),
        min(left.value, right.value),
        left,
        right,
    )
def query(n: Node, start: int, end: int):
    """Return the minimum of the values at indices ``[start, end)`` below ``n``.

    The requested range is first clipped to the node's own interval, so ranges
    that stick out past the tree are fine. Returns ``None`` when the clipped
    range is empty. Each node's ``value`` is taken to be the minimum of its
    interval, as ``build_node``/``update`` maintain it.
    """
    lo, hi = n.interval
    start, end = max(lo, start), min(hi, end)
    if end <= start:
        return None
    if (start, end) == (lo, hi):
        # The node's whole interval lies inside the query.
        return n.value

    assert n.left is not None and n.right is not None
    # Gather the non-empty partial answers from both halves.
    parts = [
        part
        for part in (query(n.left, start, end), query(n.right, start, end))
        if part is not None
    ]
    assert parts
    return parts[0] if len(parts) == 1 else min(parts)
def update(n: Node, index: int, val: float) -> None:
    """Set the value at ``index`` to ``val`` and refresh the minima above it.

    Descends from ``n`` to the leaf whose interval is ``(index, index + 1)``,
    stores ``val`` there, and recomputes ``min(left.value, right.value)`` for
    every node on the path back up.

    Raises:
        IndexError: If ``index`` lies outside ``n.interval``.
    """
    node_start, node_end = n.interval
    if not node_start <= index < node_end:
        # Without this check the descent reaches a leaf for a different index
        # and fails on the assert below (or with AttributeError under -O).
        raise IndexError(
            f"index {index} out of range [{node_start}, {node_end})"
        )
    if n.interval == (index, index + 1):
        n.value = val
        return

    assert n.left is not None and n.right is not None
    # The left child's end is the split point between the two halves.
    _, split_index = n.left.interval
    update(n.left if index < split_index else n.right, index, val)
    n.value = min(n.left.value, n.right.value)
class Tests(unittest.TestCase):
    """Unit tests for ``build_node``, ``query`` and ``update``."""

    def test_build(self):
        tree = build_node([3, 5, 2, 3])
        self.assertEqual(tree.interval, (0, 4))
        self.assertEqual(tree.right.interval, (2, 4))
        self.assertEqual(tree.left.interval, (0, 2))
        # The first leaf covers exactly index 0.
        self.assertEqual(tree.left.left.interval, (0, 1))
        self.assertEqual(tree.left.value, 3)
        self.assertEqual(tree.right.value, 2)
        self.assertEqual(tree.value, 2)

    def test_query(self):
        #                 0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19
        tree = build_node([88, 26, 51, 74, 28, 81, 19, 45, 41, 84, 89, 94, 24, 19, 87, 24, 38, 67, 37, 44])
        self.assertEqual(query(tree, 0, 1), 88)
        self.assertEqual(query(tree, 0, 2), 26)
        self.assertEqual(query(tree, 11, 15), 19)
        self.assertEqual(query(tree, 6, 13), 19)
        self.assertEqual(query(tree, 7, 13), 24)
        # Empty and fully out-of-bounds ranges yield None.
        self.assertIsNone(query(tree, 5, 5))
        self.assertIsNone(query(tree, 20, 25))

    def test_update(self):
        tree = build_node([5, 3, 8, 6])
        self.assertEqual(query(tree, 0, 4), 3)
        update(tree, 1, 10)
        self.assertEqual(query(tree, 0, 4), 5)
        self.assertEqual(query(tree, 1, 2), 10)
        update(tree, 3, 1)
        self.assertEqual(query(tree, 0, 4), 1)
        self.assertEqual(query(tree, 0, 3), 5)


if __name__ == '__main__':
    unittest.main()
from __future__ import annotations
from dataclasses import dataclass
from typing import *
import unittest
@dataclass
class Node:
    """One node of an ``IntervalTree``.

    A node covers the half-open index range ``interval = (start, end)`` and
    stores the tree's combining function applied over that range. Leaves
    cover a single index and have no children, so ``left``/``right`` default
    to ``None``.
    """

    interval: Tuple[int, int]  # half-open [start, end) range of indices
    value: float  # the tree's fn-aggregate of the values in ``interval``
    left: Optional[Node] = None  # child covering the lower half; None for a leaf
    right: Optional[Node] = None  # child covering the upper half; None for a leaf
class IntervalTree:
    """Segment tree over a fixed-length list, aggregated with a binary ``fn``.

    ``fn(a, b)`` combines the values of two adjacent ranges, e.g. ``min`` or
    ``lambda a, b: a + b``. Partial results are combined in tree order, so
    ``fn`` should be associative for ``query`` to equal a left-to-right fold
    of the range (this is assumed, not checked).

    Indices are zero-based and ranges are half-open ``[start, end)``.
    """

    def __init__(self, fn: Callable[[float, float], float], values: List[float]):
        """Build the tree for ``values`` using the combiner ``fn``.

        Raises:
            ValueError: If ``values`` is empty.
        """
        self.fn = fn
        self.root = self.build_node(0, values)

    def build_node(self, index: int, values: List[float]) -> Node:
        """Build the subtree for ``values``, whose first element sits at ``index``.

        Returns a ``Node`` covering ``(index, index + len(values))`` whose value
        is ``fn`` applied over ``values``.

        Raises:
            ValueError: If ``values`` is empty.
        """
        if len(values) == 0:
            # Without this guard an empty slice splits into itself and the
            # recursion never terminates.
            raise ValueError("IntervalTree requires at least one value")
        if len(values) == 1:
            return Node((index, index + 1), values[0], None, None)

        split = len(values) // 2
        left = self.build_node(index, values[:split])
        right = self.build_node(index + split, values[split:])
        return Node(
            (index, index + len(values)),
            self.fn(left.value, right.value),
            left,
            right,
        )

    def query(self, start: int, end: int):
        """Combine the values at indices ``[start, end)`` with ``fn``.

        The range is clipped to the tree's bounds. Returns ``None`` when the
        clipped range is empty, since a generic ``fn`` has no identity value
        to return instead.
        """

        def query_node(n: Node, start: int, end: int):
            node_start, node_end = n.interval
            start = max(node_start, start)
            end = min(node_end, end)
            if start >= end:
                return None
            if start == node_start and end == node_end:
                # The node lies entirely inside the query range.
                return n.value
            assert n.left is not None and n.right is not None
            left_q = query_node(n.left, start, end)
            right_q = query_node(n.right, start, end)
            assert right_q is not None or left_q is not None
            if left_q is None:
                return right_q
            if right_q is None:
                return left_q
            return self.fn(left_q, right_q)

        return query_node(self.root, start, end)

    def update(self, index: int, val: float) -> None:
        """Set the value at ``index`` to ``val`` and recompute its ancestors.

        Raises:
            IndexError: If ``index`` is outside the tree's range.
        """
        root_start, root_end = self.root.interval
        if not root_start <= index < root_end:
            # Without this check the descent reaches a leaf for a different
            # index and fails on an assert (or AttributeError under -O).
            raise IndexError(
                f"index {index} out of range [{root_start}, {root_end})"
            )

        def update_node(n: Node) -> None:
            if n.interval == (index, index + 1):
                n.value = val
                return
            assert n.left is not None and n.right is not None
            # The left child's end is the split point between the halves.
            _, split_index = n.left.interval
            update_node(n.left if index < split_index else n.right)
            n.value = self.fn(n.left.value, n.right.value)

        update_node(self.root)
class Tests(unittest.TestCase):
    """Unit tests for ``IntervalTree`` construction, range queries and updates."""

    def test_build(self):
        tree = IntervalTree(min, [3, 5, 2, 3])
        self.assertEqual(tree.root.interval, (0, 4))
        self.assertEqual(tree.root.right.interval, (2, 4))
        self.assertEqual(tree.root.left.interval, (0, 2))
        # The first leaf covers exactly index 0.
        self.assertEqual(tree.root.left.left.interval, (0, 1))
        self.assertEqual(tree.root.left.value, 3)
        self.assertEqual(tree.root.right.value, 2)
        self.assertEqual(tree.root.value, 2)

    def test_query(self):
        #                            0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19
        tree = IntervalTree(min, [88, 26, 51, 74, 28, 81, 19, 45, 41, 84, 89, 94, 24, 19, 87, 24, 38, 67, 37, 44])
        self.assertEqual(tree.query(0, 1), 88)
        self.assertEqual(tree.query(0, 2), 26)
        self.assertEqual(tree.query(11, 15), 19)
        self.assertEqual(tree.query(6, 13), 19)
        self.assertEqual(tree.query(7, 13), 24)
        # Empty and fully out-of-bounds ranges yield None.
        self.assertIsNone(tree.query(5, 5))
        self.assertIsNone(tree.query(20, 25))

    def test_update(self):
        tree = IntervalTree(lambda a, b: a + b, [0 for _ in range(20)])
        self.assertEqual(tree.query(0, 20), 0)
        tree.update(10, 1)
        tree.update(15, 2)
        self.assertEqual(tree.query(0, 20), 3)
        self.assertEqual(tree.query(0, 12), 1)

    def test_update_min(self):
        tree = IntervalTree(min, [5, 3, 8, 6])
        self.assertEqual(tree.query(0, 4), 3)
        tree.update(1, 10)
        self.assertEqual(tree.query(0, 4), 5)
        self.assertEqual(tree.query(1, 2), 10)
        tree.update(3, 1)
        self.assertEqual(tree.query(0, 4), 1)


if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment