Skip to content

Instantly share code, notes, and snippets.

@quinnj
Created October 30, 2019 16:10
Show Gist options
  • Save quinnj/5d3ea86a0cc9d2506bfae5c16260d361 to your computer and use it in GitHub Desktop.
Save quinnj/5d3ea86a0cc9d2506bfae5c16260d361 to your computer and use it in GitHub Desktop.
module BinaryTrees

"""
Key given to the placeholder root created by `Node(data)` / `Tree(data)`.
Such a placeholder root is not counted in `Tree.nodes`.
"""
const SENTINEL_KEY = typemax(UInt) >> 1

"""
    Node{K, T}

Mutable binary-search-tree node holding a key `value::K` and a payload `data::T`.

`left`/`right` are typed with the abstract `Node` rather than `Node{K, T}`, so a child may
carry different key/payload types than its parent (`insert!` builds children with
`Node(x, data)` from the argument types).
"""
mutable struct Node{K, T}
    value::K
    data::T
    left::Union{Node, Nothing}
    right::Union{Node, Nothing}
end

# Leaf holding payload `x` under the placeholder key `SENTINEL_KEY`.
Node(x) = Node(SENTINEL_KEY, x, nothing, nothing)
# Leaf holding key `x` and payload `data`.
Node(x, data) = Node(x, data, nothing, nothing)

"""
    show(io::IO, n::Node, indent=0)

Print the subtree rooted at `n` in pre-order, one key per line, each level indented by one
more space than its parent. A missing child is printed as `null`; payloads are not printed.
"""
function Base.show(io::IO, n::Node, indent=0)
    println(io, " " ^ indent, n.value)
    for child in (n.left, n.right)
        if child === nothing
            println(io, " " ^ (indent + 1), "null")
        else
            show(io, child, indent + 1)
        end
    end
    return nothing
end

"""
    Tree{K, T}

Binary search tree whose `root` is a `Node{K, T}` (or `nothing` when empty), plus a node
counter `nodes` maintained by `insert!` and `delete!`.
"""
mutable struct Tree{K, T}
    root::Union{Node{K, T}, Nothing}
    nodes::Int
end

# Tree whose root is a placeholder node (key `SENTINEL_KEY`) carrying `x`; the placeholder
# is not counted in `nodes`.
Tree(x) = Tree(Node(x), 0)
# Tree rooted at an existing node; `nodes` counts every node in the subtree under `n`.
Tree(n::Node) = Tree(n, countnodes(n))

"""
    Tree(A::AbstractArray{Tuple{K, T}})

Build a height-balanced tree from `(key, data)` pairs. The middle element of each index
range becomes the subtree root, so `A` must already be sorted by key for lookups to work
(this is not checked).
"""
Tree(A::AbstractArray{Tuple{K, T}}) where {K, T} =
    Tree{K, T}(buildTree(A, firstindex(A), lastindex(A)), length(A))

# Recursively build a balanced subtree from A[lo:hi]; returns `nothing` for an empty range.
function buildTree(A, lo, hi)
    lo > hi && return nothing
    mid = (lo + hi) >> 1
    n = Node(A[mid]...)
    n.left = buildTree(A, lo, mid - 1)
    n.right = buildTree(A, mid + 1, hi)
    return n
end

# Number of nodes in the subtree rooted at `n` (0 for an empty subtree).
countnodes(::Nothing) = 0
countnodes(n::Node) = 1 + countnodes(n.left) + countnodes(n.right)

function Base.show(io::IO, t::Tree)
    println(io, "Tree with $(t.nodes) nodes:")
    show(io, t.root)
    return nothing
end

"""
    incr!(t::Tree, x, incr, step)

Add `incr` to the `data` of the node whose bucket covers `x` (see the `Node` method).
Returns `false` when `t` is empty, otherwise `nothing`.
"""
incr!(t::Tree, x, incr, step) = t.root !== nothing && incr!(t.root, x, incr, step)

"""
    insert!(t::Tree, x, data) -> Int

Insert key `x` with payload `data`, replacing the payload if `x` is already present, and
return the updated node count. An empty tree gets a new root of type `Node{K, T}`.
"""
function Base.insert!(t::Tree{K, T}, x, data) where {K, T}
    if t.root === nothing
        # The Node-level insert! needs an existing node to descend from.
        t.root = Node{K, T}(x, data, nothing, nothing)
        return t.nodes += 1
    end
    return t.nodes += insert!(t.root, x, data)
end

"""
    delete!(t::Tree, x) -> Int

Remove key `x` if present and return the updated node count.
"""
function Base.delete!(t::Tree, x)
    root = t.root
    root === nothing && return t.nodes
    if cmp(x, root.value) == 0 && (root.left === nothing || root.right === nothing)
        # The Node-level delete! unlinks a node through its parent's child pointer, which
        # cannot remove the root itself; splice a root with at most one child out here.
        t.root = root.left === nothing ? root.right : root.left
        return t.nodes -= 1
    end
    return t.nodes -= delete!(root, root, x)
end

"""
    collect(t::Tree{K, T}) -> Vector{Tuple{K, T}}

Return every `(key, data)` pair from an in-order traversal (ascending key order for a valid
search tree), including a placeholder root created by `Tree(x)`.
"""
function Base.collect(t::Tree{K, T}) where {K, T}
    # Size from the actual tree: `t.nodes` excludes a placeholder root.
    A = Vector{Tuple{K, T}}(undef, countnodes(t.root))
    setindex(A, t.root, 1)
    return A
end

# Write the in-order traversal of the subtree rooted at `n` into `A`, starting at index
# `i`; returns the next free index (or `nothing` for an empty subtree).
setindex(A, ::Nothing, i) = nothing
function setindex(A, n::Node, i)
    idx = n.left === nothing ? i : setindex(A, n.left, i)
    A[idx] = (n.value, n.data)
    idx += 1
    return n.right === nothing ? idx : setindex(A, n.right, idx)
end

"""
    incr!(n::Node{K, Int}, x, incr, step)

Add `incr` to the `data` of a node in the subtree rooted at `n`.

Descend left while `x < value` and a left child exists. Otherwise, if `x < value + step`
(or `x ≈ value + step`) or there is no right child, increment this node; else descend
right. Keys are presumably bucket lower bounds spaced `step` apart — not checked here.
"""
function incr!(n::Node{K, Int}, x, incr, step) where {K}
    if x < n.value && n.left !== nothing
        return incr!(n.left, x, incr, step)
    elseif x < (n.value + step) || x ≈ (n.value + step) || n.right === nothing
        n.data += incr
        return nothing
    else
        return incr!(n.right, x, incr, step)
    end
end

"""
    insert!(n::Node, x, data) -> Bool

Insert key `x` with payload `data` into the subtree rooted at `n`. Returns `true` if a new
node was created and `false` if an existing key's payload was replaced.
"""
function Base.insert!(n::Node, x, data)
    c = cmp(x, n.value)
    if c == 0
        n.data = data
        return false
    elseif c < 0
        n.left === nothing || return insert!(n.left, x, data)
        n.left = Node(x, data)
    else
        n.right === nothing || return insert!(n.right, x, data)
        n.right = Node(x, data)
    end
    return true
end

# Copy the key and payload of the leftmost (minimum) node under `n` into `orig`.
function replacewithmin!(orig::Node, n::Node)
    while n.left !== nothing
        n = n.left
    end
    orig.value = n.value
    orig.data = n.data
    return nothing
end

"""
    delete!(parent::Node, n::Node, x) -> Bool

Remove key `x` from the subtree rooted at `n`, where `parent` is `n`'s parent. Returns
`true` if a node was removed. A matching node with two children takes over the key and
payload of the minimum of its right subtree, which is then removed from that subtree.

# Throws
- `ArgumentError` if the matching node is `parent` itself (a root) with fewer than two
  children: it cannot be unlinked from here; use `delete!(tree, x)` instead.
"""
function Base.delete!(parent::Node, n::Node, x)
    c = cmp(x, n.value)
    if c == 0
        if n.left !== nothing && n.right !== nothing
            replacewithmin!(n, n.right)
            delete!(n, n.right, n.value)
        else
            parent === n && throw(ArgumentError(
                "cannot unlink a root with fewer than two children; use delete!(tree, x)"))
            # Splice `n` out: its only child (or `nothing`) takes its place under `parent`.
            child = n.left === nothing ? n.right : n.left
            if parent.left === n
                parent.left = child
            else
                parent.right = child
            end
        end
        return true
    elseif c < 0
        return n.left === nothing ? false : delete!(n, n.left, x)
    else
        return n.right === nothing ? false : delete!(n, n.right, x)
    end
end

end # module
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment