Created
September 15, 2020 19:26
-
-
Save exyi/e275844121f0a7d37ede8b4e4fdf5773 to your computer and use it in GitHub Desktop.
Intervaláče
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from dataclasses import dataclass | |
from typing import * | |
import unittest | |
@dataclass
class Node:
    """A single node of the segment tree built by this module.

    The node covers the half-open position range ``interval == (start, end)``
    of the original list, and ``value`` holds the aggregate (in this module,
    the minimum) of those positions. Leaf nodes cover exactly one position
    and have no children; inner nodes always have both children.
    """

    interval: Tuple[int, int]  # half-open (start, end) range of positions
    value: float  # aggregate over the positions in ``interval``
    left: Optional[Node]  # lower-half child; None for a leaf
    right: Optional[Node]  # upper-half child; None for a leaf
def build_node(values: List[float], index: int = 0) -> Node:
    """Build a min segment tree over ``values``.

    Each node covers a half-open range of positions. Positions are offset by
    ``index`` so that recursive calls on a slice keep the positions of the
    whole list.

    Args:
        values: Elements to index; must be non-empty.
        index: Position of ``values[0]`` in the full list (0 for the root).

    Returns:
        The root node, covering ``[index, index + len(values))`` and holding
        ``min(values)``.

    Raises:
        ValueError: If ``values`` is empty.
    """
    if not values:
        # An empty slice can never reach the one-element base case, so the
        # recursion would otherwise only stop at RecursionError.
        raise ValueError("cannot build an interval tree from an empty list")
    if len(values) == 1:
        return Node((index, index + 1), values[0], None, None)
    split_index = len(values) // 2
    left = build_node(values[:split_index], index)
    right = build_node(values[split_index:], index + split_index)
    return Node(
        (index, index + len(values)),
        min(left.value, right.value),
        left,
        right,
    )
def query(n: Node, start: int, end: int):
    """Return the minimum over positions ``[start, end)`` inside subtree ``n``.

    The requested range is first clipped to the node's own interval.
    ``None`` is returned when nothing remains after clipping, i.e. the range
    is empty or does not overlap this node at all.
    """
    lo, hi = n.interval
    start, end = max(lo, start), min(hi, end)
    if end <= start:
        return None
    if (start, end) == (lo, hi):
        # The clipped range is exactly this node: reuse the stored minimum.
        return n.value
    assert n.left is not None and n.right is not None
    candidates = (query(n.left, start, end), query(n.right, start, end))
    parts = [part for part in candidates if part is not None]
    # A non-empty clipped range must overlap at least one child.
    assert parts
    return min(parts)
def update(n: Node, index: int, val: float) -> None:
    """Set position ``index`` to ``val`` and refresh the minima on its path.

    Args:
        n: Root of the subtree to update; ``index`` must lie in ``n.interval``.
        index: Position to overwrite.
        val: New value for that position.

    Raises:
        IndexError: If ``index`` is outside ``n.interval``.
    """
    node_start, node_end = n.interval
    if not node_start <= index < node_end:
        # Without this check an out-of-range index descends to the nearest
        # leaf and fails there with an opaque AssertionError (or an
        # AttributeError on a None child under ``python -O``).
        raise IndexError(f"index {index} out of range [{node_start}, {node_end})")
    if n.interval == (index, index + 1):
        n.value = val
        return
    assert n.left is not None and n.right is not None
    # The left child's end is the split point between the two halves.
    _, split_index = n.left.interval
    update(n.left if index < split_index else n.right, index, val)
    n.value = min(n.left.value, n.right.value)
class Tests(unittest.TestCase):
    """Unit tests for the functional min segment tree (build/query/update)."""

    def test_build(self):
        tree = build_node([3, 5, 2, 3])
        self.assertEqual(tree.interval, (0, 4))
        self.assertEqual(tree.left.interval, (0, 2))
        self.assertEqual(tree.right.interval, (2, 4))
        # Leaves cover exactly one position each.
        self.assertEqual(tree.left.left.interval, (0, 1))
        self.assertEqual(tree.right.right.interval, (3, 4))
        self.assertEqual(tree.left.value, 3)
        self.assertEqual(tree.right.value, 2)
        self.assertEqual(tree.value, 2)

    def test_query(self):
        # Positions 0..19 of the list below.
        tree = build_node([88, 26, 51, 74, 28, 81, 19, 45, 41, 84, 89, 94, 24, 19, 87, 24, 38, 67, 37, 44])
        self.assertEqual(query(tree, 0, 1), 88)
        self.assertEqual(query(tree, 0, 2), 26)
        self.assertEqual(query(tree, 11, 15), 19)
        self.assertEqual(query(tree, 6, 13), 19)
        self.assertEqual(query(tree, 7, 13), 24)
        # Empty or fully out-of-range requests yield None.
        self.assertIsNone(query(tree, 5, 5))
        self.assertIsNone(query(tree, 20, 25))

    def test_update(self):
        tree = build_node([5, 3, 8, 6])
        self.assertEqual(query(tree, 0, 4), 3)
        update(tree, 1, 9)
        self.assertEqual(query(tree, 0, 4), 5)
        self.assertEqual(query(tree, 1, 3), 8)
        update(tree, 3, 1)
        self.assertEqual(query(tree, 0, 4), 1)
        self.assertEqual(query(tree, 0, 2), 5)
if __name__ == "__main__":
    # Run the unit tests above when this file is executed as a script.
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from dataclasses import dataclass | |
from typing import * | |
import unittest | |
@dataclass
class Node:
    """Tree node used by ``IntervalTree``.

    ``interval`` is the half-open range ``(start, end)`` of list positions the
    node covers, and ``value`` is the tree's combining function folded over
    that range. A leaf covers a single position and has no children. An inner
    node always has both a left and a right child.
    """

    interval: Tuple[int, int]  # half-open (start, end) position range
    value: float  # combined value of the covered positions
    left: Optional[Node]  # child for the lower part of the range, or None
    right: Optional[Node]  # child for the upper part of the range, or None
class IntervalTree:
    """Segment tree over a fixed-length list with a pluggable combiner.

    Every node stores ``fn`` folded over the half-open range of positions it
    covers, e.g. ``IntervalTree(min, xs)`` for range minima or
    ``IntervalTree(lambda a, b: a + b, xs)`` for range sums. ``fn`` is always
    called as ``fn(left_part, right_part)``, so it need not be commutative.
    It is presumably expected to be associative, because ``query`` groups
    elements by tree structure rather than folding strictly left to right.
    """

    def __init__(self, fn, values: List[float]):
        """Build the tree over ``values`` using the combiner ``fn``.

        Raises:
            ValueError: If ``values`` is empty.
        """
        self.fn = fn
        self.root = self.build_node(0, values)

    def build_node(self: IntervalTree, index: int, values: List[float]) -> Node:
        """Return a subtree over ``values``, whose first element is at ``index``.

        Raises:
            ValueError: If ``values`` is empty.
        """
        if not values:
            # An empty slice can never reach the one-element base case, so
            # the recursion would otherwise only stop at RecursionError.
            raise ValueError("cannot build an interval tree from an empty list")
        if len(values) == 1:
            return Node((index, index + 1), values[0], None, None)
        split_index = len(values) // 2
        left = self.build_node(index, values[:split_index])
        right = self.build_node(index + split_index, values[split_index:])
        return Node(
            (index, index + len(values)),
            self.fn(left.value, right.value),
            left,
            right,
        )

    def query(self, start: int, end: int):
        """Return ``fn`` folded over positions ``[start, end)``.

        The range is clipped to the tree's bounds. Returns ``None`` when
        nothing remains after clipping: either ``start >= end``, or the range
        lies entirely outside the tree.
        """

        def query_node(n: Node, start: int, end: int):
            (node_start, node_end) = n.interval
            end = min(node_end, end)
            start = max(node_start, start)
            if start >= end:
                return None
            if end == node_end and start == node_start:
                # The node lies fully inside the range: reuse its aggregate.
                return n.value
            assert n.left is not None and n.right is not None
            left_q = query_node(n.left, start, end)
            right_q = query_node(n.right, start, end)
            # A non-empty clipped range overlaps at least one child.
            assert right_q is not None or left_q is not None
            if left_q is None:
                return right_q
            if right_q is None:
                return left_q
            return self.fn(left_q, right_q)

        return query_node(self.root, start, end)

    def update(self, index: int, val: float) -> None:
        """Set position ``index`` to ``val`` and recompute its ancestors.

        Raises:
            IndexError: If ``index`` is outside the tree's positions.
        """
        root_start, root_end = self.root.interval
        if not root_start <= index < root_end:
            # Otherwise the descent would end at an unrelated leaf and fail
            # with an opaque AssertionError (AttributeError under ``-O``).
            raise IndexError(f"index {index} out of range [{root_start}, {root_end})")

        def update_node(n: Node) -> None:
            if n.interval == (index, index + 1):
                n.value = val
                return
            assert n.left is not None and n.right is not None
            # The left child's end is the split point between the halves.
            (_, split_index) = n.left.interval
            update_node(n.left if index < split_index else n.right)
            n.value = self.fn(n.left.value, n.right.value)

        update_node(self.root)
class Tests(unittest.TestCase):
    """Unit tests for ``IntervalTree`` with ``min`` and sum combiners."""

    def test_build(self):
        tree = IntervalTree(min, [3, 5, 2, 3])
        self.assertEqual(tree.root.interval, (0, 4))
        self.assertEqual(tree.root.left.interval, (0, 2))
        self.assertEqual(tree.root.right.interval, (2, 4))
        # Leaves cover exactly one position each.
        self.assertEqual(tree.root.left.left.interval, (0, 1))
        self.assertEqual(tree.root.right.right.interval, (3, 4))
        self.assertEqual(tree.root.left.value, 3)
        self.assertEqual(tree.root.right.value, 2)
        self.assertEqual(tree.root.value, 2)

    def test_query(self):
        # Positions 0..19 of the list below.
        tree = IntervalTree(min, [88, 26, 51, 74, 28, 81, 19, 45, 41, 84, 89, 94, 24, 19, 87, 24, 38, 67, 37, 44])
        self.assertEqual(tree.query(0, 1), 88)
        self.assertEqual(tree.query(0, 2), 26)
        self.assertEqual(tree.query(11, 15), 19)
        self.assertEqual(tree.query(6, 13), 19)
        self.assertEqual(tree.query(7, 13), 24)
        # Empty or fully out-of-range requests yield None.
        self.assertIsNone(tree.query(5, 5))
        self.assertIsNone(tree.query(20, 25))

    def test_update(self):
        tree = IntervalTree(lambda a, b: a + b, [0 for _ in range(20)])
        self.assertEqual(tree.query(0, 20), 0)
        tree.update(10, 1)
        tree.update(15, 2)
        self.assertEqual(tree.query(0, 20), 3)
        self.assertEqual(tree.query(0, 12), 1)

    def test_update_min(self):
        tree = IntervalTree(min, [5, 3, 8, 6])
        tree.update(1, 9)
        self.assertEqual(tree.query(0, 4), 5)
        self.assertEqual(tree.query(1, 3), 8)
        tree.update(3, 1)
        self.assertEqual(tree.query(0, 4), 1)
if __name__ == "__main__":
    # Running the file directly executes the IntervalTree test suite.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment