Created
August 25, 2017 13:26
-
-
Save afniedermayer/fb43e60cafdc5f51ae26c6fae3bec508 to your computer and use it in GitHub Desktop.
`@isinferred` without evaluating function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Foo

# Keyword-call helpers. Each one receives the positional arguments of the call
# under test followed by the callee as the LAST positional argument, plus the
# keyword arguments of the call.

# Return `(positional_args, kwargs, result)`; the callee is invoked exactly once.
_args_and_call(args...; kwargs...) = (args[1:end-1], kwargs, args[end](args[1:end-1]...; kwargs...))
# Return `(positional_args, kwargs)` without invoking the callee.
_args(args...; kwargs...) = (args[1:end-1], kwargs)
# Invoke the callee and return only its result.
_call(args...; kwargs...) = args[end](args[1:end-1]...; kwargs...)

"""
    @inferred f(args...; kwargs...)
    @inferred A[i]

Evaluate the call and return its result. Raise an `ErrorException` (through
`error`) when the type of the result differs from the single return type that
`Base.return_types` reports for the call.
"""
macro inferred(ex)
    inference = inferred_impl(ex, true, __module__)
    quote
        let (inftypes, result) = $inference
            # A result that is itself a type is compared as `Type{result}`
            # rather than as `typeof(result)`.
            rettype = isa(result, Type) ? Type{result} : typeof(result)
            rettype == inftypes[1] || error("return type $rettype does not match inferred return type $(inftypes[1])")
            result
        end
    end
end

"""
    @isinferred f(args...; kwargs...)
    @isinferred A[i]

Return whether the single inferred return type of the call satisfies
`isleaftype`. The argument expressions are evaluated (once) to obtain their
types, but the function itself is not called.
"""
macro isinferred(ex)
    inference = inferred_impl(ex, false, __module__)
    quote
        let (inftypes, result) = $inference
            # NOTE(review): `isleaftype` was renamed `isconcretetype` in later
            # Julia releases; kept as-is to match the rest of this file.
            isleaftype(inftypes[1])
        end
    end
end

# Build the expression shared by `@inferred` and `@isinferred`.
#
# Arguments:
#   ex         - the call (or `A[i]` indexing) expression handed to the macro
#   evaluate   - `true` to also perform the call, `false` to only infer its type
#   __module__ - the module the macro was expanded in
#
# The returned expression evaluates to `(inftypes, result)`: the vector of
# inferred return types (asserted to hold exactly one entry) and either the
# value of the call or `nothing` when `evaluate` is `false`.
#
# Raises an `ErrorException` when `ex` is not a call expression.
function inferred_impl(ex, evaluate, __module__)
    # `A[i]` is treated as `getindex(A, i)`.
    if Meta.isexpr(ex, :ref)
        ex = Expr(:call, :getindex, ex.args...)
    end
    Meta.isexpr(ex, :call) || error("@inferred requires a call expression")

    f = ex.args[1]
    callargs = ex.args[2:end]

    setup = if any(a -> (Meta.isexpr(a, :kw) || Meta.isexpr(a, :parameters)), ex.args)
        # Has keywords
        argsym = gensym()
        kwsym = gensym()
        # The helper is selected at expansion time so that the argument
        # expressions appear, and are therefore evaluated, exactly once in the
        # generated code, whether or not the call is performed.
        capture = if evaluate
            quote
                $(esc(argsym)), $(esc(kwsym)), result = $(esc(Expr(:call, _args_and_call, callargs..., f)))
            end
        else
            quote
                $(esc(argsym)), $(esc(kwsym)) = $(esc(Expr(:call, _args, callargs..., f)))
                result = nothing
            end
        end
        quote
            $capture
            inftypes = $(Base.gen_call_with_extracted_types(__module__, Base.return_types, :($f($(argsym)...; $(kwsym)...))))
        end
    else
        # No keywords
        quote
            args = ($([esc(a) for a in callargs]...),)
            result = $(evaluate ? :($(esc(f))(args...)) : nothing)
            inftypes = Base.return_types($(esc(f)), Base.typesof(args...))
        end
    end

    Base.remove_linenums!(quote
        let
            $setup
            @assert length(inftypes) == 1
            inftypes, result
        end
    end)
end

end
import Foo: @inferred, @isinferred | |
import Base.Test: @test, @test_throws, @testset | |
# test @inferred and @isinferred | |
# Index into a vector that mixes an `Int` and a `String`; the element returned
# for `i == 1` is the integer, for `i == 2` the string.
uninferrable_function(i) = getindex([1, "1"], i)
# `uninferrable_function` indexes a mixed `[1, "1"]` vector; `@inferred` is
# expected to raise the `ErrorException` produced by its `error(...)` call.
@test_throws ErrorException @inferred(uninferrable_function(1)) | |
# On success `@inferred` evaluates to the value of the call itself.
@test @inferred(identity(1)) == 1 | |
# `@isinferred` reports the same two outcomes as a boolean instead of throwing.
@test !@isinferred(uninferrable_function(1)) | |
@test @isinferred(identity(1)) | |
# Progress marker printed once this group has run.
println("flag 1") | |
# Ensure @inferred and @isinferred only evaluate the arguments once | |
# Call counter for `inferred_test_function`; the tests reset and inspect it.
inferred_test_global = 0

# Bump the global call counter by one and report success.
function inferred_test_function()
    global inferred_test_global = inferred_test_global + 1
    return true
end
# `@inferred` must call the function exactly once: the counter goes 0 -> 1.
@test @inferred inferred_test_function() | |
@test inferred_test_global == 1 | |
# Reset the counter before checking `@isinferred`.
inferred_test_global = 0 | |
# `@isinferred` must not call the function at all: the counter stays at 0.
@test @isinferred inferred_test_function() | |
@test inferred_test_global == 0 | |
# Progress marker printed once this group has run.
println("flag 2") | |
# Test that @inferred and @isinferred work with A[i] expressions:
# `inferred_impl` rewrites an `Expr(:ref, ...)` into a `getindex` call, so the
# indexing syntax below is handled like an ordinary call.
@test @inferred((1:3)[2]) == 2 | |
@test @isinferred((1:3)[2]) | |
# Field-less `AbstractArray{Float64,1}` subtype used only as an indexing target.
struct SillyArray <: AbstractArray{Float64,1}
end

# Ignore the index and return, at random, either the `Int` 0 or the `Bool`
# false, so the result type varies from call to call.
function Base.getindex(a::SillyArray, i)
    if rand() > 0.5
        return 0
    else
        return false
    end
end
# `SillyArray`'s `getindex` returns either an `Int` or a `Bool`, so `@inferred`
# on the indexing expression is expected to throw; keep the test result to
# inspect the exception it captured.
test_result = @test_throws ErrorException @inferred(SillyArray()[2]) | |
# The message built by `@inferred` interpolates the runtime and inferred return
# types; it is expected to mention `Bool`.
@test contains(test_result.value.msg, "Bool") | |
@test !@isinferred(SillyArray()[2]) | |
# Progress marker printed once this group has run.
println("flag 3") | |
# Issue #14928
# Make sure an abstract error type works: the exception raised by `error("")`
# is matched here against the abstract `Exception` type.
@test_throws Exception error("") | |
# Progress marker printed once this group has run.
println("flag 4") | |
# Issue #17105 | |
# @inferred and @isinferred with kwargs | |
# Double `x`; the keyword `y` is accepted but never used in the result.
inferrable_kwtest(x; y=1) = 2 * x
# Double `x` and add the keyword `y`, so the result depends on the keyword.
uninferrable_kwtest(x; y=1) = 2 * x + y
# `inferrable_kwtest` never uses its keyword: both the positional-only call and
# the call with an explicit keyword are expected to pass `@inferred`.
@test @inferred(inferrable_kwtest(1)) == 2 | |
println("flag 4a1") | |
@test @inferred(inferrable_kwtest(1; y=1)) == 2 | |
println("flag 4a2") | |
# `uninferrable_kwtest` adds the keyword to the result. With the default `y`
# the call is expected to pass; with an explicit `y` it is expected to throw.
# NOTE(review): presumably the keyword-call path of the targeted Julia version
# does not propagate the keyword's type to inference; confirm.
@test @inferred(uninferrable_kwtest(1)) == 3 | |
# The trailing `== 2` is part of the expression handed to `@test_throws`.
@test_throws ErrorException @inferred(uninferrable_kwtest(1; y=2)) == 2 | |
println("flag 4b") | |
# The same four cases through `@isinferred`, which returns a Bool rather than
# throwing.
@test @isinferred(inferrable_kwtest(1)) | |
@test @isinferred(inferrable_kwtest(1; y=1)) | |
@test @isinferred(uninferrable_kwtest(1)) | |
@test !@isinferred(uninferrable_kwtest(1; y=2)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment