Last active: May 6, 2021 14:57
Save cyprienc/5b5909799af150710360eefd5c473568 to your computer and use it in GitHub Desktop.
MaxSumTree Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Generic, Optional, TypeVar

# Generic type parameters shared by Node and MaxSumTree.
K = TypeVar("K")  # type of the lookup key stored on each node
T = TypeVar("T")  # type of the payload element stored on each node
class Node(Generic[K, T]):
    """A node of a binary search tree ordered by ``priority``.

    Nodes whose priority is lower than or equal to this node's priority are
    placed in the left subtree; strictly greater ones go to the right.
    Each node also caches two aggregates over its whole subtree (itself
    included): ``sum`` (total priority) and ``max`` (largest priority).
    """

    def __init__(self, element: T, key: K, priority: float, parent: Optional["Node"] = None):
        self.element: T = element
        self.key: K = key
        self.priority: float = priority
        self.parent: Optional["Node"] = parent
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None
        # Subtree aggregates; a fresh node's subtree is just itself.
        self.max: float = priority
        self.sum: float = priority

    def add_child(self, node: "Node") -> None:
        """Insert ``node`` (and any subtree hanging off it) below this node.

        The attach point is found by BST descent on ``node.priority``.
        Every node on the descent path has ``sum`` increased by ``node.sum``
        and ``max`` raised to ``node.max`` when that is larger. Grafting a
        whole subtree only preserves ordering if all of its priorities fall
        in the same gap as ``node.priority``; the caller must guarantee this.

        Implemented iteratively so that degenerate (e.g. sorted-insert)
        trees deeper than the interpreter's recursion limit still work.
        """
        current = self
        while True:
            current.sum += node.sum
            # Use the subtree max, not node.priority: when a whole subtree
            # is grafted, its largest priority may live below ``node``.
            if node.max > current.max:
                current.max = node.max
            if node.priority <= current.priority:
                if current.left is None:
                    current.left = node
                    node.parent = current
                    return
                current = current.left
            else:
                if current.right is None:
                    current.right = node
                    node.parent = current
                    return
                current = current.right

    def find(self, key: K) -> Optional["Node"]:
        """Return the first node in this subtree whose key equals ``key``.

        Keys play no part in the ordering, so this is an exhaustive
        pre-order scan (this node, then the left subtree, then the right
        subtree), O(n) in the subtree size. Returns None when nothing
        matches. Iterative to avoid recursion-depth limits on deep trees.
        """
        stack = [self]
        while stack:
            current = stack.pop()
            if current.key == key:
                return current
            # Push right before left so the left subtree is searched first,
            # matching recursive pre-order.
            if current.right is not None:
                stack.append(current.right)
            if current.left is not None:
                stack.append(current.left)
        return None

    def __str__(self):
        return f"Node(key={self.key}, priority={self.priority}, element={self.element})"
class MaxSumTree(Generic[K, T]):
    """Priority-ordered binary search tree of ``Node`` objects.

    Elements are inserted with a ``key`` (used only for lookup and removal)
    and a ``priority`` (used for ordering). Every node caches the sum and
    the maximum priority of its subtree, so ``root.sum`` and ``root.max``
    describe the whole tree.
    """

    def __init__(self):
        self.root: Optional[Node[K, T]] = None

    def add(self, element: T, key: K, priority: float):
        """Insert ``element`` under ``key`` with the given ``priority``.

        Keys are not checked for uniqueness.
        """
        node = Node(element, key, priority)
        if self.root is None:
            self.root = node
        else:
            self.root.add_child(node)

    def remove(self, key: K):
        """Remove the first node (as located by ``find``) whose key is ``key``.

        The removed node is replaced by its right child, with its left
        subtree re-attached inside the right subtree. This keeps BST order
        because every left priority is <= the removed priority < every right
        priority. With no right child, the left child takes its place. The
        ``sum``/``max`` aggregates of all former ancestors are then
        recomputed.

        Raises:
            IndexError: if the tree is empty.
            KeyError: if no node has ``key``.
        """
        node = self.find(key)
        parent = node.parent
        left, right = node.left, node.right

        if right is not None:
            if left is not None:
                right.add_child(left)
            substitute = right
        else:
            substitute = left

        if substitute is not None:
            substitute.parent = parent
        if parent is None:
            self.root = substitute
        elif parent.left is node:
            parent.left = substitute
        else:
            parent.right = substitute

        # Removing refs for garbage collection
        node.parent = None
        node.left = None
        node.right = None

        # Recompute aggregates bottom-up from each ancestor's children.
        # Patching ``max`` from the removed node's children alone would be
        # wrong in two cases:
        #   - a childless node that held the maximum would leave a stale max;
        #   - a node that merely tied an ancestor's max would overwrite it
        #     with a smaller value.
        while parent is not None:
            self._refresh_aggregates(parent)
            parent = parent.parent

    @staticmethod
    def _refresh_aggregates(node: "Node") -> None:
        """Recompute ``node.sum``/``node.max`` from its priority and children."""
        total = node.priority
        largest = node.priority
        for child in (node.left, node.right):
            if child is not None:
                total += child.sum
                if child.max > largest:
                    largest = child.max
        node.sum = total
        node.max = largest

    def find(self, key: K):
        """Return the first node (pre-order) whose key equals ``key``.

        Raises:
            IndexError: if the tree is empty.
            KeyError: if no node has ``key``.
        """
        if self.root is None:
            raise IndexError("find from an empty MaxSumTree")
        node = self.root.find(key)
        if node is None:
            raise KeyError(key)
        return node
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.