Skip to content

Instantly share code, notes, and snippets.

@StefanKarpinski
Created July 28, 2012 20:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save StefanKarpinski/3194647 to your computer and use it in GitHub Desktop.
Save StefanKarpinski/3194647 to your computer and use it in GitHub Desktop.
#load ("debug.jl")
# Balance codes stored in Node.bal. LEFT and RIGHT are also used as 1-based
# indices into Node.child, which is why the tree code can compute the
# opposite side as `3 - side`.
const BALANCED = int8(0)
const LEFT = int8(1)
const RIGHT = int8(2)
# Common supertype of the two tree-node kinds defined in this file:
# Nil (empty subtree) and Node (key/value entry with two children).
abstract Avl{K, V}
# Sorted map from keys of type K to values of type V, stored as an AVL tree.
type Sortmap{K, V} # <: .... ?
# root of the tree; the empty-map constructor stores a Nil here
root :: Avl{K, V}
# number of entries; valid_set / valid_multi_set check it against the tree
count :: Int
# NOTE(review): cf is handed to validate(), but the node-level routines in
# this file compare keys with <, > and == directly -- confirm intended use.
cf :: Function # compare function
end
# True when the sorted map holds no entries.
# (The original read `map.count`, i.e. a field of the Base function `map`,
# instead of the argument `sm`.)
isempty(sm :: Sortmap) = sm.count == 0
# A Sortmap always reports itself as sorted; this method never inspects the tree.
issorted(sm :: Sortmap) = true
# Build an empty sorted map for key type K and value type V.
# The map starts with a Nil root, a count of zero and isless as compare function.
function sortmap(K, V)
    return Sortmap{K, V}(nil(K, V), 0, isless)
end

# Build a sorted map from parallel vectors of keys and values.
# The tree is produced by build_avl; the entry count is taken from length(ks).
sortmap{K, V}(ks :: Vector{K}, vs :: Vector{V}) =
    Sortmap(build_avl(ks, vs), length(ks), isless)
# Field-less node type that stands for an empty subtree.
type Nil{K, V} <: Avl{K, V}
end
# Convenience constructor: an empty subtree for key type K and value type V.
nil(K, V) = Nil{K, V}()
# Tree node holding one key/value entry.
type Node{K, V} <: Avl{K, V}
# two-element array of subtrees, indexed by LEFT (1) and RIGHT (2)
child :: Array{Avl{K, V}, 1}
# the entry's key
ky :: K
# the entry's value
value :: V
# balance code: BALANCED, LEFT or RIGHT (which side is the taller one)
bal :: Int8
end
# Report whether vector `a` is in order with respect to `cf`.
# The vector counts as sorted unless some element compares, via
# cf(current, previous), as belonging before its predecessor.
function issorted{T}(a :: Vector{T}, cf :: Function)
    idx = 2
    last_idx = length(a)
    while idx <= last_idx
        if cf(a[idx], a[idx-1])
            return false
        end
        idx += 1
    end
    return true
end
# Check that the map is a well-formed AVL tree with set semantics:
# validate() must report valid structure, the "set" property and sortedness,
# and the number of nodes found must equal the stored count.
function valid_set{K, V}(sm :: Sortmap{K, V})
    # validate returns (valid, set, sorted, height, count, key)
    report = validate(sm.root, sm.cf)
    return report[1] && report[2] && report[3] && (report[5] == sm.count)
end
# Check that the map is a well-formed AVL tree with multi-set semantics:
# same as valid_set, except that the "set" flag reported by validate()
# is ignored.
function valid_multi_set{K, V}(sm :: Sortmap{K, V})
    # validate returns (valid, set, sorted, height, count, key)
    report = validate(sm.root, sm.cf)
    return report[1] && report[3] && (report[5] == sm.count)
end
# sortmap (K, V, f :: Function) = sortmap(K, V, f)
# sortmap (K, V, f :: Function) = sortmap(K, V, f)
# Copy a sorted map: the tree is duplicated through copy(sm.root), while
# the count and the compare function are carried over unchanged.
function copy(sm :: Sortmap)
    return Sortmap(copy(sm.root), sm.count, sm.cf)
end
# show(sm :: Sortmap)
#
# insert{K,V} (sm :: Sortmap, k :: K, v :: V)
#
# delete{K,V} (sm :: Sortmap, k :: K)
#
# first{K,V} (sm :: Sortmap) = first(sm.root)
#
# shift{K,V} (sm :: Sortmap)
#
# last{K,V} (sm :: Sortmap) = last(sm.root)
#
# pop{K,V} (sm :: Sortmap)
#
# ref{K,V} (sm :: Sortmap, ind :: K)
#
# assign{K,V} (sm :: Sortmap, k :: K, v :: V)
#
# #ref{K,V} (sm :: Sortmap, ind :: Range{K})
#
# #ref{K,V} (sm :: Sortmap, ind :: Vector{K})
#
# has{K, V} (sm :: Sortmap, k :: K)
#
# get{K, V} (sm :: Sortmap, k :: K)
#
# union {K, V} (sm :: Sortmap, ?)
#
# intersect {K, V} (sm :: Sortmap, ?)
#
# difference {K, V} (sm :: Sortmap, ?)
#
# start (sm :: Sortmap) = flatten(sm.root)
#
# next (junk :: Rubbish, a :: Vector) = shift(a), sub(a, [2:length(a)])
#
# done (junk :: Rubbish, a :: Vector) = isempty(a)
#################################################################################################
# Convenience constructor: a BALANCED leaf node carrying (ky, value) whose
# two child slots both hold Nil.
function Node{K, V}(ky :: K, value :: V)
    leaf = Node(Array(Avl{K, V}, 2), ky, value, BALANCED)
    # replace the uninitialised child array with two empty subtrees
    leaf.child = [nil(K, V), nil(K, V)]
    leaf
end
# Is `node` the root of an AVL tree?
#
# Returns the tuple (valid, set, sorted, height, count, key):
#   valid  - every node's stored balance code matches its real subtree heights
#   set    - every node passes the strict cf() check against its children's keys
#   sorted - every node passes the issorted() check against its children's keys
#   height - height of the subtree (0 for Nil)
#   count  - number of nodes in the subtree
#   key    - key of this node (nothing for Nil)
#
# NOTE(review): ordering is only checked between a node and the root keys of
# its direct children, not against every key of each subtree.
validate{K, V}(node :: Nil{K, V}, cf :: Function) = (true, true, true, 0, 0, nothing)
function validate{K, V}(node :: Node{K, V}, cf :: Function)
    (valid_l, set_l, sorted_l, height_l, count_l, key_l) = validate(node.child[LEFT], cf)
    (valid_r, set_r, sorted_r, height_r, count_r, key_r) = validate(node.child[RIGHT], cf)

    # Check the AVL structure. A height difference outside -1:1 is itself a
    # violation; test for it before indexing the lookup table so that a badly
    # unbalanced tree yields valid == false instead of an out-of-bounds error.
    diff = height_r - height_l
    valid = (abs(diff) <= 1) && (node.bal == [LEFT, BALANCED, RIGHT][diff + 2])

    # Keys in left-to-right order: [left child key,] this key [, right child key]
    a = [node.ky]
    if count_l > 0
        enqueue(a, key_l)
    end
    if count_r > 0
        push(a, key_r)
    end

    # "set" requires cf(previous, current) to hold for every adjacent pair
    set_node = true
    for i in 2:length(a)
        if !cf(a[i-1], a[i])
            set_node = false
        end
    end
    sorted_node = issorted(a, cf)

    return ((valid && valid_l && valid_r),
            (set_node && set_l && set_r),
            (sorted_node && sorted_l && sorted_r),
            (1 + max(height_l, height_r)),
            (1 + count_l + count_r),
            node.ky)
end
# A subtree is empty when it is represented by a Nil.
function isempty(node :: Avl)
    return isa(node, Nil)
end

# A subtree is non-empty when it is represented by a Node.
function notempty(node :: Avl)
    return isa(node, Node)
end
# Copy a subtree. A Nil is returned as-is; for a Node a new Node is built
# and both children are copied recursively. Keys and values are passed to
# the new node unchanged (they are not copied themselves).
#
# The file used to define copy(::Node) twice with the same signature, so the
# first definition was silently replaced by the second. Only one is kept: the
# variant that fills the freshly allocated Avl{K,V} child array slot by slot.
copy{K, V}(node :: Nil{K, V}) = node
function copy{K, V}(node :: Node{K, V})
    new_node = Node(Array(Avl{K, V}, 2), node.ky, node.value, node.bal)
    new_node.child[LEFT] = copy(node.child[LEFT])
    new_node.child[RIGHT] = copy(node.child[RIGHT])
    return new_node
end
# Print the entries of a tree in key order as "( (k:v)  (k:v) ...)\n".
# Still a rough rendering: spacing and separators are not polished.
function show{K,V}(node :: Avl{K, V})
    # In-order walk that renders each entry as " (key:value) ".
    function rec(sub :: Avl)
        if isempty(sub)
            return ""
        end
        return strcat(rec(sub.child[LEFT]), " (", sub.ky, ":", sub.value, ") ", rec(sub.child[RIGHT]))
    end
    rendered = rec(node)
    print(strcat("(", rendered, ")\n"))
end
# Membership test: does the subtree contain key `ky`?
# Reaching a Nil means the key is absent.
has{K}(node :: Nil, ky :: K) = false
function has{K,V}(node :: Node{K,V}, ky :: K)
    if !(node.ky == ky)
        # (ky > node.ky) + 1 is 2 (RIGHT) for larger keys, otherwise 1 (LEFT)
        branch = (ky > node.ky) + 1
        return has(node.child[branch], ky)
    end
    return true
end
# Look up the value stored under `ky`.
# Reaching a Nil means the key is absent, which raises KeyError(ky).
get{K, V}(node :: Nil{K,V}, ky :: K) = throw(KeyError(ky))
function get{K,V}(node :: Node{K,V}, ky :: K)
    if !(node.ky == ky)
        # (ky > node.ky) + 1 is 2 (RIGHT) for larger keys, otherwise 1 (LEFT)
        branch = (ky > node.ky) + 1
        return get(node.child[branch], ky)
    end
    return node.value
end
# Implementing iteration like this felt like cheating until I realized this can often be
# the best solution: It's both efficient (linear time for traversing all values) and safe
# (no need to worry about the tree changing while traversing). Still, it probably isn't ideal.
#
# The iteration state is the flattened array of (key, value) tuples produced
# by flatten(); each step removes the head of that array.
#
# Fixes relative to the earlier version:
#  * next/done were only defined for `Rubbish`, which Avl does not subtype,
#    so iterating a tree started with start(::Avl) found no matching method.
#    Methods for Avl are added; the Rubbish ones are kept for compatibility.
#  * next() called shift(a), which already removes the first element, and
#    then also advanced the state with sub(a, 2:end), skipping every second
#    entry and turning the state into a SubArray that no longer matched the
#    `a :: Vector` signature. The state is now simply the shortened array.
abstract Rubbish
start(tree :: Avl) = flatten(tree)
next(tree :: Avl, a :: Vector) = (shift(a), a)
done(tree :: Avl, a :: Vector) = isempty(a)
next(junk :: Rubbish, a :: Vector) = (shift(a), a)
done(junk :: Rubbish, a :: Vector) = isempty(a)
# Produce an array of (key, value) tuples from a tree, in left-to-right
# (in-order) sequence, without recursion.
function flatten{K,V}(node :: Avl{K,V})
    entries = Array(Any, 0)
    # Work stack; the Nil pushed first acts as the "nothing left" marker that
    # terminates the outer loop when it is popped.
    pending = Array(Avl{K,V}, 0)
    push(pending, nil(K,V))
    while notempty(node)
        while notempty(node)
            left = node.child[LEFT]
            if notempty(left)
                # Postpone this entry: stack a stand-in that has the same
                # key/value and right subtree but no left subtree, then
                # descend to the left.
                deferred = Node(node.ky, node.value)
                deferred.child[RIGHT] = node.child[RIGHT]
                push(pending, deferred)
                node = left
            else
                # Nothing smaller remains: emit the entry and go right.
                push(entries, (node.ky, node.value))
                node = node.child[RIGHT]
            end
        end
        node = pop(pending)
    end
    return entries
end
# Indexing support: return the value stored under `ky`.
# Reaching a Nil means the key is absent, which raises KeyError(ky).
ref(node :: Nil, ky) = throw(KeyError(ky))
function ref(node :: Node, ky)
    if !(ky == node.ky)
        # (ky > node.ky) + 1 is 2 (RIGHT) for larger keys, otherwise 1 (LEFT)
        branch = (ky > node.ky) + 1
        return ref(node.child[branch], ky)
    end
    return node.value
end
# Insert (ky, value) below `node` and return (longer, new_root), where
# `longer` says whether the subtree grew in height. Keys that are not
# greater than the current key go to the left, so equal keys are added as
# additional entries rather than replaced.
insert{K, V}(node :: Nil{K, V}, ky :: K, value :: V) = (true, Node(ky, value))
function insert{K, V}(node :: Node{K, V}, ky :: K, value :: V)
    # Key comparison is hard-wired to `>` for now; a user-supplied compare
    # function (defaulting to isless) is planned.
    # `side` is the branch descended into (LEFT = 1, RIGHT = 2) and `edis`
    # is the opposite one.
    side = (ky > node.ky) + 1
    edis = 3 - side
    longer, node.child[side] = insert(node.child[side], ky, value)
    if !longer
        return (false, node)
    end
    if node.bal == BALANCED
        # was even, now leans towards the side that grew; height increased
        node.bal = side
        return (true, node)
    elseif node.bal == edis
        # the shorter side caught up; overall height is unchanged
        node.bal = BALANCED
        return (false, node)
    end
    # already leaning towards the side that grew: rebalance
    return rotate(node, side)
end
# Store `value` under `ky`: overwrite the value if the key already exists,
# otherwise add a new entry. Returns (longer, new_root), where `longer` says
# whether the subtree grew in height.
assign{K, V}(node :: Nil{K, V}, ky :: K, value :: V) = (true, Node(ky, value))
function assign{K, V}(node :: Node{K, V}, ky :: K, value :: V)
    goes_left = ky < node.ky
    if !goes_left && !(ky > node.ky)
        # neither smaller nor larger: this is the entry to overwrite
        node.value = value
        return (false, node)
    end
    side = goes_left ? LEFT : RIGHT
    edis = 3 - side     # the opposite branch
    longer, node.child[side] = assign(node.child[side], ky, value)
    if !longer
        return (false, node)
    end
    if node.bal == BALANCED
        # was even, now leans towards the side that grew; height increased
        node.bal = side
        return (true, node)
    elseif node.bal == edis
        # the shorter side caught up; overall height is unchanged
        node.bal = BALANCED
        return (false, node)
    end
    # already leaning towards the side that grew: rebalance
    return rotate(node, side)
end
###########################################################
#
# linear time build_avl DOES NOT WORK. DUE TO SOME STRANGE BUG
# I have no clue what it is. When translated to Euphoria it works fine,
# so the algorithm in general is probably ok. Must be something with the
#
# I HAVE REPLACED IT TEMPORARILY WITH this n log n function:
# Build an AVL tree from parallel vectors of keys and values by inserting
# the pairs one at a time (n log n). The keys do not need to be pre-sorted.
# Raises an error when the two vectors differ in length.
function build_avl{K, V}(ks :: Vector{K}, vs :: Vector{V})
    klen = length(ks)
    if klen != length(vs)
        # `$(length(vs))` is required here: `$length(vs)` would interpolate the
        # function `length` itself and then print the literal text "(vs)".
        error("build_avl: key and value arrays must be of same length. length(keys) = $klen, length(values) = $(length(vs))")
    end
    node = nil(K, V)
    for i in 1 : klen
        # the "grew taller" flag is irrelevant at the root
        flag, node = insert(node, ks[i], vs[i])
    end
    node
end
# # make tree from two arrays, [Key1, key2... ] and [Value1, value2...)
# # The keys must be sorted
# function build_avl{K, V}(ks :: Vector{K}, vs :: Vector{V})
# # I tried to use subarrays first, but the types were too hard to figure out :)
# function rec(fst, lst)
# len = (lst - fst) + 1
#
# println("len: $len")
#
# if len <= 0
#
# println("nil")
#
# return 0, nil(K, V)
#
# elseif len == 1
#
# #println("new node: $(ks[fst])")
# println("new node: $(Node(ks[fst], vs[fst]))")
#
# return 1, Node(ks[fst], vs[fst])
#
# end
#
# mid = fst + ifloor(len / 2)
#
# println("mid: $mid")
#
# node = Node (ks[mid], vs[mid]) # create new node
#
# println("node: $node")
#
# hl, node.child[LEFT] = rec(fst, mid-1)
#
# println("node w. l. child: $node")
#
# println("hl: $hl")
#
# hr, node.child[RIGHT] = rec(mid+1, lst)
#
# println("node w. r. child: $node")
#
# println("hr: $hr")
#
# node.bal = [LEFT, BALANCED, RIGHT] [(hr - hl) + 2]
#
# println("bal: $(node.bal)")
#
# mx = max(hl, hr)
#
# println("max: $mx")
#
# return max(hl, hr) + 1, node
# end
# klen = length(ks)
# if klen != length(vs)
# error("build_avl: key and value arrays must be of same length. length(keys) = $klen, length(values) = $length(vs)")
# end
# height, node = rec(1, klen)
# return node
# end
#
# println(build_avl([], [])) # works
# println(build_avl([1], [1])) # works
# println(build_avl([1,2], [1,2])) # works
# println(build_avl([1,2,3], [1,2,3])) # works
#
#
# println(build_avl([1,2,3,4], [1,2,3,4])) # not so much
#
#
##############################################################################
# Remove the leftmost (first) entry of the subtree rooted at `node`.
# Returns (shorter, value, new_root): whether the subtree lost height, the
# removed entry's value, and the root of the remaining subtree.
# `node` must be a Node; its children are accessed unconditionally.
function del_first{K, V}(node :: Avl{K, V})
    if isempty(node.child[LEFT]) # at the bottom yet?
        return (true, node.value, node.child[RIGHT])
    end
    shorter, ret_val, node.child[LEFT] = del_first(node.child[LEFT])
    if shorter == false
        return (false, ret_val, node)
    end
    # The left branch lost height: adjust this node's balance.
    if node.bal == LEFT
        node.bal = BALANCED
    elseif node.bal == BALANCED
        node.bal = RIGHT
        shorter = false
    else
        # node.bal == RIGHT. (This used to read `else node.bal == RIGHT`,
        # which looked like a condition but only evaluated and discarded a
        # comparison.)
        # rotate() reports true when the height is unchanged afterwards.
        longer, node = rotate(node, RIGHT)
        shorter = !longer
    end
    return (shorter, ret_val, node)
end
# Remove the rightmost (last) entry of the subtree rooted at `node`.
# Returns (shorter, value, new_root): whether the subtree lost height, the
# removed entry's value, and the root of the remaining subtree.
# `node` must be a Node; its children are accessed unconditionally.
function del_last{K, V}(node :: Avl{K, V})
    if isempty(node.child[RIGHT]) # at the bottom yet?
        return (true, node.value, node.child[LEFT])
    end
    # Keep descending with del_last. The earlier version recursed into
    # del_first here, which removed the leftmost entry of the right branch
    # instead of the overall last entry.
    shorter, ret_val, node.child[RIGHT] = del_last(node.child[RIGHT])
    if shorter == false
        return (false, ret_val, node)
    end
    # The right branch lost height: adjust this node's balance.
    if node.bal == RIGHT
        node.bal = BALANCED
    elseif node.bal == BALANCED
        node.bal = LEFT
        shorter = false
    else
        # node.bal == LEFT; rotate() reports true when the height is
        # unchanged afterwards.
        longer, node = rotate(node, LEFT)
        shorter = !longer
    end
    return (shorter, ret_val, node)
end
# ARGH!, del required a separate version of del_first that returns keys as well as values.
# This will be removed later.
#
# Remove the leftmost entry of the subtree rooted at `node`.
# Returns (shorter, key, value, new_root).
# `node` must be a Node; its children are accessed unconditionally.
function del_first2{K, V}(node :: Avl{K, V})
    if isempty(node.child[LEFT]) # at the bottom yet?
        return (true, node.ky, node.value, node.child[RIGHT])
    end
    shorter, ret_key, ret_val, node.child[LEFT] = del_first2(node.child[LEFT])
    if shorter == false
        return (false, ret_key, ret_val, node)
    end
    # The left branch lost height: adjust this node's balance.
    if node.bal == LEFT
        node.bal = BALANCED
    elseif node.bal == BALANCED
        node.bal = RIGHT
        shorter = false
    elseif node.bal == RIGHT
        # rotate() reports true when the height is unchanged afterwards
        longer, node = rotate(node, RIGHT)
        shorter = !longer
    else
        # `$(node.bal)` is required: `$node.bal` would interpolate the whole
        # node and then append the literal text ".bal".
        error("del_first2: node.bal = $(node.bal)")
    end
    return (shorter, ret_key, ret_val, node)
end
# Remove the entry held by `node` itself, once del() has located it.
# Returns (shorter, removed_value, new_root).
function del_helper{K, V}(node :: Avl{K, V})
    # With at most one child, that child (possibly Nil) takes the node's place.
    if isempty(node.child[LEFT])
        return (true, node.value, node.child[RIGHT])
    elseif isempty(node.child[RIGHT])
        return (true, node.value, node.child[LEFT])
    end
    # Two children: pull the leftmost entry out of the right branch and let
    # it take over this node's slot; the node's old value is handed back.
    shorter, succ_key, succ_val, node.child[RIGHT] = del_first2(node.child[RIGHT])
    removed_val = node.value
    node.ky = succ_key
    node.value = succ_val
    if shorter
        # The right branch lost height: adjust this node's balance.
        if node.bal == RIGHT
            node.bal = BALANCED
        elseif node.bal == BALANCED
            node.bal = LEFT
            shorter = false
        elseif node.bal == LEFT
            longer, node = rotate(node, LEFT)
            shorter = !longer
        end
    end
    return (shorter, removed_val, node)
end
# Delete the entry with key `ky` from the subtree rooted at `node`.
# Returns (shorter, removed_value, new_root), where `shorter` says whether
# the subtree lost height. Raises KeyError(ky) when the key is not present.
del{K, V}(node :: Nil{K, V}, ky :: K) = throw(KeyError(ky))
function del{K, V}(node :: Avl{K, V}, ky :: K)
    # Pick the branch explicitly, as assign() does. The earlier version left
    # `side` as the Bool result of `ky < node.ky` for the left case and then
    # used `true` as an array index and as a balance code.
    if ky < node.ky
        side = LEFT
    elseif ky > node.ky
        side = RIGHT
    else
        return del_helper(node)
    end
    edis = 3 - side     # the opposite branch
    shorter, ret_val, node.child[side] = del(node.child[side], ky)
    if shorter == false
        return (false, ret_val, node)
    end
    # The `side` branch lost height: adjust this node's balance.
    if node.bal == side
        node.bal = BALANCED
    elseif node.bal == BALANCED
        node.bal = edis
        shorter = false
    elseif node.bal == edis
        # rotate() reports true when the height is unchanged afterwards
        longer, node = rotate(node, edis)
        shorter = !longer
    end
    return (shorter, ret_val, node)
end
# Rebalance the subtree rooted at `node` by rotating its `side` child
# (LEFT = 1 or RIGHT = 2) upwards. Callers in this file invoke it when the
# `side` branch is the too-tall one.
#
# Returns (flag, new_root). `flag` is true only for the single rotation in
# which the `side` child was BALANCED; insert/assign store the flag as
# `longer`, while the delete routines negate it to obtain `shorter`.
function rotate(node, side)
# the opposite branch index
edis = 3 - side
# balance code of the child that is being lifted
side_bal = node.child[side].bal
if side_bal == edis
# double rotate
# the child leans inwards: its `edis` grandchild becomes the new root
tmp = node.child[side]
new_node = tmp.child[edis]
# hand the grandchild's two subtrees over to the old root and the old child
node.child[side] = new_node.child[edis]
tmp.child[edis] = new_node.child[side]
# the grandchild adopts the old child on `side` and the old root on `edis`
new_node.child[side] = tmp
new_node.child[edis] = node
# the new balance codes of the two demoted nodes depend on how the
# grandchild was leaning before the rotation
if new_node.bal == side
new_node.child[side].bal = BALANCED
new_node.child[edis].bal = edis
elseif new_node.bal == edis
new_node.child[side].bal = side
new_node.child[edis].bal = BALANCED
elseif new_node.bal == BALANCED
# in case of delete
new_node.child[side].bal = BALANCED
new_node.child[edis].bal = BALANCED
end
new_node.bal = BALANCED
return (false, new_node)
end
# single rotate
# the `side` child becomes the new root; the old root becomes its `edis`
# child and takes over the child's former `edis` subtree
tmp = node
node = node.child[side]
tmp.child[side] = node.child[edis]
node.child[edis] = tmp
if side_bal == side
# both the new root and the demoted old root end up balanced
node.child[edis].bal = BALANCED
node.bal = BALANCED
elseif side_bal == BALANCED
# may happen after delete
# new root leans towards `edis`, demoted old root still leans towards `side`
node.bal = edis
node.child[edis].bal = side
return (true, node)
end
return (false, node)
end
#
#
# # Doesn't even work yet !!!
# fast_insert2{K, V}(node :: Nil{K, V}, ky :: K, value :: V) = Node(ky, value)
# function fast_insert2{K, V}(node :: Node{K, V}, ky :: K, value :: V)
# # try to use loops instead of recursion
# stack = Array((Node{K, V}, Int8), 0)
#
# #find place to insert
# while notempty (node)
# side = (ky > node.ky) + 1
# push(stack, (node, side))
# node = node.child[side]
# end
# # insert new key
# node, side = pop(stack)
# node.child[side] = Node(ky, value)
#
#
# # eliminate all but the necessary conditional branches
# longer = true
# while (longer) && (isempty(stack) == false) # if longer must check for balance changes
# if (node.bal == side)
# longer, node = rotate(node, side)
# parent, side = pop(stack)
# parent.child[side] = node
# node = parent
# break # during insert, rotate will never leave the branch higher
# end
#
# longer, node.bal = begin
# bal = node.bal
# nl = int8(!longer)
# nl |= nl << 1
# upper = bal & nl # bal if not longer
# m = bal $ bal >> 1
# m $= int8(1)
# m |= m << 1
# lower = m & side
# lower = (lower & ~nl) & 0x3
#
# (bool(int8(longer) & ~(bal | bal >> 1)) , lower | upper)
# end
#
# parent, side = pop(stack)
# parent.child[side] = node
# node = parent
# end
# while isempty(stack) == false
# parent, side = pop(stack)
# parent.child[side] = node
# node = parent
# end
# return node
# end
#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment