Skip to content

Instantly share code, notes, and snippets.

@saolof
Last active September 25, 2019 23:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save saolof/5743c28b22b5308045dafe926b3e126d to your computer and use it in GitHub Desktop.
Save saolof/5743c28b22b5308045dafe926b3e126d to your computer and use it in GitHub Desktop.
Quick Julia port of Jon Harrop's symbolic-differentiation benchmark, using MLStyle to provide the `@match` macro (ref: https://gist.github.com/jdh30/f3d90a65a7abc7c9faf5c0299b002db3 )
using MLStyle
import MacroTools
import Printf
# Symbolic expression type, declared with MLStyle's `@data` macro.
# Each line inside the block declares one constructor of `MExpr`; these are the
# constructors that the `@match` patterns elsewhere in this file destructure.
@data internal MExpr begin
MInt(Int) # integer literal
Var(Symbol) # variable, identified by its name
Add(MExpr,MExpr) # sum of two sub-expressions
Mul(MExpr,MExpr) # product of two sub-expressions
Pow(MExpr,MExpr) # first sub-expression raised to the second (rendered as `^` by `to_expr`)
Ln(MExpr) # logarithm node; `d` differentiates `Ln(f)` as f'/f
end
"""
    add(a, b)

Smart constructor for sums: return an `MExpr` representing `a + b`, simplified.

The cases are tried top to bottom:
- two integer literals are folded into a single literal;
- adding `MInt(0)` on either side returns the other operand;
- an integer literal on the right is moved to the left;
- an integer literal added to a sum that already starts with an integer literal
  is folded into that literal;
- an integer literal heading the right-hand sum is hoisted to the front;
- a left-nested sum is re-associated to the right;
- anything else becomes a plain `Add(a, b)`.
"""
add(a,b) = @match (a,b) begin
    (MInt(m),MInt(n)) => MInt(m+n)
    (MInt(0),f) || (f,MInt(0)) => f
    (f,MInt(n)) => add(MInt(n),f)
    # Without this case, `add(MInt(m), Add(MInt(n), g))` reaches the hoisting rule
    # below, which only swaps the two literals and calls `add` again with the same
    # shape of arguments, recursing until the stack overflows.
    (MInt(m),Add(MInt(n),g)) => add(MInt(m+n),g)
    (f,Add(MInt(n),g)) => add(MInt(n),add(f,g))
    (Add(f,g),h) => add(f,add(g,h))
    (f,g) => Add(f,g)
end
"""
    mul(a, b)

Smart constructor for products: return an `MExpr` representing `a * b`, simplified.

The cases are tried top to bottom: fold two integer literals, absorb a zero
factor, drop a unit factor, move a right-hand integer literal to the left, hoist
an integer literal out of a right-hand product, re-associate a left-nested
product to the right, and otherwise build a plain `Mul(a, b)`.
"""
function mul(a, b)
    return @match (a, b) begin
        (MInt(m), MInt(n)) => MInt(m * n)
        (MInt(0), _) => MInt(0)
        (_, MInt(0)) => MInt(0)
        (MInt(1), rest) => rest
        (rest, MInt(1)) => rest
        (lhs, MInt(k)) => mul(MInt(k), lhs)
        # NOTE(review): unlike the corresponding rule in `add`, this builds raw `Mul`
        # nodes instead of calling `mul`, so the hoisted product is not simplified
        # any further -- confirm this is intended.
        (lhs, Mul(MInt(k), tail)) => Mul(MInt(k), Mul(lhs, tail))
        (Mul(p, q), r) => mul(p, mul(q, r))
        (lhs, rhs) => Mul(lhs, rhs)
    end
end
"""
    pow(a, b)

Smart constructor for powers: return an `MExpr` representing `a ^ b`, simplified.

The cases are tried top to bottom:
- two integer literals are folded when the result is an integer Julia can
  compute; otherwise the power is kept symbolic as `Pow(a, b)`;
- a zero exponent gives `MInt(1)`;
- a unit exponent returns the base;
- a zero base gives `MInt(0)`;
- anything else becomes a plain `Pow(a, b)`.
"""
pow(a,b) = @match (a,b) begin
    # Integer `^` throws a DomainError for a negative exponent unless the base is
    # 1 or -1 (and `d` calls `pow(f, MInt(-1))`), so only fold the cases that are
    # computable and leave the rest as a symbolic power.
    (MInt(m),MInt(n)) => ((n >= 0 || m == 1 || m == -1) ? MInt(m^n) : Pow(MInt(m),MInt(n)))
    (f,MInt(0)) => MInt(1)
    (f,MInt(1)) => f
    (MInt(0),f) => MInt(0)
    (f, g) => Pow(f, g)
end
"""
    ln(x)

Smart constructor for logarithms: return `MInt(0)` when `x` equals `MInt(1)`,
otherwise wrap `x` in an `Ln` node.
"""
function ln(x)
    if x == MInt(1)
        return MInt(0)
    end
    return Ln(x)
end
"""
    d(x, expr)

Return the derivative of the expression `expr` with respect to the variable
named `x`, built with the simplifying constructors `add`, `mul`, `pow` and `ln`.

Rules applied:
- `Var(y)` gives `MInt(1)` when `y == x` and `MInt(0)` otherwise;
- `MInt(_)` gives `MInt(0)`;
- `Add(f, g)` gives `f' + g'`;
- `Mul(f, g)` gives `f*g' + g*f'`;
- `Pow(f, g)` gives `f^g * (g*f'*f^(-1) + ln(f)*g')`;
- `Ln(f)` gives `f' * f^(-1)`.
"""
d(x,expr) = @match expr begin
    # A bare `Var(x)` pattern would bind a fresh `x` and match every variable;
    # capture the name and compare it with the differentiation variable instead.
    Var(y) => (y == x ? MInt(1) : MInt(0))
    MInt(_) => MInt(0)
    Add(f, g) => add(d(x,f), d(x, g))
    Mul(f, g) => add(mul(f,d(x,g)), mul(g, d(x, f)) )
    Pow(f, g) => mul(pow(f, g),add(mul(mul(g, d(x, f)),pow(f, MInt(-1))), mul(ln(f),d(x, g))))
    Ln(f) => mul(d(x, f), pow(f, MInt(-1)))
end
"""
    count_leaves(x)

Count the leaf nodes (`MInt` and `Var`) contained in the expression `x`.
"""
function count_leaves(x)
    return @match x begin
        MInt(_) => 1
        Var(_) => 1
        Add(l, r) => count_leaves(l) + count_leaves(r)
        Mul(l, r) => count_leaves(l) + count_leaves(r)
        Pow(l, r) => count_leaves(l) + count_leaves(r)
        Ln(arg) => count_leaves(arg)
    end
end
"""
    to_expr(expr::MExpr)

Convert `expr` into ordinary Julia syntax: an `Int` for `MInt`, a `Symbol` for
`Var`, and a call expression using `+`, `*`, `^` or `ln` for the composite nodes.
"""
to_expr(expr::MExpr) = @match expr begin
    MInt(x) || Var(x) => x
    Add(f,g) => :($(to_expr(f)) + $(to_expr(g)))
    Mul(f,g) => :($(to_expr(f)) * $(to_expr(g)))
    Pow(f,g) => :($(to_expr(f)) ^ $(to_expr(g)))
    # Convert the argument recursively like the other cases; interpolating `f`
    # directly would embed the raw `MExpr` object in the resulting expression.
    Ln(f) => :(ln($(to_expr(f))))
end
# Print an `MExpr`. Expressions with more than 100 leaves are abbreviated to
# `<<N>>`, where N is the leaf count; smaller ones are printed as the Julia
# expression produced by `to_expr`.
function Base.print(io::IO, f::MExpr)
    leaves = count_leaves(f)  # computed once and reused in the abbreviation
    if leaves > 100
        # The original `$"..."` form is not valid Julia outside a quoted expression.
        print(io, "<<", leaves, ">>")
    else
        print(io, to_expr(f))
    end
end
"""
    nest(n, f, x)

Apply `f` to `x` repeatedly, once for each element of `1:n`, and return the
final value. When `1:n` is empty, `x` is returned unchanged.
"""
function nest(n, f, x)
    acc = x
    for _ in 1:n
        acc = f(acc)
    end
    return acc
end
"""
    deriv(f)

Differentiate the expression `f` with respect to `:x`, print a line of the form
`D(<f>) = <derivative>` to standard output, and return the derivative.
"""
function deriv(f)
    # The result must not be named `d`: a local `d` would shadow the function `d`
    # throughout this body and make the call below fail with an UndefVarError.
    df = d(:x, f)
    # `%!` is OCaml's "flush" directive and is not a valid Printf conversion in
    # Julia, so flush explicitly instead.
    Printf.@printf("D(%s) = %s\n", string(f), string(df))
    flush(stdout)
    return df
end
# Convenience macro for typing expressions at the REPL: walks the given syntax
# tree and replaces every integer literal with an `MInt` and every quoted symbol
# (such as `:x`) with a `Var`; all other nodes are left untouched. Combined with
# the operator overloads on `MExpr`, `@mexpr :x^:x` then evaluates to an `MExpr`.
macro mexpr(expr)
    function lift(node)
        if node isa Integer
            return MInt(node)
        elseif node isa QuoteNode
            return Var(node.value)
        else
            return node
        end
    end
    return MacroTools.postwalk(lift, expr)
end
# Arithmetic operators on `MExpr` values forward to the simplifying constructors.

# `a + b` builds a simplified sum.
function Base.:+(a::MExpr, b::MExpr)
    return add(a, b)
end

# `a * b` builds a simplified product.
function Base.:*(a::MExpr, b::MExpr)
    return mul(a, b)
end

# `a ^ b` builds a simplified power.
function Base.:^(a::MExpr, b::MExpr)
    return pow(a, b)
end
"""
    main(n)

Run the benchmark: start from the expression `x^x`, apply `deriv` to it `n`
times in succession while timing the whole run with `@time`, and return the
final expression.
"""
function main(n)
    seed = @mexpr :x^:x
    return @time nest(n, deriv, seed)
end
# Script entry point: when not running interactively, read the number of
# differentiation steps from the first command-line argument and run the benchmark.
if !isinteractive()
    # `let` keeps the parsed value out of the global namespace.
    let steps = isempty(ARGS) ? nothing : tryparse(Int, ARGS[1])
        # Report a missing or non-integer argument explicitly instead of failing
        # with a bare BoundsError or parse error.
        steps === nothing &&
            error("usage: julia $(PROGRAM_FILE) <n>   (n must be an integer)")
        main(steps)
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment