Skip to content

Instantly share code, notes, and snippets.

@mfalt
Last active January 28, 2022 17:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mfalt/f4756e351ae1575a7d831e7f822324eb to your computer and use it in GitHub Desktop.
Save mfalt/f4756e351ae1575a7d831e7f822324eb to your computer and use it in GitHub Desktop.
module SimpleTreeInterface

export ChildrenFirst, ParentsFirst
export NodeOnly, NodeDepth, NodeIndex

"""
Supertype of the singleton types that choose what a tree iterator yields for each node:
`NodeOnly`, `NodeDepth` or `NodeIndex`.
"""
abstract type AbstractReturnMode end

"""
Supertype of the tree iterators. `T` is the type of the node the iteration starts from and
`RetMode` is the `AbstractReturnMode` that decides the shape of each yielded item.
"""
abstract type AbstractTreeIterator{T,RetMode<:AbstractReturnMode} end

"Yield each node on its own."
struct NodeOnly <: AbstractReturnMode end

"Yield `(node, depth)` tuples; the node the iteration starts from has depth `0`."
struct NodeDepth <: AbstractReturnMode end

"""
Yield `(node, path)` tuples, where `path::Vector{Int}` lists the 1-based child positions
leading from the node the iteration starts from down to `node` (empty for that node).
Every yielded `path` is a fresh copy, so it can be stored safely.
"""
struct NodeIndex <: AbstractReturnMode end

"""
    ParentsFirst(node)
    ParentsFirst(node, mode::AbstractReturnMode)

Iterate over `node` and its descendants in pre-order: every node is visited before its
children, and children are visited in the order `children(node)` lists them. `mode`
defaults to `NodeOnly()`.

# Examples
```julia
for n in ParentsFirst(root)
    println(n)
end
```

Nodes must either have the fields `parent` and `children`, or the functions
`SimpleTreeInterface.parent(node)` and `SimpleTreeInterface.children(node)` must be
extended for their type. `children(node)` must support `length` and 1-based `getindex`.
Only `node` and its descendants are visited, even when `node` itself has a parent; the
structure is assumed to be a proper tree.
"""
struct ParentsFirst{T,RetMode} <: AbstractTreeIterator{T,RetMode}
    x::T
end

"""
    ChildrenFirst(node)
    ChildrenFirst(node, mode::AbstractReturnMode)

Iterate over `node` and its descendants in post-order: every node is visited after all of
its children, so `node` itself comes last. Children are visited in the order
`children(node)` lists them. `mode` defaults to `NodeOnly()`.

# Examples
```julia
for n in ChildrenFirst(root)
    println(n)
end
```

Same node requirements as `ParentsFirst`: fields `parent` and `children`, or extended
`SimpleTreeInterface.parent` / `SimpleTreeInterface.children`, with `children(node)`
supporting `length` and 1-based `getindex`. Only `node` and its descendants are visited;
the structure is assumed to be a proper tree.
"""
struct ChildrenFirst{T,RetMode} <: AbstractTreeIterator{T,RetMode}
    x::T
end

# Generic constructors shared by every `AbstractTreeIterator` subtype: the node type is
# taken from `x` and the return mode from the optional singleton argument.
(::Type{TreeItr})(x::N, ::RetMode) where {N, RetMode <: AbstractReturnMode, TreeItr<:AbstractTreeIterator{S,T} where {S,T}} =
    TreeItr{N,RetMode}(x)
(::Type{TreeItr})(x::N) where {N, TreeItr <:AbstractTreeIterator{S,T} where {S,T}} =
    TreeItr{N,NodeOnly}(x)

"""
    children(node)

Return the ordered children of `node`; defaults to the field `node.children`. Extend this
function for node types that store their children differently.
"""
children(x) = x.children

"""
    parent(node)

Return the parent of `node`; defaults to the field `node.parent`. The iterators never call
it on the node an iteration starts from.
"""
parent(x) = x.parent

# Build the item yielded for `node` according to the iterator's return mode. The path is
# copied for `NodeIndex` because `idx` is mutated in place as the iteration advances.
to_item(::AbstractTreeIterator{T,NodeOnly}, node, idx) where {T} = node
to_item(::AbstractTreeIterator{T,NodeDepth}, node, idx) where {T} = (node, length(idx))
to_item(::AbstractTreeIterator{T,NodeIndex}, node, idx) where {T} = (node, copy(idx))

# The number of nodes is not known up front. `eltype` assumes that every visited node has
# the same type `T` as the starting node (true for e.g. `children::Vector{T}`).
Base.IteratorEltype(::Type{<:AbstractTreeIterator}) = Base.HasEltype()
Base.IteratorSize(::Type{<:AbstractTreeIterator}) = Base.SizeUnknown()
Base.eltype(::Type{<:AbstractTreeIterator{T,NodeOnly}}) where {T} = T
Base.eltype(::Type{<:AbstractTreeIterator{T,NodeDepth}}) where {T} = Tuple{T,Int}
Base.eltype(::Type{<:AbstractTreeIterator{T,NodeIndex}}) where {T} = Tuple{T,Vector{Int}}

# Navigation helpers. `idx` holds the 1-based position of each node on the path from the
# starting node down to the current node, so `idx[end]` is the current node's position
# among its siblings and `isempty(idx)` means "at the starting node".

# True when `node` has no children.
isleaf(node) = length(children(node)) == 0

# True when `node` (whose path `idx` must be non-empty) is its parent's last child.
# Decided by position rather than by comparing nodes, so value-equal siblings and types
# with a structural `==` are handled correctly.
is_last_sibling(node, idx) = idx[end] == length(children(parent(node)))

# Step down to the first child of `node`, recording its position in `idx`.
function to_first_child!(node, idx)
    push!(idx, 1)
    return children(node)[1]
end

# Step sideways to the next sibling of `node`, advancing its position in `idx`.
function to_next_sibling!(node, idx)
    idx[end] += 1
    return children(parent(node))[idx[end]]
end

# Step up to the parent of `node`, dropping the last position from `idx`.
function to_parent!(node, idx)
    pop!(idx)
    return parent(node)
end

# Follow first children from `node` until reaching a leaf, extending `idx` on the way.
function to_first_leaf!(node, idx)
    while !isleaf(node)
        node = to_first_child!(node, idx)
    end
    return node
end

function Base.iterate(itr::ChildrenFirst)
    idx = Int[]
    # Post-order starts at the leftmost leaf below the starting node.
    node = to_first_leaf!(itr.x, idx)
    return to_item(itr, node, idx), (node, idx)
end

function Base.iterate(itr::ChildrenFirst, state)
    node, idx = state
    # The starting node is always yielded last.
    isempty(idx) && return nothing
    if is_last_sibling(node, idx)
        # Every sibling subtree is finished, so the parent comes next.
        node = to_parent!(node, idx)
    else
        # The next sibling's subtree is visited starting from its leftmost leaf.
        node = to_first_leaf!(to_next_sibling!(node, idx), idx)
    end
    return to_item(itr, node, idx), (node, idx)
end

function Base.iterate(itr::ParentsFirst)
    idx = Int[]
    return to_item(itr, itr.x, idx), (itr.x, idx)
end

function Base.iterate(itr::ParentsFirst, state)
    node, idx = state
    if !isleaf(node)
        # Pre-order: a node's first child directly follows it.
        node = to_first_child!(node, idx)
        return to_item(itr, node, idx), (node, idx)
    end
    # `node` is a leaf: climb past every ancestor that is a last child, never above the
    # starting node, until reaching a node that has a next sibling.
    while !isempty(idx) && is_last_sibling(node, idx)
        node = to_parent!(node, idx)
    end
    # Back at the starting node: the whole subtree has been visited.
    isempty(idx) && return nothing
    node = to_next_sibling!(node, idx)
    return to_item(itr, node, idx), (node, idx)
end

end
### Example usage:
"""
    Node{T}(val)
    Node{T}(parent, val)

Minimal mutable tree node for the example. It stores a value, its parent and an ordered
vector of children in the `parent`/`children` fields.

A root node is its own parent. The two-argument form links the new node to `parent` but
does not add it to `parent.children`; the caller does that.
"""
mutable struct Node{T}
    val::T
    parent::Node{T}
    children::Vector{Node{T}}

    # Root: allocate with only `val` set, then make the node point at itself so the
    # `parent` field is never left undefined.
    function Node{T}(val::T) where {T}
        root = new{T}(val)
        root.parent = root
        root.children = Node{T}[]
        return root
    end

    # Child: fully initialised in one step, starting with no children of its own.
    function Node{T}(parent::Node{T}, val::T) where {T}
        return new{T}(val, parent, Node{T}[])
    end
end
"""
    Node(parent, val)

Create a child of `parent` holding `val`, taking the element type `T` from `parent`.
"""
function Node(parent::Node{T}, val::T) where {T}
    return Node{T}(parent, val)
end
# Forward the two-argument `show` (used by `print`/`println`) to the `text/plain` method,
# so that `println(node)` prints the node's whole subtree.
function Base.show(io::IO, n::Node)
    return show(io, MIME("text/plain"), n)
end
# Print `node` and its descendants in pre-order, one per line, each prefixed with one `-`
# per level below `node` (no prefix for `node` itself).
function Base.show(io::IO, ::MIME"text/plain", node::Node)
    # Names are module-qualified so this method works even if the enclosing module never
    # ran `using .SimpleTreeInterface`.
    itr = SimpleTreeInterface.ParentsFirst(node, SimpleTreeInterface.NodeDepth())
    for (n, depth) in itr
        println(io, "-"^depth * "Node: $(n.val)")
    end
    return nothing
end
# The demo runs outside `SimpleTreeInterface`, so its exported iterators (`ParentsFirst`,
# `ChildrenFirst`) and return modes (`NodeDepth`, ...) must be brought into scope.
using .SimpleTreeInterface

# Build the demo tree: 1 has children 2, 3, 4; 2 has children 5, 6; 5 has children
# 9, 10, 11; 3 has children 7, 8. `Node{Int}(parent, val)` only links a child to its
# parent, so each parent's `children` vector is assigned explicitly afterwards.
n1 = Node{Int}(1)
n2 = [Node{Int}(n1, i) for i in 2:4];
n1.children = n2;
n21 = [Node{Int}(n2[1], i) for i in 5:6];
n2[1].children = n21;
n211 = [Node{Int}(n21[1], i) for i in 9:11];
n21[1].children = n211;
n22 = [Node{Int}(n2[2], i) for i in 7:8];
n2[2].children = n22;

# Whole-tree dump through the `show` methods defined for `Node`.
println(n1)

# Pre-order: every parent before its children.
for n in ParentsFirst(n1)
    println(n.val)
end

# Pre-order with each node's depth below `n1` (0 for `n1` itself).
for (n, d) in ParentsFirst(n1, NodeDepth())
    println(n.val, " depth:", d)
end

# Post-order: every node after its children, `n1` last.
for n in ChildrenFirst(n1)
    println(n.val)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment