Skip to content

Instantly share code, notes, and snippets.

@yatsuta
Created April 9, 2010 11:16
Show Gist options
  • Save yatsuta/361063 to your computer and use it in GitHub Desktop.
Save yatsuta/361063 to your computer and use it in GitHub Desktop.
PRML Max-sum Implementation for Erlang
-module(prml).
-compile(export_all).
%% fn, calc
%% Builds a "fn": the pair {VarArg, Body}. VarArg is the ordered list of
%% variable names, and Body is a unary fun that calc/2 applies to the list
%% of domain values for those variables.
make_fn(VarArg, Body) ->
    Fn = {VarArg, Body},
    Fn.
%% Translates the variable names in VarArg into their bound domain values,
%% preserving VarArg's order, by looking each one up in VarBinding.
map_var_arg(VarBinding, VarArg) ->
    lists:map(fun(Var) -> lookup_dom_val(Var, VarBinding) end, VarArg).
%% Returns the domain value bound to Var in VarBinding, a list of
%% {Var, DomVal} pairs keyed on the first element.
%% lists:keyfind/3 returns the whole matching tuple, so the value has to be
%% destructured out of it. An unbound Var crashes with {badmatch, false}.
lookup_dom_val(Var, VarBinding) ->
    {Var, DomVal} = lists:keyfind(Var, 1, VarBinding),
    DomVal.
%% Evaluates a fn under VarBinding: the fn's variables are resolved to
%% domain values (in VarArg order) and the resulting list is fed to Body.
calc({VarArg, Body}, VarBinding) ->
    Arg = map_var_arg(VarBinding, VarArg),
    Body(Arg).
%% konst
%% konst(N) ->
%% make_fn([], fun(_) -> N end).
%% add, mult, prod (utilities)
%% Binary addition, usable as a fun value.
add(X, Y) -> erlang:'+'(X, Y).

%% Binary multiplication, usable as a fun value.
mult(X, Y) -> erlang:'*'(X, Y).

%% Product of all numbers in L; the product of [] is 1.
%% Accumulates left to right, as mult(Elem, Acc), starting from 1.
prod(L) -> prod_acc(L, 1).

%% Tail-recursive worker for prod/1.
prod_acc([], Acc) -> Acc;
prod_acc([H | T], Acc) -> prod_acc(T, mult(H, Acc)).
%% map_reduce (utilities)
%% Applies Map to every element of L and passes the list of results, in the
%% same order as L, to Reduce.
%%   function - evaluate sequentially in the calling process.
%%   process  - evaluate each Map(X) in its own linked process.
%% In the process case every reply is tagged with a fresh reference plus
%% the element's position. The caller therefore receives only its own
%% replies, leaving unrelated mailbox messages untouched, and collects them
%% in input order without sorting. Workers are linked so that a crashing
%% Map takes the caller down instead of leaving it blocked in receive.
map_reduce(function, Map, Reduce, L) ->
    Reduce([Map(X) || X <- L]);
map_reduce(process, Map, Reduce, L) ->
    Parent = self(),
    Ref = make_ref(),
    Indexed = lists:zip(lists:seq(1, length(L)), L),
    lists:foreach(
        fun({I, X}) ->
            spawn_link(fun() -> Parent ! {Ref, I, Map(X)} end)
        end,
        Indexed),
    %% Ref and I are already bound here, so each receive is selective.
    Reduce([receive {Ref, I, Y} -> Y end || {I, _} <- Indexed]).
%% sum_fn, sum_for, add_fun
%% prod_fun, prod_for, mult_fn
%% maxi_fn, maxi_for, max_fn
%% Evaluation strategy that every reducer in this module passes to
%% map_reduce/4. The value 'process' selects one process per mapped element.
type() ->
    process.
%% Sorted, duplicate-free union of several variable lists.
var_arg_union(VarArgs) ->
    AllVars = [Var || VarArg <- VarArgs, Var <- VarArg],
    lists:usort(AllVars).
%% Zero-based index of the first element of L that exactly equals E,
%% or false when there is none.
pos(E, L) -> pos(E, L, 0).

%% Worker for pos/2. Matching E twice in the head requires an exact
%% (=:=) match.
pos(_, [], _) -> false;
pos(E, [E | _], Idx) -> Idx;
pos(E, [_ | Rest], Idx) -> pos(E, Rest, Idx + 1).
%% Element of L at zero-based position N.
val_by_pos(L, N) ->
    OneBased = N + 1,
    lists:nth(OneBased, L).

%% Domain value at zero-based position Pos of an argument list.
get_val(Arg, Pos) -> lists:nth(Pos + 1, Arg).
%% Zero-based position of Var within VarArg; false if it does not occur.
get_pos(Var, VarArg) ->
    Index = pos(Var, VarArg),
    Index.
%% Projects SourceArg, the values for variables SourceVarArg, onto the
%% variables VarArg. The result holds, in VarArg order, the value each
%% variable has in SourceArg.
extract_arg(VarArg, SourceVarArg, SourceArg) ->
    [get_val(SourceArg, get_pos(Var, SourceVarArg)) || Var <- VarArg].
%% Splits a list of {VarArg, Body} fns into {VarArgs, Bodies}, keeping order.
split_into_var_args_and_bodies(Fns) ->
    lists:foldr(fun({VarArg, Body}, {VarArgs, Bodies}) ->
                        {[VarArg | VarArgs], [Body | Bodies]}
                end,
                {[], []},
                Fns).
%% Combines several fns pointwise into one fn over the union of their
%% variables. Evaluating the result evaluates each input fn on its own
%% slice of the argument, then folds the resulting values with Reduce,
%% using map_reduce/4 with the strategy from type/0.
reduce_fns(Reduce, Fns) ->
    {VarArgs, _Bodies} = split_into_var_args_and_bodies(Fns),
    Union = var_arg_union(VarArgs),
    make_fn(Union,
            fun(UnionArg) ->
                EvalOne =
                    fun({OwnVarArg, OwnBody}) ->
                        OwnBody(extract_arg(OwnVarArg, Union, UnionArg))
                    end,
                map_reduce(type(), EvalOne, Reduce, Fns)
            end).
%% Pointwise sum of a list of fns.
sum_fn(Fns) ->
    reduce_fns(fun(Values) -> lists:sum(Values) end, Fns).

%% Pointwise sum of F(E) over every E in L.
sum_for(L, F) ->
    map_reduce(type(), F, fun(Fns) -> sum_fn(Fns) end, L).

%% Pointwise sum of two fns.
add_fn(Fn1, Fn2) ->
    sum_fn([Fn1, Fn2]).
%% Pointwise product of a list of fns.
prod_fn(Fns) ->
    reduce_fns(fun(Values) -> prod(Values) end, Fns).

%% Pointwise product of F(E) over every E in L.
prod_for(L, F) ->
    map_reduce(type(), F, fun(Fns) -> prod_fn(Fns) end, L).

%% Pointwise product of two fns.
mult_fn(Fn1, Fn2) ->
    prod_fn([Fn1, Fn2]).
%% Pointwise maximum of a list of fns.
maxi_fn(Fns) ->
    reduce_fns(fun(Values) -> lists:max(Values) end, Fns).

%% Pointwise maximum of F(E) over every E in L.
maxi_for(L, F) ->
    map_reduce(type(), F, fun(Fns) -> maxi_fn(Fns) end, L).

%% Pointwise maximum of two fns.
max_fn(Fn1, Fn2) ->
    maxi_fn([Fn1, Fn2]).
%% log_of, exp_of
%% Post-composes the unary function F onto a fn's body, keeping its variables.
compose_fn(F, {VarArg, Body}) ->
    Composed = fun(Arg) ->
                   Inner = Body(Arg),
                   F(Inner)
               end,
    make_fn(VarArg, Composed).

%% Natural log of a fn, applied pointwise (math:log/1 fails on 0).
log_of(Fn) -> compose_fn(fun(X) -> math:log(X) end, Fn).

%% Exponential of a fn, applied pointwise.
exp_of(Fn) -> compose_fn(fun(X) -> math:exp(X) end, Fn).
%% sum_vars, maxi_vars
%% Drops the element at zero-based position VarPos from VarArg.
remove_var(VarPos, VarArg) ->
    {Prefix, [_Removed | Suffix]} = lists:split(VarPos, VarArg),
    lists:append(Prefix, Suffix).

%% Inserts DomVal so that it ends up at zero-based position Pos of Arg.
insert_dom_val(Pos, Arg, DomVal) ->
    {Prefix, Suffix} = lists:split(Pos, Arg),
    lists:append([Prefix, [DomVal], Suffix]).
%% Fixes Var to DomVal. The resulting fn has Var removed from its
%% variables, and its body re-inserts DomVal at Var's original slot before
%% calling the original Body.
partial_fn(Var, DomVal, {VarArg, Body}) ->
    Slot = pos(Var, VarArg),
    Remaining = remove_var(Slot, VarArg),
    make_fn(Remaining,
            fun(PartialArg) ->
                FullArg = insert_dom_val(Slot, PartialArg, DomVal),
                Body(FullArg)
            end).
%% Marginalises Var out of Fn by summing over every value in domain_of(Var).
sum_var(Var, Fn) ->
    Slice = fun(DomVal) -> partial_fn(Var, DomVal, Fn) end,
    sum_for(domain_of(Var), Slice).

%% Sums out each variable of Vars. The last variable is eliminated first,
%% matching a right fold.
sum_vars([], Fn) -> Fn;
sum_vars([Var | Rest], Fn) -> sum_var(Var, sum_vars(Rest, Fn)).
%% Maximises Var out of Fn over every value in domain_of(Var).
maxi_var(Var, Fn) ->
    Slice = fun(DomVal) -> partial_fn(Var, DomVal, Fn) end,
    maxi_for(domain_of(Var), Slice).

%% Maximises out each variable of Vars. The last variable is eliminated
%% first, matching a right fold.
maxi_vars([], Fn) -> Fn;
maxi_vars([Var | Rest], Fn) -> maxi_var(Var, maxi_vars(Rest, Fn)).
%% max-sum
%% L with every element exactly equal to E removed.
except(L, E) ->
    lists:filter(fun(X) -> X =/= E end, L).
%% Max-sum message from factor F to variable X: the log of F's function
%% plus the incoming messages from F's other neighbours, with those other
%% neighbours maximised out.
ms_mu_f_x(F, X) ->
    Others = except(ne(F), X),
    Incoming = sum_for(Others, fun(X2) -> ms_mu_x_f(X2, F) end),
    LogFactor = log_of(fn_of(F)),
    maxi_vars(Others, add_fn(LogFactor, Incoming)).

%% Max-sum message from variable X to factor F: the sum of messages from
%% X's other neighbouring factors (a zero fn when there are none).
ms_mu_x_f(X, F) ->
    OtherFactors = except(ne(X), F),
    sum_for(OtherFactors, fun(F2) -> ms_mu_f_x(F2, X) end).
%% Maximum value reached from root X: sums the messages from all of X's
%% neighbouring factors, maximises X out, and exponentiates back out of the
%% log domain. main/0 evaluates the result with an empty binding.
pmax(X) ->
    Incoming = sum_for(ne(X), fun(F) -> ms_mu_f_x(F, X) end),
    exp_of(maxi_var(X, Incoming)).
%% Neighbour lists of the chain factor graph x1 - f1 - x2 - f2 - x3.
%% Factor nodes:
ne(f1) -> [x1, x2];
ne(f2) -> [x2, x3];
%% Variable nodes:
ne(x1) -> [f1];
ne(x2) -> [f1, f2];
ne(x3) -> [f2].
%% Every variable of the model is binary.
domain_of(Var) when Var =:= x1; Var =:= x2; Var =:= x3 ->
    [0, 1].
%% Prior on x1: P(x1 = 0) = 0.3, and P(x1 = 1) is its complement.
p_x1() ->
    P0 = 0.3,
    make_fn([x1],
            fun([0]) -> P0;
               ([1]) -> 1.0 - P0
            end).
%% Conditional table P(x2 | x1), with arguments ordered [x2, x1].
%% Each x2 = 1 entry is the complement of the x2 = 0 entry for the same x1.
p_x2_given_x1() ->
    P0GivenX1Is0 = 0.7,
    P0GivenX1Is1 = 0.4,
    make_fn([x2, x1],
            fun([0, 0]) -> P0GivenX1Is0;
               ([1, 0]) -> 1.0 - P0GivenX1Is0;
               ([0, 1]) -> P0GivenX1Is1;
               ([1, 1]) -> 1.0 - P0GivenX1Is1
            end).
%% Conditional table P(x3 | x2), with arguments ordered [x3, x2].
%% NOTE(review): 0.0000001 presumably stands in for 0, because log_of/1
%% applies math:log/1, which fails on 0. Confirm intent.
p_x3_given_x2() ->
    P0GivenX2Is0 = 0.0000001,
    P0GivenX2Is1 = 0.6,
    make_fn([x3, x2],
            fun([0, 0]) -> P0GivenX2Is0;
               ([1, 0]) -> 1.0 - P0GivenX2Is0;
               ([0, 1]) -> P0GivenX2Is1;
               ([1, 1]) -> 1.0 - P0GivenX2Is1
            end).
%% Function attached to each factor node.
fn_of(f2) ->
    p_x3_given_x2();
fn_of(f1) ->
    %% f1 carries both the prior on x1 and the x1 -> x2 table.
    mult_fn(p_x2_given_x1(), p_x1()).
%% main
%% Prints pmax computed with each variable as the root, one line per variable.
main() ->
    Values = [calc(pmax(Var), []) || Var <- [x1, x2, x3]],
    io:format("pmax(x1): ~w~npmax(x2): ~w~npmax(x3): ~w~n", Values).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment