# Created December 21, 2015 21:43
-
-
Save davidagold/674a9fe8a34e64859dad to your computer and use it in GitHub Desktop.
# Revised lift macro
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generic fallbacks so that the code generated by `@^` (which calls `isnull`
# and `get` on every referenced argument) also works on plain values.

# A value with no more specific `isnull` method is never null.
function Base.isnull(x)
    return false
end

# "Unwrapping" a value with no more specific `get` method returns it as is.
function Base.get(x)
    return x
end
# `@^ expr T` lifts `expr` over possibly-null arguments.
#
# `expr` is rewritten by `gen_calls`, which replaces every symbol and indexing
# expression it references with `get(...)` and records each of them in a cache.
# The expansion checks all recorded arguments with `isnull`:
#   * if any of them is null it evaluates to the empty `Nullable{T}()`;
#   * otherwise it evaluates to the rewritten expression (a top-level call is
#     wrapped in `Nullable(...)` by `gen_calls`).
#
# If `expr` references no symbol or indexing expression at all (for example it
# is built only from literals) nothing can be null, so the rewritten expression
# is returned without a null check.
#
# Throws `ArgumentError` at expansion time unless exactly one type argument
# follows the expression.
macro ^(call, T...)
    length(T) == 1 || throw(ArgumentError("@^: wrong number of arguments"))
    e_type = T[1]

    arg_cache = Dict{Union{Symbol, Expr}, Expr}()
    e_call = gen_calls(call, arg_cache, true)
    args = collect(keys(arg_cache))

    # Without this guard `args[1]` below raises a BoundsError for expressions
    # that reference no variables.
    isempty(args) && return esc(e_call)

    # Build `isnull(a1) || isnull(a2) || ...` over every recorded argument.
    e_nullcheck = :(isnull($(args[1])))
    for arg in args[2:end]
        e_nullcheck = Expr(:||, e_nullcheck, :(isnull($arg)))
    end

    res = quote
        if $e_nullcheck
            Nullable{$e_type}()
        else
            $e_call
        end
    end
    # The whole expansion is escaped so it is evaluated in the caller's scope.
    return esc(res)
end
# Base case: a node with no more specific `gen_calls` method (e.g. a literal)
# is returned untouched; `arg_cache` and `top` are ignored.
function gen_calls(e, arg_cache, top)
    return e
end
# Base case for symbols: a symbol `s` is rewritten to the expression `get(s)`.
# The rewritten expression is stored in `arg_cache` under `s` the first time
# the symbol is seen; later occurrences return the cached entry.
function gen_calls(e::Symbol, arg_cache, top)
    return get!(arg_cache, e) do
        :(get($e))
    end
end
# Recursively rewrite an expression tree, dispatching on the head of `e`:
#
#   :call        the callee `e.args[1]` is kept and every argument is rewritten
#                with `top = false`; when `top` is true the rebuilt call is
#                wrapped in `Nullable(...)`.
#   :ref         the whole indexing expression is replaced by `get(e)` and
#                recorded in `arg_cache` (keyed by `e`).
#   :if/:elseif  the condition is rewritten with `top = false` so it is never
#                wrapped in `Nullable`; the branches inherit `top`.
#   :block       every statement is rewritten with the same `top`.
#   :comparison  operands (odd positions) are rewritten, operators (even
#                positions) are kept, for chains of any length.
#   otherwise    `e` is returned unchanged.
function gen_calls(e::Expr, arg_cache, top)
    if e.head == :call
        new_args = [ gen_calls(arg, arg_cache, false) for arg in e.args[2:end] ]
        res = Expr(:call, e.args[1], new_args...)
        return top == true ? :(Nullable($res)) : res
    elseif e.head == :ref
        return get!(arg_cache, e, :(get($e)))
    # `if` support
    elseif e.head == :if || e.head == :elseif
        # The condition must stay a plain Bool, so it is never treated as the
        # top-level call.
        cond = gen_calls(e.args[1], arg_cache, false)
        branches = [ gen_calls(expr, arg_cache, top) for expr in e.args[2:end] ]
        return Expr(e.head, cond, branches...)
    elseif e.head == :block
        return Expr(:block, [ gen_calls(expr, arg_cache, top) for expr in e.args ]...)
    elseif e.head == :comparison
        # args alternate operand, operator, operand, ...; rewrite every operand
        # so that chains such as `a < b < c` are not truncated.
        new_args = Any[ isodd(i) ? gen_calls(e.args[i], arg_cache, top) : e.args[i]
                        for i in eachindex(e.args) ]
        return Expr(:comparison, new_args...)
    # /if support
    else
        return e
    end
end
# Recursive case for an `args` array: rewrite each element with the same
# `arg_cache` and `top`, and return the rewritten elements as a new array.
function gen_calls(args::Array, arg_cache, top)
    return map(args) do arg
        gen_calls(arg, arg_cache, top)
    end
end
### Perf tests
using NullableArrays

# Seed the global RNG so that every run draws the same data.
srand(1)

# Two random vectors and a random Bool vector, all of the same length.
A = rand(5_000_000)
B = rand(5_000_000)
M = rand(Bool, 5_000_000)

# Nullable containers used by the kernels below.
# NOTE(review): `M` is presumably the per-element null mask for `Y` — confirm
# against the NullableArrays constructor documentation.
X = NullableArray(A)
Y = NullableArray(A, M)
Z = similar(X)
# Hand-written kernel: returns an empty `Nullable{Float64}` when `y` is null,
# otherwise `Nullable(b * y.value)`.
@inline function _g(b::Float64, y::Nullable{Float64})
    y.isnull && return Nullable{Float64}()
    return Nullable(b * y.value)
end
# Fill `Z` in place: each entry becomes the result of the `@^`-lifted product
# `B[idx] * Y[idx]` (with `Float64` as the element type of the null result).
function f(Z, B, Y)
    for idx in eachindex(Z)
        Z[idx] = @^ B[idx] * Y[idx] Float64
    end
    return nothing
end
# Fill `Z` in place: each entry becomes `_g(B[idx], Y[idx])`, the hand-written
# counterpart of the lifted product.
function g(Z, B, Y)
    for idx in eachindex(Z)
        Z[idx] = _g(B[idx], Y[idx])
    end
    return nothing
end
# Run every kernel/argument combination once before timing it (the first call
# of a method includes compilation).
f(Z, B, Y);
f(Z, X, Y);
g(Z, B, Y);

# Timed runs of the same calls.
@time f(Z, B, Y)
@time f(Z, X, Y)
@time g(Z, B, Y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment