Skip to content

Instantly share code, notes, and snippets.

@Keno
Created June 23, 2021 21:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Keno/907c7ce6a7393f8d4224a4ac24c68b12 to your computer and use it in GitHub Desktop.
Save Keno/907c7ce6a7393f8d4224a4ac24c68b12 to your computer and use it in GitHub Desktop.
#=
using Revise
Revise.track(Core.Compiler)
=#
using Symbolics
"""
    bar(a, b)

Return the product `a * b`.
"""
function bar(a, b)
    return a * b
end
"""
    foo(a, b, c)

Return `bar(a, b) + c`.
"""
foo(a, b, c) = bar(a, b) + c
"""
    matmul(a, b)

Return `a * b`.
"""
matmul(a, b) = a * b
"""
    concat(a, b)

Return `vcat(a, b)`.
"""
concat(a, b) = vcat(a, b)
using Cthulhu
using Symbolics
"""
    SymbolicInt(sym)

Lattice payload holding a single symbolic expression `sym` of type
`Symbolics.Symbolic{Int}`.
"""
struct SymbolicInt
    sym::Symbolics.Symbolic{Int}
end
# Display a `SymbolicInt` exactly like the symbolic expression it wraps.
function Base.show(io::IO, a::SymbolicInt)
    return show(io, a.sym)
end
"""
    SymbolicDimArray(dims)

Lattice payload describing an array's dimensions. `dims` is untyped (`Any`);
other code in this file stores a `PartialStruct` there and reads its `fields`.
"""
struct SymbolicDimArray
    dims
end
using Cthulhu: CthulhuInterpreter, get_specialization,
InferenceResult, InferenceState
using Core.Compiler: CustomLattice, LatticeCallbacks, typeinf, Builtin, Const,
getfield_tfunc, ⊑, PartialStruct, abstract_call_gf_by_type, AbstractInterpreter,
tuple_tfunc, CallMeta, nfields_tfunc, LatticeUnion, tmerge, instanceof_tfunc,
widenconst
using Base.Experimental: @opaque
# Render a `CustomLattice` as `CustomLattice(<payload>)`.
Base.show(io::IO, cl::CustomLattice) = print(io, "CustomLattice($(cl.payload))")
# Render a `SymbolicDimArray` as `D(...)`.
#
# When `a.dims` is a `PartialStruct`, its fields are printed comma-separated,
# with any `CustomLattice` field replaced by its payload. Any other `dims`
# value is printed as-is inside `D(...)`.
function Base.show(io::IO, a::SymbolicDimArray)
    dims = a.dims
    if !isa(dims, PartialStruct)
        # Return right away: only a `PartialStruct` has the `fields` that the
        # code below reads.
        print(io, "D(", dims, ")")
        return nothing
    end
    print(io, "D(")
    shown = map(dims.fields) do x
        isa(x, CustomLattice) ? x.payload : x
    end
    join(io, shown, ',')
    print(io, ")")
    return nothing
end
# Declare the symbolic variables `a`, `b` and `c` (each annotated `::Int`).
# They are wrapped in `SymbolicInt` below to build `avar`, `bvar` and `cvar`.
@syms a::Int b::Int c::Int
"""
    symbolic_tmerge(a, b)

Merge two lattice elements where at least one is expected to be a
`CustomLattice` carrying a payload from this file.

- `Const` with a `SymbolicInt` element (either order): `LatticeUnion(a, b)`.
- Two `SymbolicInt` elements: `a` when the symbols are `isequal`, otherwise
  `LatticeUnion(a, b)`.
- Two `SymbolicDimArray` elements: a new `CustomLattice` with `a.typ` whose
  dimension fields are the pairwise `tmerge` of the inputs' fields.

Any other combination prints `(a, b)` and throws.
"""
function symbolic_tmerge(a, b)
    a_custom = isa(a, CustomLattice)
    b_custom = isa(b, CustomLattice)

    # A constant joined with a symbolic integer, in either argument order.
    if isa(a, Const) && b_custom
        isa(b.payload, SymbolicInt) && return LatticeUnion(a, b)
    elseif a_custom && isa(b, Const)
        isa(a.payload, SymbolicInt) && return LatticeUnion(a, b)
    end

    if a_custom && b_custom
        ap = a.payload
        bp = b.payload

        if isa(ap, SymbolicInt) && isa(bp, SymbolicInt)
            return isequal(ap.sym, bp.sym) ? a : LatticeUnion(a, b)
        end

        if isa(ap, SymbolicDimArray) && isa(bp, SymbolicDimArray)
            @assert isa(ap.dims, PartialStruct)
            @assert isa(bp.dims, PartialStruct)
            apd = ap.dims
            bpd = bp.dims
            @assert length(apd.fields) == length(bpd.fields)
            # Merge dimension by dimension; the tuple type is all-`Int` with
            # one slot per field.
            merged = Any[tmerge(x, y) for (x, y) in zip(apd.fields, bpd.fields)]
            dims = PartialStruct(NTuple{length(apd.fields), Int}, merged)
            return CustomLattice(a.typ, SymbolicDimArray(dims), callbacks)
        end
    end

    # No rule matched: dump the operands and fail loudly.
    @show (a, b)
    error()
end
# Callback table passed as the last argument to every `CustomLattice` built in
# this file. Each entry is an opaque closure that forwards its arguments through
# `Base.invokelatest` to, in order: `⊑ₛ`, `symbolic_tmeetbound`,
# `symbolic_tmerge` and `symbolic_tfunc`.
# NOTE(review): the meaning of each positional slot is defined by
# `LatticeCallbacks`, which is not visible here — confirm the order against it.
# NOTE(review): `invokelatest` is presumably used so that redefinitions of the
# targets are picked up without rebuilding this table — confirm.
callbacks = LatticeCallbacks(
(@opaque (a::Any, b)->Base.invokelatest(⊑ₛ, a, b)),
(@opaque (a::Any, b)->Base.invokelatest(symbolic_tmeetbound, a, b)),
(@opaque (a::Any, b)->Base.invokelatest(symbolic_tmerge, a, b)),
(@opaque (a::Builtin, b::Vector{Any})->Base.invokelatest(symbolic_tfunc, a, b)),
)
# Intercept generic-function call inference for `CthulhuInterpreter`.
#
# A call whose callee `f` is a subtype of `Array`, whose first argument is
# `Const(undef)`, and whose remaining arguments are all `⊑ Int` with at least
# one `CustomLattice` among them, is answered directly: the result is a
# `CustomLattice` of type `f` whose `SymbolicDimArray` payload holds
# `tuple_tfunc` of the dimension arguments. Every other call is forwarded to
# the `AbstractInterpreter` method via `invoke`.
function Core.Compiler.abstract_call_gf_by_type(interp::CthulhuInterpreter, @nospecialize(f),
        fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype),
        sv::InferenceState,
        max_methods::Int = Core.Compiler.InferenceParams(interp).MAX_METHODS)
    # `length(argtypes) >= 2` guards the `argtypes[2]` access: a zero-argument
    # constructor call such as `Vector{T}()` only has the callee slot.
    if isa(f, Type) && f <: Array && length(argtypes) >= 2 &&
            argtypes[2] == Const(undef)
        dimtypes = argtypes[3:end]  # slice once; reused by all three checks below
        if all(x -> x ⊑ Int, dimtypes) && any(x -> isa(x, CustomLattice), dimtypes)
            return CallMeta(
                CustomLattice(f, SymbolicDimArray(tuple_tfunc(dimtypes)), callbacks),
                false)
        end
    end
    # Fall back to the generic implementation, bypassing this method.
    return invoke(Core.Compiler.abstract_call_gf_by_type,
        Tuple{AbstractInterpreter, Any, Union{Nothing,Vector{Any}}, Vector{Any},
              Any, InferenceState, Int},
        interp, f, fargs, argtypes, atype, sv, max_methods)
end
"""
    ⊑ₛ(a, b)

Partial-order callback for `CustomLattice` elements. `a` must be a
`CustomLattice`.

- `SymbolicInt` payload: returns `isa(b, Int64)`.
- `SymbolicDimArray` payload: `b` must carry a `SymbolicDimArray` payload too;
  returns `⊑` applied to the two `dims` values.

Any other payload throws. The operands are printed on every call.
"""
function ⊑ₛ(a, b)
    @assert isa(a, CustomLattice)
    @Core.Main.Base.show (a, b)
    if isa(a.payload, SymbolicInt)
        # NOTE(review): `isa(b, Int64)` holds only when `b` is itself an
        # `Int64` value. If the intent was to compare against the type
        # `Int64`, this wants `b === Int64` — confirm.
        return isa(b, Int64)
    elseif isa(a.payload, SymbolicDimArray)
        @assert isa(b.payload, SymbolicDimArray)
        return ⊑(a.payload.dims, b.payload.dims)
    end
    # Unknown payload kind: fail loudly, like the other callbacks in this file.
    error()
end
"""
    symbolic_tfunc(f, args)

Transfer-function callback: compute the lattice result of calling the builtin
or intrinsic `f` on `args`. Operands are read from `args[2]` onward; the
recursive call in the `arraylen` branch passes `Const(f)` as `args[1]`.

Handled callees:

- `mul_int`, `add_int`, `sub_int`: both operands must be a `Const` or a
  `SymbolicInt` element, otherwise `Any` is returned. The operation is applied
  to the constant values / symbols; an `Int` result becomes `Const`, anything
  else is wrapped in a new `SymbolicInt` element.
- `Core.arraysize`: `getfield_tfunc` on the array's symbolic dims, or `Int`
  when the array is not a `CustomLattice` or the index is not a `Const`.
- `Core.Intrinsics.arraylen`: the product of all symbolic dims, or `Int` when
  the array is not a `CustomLattice`.
- `typeassert`: returns the value unchanged when its `typ` is a subtype of the
  asserted type.

Anything else prints diagnostics and throws.
"""
function symbolic_tfunc(f, args)
    @show f === Core.Intrinsics.mul_int
    # `sub_int` is part of this guard so that its arm in the dispatch below is
    # actually reachable.
    if f === Core.Intrinsics.mul_int || f === Core.Intrinsics.add_int ||
            f === Core.Intrinsics.sub_int
        a2 = args[2]
        a3 = args[3]
        # Give up with `Any` unless each operand is a constant or a symbolic
        # integer.
        if !(isa(a2, CustomLattice) && isa(a2.payload, SymbolicInt)) &&
                !isa(a2, Const)
            return Any
        end
        if !(isa(a3, CustomLattice) && isa(a3.payload, SymbolicInt)) &&
                !isa(a3, Const)
            return Any
        end
        # Unwrap to either the constant's value or the symbolic expression.
        a2v = isa(a2, Const) ? a2.val : a2.payload.sym
        a3v = isa(a3, Const) ? a3.val : a3.payload.sym
        rva = if f === Core.Intrinsics.mul_int
            a2v * a3v
        elseif f === Core.Intrinsics.add_int
            a2v + a3v
        elseif f === Core.Intrinsics.sub_int
            a2v - a3v
        else
            error()
        end
        @show (a2v, a3v, rva)
        if isa(rva, Int)
            # The result is a plain `Int`, so it is a compile-time constant.
            return Const(rva)
        else
            return CustomLattice(Int, SymbolicInt(rva), callbacks)
        end
    elseif f === Core.arraysize
        a2 = args[2]
        a3 = args[3]
        isa(a2, CustomLattice) || return Int
        isa(a3, Const) || return Int
        a2p = a2.payload
        if isa(a2p, SymbolicDimArray)
            # Dimension `a3` of the array is field `a3` of its dims tuple.
            return getfield_tfunc(a2p.dims, a3)
        else
            error()
        end
    elseif f === Core.Intrinsics.arraylen
        a2 = args[2]
        isa(a2, CustomLattice) || return Int
        a2p = a2.payload
        if isa(a2p, SymbolicDimArray)
            @assert isa(a2p.dims, PartialStruct)
            nf = nfields_tfunc(a2p.dims)
            @assert isa(nf, Const)
            flds = map(1:nf.val) do i
                getfield_tfunc(a2p.dims, Const(i))
            end
            length(flds) == 1 && return flds[1]
            # Total length is the product of all dims, folded through the
            # `mul_int` rule above.
            return foldl(flds) do a, b
                symbolic_tfunc(Core.Intrinsics.mul_int,
                    Any[Const(Core.Intrinsics.mul_int), a, b])
            end
        else
            error()
        end
    elseif f === typeassert
        a2 = args[2]
        a3 = args[3]
        t = instanceof_tfunc(a3)[1]
        if a2.typ <: widenconst(t)
            return a2
        end
        error()
    else
        @show f
        @show args
        error()
    end
end
"""
    symbolic_tmeetbound(a, b)

Meet callback for the custom lattice.

- If either side is `Any`, the other side is returned.
- If `b` is a `CustomLattice`: returns `b` when `a === b.typ`; returns
  `Union{}` when `Core.Compiler.tmeet(a, b.typ)` is `Union{}`; returns `a` when
  both sides are `SymbolicInt` elements with `isequal` symbols.

Every other combination prints `(a, b)` and throws.
"""
function symbolic_tmeetbound(a, b)
    if b === Any
        return a
    elseif a === Any
        return b
    end

    if isa(b, CustomLattice)
        a === b.typ && return b
        Core.Compiler.tmeet(a, b.typ) === Union{} && return Union{}
        # Two symbolic integers meet only when they are the same symbol.
        same_symbol = isa(a, CustomLattice) &&
            isa(a.payload, SymbolicInt) &&
            isa(b.payload, SymbolicInt) &&
            isequal(a.payload.sym, b.payload.sym)
        same_symbol && return a
    end

    # No rule matched: dump the operands and fail loudly.
    @show (a, b)
    error()
end
# `Int`-typed lattice elements, one per symbolic variable `a`, `b`, `c`.
avar = CustomLattice(Int, SymbolicInt(a), callbacks)
bvar = CustomLattice(Int, SymbolicInt(b), callbacks)
cvar = CustomLattice(Int, SymbolicInt(c), callbacks)
"""
    run_1()

Run `typeinf` on `foo(::Int, ::Int, ::Int)` with the three arguments seeded as
the symbolic elements `avar`, `bvar` and `cvar`, then print `frame.src` and
`interp.msgs`.
"""
function run_1()
    spec = get_specialization(foo, Tuple{Int, Int, Int})
    interp = CthulhuInterpreter()
    frame = InferenceState(
        InferenceResult(spec, [typeof(foo), avar, bvar, cvar]),
        true,
        interp,
    )
    typeinf(interp, frame)
    @show frame.src
    return @show interp.msgs
end
# Symbolic-shape stand-ins for the two `matmul` arguments used by `run_2`:
# `aavar` has dimension fields (avar, bvar), `bbvar` has (bvar, cvar).
aavar = CustomLattice(
    Array{Float64, 2},
    SymbolicDimArray(PartialStruct(Tuple{Int, Int}, Any[avar, bvar])),
    callbacks,
)
bbvar = CustomLattice(
    Array{Float64, 2},
    SymbolicDimArray(PartialStruct(Tuple{Int, Int}, Any[bvar, cvar])),
    callbacks,
)
"""
    run_2()

Run `typeinf` on `matmul(::Matrix{Float64}, ::Matrix{Float64})` with the
arguments seeded as the symbolic-shape elements `aavar` and `bbvar`, then print
`frame.src` and `interp.msgs`.
"""
function run_2()
    spec = get_specialization(matmul, Tuple{Array{Float64, 2}, Array{Float64, 2}})
    interp = CthulhuInterpreter()
    frame = InferenceState(
        InferenceResult(spec, [typeof(matmul), aavar, bbvar]),
        true,
        interp,
    )
    typeinf(interp, frame)
    @show frame.src
    return @show interp.msgs
end
# Symbolic-shape stand-ins for the two `concat` arguments used by `run_3`:
# `aavar2` has dimension fields (avar, bvar), `bbvar2` has (cvar, bvar).
aavar2 = CustomLattice(
    Array{Float64, 2},
    SymbolicDimArray(PartialStruct(Tuple{Int, Int}, Any[avar, bvar])),
    callbacks,
)
bbvar2 = CustomLattice(
    Array{Float64, 2},
    SymbolicDimArray(PartialStruct(Tuple{Int, Int}, Any[cvar, bvar])),
    callbacks,
)
"""
    run_3()

Run `typeinf` on `concat(::Matrix{Float64}, ::Matrix{Float64})` with the
arguments seeded as the symbolic-shape elements `aavar2` and `bbvar2`, then
print `frame.src` and `interp.msgs`.
"""
function run_3()
    spec = get_specialization(concat, Tuple{Array{Float64, 2}, Array{Float64, 2}})
    interp = CthulhuInterpreter()
    frame = InferenceState(
        InferenceResult(spec, [typeof(concat), aavar2, bbvar2]),
        true,
        interp,
    )
    typeinf(interp, frame)
    @show frame.src
    return @show interp.msgs
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment