Skip to content

Instantly share code, notes, and snippets.

@n1chre
Last active March 30, 2017 07:45
Show Gist options
  • Save n1chre/eb73d059cc2cd40c92b8ede01d85a45e to your computer and use it in GitHub Desktop.
Save n1chre/eb73d059cc2cd40c92b8ede01d85a45e to your computer and use it in GitHub Desktop.
Implementation of a red black tree in Erlang, guided by Sedgewick's paper
%%%-------------------------------------------------------------------
%%% @author fhrenic
%%% @copyright (C) 2017, FER
%%% @doc
%%% Left-leaning red-black binary search tree, following Sedgewick's paper.
%%% @end
%%% Created : 27. Mar 2017 00:01
%%%-------------------------------------------------------------------
-module(rbtree).
-author("fhrenic").
%% API
-export([
tree/0, fromList/1,
inorder/1, preorder/1, postorder/1,
treesize/1, height/1, contains/2,
add/3, get/2, delete/2,
min/1, max/1, floor/2, ceiling/2,
rank/2, select/2, map/2,
size_con/1, height_con/1,
combine_con/2, combine_con/3
]).
%% A tree node: {key, value, color, left child, right child}.
%% `color` is `red` or `black`; nodes are created red unless told otherwise.
%% The atom `null` stands for an empty (sub)tree, which is why `left` and
%% `right` default to it.
-record(node, {key, value, color = red, left = null, right = null}).
%% Return an empty tree.
%% The rest of the module represents the empty tree as the atom `null` (every
%% function has a `null` clause). A default #node{} would be a phantom node
%% with key `undefined`: it would show up in traversals and sizes, and keys
%% added to it would be compared against `undefined`.
tree() -> null.
%% Build a tree from a list; every element is inserted with add/3 using the
%% element as both key and value, in list order. [] yields the empty tree.
fromList(Elements) ->
    lists:foldl(fun(Element, Acc) -> add(Acc, Element, Element) end, null, Elements).
%% Values of the tree in ascending key order.
%% The list is built back to front with an accumulator (right subtree first),
%% so each value is consed once: O(n) instead of the O(n * height) cost of
%% repeatedly copying left results with ++.
inorder(Tree) -> inorder(Tree, []).

inorder(null, Acc) -> Acc;
inorder(N, Acc) ->
    inorder(N#node.left, [N#node.value | inorder(N#node.right, Acc)]).
%% Values of the tree in pre-order (node, left subtree, right subtree).
%% Uses an accumulator so the result is built in O(n) without ++ copies.
preorder(Tree) -> preorder(Tree, []).

preorder(null, Acc) -> Acc;
preorder(N, Acc) ->
    [N#node.value | preorder(N#node.left, preorder(N#node.right, Acc))].
%% Values of the tree in post-order (left subtree, right subtree, node).
%% Uses an accumulator so the result is built in O(n); the original
%% `... ++ ... ++ [Value]` copied both subtree lists at every level.
postorder(Tree) -> postorder(Tree, []).

postorder(null, Acc) -> Acc;
postorder(N, Acc) ->
    postorder(N#node.left, postorder(N#node.right, [N#node.value | Acc])).
%% Number of nodes in the tree; the empty tree (`null`) has size 0.
%% Walks every node, so it costs O(n).
treesize(null) -> 0;
treesize(N) ->
    LeftCount = treesize(N#node.left),
    RightCount = treesize(N#node.right),
    LeftCount + RightCount + 1.
%% Number of nodes on the longest root-to-leaf path; `null` has height 0.
height(null) -> 0;
height(N) ->
    LeftHeight = height(N#node.left),
    RightHeight = height(N#node.right),
    if
        LeftHeight >= RightHeight -> LeftHeight + 1;
        true -> RightHeight + 1
    end.
%% Return a tree with the same shape, keys and colors, where the value V
%% stored under key K is replaced by Fun(K, V).
map(null, _Fun) -> null;
map(N, Fun) ->
    NewValue = Fun(N#node.key, N#node.value),
    NewLeft = map(N#node.left, Fun),
    NewRight = map(N#node.right, Fun),
    N#node{value = NewValue, left = NewLeft, right = NewRight}.
%% concurrent functions
%% Number of nodes in Tree, computed in parallel with combine_con/2.
size_con(Tree) ->
    CountNode = fun(LeftCount, RightCount) -> LeftCount + RightCount + 1 end,
    combine_con(Tree, CountNode).
%% Height of Tree, computed in parallel with combine_con/2.
height_con(Tree) ->
    TallerPlusOne = fun(LeftHeight, RightHeight) -> 1 + erlang:max(LeftHeight, RightHeight) end,
    combine_con(Tree, TallerPlusOne).
%% Fold Tree bottom-up in parallel.
%% An empty subtree (`null`) contributes 0. A node contributes
%% Combine(LeftResult, RightResult), where the two child results are computed
%% in separate processes. Keys and values are not passed to Combine, so this
%% computes shape statistics (size, height, ...).
%%
%% Each child replies with a message tagged by its own pid. This has three
%% effects:
%%   - the left result is always Combine's first argument; an untagged
%%     `receive X` took whichever child answered first;
%%   - unrelated messages already in the caller's mailbox can never be taken
%%     as a result;
%%   - because children are monitored, if one crashes (e.g. Combine raises)
%%     the waiting process exits with the same reason instead of blocking
%%     forever.
combine_con(Tree, Combine) ->
    combine_value(Tree, Combine).

%% Backward-compatible entry point. Computes the combined value of Tree and
%% sends it to Parent as a bare message, the protocol the original version
%% used. Returns the value that was sent.
combine_con(Parent, Tree, Combine) ->
    Parent ! combine_value(Tree, Combine).

%% Combined value of one subtree; the two children are evaluated concurrently.
combine_value(null, _Combine) ->
    0;
combine_value(N, Combine) ->
    %% Bind the subtrees first so each closure captures only its own subtree
    %% rather than the whole node.
    LeftTree = N#node.left,
    RightTree = N#node.right,
    LeftChild = spawn_child(fun() -> combine_value(LeftTree, Combine) end),
    RightChild = spawn_child(fun() -> combine_value(RightTree, Combine) end),
    LeftValue = await_child(LeftChild),
    RightValue = await_child(RightChild),
    Combine(LeftValue, RightValue).

%% Start a monitored process that sends {ItsOwnPid, Fun()} to the caller.
spawn_child(Fun) ->
    Parent = self(),
    spawn_monitor(fun() -> Parent ! {self(), Fun()} end).

%% Wait for the result of a process started by spawn_child/1.
%% A successful child's reply always arrives before its 'DOWN' message.
%% The monitor is therefore removed, and any pending 'DOWN' message flushed,
%% once the value is in hand.
await_child({Pid, MRef}) ->
    receive
        {Pid, Value} ->
            erlang:demonitor(MRef, [flush]),
            Value;
        {'DOWN', MRef, process, Pid, Reason} ->
            exit(Reason)
    end.
%% Look up Key and return its value, or `null` when Key is absent.
%% get(Tree, Key)
%% Keys are compared with < / > / ==, the same relations put/3 uses to place
%% them. For example, a key stored as 1 is found when asked for 1.0.
%% Matching the key in the clause head would compare with =:= instead; for
%% such keys none of the clauses would match and the call would crash.
%% NOTE: a stored value of `null` cannot be told apart from a missing key
%% here; use contains/2 to test membership.
get(null, _Key) ->
    null;
get(N, Key) when Key < N#node.key ->
    get(N#node.left, Key);
get(N, Key) when Key > N#node.key ->
    get(N#node.right, Key);
get(N, _Key) ->
    N#node.value.
%% True when Key is present in the tree.
%% This walks the tree itself instead of comparing get/2's result with
%% `null`. As a result, a key whose stored value is the atom `null` is still
%% reported as present.
contains(null, _Key) -> false;
contains(N, Key) when Key < N#node.key -> contains(N#node.left, Key);
contains(N, Key) when Key > N#node.key -> contains(N#node.right, Key);
contains(#node{}, _Key) -> true.
%% Return the node with the largest key that is not greater than Key, or
%% `null` if every key in the tree is greater than Key.
floor(null, _) -> null;
floor(N, Key) ->
    NodeKey = N#node.key,
    if
        Key == NodeKey ->
            N;
        Key < NodeKey ->
            floor(N#node.left, Key);
        true ->
            %% N qualifies; a closer candidate can only be in the right subtree.
            case floor(N#node.right, Key) of
                null -> N;
                Closer -> Closer
            end
    end.
%% Return the node with the smallest key that is not smaller than Key, or
%% `null` if every key in the tree is smaller than Key.
ceiling(null, _) -> null;
ceiling(N, Key) ->
    NodeKey = N#node.key,
    if
        Key == NodeKey ->
            N;
        Key > NodeKey ->
            ceiling(N#node.right, Key);
        true ->
            %% N qualifies; a closer candidate can only be in the left subtree.
            case ceiling(N#node.left, Key) of
                null -> N;
                Closer -> Closer
            end
    end.
%% Insert Key/Value with put/3, rebalance the root, and paint the root black.
%% Returns the new root.
add(Tree, Key, Value) ->
    Inserted = put(Tree, Key, Value),
    NewRoot = balance(Inserted),
    NewRoot#node{color = black}.
%% Recursive insertion, as in Sedgewick's LLRB put. New nodes are created red.
%% balance/1 is applied to every node on the way back up, not only at the
%% root. Otherwise red links introduced deep in the tree (right-leaning reds,
%% two reds in a row, temporary 4-nodes) are never repaired and the tree
%% degenerates.
%% If Key is already present, its value is replaced by Value.
put(null, Key, Value) ->
    #node{key = Key, value = Value};
put(N, Key, Value) when Key < N#node.key ->
    balance(N#node{left = put(N#node.left, Key, Value)});
put(N, Key, Value) when Key > N#node.key ->
    balance(N#node{right = put(N#node.right, Key, Value)});
put(N, _Key, Value) ->
    N#node{value = Value}.
%% Return the node holding the K-th smallest key (0-based), or `null` when
%% K is out of range.
%% A node's rank within its own subtree is the size of its LEFT subtree.
%% Comparing K with the size of the whole subtree (as before) sent every
%% in-range K to the left, so the right node was never found.
%% treesize/1 is O(n), so this costs O(n) per call.
select(null, _) -> null;
select(N, K) ->
    T = treesize(N#node.left),
    if
        T > K -> select(N#node.left, K);
        T < K -> select(N#node.right, K - T - 1);
        true -> N
    end.
%% Number of keys in the tree that are strictly smaller than Key.
rank(null, _) -> 0;
rank(N, Key) ->
    NodeKey = N#node.key,
    LeftSize = fun() -> treesize(N#node.left) end,
    if
        Key < NodeKey -> rank(N#node.left, Key);
        Key > NodeKey -> LeftSize() + 1 + rank(N#node.right, Key);
        true -> LeftSize()
    end.
%%%%%%%%%%%%%%%
%% %%
%% DELETING %%
%% %%
%%%%%%%%%%%%%%%
%% Remove Key (and its value) from the tree and return the new root, or
%% `null` if the tree becomes empty. If Key is not present, no node is
%% removed.
%%
%% Top-down LLRB deletion after Sedgewick:
%%   1. Make the root red when neither child is red (an absent child counts
%%      as black).
%%   2. Run the recursive pass delete_/2.
%%   3. Paint the resulting root black again.
delete(null, _Key) ->
    null;
delete(N = #node{left = #node{color = red}}, Key) ->
    delete__(N, Key);
delete(N = #node{right = #node{color = red}}, Key) ->
    delete__(N, Key);
delete(N, Key) ->
    delete__(N#node{color = red}, Key).

%% Run the recursive delete from the root and paint the resulting root black.
%% The result is bound to a fresh variable. Pre-binding a variable to a
%% record and then using it as a case pattern would make the clause an
%% equality test instead of a match.
delete__(Node, Key) ->
    case delete_(Node, Key) of
        null -> null;
        Root -> Root#node{color = black}
    end.

%% Recursive delete on a subtree.
%% It recurses through delete_/2 itself, not the public delete/2, whose root
%% handling must run only once. Every level is rebalanced with balance/1 on
%% the way back up.
delete_(null, _Key) ->
    null;
delete_(N, Key) when Key < N#node.key ->
    %% Going left: make sure the left child is not a 2-node before descending.
    Node = ifIsBlack(N, fun moveRedLeft/1),
    balance(Node#node{left = delete_(Node#node.left, Key)});
delete_(N, Key) ->
    %% Going right or deleting here: lean a red left link to the right first.
    Rotated = rotateRight_ifIsRed(N),
    case sameAndNoRight(Rotated, Key) of
        null ->
            %% Key sits in a node with no right child: drop that node.
            null;
        Node ->
            Moved = ifIsBlack2(Node, fun moveRedRight/1),
            balance(delete_here_or_right(Moved, Key))
    end.

%% Either remove the node holding Key, or continue the delete in the right
%% subtree. A matching node with no right child is simply dropped. Otherwise
%% the matching node takes the key and value of its successor (the minimum
%% of the right subtree), and that minimum is removed from the right subtree.
delete_here_or_right(N = #node{right = null}, Key) when Key == N#node.key ->
    null;
delete_here_or_right(N, Key) when Key == N#node.key ->
    Min = min(N#node.right),
    N#node{key = Min#node.key, value = Min#node.value,
           right = deleteMin(N#node.right)};
delete_here_or_right(N, Key) ->
    N#node{right = delete_(N#node.right, Key)}.
%% Remove the node with the smallest key from a subtree and return the
%% rebalanced subtree. It does not recolor the result; the delete code calls
%% it on a right subtree.
deleteMin(null) ->
    null;
deleteMin(#node{left = null, right = Right}) ->
    %% This node is the minimum; its right subtree (if any) takes its place.
    Right;
deleteMin(Node) ->
    Prepared = ifIsBlack(Node, fun moveRedLeft/1),
    NewLeft = deleteMin(Prepared#node.left),
    balance(Prepared#node{left = NewLeft}).
%%%%%%%%%%%%%%%
%% %%
%% BALANCING %%
%% %%
%%%%%%%%%%%%%%%
%% Restore the left-leaning invariants at one node, in this order:
%%   1. rotate a right-leaning red link to the left;
%%   2. rotate two left reds in a row to the right;
%%   3. split a node with two red children by flipping colors.
%% `null` passes through unchanged.
balance(Node) ->
    Leaned = rotateLeft_if(Node),
    Straightened = rotateRight_if(Leaned),
    flipColors_if(Straightened).
%% Turn a right-leaning link into a left-leaning one.
%% The right child becomes the subtree root and inherits the old root's
%% color. The old root becomes its red left child and adopts the middle
%% subtree.
rotateLeft(Old = #node{color = Color, right = New = #node{left = Middle}}) ->
    Demoted = Old#node{color = red, right = Middle},
    New#node{color = Color, left = Demoted}.
%% Turn a left-leaning link into a right-leaning one.
%% The left child becomes the subtree root and inherits the old root's
%% color. The old root becomes its red right child and adopts the middle
%% subtree.
rotateRight(Old = #node{color = Color, left = New = #node{right = Middle}}) ->
    Demoted = Old#node{color = red, left = Middle},
    New#node{color = Color, right = Demoted}.
%% Toggle the color of a node and of both of its children (both must exist).
flipColors(N = #node{left = L = #node{}, right = R = #node{}}) ->
    NewLeft = L#node{color = flipColor(L#node.color)},
    NewRight = R#node{color = flipColor(R#node.color)},
    N#node{color = flipColor(N#node.color), left = NewLeft, right = NewRight}.
%% Swap a single color: black <-> red. Any other term is rejected.
flipColor(black) ->
    red;
flipColor(red) ->
    black.
%% Push a red link down toward the left child: flip the colors of the node
%% and its children, then let moveRedLeft_/1 fix up the result.
moveRedLeft(Node) ->
    Flipped = flipColors(Node),
    moveRedLeft_(Flipped).
%% Push a red link down toward the right child: flip the colors of the node
%% and its children, then let moveRedRight_/1 fix up the result.
moveRedRight(Node) ->
    Flipped = flipColors(Node),
    moveRedRight_(Flipped).
%% Second half of Sedgewick's moveRedLeft, run after the color flip.
%% If the right child's LEFT child is red, borrow it: rotate the right child
%% right, rotate the node left, and flip colors again.
%% The previous pattern tested the right child's RIGHT child. That missed the
%% case this step exists for, and it crashed in rotateRight/1 whenever that
%% right child had no left child.
moveRedLeft_(N = #node{right = R = #node{left = #node{color = red}}}) ->
    flipColors(rotateLeft(N#node{right = rotateRight(R)}));
moveRedLeft_(Node) ->
    Node.
%% Second half of moveRedRight, run after the color flip.
%% If the left child's left child is red, rotate right and flip colors again;
%% otherwise leave the node unchanged.
moveRedRight_(Node) ->
    case Node of
        #node{left = #node{left = #node{color = red}}} ->
            flipColors(rotateRight(Node));
        _ ->
            Node
    end.
%% Utility
%% Node with the smallest key (the leftmost node), or `null` for an empty tree.
min(null) -> null;
min(N) ->
    case N#node.left of
        null -> N;
        Left -> min(Left)
    end.
%% Node with the largest key (the rightmost node), or `null` for an empty tree.
max(null) -> null;
max(N) ->
    case N#node.right of
        null -> N;
        Right -> max(Right)
    end.
%% helpers
%% Apply Fun (callers pass moveRedLeft/1) when the left child is a black node
%% whose own left child is not red. An absent (`null`) grandchild counts as
%% black, as in Sedgewick's isRed test. Requiring a black grandchild node, as
%% before, skipped the move above black leaves. Otherwise return the node
%% unchanged.
ifIsBlack(N = #node{left = #node{color = black, left = null}}, Fun) -> Fun(N);
ifIsBlack(N = #node{left = #node{color = black, left = #node{color = black}}}, Fun) -> Fun(N);
ifIsBlack(Node, _) -> Node.
%% Apply Fun (callers pass moveRedRight/1) when the right child is a black
%% node whose left child is not red. An absent (`null`) grandchild counts as
%% black, as in Sedgewick's isRed test. Requiring a black grandchild node, as
%% before, skipped the move above black leaves. Otherwise return the node
%% unchanged.
ifIsBlack2(N = #node{right = #node{color = black, left = null}}, Fun) -> Fun(N);
ifIsBlack2(N = #node{right = #node{color = black, left = #node{color = black}}}, Fun) -> Fun(N);
ifIsBlack2(Node, _) -> Node.
%% If the left link is red, rotate right so that red link leans right;
%% otherwise return the node unchanged (including `null`).
rotateRight_ifIsRed(Node) ->
    case Node of
        #node{left = #node{color = red}} -> rotateRight(Node);
        _ -> Node
    end.
%% Return `null` (drop the node) when it holds exactly Key (=:=) and has no
%% right child; otherwise return the node unchanged.
sameAndNoRight(Node, Key) ->
    case Node of
        #node{key = Key, right = null} -> null;
        _ -> Node
    end.
%% Rotate left when the right link is red and the left link is not.
%% A missing (`null`) left child counts as black. The old pattern demanded a
%% black left node, so a red right child under a node with no left child was
%% never rotated and right-leaning red links survived.
rotateLeft_if(N = #node{left = null, right = #node{color = red}}) -> rotateLeft(N);
rotateLeft_if(N = #node{left = #node{color = black}, right = #node{color = red}}) -> rotateLeft(N);
rotateLeft_if(Node) -> Node.
%% Rotate right when the left child and its left child are both red;
%% otherwise return the node unchanged.
rotateRight_if(Node) ->
    case Node of
        #node{left = #node{color = red, left = #node{color = red}}} -> rotateRight(Node);
        _ -> Node
    end.
%% Split a node whose two children are both red by flipping colors;
%% otherwise return the node unchanged.
flipColors_if(Node) ->
    case Node of
        #node{left = #node{color = red}, right = #node{color = red}} -> flipColors(Node);
        _ -> Node
    end.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment