Skip to content

Instantly share code, notes, and snippets.

@mzaks
Created August 19, 2023 17:59
Show Gist options
  • Save mzaks/821b6128da63f4db6590b6401b25fca7 to your computer and use it in GitHub Desktop.
Save mzaks/821b6128da63f4db6590b6401b25fca7 to your computer and use it in GitHub Desktop.
FibyTree Mojo
from Bit import bit_length
from String import String
from Vector import DynamicVector, UnsafeFixedVector
from List import VariadicList
struct FibyTree[T: AnyType, cmp: fn(T, T)->Int, to_str: fn(T) -> String]:
    """Ordered set of `T` stored as a binary search tree inside flat vectors.

    Layout:
    - `elements[i]` is the value of node `i`; the root always lives at slot 0.
    - `left[i]` / `right[i]` are the child slots of node `i`. A missing child
      is encoded by the node's own index (`left[i] == i` means "no left").
    - Removed nodes stay in the vectors and are only counted in `deleted`,
      so the live size is `len(elements) - deleted`.
    - After `_balance_with` (used by `balance()` and the in-place set
      operations) the nodes are stored in Eytzinger (BFS) order and
      `balanced` is True, which lets lookups use index arithmetic.

    `cmp(a, b)` must return a negative number when `a < b`, zero when the
    values are equal and a positive number when `a > b`. `to_str` is only
    used by `print_tree`.

    NOTE(review): child links are `UInt16`, so the tree presumably cannot
    address more than 65536 slots (live plus deleted) - confirm before using
    it for larger sets.
    """

    # Selectors for `_combine`.
    alias Union = 0
    alias Intersection = 1
    alias Difference = 2
    alias SymetricDifference = 3
    alias OtherDifference = 4
    # Selectors for `_check`.
    alias IsDisjoint = 5
    alias IsSubset = 6
    alias IsSuperset = 7

    var elements: DynamicVector[T]    # node values, root at slot 0
    var left: DynamicVector[UInt16]   # left child slot, or own slot if none
    var right: DynamicVector[UInt16]  # right child slot, or own slot if none
    var deleted: Int                  # number of dead (unlinked) slots
    var max_depth: UInt16             # upper bound on the depth of the tree
    var balanced: Bool                # True while in Eytzinger layout

    fn __init__(inout self, *elements: T):
        """Build a tree from `elements`; repeated values are stored once."""
        self.elements = DynamicVector[T]()
        self.left = DynamicVector[UInt16]()
        self.right = DynamicVector[UInt16]()
        self.deleted = 0
        self.max_depth = 0
        self.balanced = False
        let elements_list: VariadicList[T] = elements
        for i in range(len(elements_list)):
            self.add(elements[i])

    fn __moveinit__(inout self, owned existing: Self):
        """Take over the storage and bookkeeping of `existing`."""
        self.elements = existing.elements
        self.left = existing.left
        self.right = existing.right
        self.deleted = existing.deleted
        self.max_depth = existing.max_depth
        self.balanced = existing.balanced

    @always_inline("nodebug")
    fn has_left(self, parent: UInt16) -> Bool:
        """Return True if node `parent` has a left child."""
        return (self.left[parent.to_int()] != parent).__bool__()

    @always_inline("nodebug")
    fn has_right(self, parent: UInt16) -> Bool:
        """Return True if node `parent` has a right child."""
        return (self.right[parent.to_int()] != parent).__bool__()

    fn add(inout self, element: T):
        """Insert `element` unless an equal element is already present.

        New values are attached as leaves without per-insert rebalancing.
        If the recorded depth exceeds the optimal depth for the current size
        by more than 63 levels, the whole tree is rebuilt balanced.
        """
        if self.__len__() == 0:
            self._set_root(element)
            self._set_max_depth(1)
            return
        var parent = 0
        var depth: UInt16 = 1
        while True:
            let diff = cmp(self.elements[parent], element)
            if diff == 0:
                return  # already present
            depth += 1
            if diff > 0:
                # element < parent: continue in, or attach to, the left side.
                let left = self.left[parent].to_int()
                if left == parent:
                    self._add_left(parent, element)
                    break
                else:
                    parent = left
            else:
                let right = self.right[parent].to_int()
                if right == parent:
                    self._add_right(parent, element)
                    break
                else:
                    parent = right
        self.balanced = False
        self._set_max_depth(depth)
        if self.max_depth > self._optimal_depth() + 63:
            self.balance()

    @always_inline("nodebug")
    fn _set_max_depth(inout self, candidate: UInt16):
        """Raise `max_depth` to `candidate` if it is larger."""
        if self.max_depth < candidate:
            self.max_depth = candidate

    fn _optimal_depth(self) -> UInt16:
        """Depth of a perfectly balanced tree holding `__len__()` nodes."""
        return bit_length(UInt16(self.__len__()))

    @always_inline("nodebug")
    fn _set_root(inout self, element: T):
        """Make `element` the only live node, stored as a childless root.

        Only called when the tree is empty. If the vectors still hold dead
        slots, slot 0 is reused and is no longer counted as deleted.
        """
        if len(self.elements) == 0:
            self.elements.push_back(element)
            self.left.push_back(0)
            self.right.push_back(0)
        else:
            self.elements[0] = element
            self.left[0] = 0
            self.right[0] = 0
            if self.deleted > 0:
                self.deleted -= 1

    @always_inline("nodebug")
    fn _add_left(inout self, parent: UInt16, element: T):
        """Append `element` as a new leaf and link it as `parent`'s left child."""
        let index = len(self.elements)
        self.elements.push_back(element)
        self.left.push_back(index)
        self.right.push_back(index)
        self.left[parent.to_int()] = index

    @always_inline("nodebug")
    fn _add_right(inout self, parent: UInt16, element: T):
        """Append `element` as a new leaf and link it as `parent`'s right child."""
        let index = len(self.elements)
        self.elements.push_back(element)
        self.left.push_back(index)
        self.right.push_back(index)
        self.right[parent.to_int()] = index

    @always_inline("nodebug")
    fn delete(inout self, element: T) -> Bool:
        """Remove `element`; return True if it was present.

        The freed slot is not reclaimed. It is only counted in `deleted`
        until the tree is rebuilt.
        """
        let index_tuple = self._get_index(element)
        let parent = index_tuple.get[0, Int]()
        let index = index_tuple.get[1, Int]()
        if index == -1:
            return False
        self.balanced = False
        if self._is_leaf(index):
            self._delete_leaf(index, parent)
            return True
        if self.has_left(index) and not self.has_right(index):
            if index == 0:
                # The root must stay at slot 0: pull the single (left) child's
                # value and links up into the root; the child's slot is dead.
                let left = self.left[0]
                self.elements[0] = self.elements[left.to_int()]
                if self.has_left(left):
                    self.left[0] = self.left[left.to_int()]
                else:
                    self.left[0] = 0
                if self.has_right(left):
                    self.right[0] = self.right[left.to_int()]
                else:
                    self.right[0] = 0
            else:
                # Bypass the node: its parent now points at its left child.
                if self.left[parent] == index:
                    self.left[parent] = self.left[index]
                else:
                    self.right[parent] = self.left[index]
            self.deleted += 1
            return True
        if self.has_right(index) and not self.has_left(index):
            if index == 0:
                # Same as above, mirrored for a single right child.
                let right = self.right[0]
                self.elements[0] = self.elements[right.to_int()]
                if self.has_left(right):
                    self.left[0] = self.left[right.to_int()]
                else:
                    self.left[0] = 0
                if self.has_right(right):
                    self.right[0] = self.right[right.to_int()]
                else:
                    self.right[0] = 0
            else:
                if self.left[parent] == index:
                    self.left[parent] = self.right[index]
                else:
                    self.right[parent] = self.right[index]
            self.deleted += 1
            return True
        # Two children: replace the value with its in-order predecessor.
        return self._swap_with_next_smaller_leaf(index)

    @always_inline("nodebug")
    fn sorted_elements(self) -> UnsafeFixedVector[T]:
        """Return all live elements in ascending order (iterative in-order walk)."""
        let number_of_elements = self.__len__()
        var result = UnsafeFixedVector[T](number_of_elements)
        if number_of_elements == 0:
            return result
        var stack = DynamicVector[UInt16](self.max_depth.to_int())
        var current: UInt16 = 0
        while len(result) < number_of_elements:
            # Descend to the leftmost node of `current`'s subtree, unless we
            # just came back up from that left subtree (its left child is
            # then not greater than the last emitted value).
            if len(result) == 0 or cmp(result[len(result) - 1], self.elements[self.left[current.to_int()].to_int()]) < 0:
                while self.has_left(current):
                    stack.push_back(current)
                    current = self.left[current.to_int()]
            result.append(self.elements[current.to_int()])
            if self.has_right(current):
                current = self.right[current.to_int()]
            elif len(stack) > 0:
                current = stack.pop_back()
            # An empty stack here means `current` was the largest element:
            # nothing is left to visit, and popping would underflow.
        return result

    fn clear(inout self):
        """Remove all elements and reset the bookkeeping."""
        self.elements.clear()
        self.left.clear()
        self.right.clear()
        self.deleted = 0
        self.max_depth = 0
        self.balanced = False

    fn union(self, other: Self) -> Self:
        """Return a new balanced tree with the elements of either tree."""
        var result = Self()
        let combined: UnsafeFixedVector[T]
        if other.__len__() == 0:
            combined = self.sorted_elements()
        elif self.__len__() == 0:
            combined = other.sorted_elements()
        else:
            combined = self._combine[Self.Union](other)
        result._balance_with(combined)
        return result^

    fn union_inplace(inout self, other: Self):
        """Add every element of `other` to `self`."""
        if other.__len__() == 0:
            return
        if self.__len__() == 0:
            self._balance_with(other.sorted_elements())
            return
        let combined = self._combine[Self.Union](other)
        self._balance_with(combined)

    fn intersection(self, other: Self) -> Self:
        """Return a new balanced tree with the elements present in both trees."""
        var result = FibyTree[T, cmp, to_str]()
        if other.__len__() == 0:
            return result^
        if self.__len__() == 0:
            return result^
        let combined = self._combine[Self.Intersection](other)
        result._balance_with(combined)
        return result^

    fn intersection_inplace(inout self, other: Self):
        """Keep only the elements of `self` that are also in `other`."""
        if other.__len__() == 0:
            self.clear()
            return
        if self.__len__() == 0:
            self.clear()
            return
        let combined = self._combine[Self.Intersection](other)
        self._balance_with(combined)

    fn difference(self, other: Self) -> Self:
        """Return a new balanced tree with the elements of `self` not in `other`."""
        var result = FibyTree[T, cmp, to_str]()
        let combined: UnsafeFixedVector[T]
        if other.__len__() == 0 or self.__len__() == 0:
            combined = self.sorted_elements()
        else:
            combined = self._combine[Self.Difference](other)
        result._balance_with(combined)
        return result^

    fn difference_inplace(inout self, other: Self):
        """Remove from `self` every element that is also in `other`."""
        if other.__len__() == 0 or self.__len__() == 0:
            return
        let combined = self._combine[Self.Difference](other)
        self._balance_with(combined)

    fn other_difference_inplace(inout self, other: Self):
        """Replace `self` with the elements of `other` that are not in `self`."""
        if other.__len__() == 0:
            self.clear()
            return
        if self.__len__() == 0:
            self._balance_with(other.sorted_elements())
            return
        let combined = self._combine[Self.OtherDifference](other)
        self._balance_with(combined)

    fn symetric_difference(self, other: Self) -> Self:
        """Return a new balanced tree with the elements in exactly one tree."""
        var result = FibyTree[T, cmp, to_str]()
        let combined: UnsafeFixedVector[T]
        if other.__len__() == 0:
            combined = self.sorted_elements()
        elif self.__len__() == 0:
            combined = other.sorted_elements()
        else:
            combined = self._combine[Self.SymetricDifference](other)
        result._balance_with(combined)
        return result^

    fn symetric_difference_inplace(inout self, other: Self):
        """Replace `self` with the elements that are in exactly one tree."""
        if other.__len__() == 0:
            return
        if self.__len__() == 0:
            self._balance_with(other.sorted_elements())
            return
        let combined = self._combine[Self.SymetricDifference](other)
        self._balance_with(combined)

    @always_inline("nodebug")
    fn _combine[type: Int](self, other: Self) -> UnsafeFixedVector[T]:
        """Merge the sorted contents of `self` and `other` according to `type`.

        Both trees are walked in order at the same time (a sorted merge), so
        the result is ascending and duplicate-free. `type` selects the
        survivors:
        - `Union`: every element.
        - `Intersection`: elements present in both trees.
        - `Difference`: elements only in `self`.
        - `OtherDifference`: elements only in `other`.
        - `SymetricDifference`: elements in exactly one tree.

        Both trees must be non-empty; the public callers check this first.
        """
        let num1 = self.__len__()
        let num2 = other.__len__()
        var combined = UnsafeFixedVector[T](num1 + num2)
        var cur1: UInt16 = 0
        var cur2: UInt16 = 0
        var stack1 = DynamicVector[UInt16](self.max_depth.to_int())
        var stack2 = DynamicVector[UInt16](other.max_depth.to_int())
        var last_returned1 = UnsafeFixedVector[T](1)
        var last_returned2 = UnsafeFixedVector[T](1)
        var e1 = self._sorted_iter(cur1, stack1, last_returned1)
        last_returned1.append(e1)
        var e2 = other._sorted_iter(cur2, stack2, last_returned2)
        last_returned2.append(e2)
        # `taken*` counts elements pulled from each tree so far. While
        # `has*` is True, `e*` is the current, not yet consumed, element.
        var taken1 = 1
        var taken2 = 1
        var has1 = True
        var has2 = True
        while has1 and has2:
            let diff = cmp(e1, e2)
            var advance1 = False
            var advance2 = False
            if diff < 0:
                # e1 exists only in self.
                @parameter
                if type == Self.Union or type == Self.Difference or type == Self.SymetricDifference:
                    combined.append(e1)
                advance1 = True
            elif diff > 0:
                # e2 exists only in other.
                @parameter
                if type == Self.Union or type == Self.SymetricDifference or type == Self.OtherDifference:
                    combined.append(e2)
                advance2 = True
            else:
                # Present in both.
                @parameter
                if type == Self.Union or type == Self.Intersection:
                    combined.append(e1)
                advance1 = True
                advance2 = True
            if advance1:
                if taken1 < num1:
                    e1 = self._sorted_iter(cur1, stack1, last_returned1)
                    last_returned1.clear()
                    last_returned1.append(e1)
                    taken1 += 1
                else:
                    has1 = False
            if advance2:
                if taken2 < num2:
                    e2 = other._sorted_iter(cur2, stack2, last_returned2)
                    last_returned2.clear()
                    last_returned2.append(e2)
                    taken2 += 1
                else:
                    has2 = False
        # At most one side still has elements; they exist only on that side.
        @parameter
        if type == Self.Union or type == Self.Difference or type == Self.SymetricDifference:
            while has1:
                combined.append(e1)
                if taken1 < num1:
                    e1 = self._sorted_iter(cur1, stack1, last_returned1)
                    last_returned1.clear()
                    last_returned1.append(e1)
                    taken1 += 1
                else:
                    has1 = False
        @parameter
        if type == Self.Union or type == Self.SymetricDifference or type == Self.OtherDifference:
            while has2:
                combined.append(e2)
                if taken2 < num2:
                    e2 = other._sorted_iter(cur2, stack2, last_returned2)
                    last_returned2.clear()
                    last_returned2.append(e2)
                    taken2 += 1
                else:
                    has2 = False
        return combined

    fn is_subset(self, other: Self) -> Bool:
        """Return True if every element of `self` is in `other`."""
        return self._check[Self.IsSubset](other)

    fn is_superset(self, other: Self) -> Bool:
        """Return True if every element of `other` is in `self`."""
        return self._check[Self.IsSuperset](other)

    fn is_disjoint(self, other: Self) -> Bool:
        """Return True if the trees share no element."""
        return self._check[Self.IsDisjoint](other)

    @always_inline("nodebug")
    fn _check[type: Int](self, other: Self) -> Bool:
        """Evaluate a set relation selected by `type` with a sorted merge.

        The merge stops as soon as the answer is known: on the first
        element of `self` missing from `other` (`IsSubset`), the first
        element of `other` missing from `self` (`IsSuperset`), or the first
        common element (`IsDisjoint`).
        """
        let num1 = self.__len__()
        let num2 = other.__len__()
        @parameter
        if type == Self.IsSubset:
            if num1 == 0:
                return True
            if num1 > num2 or num2 == 0:
                return False
        @parameter
        if type == Self.IsSuperset:
            if num2 == 0:
                return True
            if num1 < num2 or num1 == 0:
                return False
        @parameter
        if type == Self.IsDisjoint:
            if num1 == 0 or num2 == 0:
                return True
        var cur1: UInt16 = 0
        var cur2: UInt16 = 0
        var stack1 = DynamicVector[UInt16](self.max_depth.to_int())
        var stack2 = DynamicVector[UInt16](other.max_depth.to_int())
        var last_returned1 = UnsafeFixedVector[T](1)
        var last_returned2 = UnsafeFixedVector[T](1)
        var e1 = self._sorted_iter(cur1, stack1, last_returned1)
        last_returned1.append(e1)
        var e2 = other._sorted_iter(cur2, stack2, last_returned2)
        last_returned2.append(e2)
        var taken1 = 1
        var taken2 = 1
        var has1 = True
        var has2 = True
        while has1 and has2:
            let diff = cmp(e1, e2)
            var advance1 = False
            var advance2 = False
            if diff < 0:
                # e1 is not in other.
                @parameter
                if type == Self.IsSubset:
                    return False
                advance1 = True
            elif diff > 0:
                # e2 is not in self.
                @parameter
                if type == Self.IsSuperset:
                    return False
                advance2 = True
            else:
                @parameter
                if type == Self.IsDisjoint:
                    return False
                advance1 = True
                advance2 = True
            if advance1:
                if taken1 < num1:
                    e1 = self._sorted_iter(cur1, stack1, last_returned1)
                    last_returned1.clear()
                    last_returned1.append(e1)
                    taken1 += 1
                else:
                    has1 = False
            if advance2:
                if taken2 < num2:
                    e2 = other._sorted_iter(cur2, stack2, last_returned2)
                    last_returned2.clear()
                    last_returned2.append(e2)
                    taken2 += 1
                else:
                    has2 = False
        @parameter
        if type == Self.IsSubset:
            # Subset iff self ran out without an unmatched element.
            return not has1
        @parameter
        if type == Self.IsSuperset:
            return not has2
        @parameter
        if type == Self.IsDisjoint:
            return True
        return False

    @always_inline("nodebug")
    fn _sorted_iter(self, inout current: UInt16, inout stack: DynamicVector[UInt16], inout last_returned: UnsafeFixedVector[T]) -> T:
        """Return the next element of an in-order walk and advance the walk.

        `current` and `stack` hold the walk state and must start as 0 and
        empty. `last_returned` is used as an optional (capacity 1): empty
        before the first call, otherwise holding the previous result, which
        the caller keeps updated. Call at most `__len__()` times.
        """
        if len(last_returned) == 0 or cmp(last_returned[0], self.elements[self.left[current.to_int()].to_int()]) < 0:
            while self.has_left(current):
                stack.push_back(current)
                current = self.left[current.to_int()]
        let result = self.elements[current.to_int()]
        if self.has_right(current):
            current = self.right[current.to_int()]
        elif len(stack) > 0:
            current = stack.pop_back()
        # An empty stack means `result` is the maximum; the walk is finished.
        return result

    @always_inline("nodebug")
    fn __len__(self) -> Int:
        """Number of live elements."""
        return len(self.elements) - self.deleted

    @always_inline("nodebug")
    fn __contains__(self, element: T) -> Bool:
        """Return True if an element equal to `element` is present."""
        return self._get_index(element).get[1, Int]() > -1

    fn _get_index(self, element: T) -> (Int, Int):
        """Locate `element`, returning `(parent, index)`.

        `index` is -1 when the element is absent; `parent` is then the slot
        where the search stopped. A match at the root reports parent 0, and
        an empty tree returns (-1, -1).
        """
        if self.__len__() == 0:
            return -1, -1
        if self.balanced:
            return self._get_index_balanced(element)
        var parent = 0
        var index = 0
        while True:
            let diff = cmp(self.elements[index], element)
            if diff == 0:
                return parent, index
            if diff > 0:
                let left = self.left[index].to_int()
                if left == index:
                    return index, -1
                else:
                    parent = index
                    index = left
            else:
                let right = self.right[index].to_int()
                if right == index:
                    return index, -1
                else:
                    parent = index
                    index = right

    fn _get_index_balanced(self, element: T) -> (Int, Int):
        """Lookup for the Eytzinger layout: children of slot i are 2i+1, 2i+2."""
        var parent = 0
        var index = 0
        let len = self.__len__()
        while index < len:
            let diff = cmp(element, self.elements[index])
            if diff == 0:
                return parent, index
            parent = index
            # `diff >> 63` is -1 for a negative diff (go left, 2i+1) and 0
            # otherwise (go right, 2i+2); this assumes a 64-bit Int.
            index = (index + 1) * 2 + (diff >> 63)
        return parent, -1

    fn min_index(self) -> Int:
        """Slot of the smallest live element, or -1 for an empty tree."""
        if self.__len__() < 2:
            return self.__len__() - 1
        if self.balanced:
            # Leftmost slot of the deepest (left-filled) level.
            return (1 << (self.max_depth.to_int() - 1)) - 1
        var cand = self.left[0]
        while self.has_left(cand):
            cand = self.left[cand.to_int()]
        return cand.to_int()

    fn max_index(self) -> Int:
        """Slot of the largest live element, or -1 for an empty tree."""
        let size = self.__len__()
        if size < 2:
            return size - 1
        if self.balanced:
            # Rightmost slot of the last full level: the deepest level when
            # the tree is perfect, otherwise the level above it.
            if size == (1 << self.max_depth.to_int()) - 1:
                return size - 1
            return (1 << (self.max_depth.to_int() - 1)) - 2
        var cand = self.right[0]
        while self.has_right(cand):
            cand = self.right[cand.to_int()]
        return cand.to_int()

    fn _swap_with_next_smaller_leaf(inout self, index: UInt16) -> Bool:
        """Delete the two-child node at `index` via its in-order predecessor.

        The predecessor (rightmost node of the left subtree) donates its
        value to `index` and is unlinked. Returns False, changing nothing,
        when `index` has no left child.
        """
        var parent = index
        var candidate = self.left[index.to_int()]
        if candidate == index:
            return False
        while True:
            if self._is_leaf(candidate):
                self.elements[index.to_int()] = self.elements[candidate.to_int()]
                self._delete_leaf(candidate.to_int(), parent.to_int())
                return True
            let right = self.right[candidate.to_int()]
            if right == candidate:
                # `candidate` has only a left subtree: splice it into the link
                # that pointed at `candidate`. On the first step that is the
                # LEFT link of `index`; on later steps it is a RIGHT link.
                self.elements[index.to_int()] = self.elements[candidate.to_int()]
                if parent == index:
                    self.left[parent.to_int()] = self.left[candidate.to_int()]
                else:
                    self.right[parent.to_int()] = self.left[candidate.to_int()]
                self.deleted += 1
                return True
            parent = candidate
            candidate = right

    @always_inline("nodebug")
    fn _is_leaf(self, index: UInt16) -> Bool:
        """Return True if node `index` has neither child."""
        return (self.left[index.to_int()] == index).__bool__() and (self.right[index.to_int()] == index).__bool__()

    @always_inline("nodebug")
    fn _delete_leaf(inout self, index: Int, parent: Int):
        """Unlink leaf `index` from `parent` and count its slot as deleted."""
        self.deleted += 1
        if self.left[parent] == index:
            self.left[parent] = parent
        else:
            self.right[parent] = parent

    fn balance(inout self):
        """Rebuild the tree in balanced Eytzinger layout, dropping dead slots."""
        let sorted_elements = self.sorted_elements()
        self._balance_with(sorted_elements)

    @always_inline("nodebug")
    fn _balance_with(inout self, sorted_elements: UnsafeFixedVector[T]):
        """Replace the contents with `sorted_elements` in Eytzinger layout.

        `sorted_elements` must be ascending and duplicate-free, as produced
        by `sorted_elements()` and `_combine`.
        """
        let new_size = len(sorted_elements)
        self.elements.resize(new_size)
        self.left.resize(new_size)
        self.right.resize(new_size)
        var i: Int = 0
        self._eytzinger(i, 1, sorted_elements)
        # Links must be computed against the NEW size. `__len__()` would
        # still subtract the stale `deleted` count here and leave trailing
        # nodes unlinked.
        for index in range(new_size):
            let l = (index + 1) * 2 - 1
            let r = (index + 1) * 2
            if l < new_size:
                self.left[index] = l
            else:
                self.left[index] = index
            if r < new_size:
                self.right[index] = r
            else:
                self.right[index] = index
        self.deleted = 0
        self.balanced = True
        self.max_depth = self._optimal_depth()

    fn _eytzinger(inout self, inout i: Int, k: Int, v: UnsafeFixedVector[T]):
        """Fill 1-based slot `k` and its subtree (children 2k, 2k+1) in order.

        `i` is the index of the next unused value of `v`.
        """
        if k <= len(v):
            self._eytzinger(i, k * 2, v)
            self.elements[k - 1] = v[i]
            i += 1
            self._eytzinger(i, k * 2 + 1, v)

    fn print_tree(self, root: UInt16 = 0):
        """Print the subtree rooted at slot `root` (default: the whole tree).

        One node per line, indented by depth. "・" marks an empty tree or a
        missing left child that has a right sibling.
        """
        if self.__len__() == 0:
            print("・")
            return
        self._print("", root)

    fn _print(self, indentation: String, index: UInt16):
        """Recursive helper for `print_tree`."""
        if len(indentation) > 0:
            print(indentation, "-", to_str(self.elements[index.to_int()]))
        else:
            print("-", to_str(self.elements[index.to_int()]))
        if self.has_left(index):
            self._print(indentation + " ", self.left[index.to_int()])
        elif self.has_right(index):
            print(indentation + " ", "- ・")
        if self.has_right(index):
            self._print(indentation + " ", self.right[index.to_int()])
fn test_start_with_empty_tree():
    """Grow an empty tree, delete everything again, then rebalance.

    NOTE(review): `fiby`, `assert_tree` and the `assert_*` helpers are not
    defined in this file. From the expected values, `assert_tree(t, n, d)`
    presumably checks `len(t) == n` and `t.max_depth == d` - confirm.
    """
    var bst = fiby()
    assert_tree(bst, 0, 0)
    bst.add(13)
    assert_tree(bst, 1, 1)
    bst.add(15)
    assert_tree(bst, 2, 2)
    _= assert_true(bst.delete(13))
    # Deleting does not lower max_depth; only a rebuild resets it.
    assert_tree(bst, 1, 2)
    _= assert_true(bst.delete(15))
    assert_tree(bst, 0, 2)
    _= assert_false(bst.__contains__(15))
    bst.balance()
    assert_tree(bst, 0, 0)
fn test_longer_sequence_dedup_and_balance():
    """Insert a sequence containing a duplicate (11), balance, then delete all."""
    var bst = fiby(5, 6, 3, 8, 11, 34, 56, 12, 48, 11, 9)
    assert_vec(bst.sorted_elements(), vec(3, 5, 6, 8, 9, 11, 12, 34, 48, 56))
    # Deepest insertion path: 5 -> 6 -> 8 -> 11 -> 34 -> 56 -> 48.
    assert_tree(bst, 10, 7)
    bst.balance()
    assert_tree(bst, 10, 4)
    _= assert_true(bst.delete(8))
    assert_vec(bst.sorted_elements(), vec(3, 5, 6, 9, 11, 12, 34, 48, 56))
    assert_tree(bst, 9, 4)
    let elements = bst.sorted_elements()
    for i in range(len(elements)):
        _= assert_true(bst.delete(elements[i]))
    assert_tree(bst, 0, 4)
    # Re-adding to an emptied tree reuses slot 0; max_depth is not reset.
    bst.add(13)
    assert_tree(bst, 1, 4)
    bst.clear()
    assert_tree(bst, 0, 0)
fn test_add_ascending():
    """Ascending inserts degenerate into a chain; balance() restores depth 4."""
    var bst = fiby()
    for i in range(10):
        bst.add(i)
    assert_vec(bst.sorted_elements(), vec(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
    assert_tree(bst, 10, 10)
    bst.balance()
    assert_tree(bst, 10, 4)
    # Lookups after balance() go through the Eytzinger index arithmetic.
    for i in range(10):
        _= assert_true(bst.__contains__(i))
fn test_union_inplace():
    """union_inplace with empty, overlapping, disjoint and one-element operands."""
    var b1 = fiby(1, 2, 3)
    b1.union_inplace(fiby())
    assert_vec(b1.sorted_elements(), vec(1, 2, 3))
    var b2 = fiby()
    b2.union_inplace(b1)
    assert_vec(b2.sorted_elements(), vec(1, 2, 3))
    b1.union_inplace(fiby(3, 4, 1))
    assert_vec(b1.sorted_elements(), vec(1, 2, 3, 4))
    b1.union_inplace(fiby(2, 3))
    assert_vec(b1.sorted_elements(), vec(1, 2, 3, 4))
    b1.union_inplace(fiby(9, 12, 11, 10))
    assert_vec(b1.sorted_elements(), vec(1, 2, 3, 4, 9, 10, 11, 12))
    # Single-element operands on both sides.
    # NOTE(review): a one-element tree merged with a larger one whose
    # elements are all greater (e.g. fiby(1) with fiby(2, 3)) is not covered.
    b2 = fiby(1)
    b2.union_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec(1))
    b2.union_inplace(fiby(2))
    assert_vec(b2.sorted_elements(), vec(1, 2))
    b2 = fiby(2)
    b2.union_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec(1, 2))
fn test_intersection_inplace():
    """intersection_inplace with overlapping, empty and one-element operands."""
    var b1 = fiby(1, 2, 3)
    b1.intersection_inplace(fiby(3, 4, 1, 6, 7, 10))
    assert_vec(b1.sorted_elements(), vec(1, 3))
    # Intersecting with an empty tree clears self.
    b1.intersection_inplace(fiby())
    assert_vec(b1.sorted_elements(), vec[Int]())
    b1.intersection_inplace(fiby(3, 4, 1))
    assert_vec(b1.sorted_elements(), vec[Int]())
    var b2 = fiby(3, 4, 1, 6, 7, 10)
    b2.intersection_inplace(fiby(1, 2, 3, 8))
    assert_vec(b2.sorted_elements(), vec(1, 3))
    b2 = fiby(1)
    b2.intersection_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec(1))
    b2 = fiby(1)
    b2.intersection_inplace(fiby(2))
    assert_vec(b2.sorted_elements(), vec[Int]())
    b2 = fiby(2)
    b2.intersection_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec[Int]())
fn test_difference_inplace():
    """difference_inplace (self minus other) across edge cases."""
    var b1 = fiby(1, 2, 3)
    b1.difference_inplace(fiby(5, 6, 7, 1))
    assert_vec(b1.sorted_elements(), vec(2, 3))
    b1.difference_inplace(fiby())
    assert_vec(b1.sorted_elements(), vec(2, 3))
    b1.difference_inplace(fiby(1, 12, 34))
    assert_vec(b1.sorted_elements(), vec(2, 3))
    # An empty self stays empty and untouched.
    var b2 = fiby()
    b2.difference_inplace(fiby(1, 2, 3))
    assert_tree(b2, 0, 0)
    b2 = fiby(1)
    b2.difference_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec[Int]())
    b2 = fiby(1)
    b2.difference_inplace(fiby(2))
    assert_vec(b2.sorted_elements(), vec(1))
    b2 = fiby(2)
    b2.difference_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec(2))
fn test_other_difference_inplace():
    """other_difference_inplace: self becomes `other` minus self."""
    var b1 = fiby(1, 2, 3)
    b1.other_difference_inplace(fiby(5, 6, 7, 1))
    assert_vec(b1.sorted_elements(), vec(5, 6, 7))
    # An empty other yields an empty result.
    b1.other_difference_inplace(fiby())
    assert_vec(b1.sorted_elements(), vec[Int]())
    b1 = fiby(1, 2, 3)
    b1.other_difference_inplace(fiby(0, 1, 12, 34))
    assert_vec(b1.sorted_elements(), vec(0, 12, 34))
    var b2 = fiby()
    b2.other_difference_inplace(fiby(1, 2, 3))
    assert_vec(b2.sorted_elements(), vec(1, 2, 3))
    b2 = fiby(1)
    b2.other_difference_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec[Int]())
    b2 = fiby(1)
    b2.other_difference_inplace(fiby(2))
    assert_vec(b2.sorted_elements(), vec(2))
    b2 = fiby(2)
    b2.other_difference_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec(1))
fn test_symetric_difference_inplace():
    """symetric_difference_inplace: keep elements in exactly one operand."""
    var b1 = fiby(1, 2, 3)
    b1.symetric_difference_inplace(fiby(3, 4, 5))
    assert_vec(b1.sorted_elements(), vec(1, 2, 4, 5))
    b1.symetric_difference_inplace(fiby(0, 2, 8, 5, 13))
    assert_vec(b1.sorted_elements(), vec(0, 1, 4, 8, 13))
    b1.symetric_difference_inplace(fiby())
    assert_vec(b1.sorted_elements(), vec(0, 1, 4, 8, 13))
    var b2 = fiby()
    b2.symetric_difference_inplace(fiby(1, 2, 3))
    assert_vec(b2.sorted_elements(), vec(1, 2, 3))
    b2 = fiby()
    b2.symetric_difference_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec(1))
    b2 = fiby(1)
    b2.symetric_difference_inplace(fiby())
    assert_vec(b2.sorted_elements(), vec(1))
    b2 = fiby(1)
    b2.symetric_difference_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec[Int]())
    b2 = fiby(1)
    b2.symetric_difference_inplace(fiby(2))
    assert_vec(b2.sorted_elements(), vec(1, 2))
    b2 = fiby(2)
    b2.symetric_difference_inplace(fiby(1))
    assert_vec(b2.sorted_elements(), vec(1, 2))
fn test_union():
    """Placeholder for the non-inplace `union`; the check is disabled.

    `b0` and `b1` are currently unused because the assertion below is
    commented out.
    """
    let b0 = fiby(2)
    let b1 = fiby(1)
    # blocked by a compiler bug
    # The union of {2} and {1} is {1, 2}; re-enable as:
    # assert_vec(b0.union(b1).sorted_elements(), vec(1, 2))
fn test_disjoint():
    """is_disjoint: empty operands are disjoint; any shared element is not."""
    _= assert_true(fiby().is_disjoint(fiby()))
    _= assert_true(fiby().is_disjoint(fiby(1)))
    _= assert_true(fiby(1).is_disjoint(fiby()))
    _= assert_true(fiby(1).is_disjoint(fiby(2)))
    _= assert_false(fiby(1).is_disjoint(fiby(1)))
    _= assert_true(fiby(1, 3).is_disjoint(fiby(2, 5, 6)))
    _= assert_true(fiby(1, 3, 5).is_disjoint(fiby(2, 0, 7)))
    _= assert_false(fiby(1, 3, 5).is_disjoint(fiby(2, 5)))
    _= assert_false(fiby(1, 5).is_disjoint(fiby(2, 5)))
fn test_subset():
    """is_subset: the empty set is a subset of everything; sizes short-circuit."""
    _= assert_true(fiby().is_subset(fiby()))
    _= assert_true(fiby().is_subset(fiby(1, 2, 3)))
    _= assert_true(fiby(3).is_subset(fiby(3)))
    _= assert_true(fiby(3).is_subset(fiby(1, 2, 3)))
    _= assert_true(fiby(3, 1).is_subset(fiby(1, 2, 3)))
    _= assert_true(fiby(3, 1, 2).is_subset(fiby(1, 2, 3)))
    _= assert_false(fiby(1).is_subset(fiby(3)))
    _= assert_false(fiby(3, 1, 2, 5).is_subset(fiby(1, 2, 3)))
    _= assert_false(fiby(3, 1, 5).is_subset(fiby(1, 2, 3)))
fn test_superset():
    """is_superset: mirror of test_subset with operands swapped."""
    _= assert_false(fiby(1).is_superset(fiby(2)))
    _= assert_false(fiby(1, 5, 8).is_superset(fiby(1, 5, 8, 9)))
    _= assert_false(fiby(1, 5, 8).is_superset(fiby(1, 5, 9)))
    _= assert_true(fiby().is_superset(fiby()))
    _= assert_true(fiby(1).is_superset(fiby(1)))
    _= assert_true(fiby(1, 5, 8, 9).is_superset(fiby(1, 5, 8, 9)))
    _= assert_true(fiby(1, 5, 8, 9).is_superset(fiby(1, 5, 8)))
    _= assert_true(fiby(0, 1, 5, 8, 9).is_superset(fiby(1, 5, 8)))
fn test_min_index():
    """min_index returns the slot holding the smallest element.

    Unbalanced trees store nodes in insertion order, so the slot equals the
    position at which 1 was inserted. Once balanced, every 5-element tree
    keeps its minimum at slot 3, the leftmost slot of the deepest level.
    """
    _= assert_equal(fiby().min_index(), -1)
    _= assert_equal(fiby(1).min_index(), 0)
    _= assert_equal(fiby(1, 2, 3, 4).min_index(), 0)
    _= assert_equal(fiby(3, 4, 1, 2).min_index(), 2)
    var f = fiby(3, 4, 1, 2)
    f.balance()
    _= assert_equal(f.min_index(), 3)
    # The minimum (1) inserted at each position 0..4.
    f = fiby(1, 3, 2, 4, 5)
    _= assert_equal(f.min_index(), 0)
    f = fiby(3, 1, 2, 4, 5)
    _= assert_equal(f.min_index(), 1)
    f = fiby(3, 2, 1, 4, 5)
    _= assert_equal(f.min_index(), 2)
    f = fiby(3, 2, 4, 1, 5)
    _= assert_equal(f.min_index(), 3)
    f = fiby(3, 2, 4, 5, 1)
    _= assert_equal(f.min_index(), 4)
    # Same inputs after balancing.
    f = fiby(1, 3, 2, 4, 5)
    f.balance()
    _= assert_equal(f.min_index(), 3)
    f = fiby(3, 1, 2, 4, 5)
    f.balance()
    _= assert_equal(f.min_index(), 3)
    f = fiby(3, 2, 1, 4, 5)
    f.balance()
    _= assert_equal(f.min_index(), 3)
    f = fiby(3, 2, 4, 1, 5)
    f.balance()
    _= assert_equal(f.min_index(), 3)
    f = fiby(3, 2, 4, 5, 1)
    f.balance()
    _= assert_equal(f.min_index(), 3)
fn test_max_index():
    """max_index returns the slot holding the largest element.

    Unbalanced trees store nodes in insertion order. Balanced trees keep
    the maximum at the rightmost slot of the last full level (slot 6 for
    a perfect 7-node tree, slot 2 for 4 or 5 nodes).
    """
    _= assert_equal(fiby().max_index(), -1)
    _= assert_equal(fiby(1).max_index(), 0)
    _= assert_equal(fiby(1, 2, 3, 4).max_index(), 3)
    _= assert_equal(fiby(3, 4, 1, 2).max_index(), 1)
    var f = fiby(3, 4, 1, 2)
    f.balance()
    _= assert_equal(f.max_index(), 2)
    # Perfect tree of 7 nodes: the maximum sits in the last slot.
    f = fiby(3, 4, 1, 2, 5, 6, 0)
    f.balance()
    _= assert_equal(f.max_index(), 6)
    _= assert_equal(f.elements[6], 6)
    # The maximum (5) inserted at positions 4..0.
    f = fiby(1, 3, 2, 4, 5)
    _= assert_equal(f.max_index(), 4)
    f = fiby(3, 1, 2, 5, 4)
    _= assert_equal(f.max_index(), 3)
    f = fiby(3, 2, 5, 4, 1)
    _= assert_equal(f.max_index(), 2)
    f = fiby(3, 5, 4, 1, 2)
    _= assert_equal(f.max_index(), 1)
    # NOTE(review): 5 appears twice here, so the tree has only 4 elements;
    # the pattern of the other cases suggests fiby(5, 2, 4, 1, 3) was
    # intended. The expectations below hold for both inputs.
    f = fiby(5, 2, 4, 5, 3)
    _= assert_equal(f.max_index(), 0)
    # Same inputs after balancing.
    f = fiby(1, 3, 2, 4, 5)
    f.balance()
    _= assert_equal(f.max_index(), 2)
    f = fiby(3, 1, 2, 5, 4)
    f.balance()
    _= assert_equal(f.max_index(), 2)
    f = fiby(3, 2, 5, 4, 1)
    f.balance()
    _= assert_equal(f.max_index(), 2)
    f = fiby(3, 5, 4, 1, 2)
    f.balance()
    _= assert_equal(f.max_index(), 2)
    f = fiby(5, 2, 4, 5, 3)
    f.balance()
    _= assert_equal(f.max_index(), 2)
    _= assert_equal(f.elements[2], 5)
# Test driver: runs every test in order and prints a marker on completion.
# NOTE(review): top-level statements presumably target a notebook/REPL
# session; a compiled .mojo file would need them inside `fn main()`.
test_start_with_empty_tree()
test_longer_sequence_dedup_and_balance()
test_add_ascending()
test_union_inplace()
test_intersection_inplace()
test_difference_inplace()
test_other_difference_inplace()
test_symetric_difference_inplace()
test_union()
test_disjoint()
test_subset()
test_superset()
test_min_index()
test_max_index()
print("Done!!!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment