Skip to content

Instantly share code, notes, and snippets.

@aviatesk
Last active November 26, 2020 06:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aviatesk/9314f9a5cd28ad6f3dca519e87ce7558 to your computer and use it in GitHub Desktop.
Save aviatesk/9314f9a5cd28ad6f3dca519e87ce7558 to your computer and use it in GitHub Desktop.
# %%
import Base:
uniontypes,
get_world_counter,
_methods_by_ftype
import Core:
Const,
MethodMatch,
MethodInstance
import Core.Compiler:
NativeInterpreter,
specialize_method,
InferenceResult,
InferenceState,
retrieve_code_info,
typeinf,
widenconst
"""
    infer(tt, world = get_world_counter(), interp = NativeInterpreter(world))

Run type inference for the call signature `tt` (a `Tuple` type whose first parameter
is the callee's type, followed by the argument types) and return the
`InferenceResult`; its `result` field holds the inferred return type (possibly a
`Const`).

Return `nothing` instead when:
- method lookup does not yield exactly one matching method, or
- no lowered source is available to infer (`InferenceState` returned `nothing`).

NOTE(review): relies on `Core.Compiler` internals whose signatures differ between
Julia versions; written against the API imported at the top of this file.
"""
function infer(@nospecialize(tt),
               world = get_world_counter(),
               interp = NativeInterpreter(world),
               )
    mms = _methods_by_ftype(tt, -1, world)
    # A failed lookup is reported with a non-vector sentinel (`false`/`nothing`,
    # depending on the Julia version); treat it like "no unique match" rather than
    # letting `first(mms)::MethodMatch` throw.
    mms isa AbstractVector && length(mms) == 1 || return nothing
    mm = first(mms)::MethodMatch
    linfo = specialize_method(mm.method, mm.spec_types, mm.sparams)
    result = InferenceResult(linfo)
    frame = InferenceState(result, #=cached=# true, interp)
    # `InferenceState` yields `nothing` when `retrieve_code_info` finds no source
    # (e.g. builtins); `typeinf` cannot accept that, so bail out.
    frame === nothing && return nothing
    typeinf(interp, frame)
    return result
end
"""
    typed_filter(f, a::Array)

Like `filter(f, a)`, but when `eltype(a)` is a `Union`, every union component for
which `f` is inferred to return the constant `false` is dropped from the result's
element type. For example, `typed_filter(!ismissing, a)` yields a `Vector{Int}`
rather than a `Vector{Union{Missing,Int}}`. Non-`Union` element types defer to
`filter`. As with `filter`, the result is always a `Vector`.
"""
function typed_filter(f, a::Array{T,N}) where {T,N}
    if @generated
        # Inside the generator `f` is the predicate's *type* and `T` the element type.
        isa(T, Union) || return :(filter(f, a))
        kept = []
        for component in uniontypes(T)
            result = infer(Tuple{f,component})
            if isa(result, InferenceResult)
                rt = result.result
                # `f` provably rejects this component, so it never reaches the output.
                isa(rt, Const) && rt.val === false && continue
            end
            push!(kept, component)
        end
        et′ = Union{kept...}
        loop = if !isempty(kept) && all(isbitstype, kept)
            quote
                # All components are isbits and stored inline, so reading the
                # not-yet-written slot `b[1]` is safe. It serves as a dummy value that
                # keeps the loop branch-free; a rejected write is later overwritten
                # or truncated by `resize!`.
                fallback = first(b)
                for ai in a
                    c = f(ai)::Bool
                    @inbounds b[j] = (c ? ai : fallback)::$(et′)
                    j = ifelse(c, j + 1, j)
                end
            end
        else
            quote
                # Boxed (or `Union{}`) element storage: an unwritten slot is `#undef`
                # and reading it throws `UndefRefError`, so store accepted elements only.
                for ai in a
                    if f(ai)::Bool
                        @inbounds b[j] = ai::$(et′)
                        j += 1
                    end
                end
            end
        end
        return quote
            j = 1
            b = Vector{$(et′)}(undef, length(a))
            isempty(b) && return b
            $loop
            resize!(b, j - 1)
            sizehint!(b, length(b))
            b
        end
    else
        filter(f, a)
    end
end
"""
    typed_filter(pred, s::AbstractSet)

Set counterpart of the `Array` method. When `eltype(s)` is a `Union`, drop the
components on which `pred` is inferred to return the constant `false`, then collect
the kept elements into an empty container with the narrowed element type. The
container is obtained from `Base.emptymutable`, so it has the same kind as the
non-generated `filter` path would produce (`Set` for a plain `Set`).
"""
function typed_filter(pred, s::AbstractSet)
    if @generated
        # Inside the generator `pred` and `s` are types, not values.
        et = eltype(s)
        isa(et, Union) || return :(filter(pred, s))
        kept = []
        for component in uniontypes(et)
            result = infer(Tuple{pred,component})
            if isa(result, InferenceResult)
                rt = result.result
                # `pred` provably rejects this component; leave it out of the eltype.
                isa(rt, Const) && rt.val === false && continue
            end
            push!(kept, component)
        end
        et′ = Union{kept...}
        # `emptymutable` keeps the output container consistent with the `filter`
        # fallback below instead of always producing a `Set`.
        return :(Base.mapfilter(pred, push!, s, Base.emptymutable(s, $(et′))))
    else
        return filter(pred, s)
    end
end
"""
    typed_replace(new::Base.Callable, A; count = nothing)

Like `replace(new, A)`. When no `count` is given and `eltype(A)` is a `Union`, the
output element type is the union of the return types inferred for `new` on each
union component, rather than `eltype(A)`.

If inference fails for any component, or the element type is not a `Union`, fall
back to `replace(new, A; count = typemax(Int))`. An explicit `count` is forwarded
to `replace` unchanged.
"""
function typed_replace(new::Base.Callable, A; count = nothing)
    if @generated
        # Generator context: `new`, `A` and `count` are bound to types here.
        count === Nothing || return :(replace(new, A; count))
        fallback = :(replace(new, A; count = typemax(Int)))
        elty = eltype(A)
        elty isa Union || return fallback
        outtypes = Any[]
        for component in uniontypes(elty)
            inferred = infer(Tuple{new,component})
            # Any component that cannot be inferred disables the narrowing.
            inferred isa InferenceResult || return fallback
            push!(outtypes, widenconst(inferred.result))
        end
        newelty = Union{outtypes...}
        return :(Base._replace!(
            new, Base._similar_or_copy(A, $newelty), A, Base.check_count(typemax(Int)),
        ))
    else
        # `replace` wants an integer `count`, so map the `nothing` default to "no limit".
        return replace(new, A; count = isnothing(count) ? typemax(Int) : count)
    end
end
# %%
# Benchmark input: 1:1000000 with each entry independently replaced by `missing`
# when `rand(Bool)` is true, giving an element type of `Union{Missing, Int}`.
ary = [rand(Bool) ? missing : i for i in 1:1000000]

# Apply the one-argument filtering function `pick` to `xs` and sum what it keeps.
summer(pick, xs) = sum(pick(xs))

# Run both implementations once outside the benchmark harness.
filter(!ismissing, ary)
typed_filter(!ismissing, ary)

using BenchmarkTools
@btime filter(!ismissing, $ary)
@btime typed_filter(!ismissing, $ary)
@btime summer(Base.Fix1(filter, !ismissing), $ary)
@btime summer(Base.Fix1(typed_filter, !ismissing), $ary)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment