Skip to content

Instantly share code, notes, and snippets.

@cyprienc
Last active May 6, 2021 14:57
Show Gist options
  • Save cyprienc/5b5909799af150710360eefd5c473568 to your computer and use it in GitHub Desktop.
Save cyprienc/5b5909799af150710360eefd5c473568 to your computer and use it in GitHub Desktop.
MaxSumTree Python
from typing import Optional, TypeVar, Generic
# Type parameters shared by Node and MaxSumTree:
# T is the type of the stored element, K is the type of the lookup key.
T, K = TypeVar("T"), TypeVar("K")
class Node(Generic[K, T]):
    """A node of a binary search tree ordered by ``priority``.

    Priorities less than or equal to a node's priority are placed in its left
    subtree; strictly greater ones go to its right subtree. Each node also
    caches two aggregates over the subtree rooted at it:

    * ``sum`` -- the total of all priorities in the subtree;
    * ``max`` -- the largest priority in the subtree.

    ``key`` is used only for lookup (``find``); it plays no part in ordering.
    """

    def __init__(self, element: T, key: K, priority: float, parent: Optional["Node"] = None):
        self.element: T = element
        self.key: K = key
        self.priority: float = priority
        self.parent: Optional["Node"] = parent
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None
        # Subtree aggregates: a freshly created node's subtree is just itself.
        self.max: float = priority
        self.sum: float = priority

    def add_child(self, node: "Node") -> None:
        """Insert ``node`` (together with any subtree hanging off it) below this node.

        Walks down from ``self`` following the priority ordering and attaches
        ``node`` at the first empty child slot, updating the cached ``sum`` and
        ``max`` of every node on the way down.

        The walk is iterative so that long, degenerate chains (for example the
        result of inserting monotonically increasing priorities) cannot exceed
        Python's recursion limit.
        """
        current = self
        while True:
            # Every node on the insertion path gains the whole inserted subtree.
            current.sum += node.sum
            # Compare against node.max rather than node.priority so that inserting
            # a subtree whose maximum lies below its root keeps ``max`` correct.
            if node.max > current.max:
                current.max = node.max
            if node.priority <= current.priority:
                if current.left is None:
                    current.left = node
                    node.parent = current
                    return
                current = current.left
            else:
                if current.right is None:
                    current.right = node
                    node.parent = current
                    return
                current = current.right

    def find(self, key: K) -> Optional["Node"]:
        """Return the node in this subtree whose ``key`` equals ``key``, or ``None``.

        Nodes are visited in pre-order (this node, then the left subtree, then
        the right subtree), so with duplicate keys the first one in that order
        is returned. Keys are not part of the tree ordering, so the search may
        visit every node. An explicit stack is used instead of recursion to
        avoid hitting the recursion limit on deep trees.
        """
        pending = [self]
        while pending:
            current = pending.pop()
            if current.key == key:
                return current
            # Push right before left so the left subtree is fully explored first.
            if current.right is not None:
                pending.append(current.right)
            if current.left is not None:
                pending.append(current.left)
        return None

    def __str__(self):
        return f"Node(key={self.key}, priority={self.priority}, element={self.element})"
class MaxSumTree(Generic[K, T]):
    """Keyed elements with priorities, stored in a priority-ordered ``Node`` tree.

    When the tree is non-empty, ``root.sum`` is the total of all priorities and
    ``root.max`` is the largest priority; ``root`` is ``None`` while empty.
    """

    def __init__(self):
        self.root: Optional[Node[K, T]] = None

    def add(self, element: T, key: K, priority: float) -> None:
        """Insert ``element`` under ``key`` with the given ``priority``.

        Keys are not checked for uniqueness; with duplicates, ``find`` and
        ``remove`` act on whichever match ``Node.find`` returns first.
        """
        node = Node(element, key, priority)
        if self.root is None:
            self.root = node
        else:
            self.root.add_child(node)

    def remove(self, key: K) -> None:
        """Remove the node stored under ``key`` and restore the cached aggregates.

        The removed node is replaced by its right child, with its left subtree
        re-attached under the smallest node of that right subtree; if it has no
        right child it is replaced by its left child. Afterwards ``sum`` and
        ``max`` are recomputed from the children for every node on the path
        from the lowest modified node up to the root.

        Raises:
            IndexError: if the tree is empty.
            KeyError: if ``key`` is not present.
        """
        node = self.find(key)
        parent = node.parent
        left, right = node.left, node.right

        # Lowest node whose set of descendants changed; aggregates are rebuilt
        # from here upwards.
        lowest_changed = parent
        if right is not None:
            substitute = right
            if left is not None:
                # Every priority in ``left`` is <= node.priority, and every
                # priority in ``right`` is > node.priority, so ``left`` belongs
                # as the left child of the smallest node in ``right``.
                anchor = right
                while anchor.left is not None:
                    anchor = anchor.left
                anchor.left = left
                left.parent = anchor
                lowest_changed = anchor
        else:
            substitute = left

        # Hook the substitute (possibly None) into the removed node's place.
        if substitute is not None:
            substitute.parent = parent
        if parent is None:
            self.root = substitute
        elif parent.left is node:
            parent.left = substitute
        else:
            parent.right = substitute

        # Clear the detached node's links to help garbage collection.
        node.parent = None
        node.left = None
        node.right = None

        # Recompute rather than patch the aggregates: removing the maximum (or
        # one of several equal priorities) cannot be fixed by looking at the
        # removed node's children alone.
        current = lowest_changed
        while current is not None:
            self._refresh(current)
            current = current.parent

    @staticmethod
    def _refresh(node: "Node") -> None:
        """Recompute ``node.sum`` and ``node.max`` from its priority and its children's aggregates."""
        total = node.priority
        highest = node.priority
        for child in (node.left, node.right):
            if child is not None:
                total += child.sum
                highest = max(highest, child.max)
        node.sum = total
        node.max = highest

    def find(self, key: K) -> "Node":
        """Return the node stored under ``key``.

        Raises:
            IndexError: if the tree is empty.
            KeyError: if no node has ``key``.
        """
        if self.root is None:
            raise IndexError("find() on an empty MaxSumTree")
        node = self.root.find(key)
        if node is None:
            raise KeyError(key)
        return node
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment