Skip to content

Instantly share code, notes, and snippets.

@bananabrick
Created August 7, 2020 16:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bananabrick/a3676673ab072c1e6fe4fb3852ecbdff to your computer and use it in GitHub Desktop.
Save bananabrick/a3676673ab072c1e6fe4fb3852ecbdff to your computer and use it in GitHub Desktop.
import sys
from collections import defaultdict
from functools import lru_cache
from collections import Counter
def mi(s):
    """Return a lazy ``map`` of the whitespace-separated integers in ``s``."""
    tokens = s.strip().split()
    return map(int, tokens)
def lmi(s):
    """Parse the whitespace-separated integers in ``s`` into a list."""
    return [int(token) for token in s.strip().split()]
def mf(f, s):
    """Return a lazy ``map`` applying ``f`` to every item of ``s``."""
    mapped = map(f, s)
    return mapped
def lmf(f, s):
    """Apply ``f`` to every item of ``s`` and collect the results in a list."""
    return [f(item) for item in s]
class BSTNode(object):
    """A binary-tree node: a value plus left, right and parent links."""

    def __init__(self, val):
        self.val = val
        # A freshly created node is not attached to any tree yet.
        self.disconnect()

    def disconnect(self):
        """Reset the left, right and parent links to ``None``."""
        self.left = self.right = self.parent = None
class BST(object):
    """An unbalanced binary search tree ordered by ``node.val``.

    Values equal to a node's value are placed in (and searched for in) that
    node's right subtree. Because the tree is not rebalanced here, it can
    degenerate into a chain; ``minimum``, ``find`` and ``insert`` are
    therefore iterative so that deep trees cannot exhaust the interpreter's
    recursion limit.
    """

    # Node class instantiated by insert(); subclasses may override it.
    Node = BSTNode

    def __init__(self):
        self.root = None

    @staticmethod
    def minimum(node):
        """Find and return the minimum node in the sub-tree rooted at 'node'."""
        while node.left is not None:
            node = node.left
        return node

    def __str__(self):
        """Return the BST rendered as multi-line ASCII art.

        Code for printing was acquired from MIT 6.006. The rendering helper
        is recursive, with depth equal to the height of the tree.
        """
        if self.root is None:
            return '<empty tree>'

        def recurse(node):
            # Returns (lines, pos, width) for the sub-tree rooted at node.
            if node is None:
                return [], 0, 0
            label = str(node.val)
            left_lines, left_pos, left_width = recurse(node.left)
            right_lines, right_pos, right_width = recurse(node.right)
            middle = max(right_pos + left_width - left_pos + 1, len(label), 2)
            pos = left_pos + middle // 2
            width = left_pos + middle + right_width - right_pos
            # Pad the shorter side so both columns have the same row count.
            while len(left_lines) < len(right_lines):
                left_lines.append(' ' * left_width)
            while len(right_lines) < len(left_lines):
                right_lines.append(' ' * right_width)
            if (middle - len(label)) % 2 == 1 and node.parent is not None and \
                    node is node.parent.left and len(label) < middle:
                label += '.'
            label = label.center(middle, '.')
            if label[0] == '.':
                label = ' ' + label[1:]
            if label[-1] == '.':
                label = label[:-1] + ' '
            lines = [
                ' ' * left_pos + label + ' ' * (right_width - right_pos),
                ' ' * left_pos + '/' + ' ' * (middle - 2) +
                '\\' + ' ' * (right_width - right_pos),
            ] + [
                left_line + ' ' * (width - left_width - right_width) +
                right_line
                for left_line, right_line in zip(left_lines, right_lines)
            ]
            return lines, pos, width

        return '\n'.join(recurse(self.root)[0])

    def find(self, val):
        """Search for val in the BST.

        Returns the first matching node met on the way down from the root,
        or None when val is not present.
        """
        node = self.root
        while node is not None:
            if node.val == val:
                return node
            # Equal values live on the right, hence "<=" rather than "<".
            node = node.right if node.val <= val else node.left
        return None

    def __contains__(self, val):
        return self.find(val) is not None

    def insert(self, val):
        """Insert val into the BST as a new leaf.

        Returns the newly inserted node.
        """
        new_node = self.Node(val)
        if self.root is None:
            self.root = new_node
            return new_node

        node = self.root
        while True:
            if node.val <= val:
                if node.right is None:
                    node.right = new_node
                    break
                node = node.right
            else:
                if node.left is None:
                    node.left = new_node
                    break
                node = node.left
        new_node.parent = node
        return new_node

    def delete(self, val):
        """Delete one node holding val from the BST.

        Returns a (deleted node, parent node) tuple, or (None, None) when val
        is not in the tree. The parent is the node the removed node was
        hanging from (None if it was the root).

        When the matching node has two children, its value is swapped with
        that of its in-order successor and the successor's node object is the
        one unlinked. Node objects therefore do not keep a fixed value across
        deletions: references to nodes held outside the tree may end up
        holding a different value or pointing at a detached node.
        """
        node = self.find(val)
        if node is None:
            return (None, None)

        if node.left is not None and node.right is not None:
            # Move val down into the successor, which has no left child and
            # can therefore be spliced out directly.
            successor = self.minimum(node.right)
            successor.val, node.val = node.val, successor.val
            node = successor

        child = node.left if node.left is not None else node.right
        parent = node.parent

        # Hook the (possibly missing) child into the slot node occupied.
        if parent is None:
            self.root = child
        elif node is parent.left:
            parent.left = child
        else:
            parent.right = child
        if child is not None:
            child.parent = parent

        node.disconnect()
        return node, parent
def height(node):
    """Return ``node.height``, or -1 when ``node`` is missing (falsy)."""
    return node.height if node else -1
def update_height(node):
    """Recompute ``node.height`` from the heights of its two children."""
    tallest_child = max(height(node.left), height(node.right))
    node.height = tallest_child + 1
def balanced(node):
    """Tell whether the AVL property holds at ``node``.

    Returns True when the heights of the left and right subtrees differ
    by at most one.
    """
    difference = height(node.left) - height(node.right)
    return -1 <= difference <= 1
class AVLNode(BSTNode):
    """A BSTNode that additionally records the height of its sub-tree."""

    def __init__(self, val):
        super().__init__(val)
        # A new node is created without children, so it starts at height 0.
        self.height = 0
class AVL(BST):
    """A BST that keeps itself balanced through rotations.

    After every insert and delete, heights are refreshed along the path to
    the root and any node violating the AVL property is rotated back into
    balance.
    """

    Node = AVLNode

    def _swap_child(self, old, new):
        """Put ``new`` into the slot (parent link or root) that ``old`` holds."""
        above = old.parent
        if not above:
            self.root = new
        elif old is above.left:
            above.left = new
        else:
            above.right = new

    def rotate_left(self, node):
        """Rotate ``node`` down to the left; its right child takes its place.

        Returns a (node, promoted child) tuple. Raises ValueError when
        ``node`` has no right child.
        """
        pivot = node.right
        if not pivot:
            raise ValueError("Left rotate not possible.")

        above = node.parent
        self._swap_child(node, pivot)
        pivot.parent = above
        node.parent = pivot

        # The pivot's left sub-tree becomes node's right sub-tree.
        orphan = pivot.left
        if orphan:
            orphan.parent = node
        pivot.left = node
        node.right = orphan

        # node is now below pivot, so its height must be refreshed first.
        update_height(node)
        update_height(pivot)
        return node, pivot

    def rotate_right(self, node):
        """Rotate ``node`` down to the right; its left child takes its place.

        Returns a (node, promoted child) tuple. Raises ValueError when
        ``node`` has no left child.
        """
        pivot = node.left
        if not pivot:
            raise ValueError("Right rotate not possible.")

        above = node.parent
        self._swap_child(node, pivot)
        pivot.parent = above
        node.parent = pivot

        # The pivot's right sub-tree becomes node's left sub-tree.
        orphan = pivot.right
        if orphan:
            orphan.parent = node
        pivot.right = node
        node.left = orphan

        # node is now below pivot, so its height must be refreshed first.
        update_height(node)
        update_height(pivot)
        return node, pivot

    def fix_one_imbalance(self, node):
        """Restore the AVL property at ``node`` with one or two rotations.

        Only meant to be called when ``node`` is actually unbalanced.
        """
        if height(node.left) > height(node.right):
            heavy = node.left
            # Left-right case: straighten the child before the main rotation.
            if height(heavy.left) < height(heavy.right):
                self.rotate_left(heavy)
            self.rotate_right(node)
        else:
            heavy = node.right
            # Right-left case: straighten the child before the main rotation.
            if height(heavy.right) < height(heavy.left):
                self.rotate_right(heavy)
            self.rotate_left(node)

    def fix_tree_imbalance(self, node):
        """Refresh heights and rebalance from ``node`` up to the root."""
        while node is not None:
            # Remember the parent first: a rotation re-parents node.
            above = node.parent
            update_height(node)
            if not balanced(node):
                self.fix_one_imbalance(node)
            node = above

    def insert(self, val):
        """Insert ``val`` as in BST, then rebalance; returns the new node."""
        inserted = super().insert(val)
        self.fix_tree_imbalance(inserted)
        return inserted

    def delete(self, val):
        """Delete ``val`` as in BST, then rebalance.

        Returns the same (deleted node, parent node) tuple as BST.delete.
        """
        removed, above = super().delete(val)
        if removed and above:
            self.fix_tree_imbalance(above)
        return removed, above
def size(node):
    """Return ``node.size``, or 0 when ``node`` is missing (falsy)."""
    return node.size if node else 0
def update_size(node):
    """Recompute ``node.size`` as itself plus the sizes of both children."""
    children_total = size(node.left) + size(node.right)
    node.size = 1 + children_total
def largest(node, parent=None):
    """Find the right-most node of the sub-tree rooted at ``node``.

    Returns a (right-most node, node it was reached from) tuple; the second
    item is the ``parent`` argument when ``node`` itself has no right child.
    Note that a bare None (not a tuple) is returned when ``node`` is missing.
    """
    if not node:
        return None
    current, above = node, parent
    while current.right:
        current, above = current.right, current
    return current, above
def get_largest(tree, k=3):
    """Return up to ``k`` of the largest values in ``tree``, largest first.

    Args:
        tree: an object whose ``root`` is the top node of a binary search
            tree (or None when empty); nodes expose ``val``, ``left`` and
            ``right``.
        k: how many values to collect (defaults to 3).

    Returns:
        A list of node values in descending order. It is shorter than ``k``
        when the tree holds fewer nodes, and empty for an empty tree.

    The values are gathered with a reverse in-order walk (right, node, left)
    that stops as soon as ``k`` values were seen. Walking the tree this way,
    rather than reasoning about parents and children of the maximum case by
    case, also covers shapes such as a maximum whose only child is a left
    leaf, where the third largest value sits above the maximum.
    """
    found = []
    pending = []  # Ancestors whose own value and left side are still unvisited.
    node = tree.root
    while len(found) < k and (pending or node is not None):
        # Dive as far right as possible: that is where larger values are.
        while node is not None:
            pending.append(node)
            node = node.right
        node = pending.pop()
        found.append(node.val)
        node = node.left
    return found
def main(curr, ops):
    """Apply each operation and print "YES" or "NO" after every one.

    Args:
        curr: iterable of the initial items (duplicates are counted).
        ops: iterable of (op, n) pairs. ``'+'`` adds one occurrence of ``n``;
            any other op removes one occurrence.

    After each operation "YES" is printed when the three highest occurrence
    counts c1 >= c2 >= c3 satisfy one of: c1 >= 8; c1 >= 6 and c2 >= 2;
    c1 >= 4 and c2 >= 4; c1 >= 4 and c2 >= 2 and c3 >= 2. Otherwise "NO" is
    printed.

    Raises:
        KeyError: when asked to remove an item that has no occurrences left.
    """
    # The Counter is the single source of truth for each item's count. Tree
    # nodes are not cached: BST.delete may swap values between node objects,
    # so a remembered node can end up holding another item's entry.
    counts = Counter(curr)
    tree = AVL()
    for item, count in counts.items():
        tree.insert((count, item))

    for op, n in ops:
        prev_count = counts.get(n, 0)
        if op == '+':
            new_count = prev_count + 1
        else:
            if prev_count == 0:
                raise KeyError(n)
            new_count = prev_count - 1

        # Re-key the item's (count, item) entry in the tree.
        if prev_count:
            tree.delete((prev_count, n))
        if new_count:
            tree.insert((new_count, n))
            counts[n] = new_count
        else:
            # Forget exhausted items so a later '+' starts again from zero.
            del counts[n]

        # Pad with zeros so fewer than three distinct items (including an
        # empty tree) are handled without special cases.
        top = [entry[0] for entry in get_largest(tree)]
        top += [0] * (3 - len(top))
        first, second, third = top[:3]

        possible = (
            first >= 8
            or (first >= 6 and second >= 2)
            or (first >= 4 and second >= 4)
            or (first >= 4 and second >= 2 and third >= 2)
        )
        print("YES" if possible else "NO")
if __name__ == "__main__":
    # Input layout as consumed here: line 1 holds the initial integers and
    # every line after line 2 holds one "<op> <integer>" operation. Lines 0
    # and 2 are ignored (presumably the element and operation counts --
    # confirm against the problem statement).
    curr = []
    ops = []
    for e, line in enumerate(sys.stdin):
        if e == 1:
            curr = lmi(line)
        elif e > 2 and line.strip():
            # Blank lines (e.g. a trailing newline at end of input) are
            # skipped instead of failing to unpack into (op, n).
            op, n = line.split()
            ops.append((op, int(n)))
    main(curr, ops)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment