Created
July 3, 2019 12:59
-
-
Save tshort/5e5c9e20e44649a45835538deb2fb63a to your computer and use it in GitHub Desktop.
Code to replace ccalls
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LLVM | |
using LLVM.Interop | |
using TypedCodeUtils | |
import TypedCodeUtils: reflect, filter, lookthrough, canreflect, | |
DefaultConsumer, Reflection, Callsite, | |
identify_invoke, identify_call, identify_foreigncall, | |
process_invoke, process_call | |
using MacroTools | |
""" | |
Replace the pointer in any calls including an `inttoptr` with the symbol in Dict `d` that maps | |
a pointer to a symbol. | |
""" | |
function fix_ccall!(mod::LLVM.Module, d) | |
changed = false | |
for fun in functions(mod), blk in blocks(fun), instr in instructions(blk) | |
if instr isa LLVM.CallInst | |
dest = called_value(instr) | |
if dest isa ConstantExpr && occursin("inttoptr", string(dest)) | |
@show ptr = Ptr{Cvoid}(convert(Int, first(operands(dest)))) | |
if haskey(d, ptr) | |
sym = d[ptr] | |
newdest = LLVM.Function(mod, string(sym), LLVM.FunctionType(llvmtype(dest))) | |
replace_uses!(dest, newdest) | |
changed = true | |
end | |
end | |
end | |
end | |
return changed | |
end | |
# Extract the symbol reference from a foreigncall/cglobal argument.
# A `QuoteNode` is returned unchanged; for an `Expr`, its 2nd and 3rd arguments
# are evaluated and returned as a 2-tuple.
getsym(x::QuoteNode) = x
getsym(x::Expr) = (eval(x.args[2]), eval(x.args[3]))
""" | |
Returns a `Dict` mapping function address to symbol name for all `ccall`s and | |
`cglobal`s called from the function. This descends into other invocations | |
within the function. | |
""" | |
find_ccalls(f, tt) = find_ccalls(reflect(f, tt)) | |
function find_ccalls(ref::Reflection) | |
result = Dict{Ptr{Nothing}, Symbol}() | |
foreigncalls = filter((c) -> lookthrough(identify_foreigncall, c), ref.CI.code) | |
for fc in foreigncalls | |
sym = getsym(fc[2].args[1]) | |
address = eval(:(cglobal($(sym)))) | |
result[address] = sym isa Tuple ? sym[1] : sym.value | |
end | |
cglobals = filter((c) -> lookthrough(c -> c.head === :call && c.args[1] isa GlobalRef && | |
c.args[1].name == :cglobal, c), ref.CI.code) | |
for fc in cglobals | |
sym = getsym(fc[2].args[2]) | |
address = eval(:(cglobal($(sym)))) | |
result[address] = sym isa Tuple ? sym[1] : sym.value | |
end | |
invokes = filter((c) -> lookthrough(identify_invoke, c), ref.CI.code) | |
invokes = map((arg) -> process_invoke(DefaultConsumer(), ref, arg...), invokes) | |
for fi in invokes | |
canreflect(fi) || continue | |
merge!(result, find_ccalls(reflect(fi))) | |
end | |
return result | |
end | |
## Examples:
# find_ccalls(Threads.nthreads, Tuple{})
# find_ccalls(time, Tuple{})
# find_ccalls(muladd, Tuple{Array{Float64,2},Array{Float64,2},Array{Float64,2}})
# Wrap the module reference that `jl_get_llvm_module` returns for
# `native_code.p` in an `LLVM.Module`.
llvmmod(native_code) =
    LLVM.Module(ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                      (Ptr{Cvoid},), native_code.p))
# Thin wrapper around the raw pointer returned by `jl_create_native` (see `irgen`).
struct LLVMNativeCode
    p::Ptr{Cvoid}
end
""" | |
Returns an LLVM module for the function call `f` with TupleTypes `tt`. | |
""" | |
function irgen(@nospecialize(f), @nospecialize(tt)) | |
# get the method instance | |
world = typemax(UInt) | |
meth = which(f, tt) | |
sig_tt = Tuple{typeof(f), tt.parameters...} | |
(ti, env) = ccall(:jl_type_intersection_with_env, Any, | |
(Any, Any), sig_tt, meth.sig)::Core.SimpleVector | |
meth = Base.func_for_method_checked(meth, ti) | |
linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, | |
(Any, Any, Any, UInt), meth, ti, env, world) | |
# set-up the compiler interface | |
params = Base.CodegenParams(track_allocations=false, | |
code_coverage=false, | |
static_alloc=false, | |
prefer_specsig=true) | |
# generate IR | |
local llvm_mod_ref | |
native_code = ccall(:jl_create_native, Ptr{Cvoid}, | |
(Vector{Core.MethodInstance}, Base.CodegenParams), [linfo], params) | |
@assert native_code != C_NULL | |
m = llvmmod(LLVMNativeCode(native_code)) | |
d = find_ccalls(f, tt) | |
fix_ccall!(m, d) | |
return m | |
end | |
# Example: generate the module for `time()`.
# `irgen` already calls `fix_ccall!(m, d)` with the table from `find_ccalls`,
# and `fix_ccall!` has no one-argument method, so no further call is made here.
m = irgen(time, Tuple{})
Here's an update with code added to replace references to global variables. Right now, it just supports strings:
import Libdl
using LLVM
using LLVM.Interop
using TypedCodeUtils
import TypedCodeUtils: reflect, filter, lookthrough, canreflect,
DefaultConsumer, Reflection, Callsite,
identify_invoke, identify_call, identify_foreigncall,
process_invoke, process_call
using MacroTools
# Ask the Julia runtime (`jl_type_to_llvm`) for the LLVM type corresponding to
# the Julia type `x` and wrap the returned reference in an `LLVMType`.
# The "is boxed" flag written by the runtime is discarded.
function julia_to_llvm(@nospecialize x)
    boxed_flag = Ref{UInt8}()
    typeref = ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef, (Any, Ref{UInt8}), x, boxed_flag)
    return LLVMType(typeref)
end
# LLVM types used when building IR by hand. Each one is obtained from
# `julia_to_llvm` (or derived from such a result), so the order of the first
# three declarations matters.

# Type returned by `julia_to_llvm` for `Any`, its element type, and a pointer to it.
const jl_value_t_ptr = julia_to_llvm(Any)
const jl_value_t = eltype(jl_value_t_ptr)
const jl_value_t_ptr_ptr = LLVM.PointerType(jl_value_t_ptr)
# cheat on these for now:
# (every specialised pointer name below is just an alias of `jl_value_t_ptr`)
const jl_datatype_t_ptr = jl_value_t_ptr
const jl_unionall_t_ptr = jl_value_t_ptr
const jl_typename_t_ptr = jl_value_t_ptr
const jl_sym_t_ptr = jl_value_t_ptr
const jl_svec_t_ptr = jl_value_t_ptr
const jl_module_t_ptr = jl_value_t_ptr
const jl_array_t_ptr = jl_value_t_ptr
# LLVM types for Julia's primitive Bool / integer types.
const bool_t = julia_to_llvm(Bool)
const int8_t = julia_to_llvm(Int8)
const int16_t = julia_to_llvm(Int16)
const int32_t = julia_to_llvm(Int32)
const int64_t = julia_to_llvm(Int64)
const uint8_t = julia_to_llvm(UInt8)
const uint16_t = julia_to_llvm(UInt16)
const uint32_t = julia_to_llvm(UInt32)
const uint64_t = julia_to_llvm(UInt64)
# Floating-point types; `float_t`/`float32_t` and `double_t`/`float64_t` are
# two names for the same Julia types (Float32 and Float64 respectively).
const float_t = julia_to_llvm(Float32)
const double_t = julia_to_llvm(Float64)
const float32_t = julia_to_llvm(Float32)
const float64_t = julia_to_llvm(Float64)
# `void_t` comes from `Nothing`; `size_t` comes from the platform `Int`.
const void_t = julia_to_llvm(Nothing)
const size_t = julia_to_llvm(Int)
# Pointer types built on top of the scalar types above.
const int8_t_ptr = LLVM.PointerType(int8_t)
const void_t_ptr = LLVM.PointerType(void_t)
# Fetch the module reference that `jl_get_llvm_module` reports for
# `native_code.p` and wrap it in an `LLVM.Module`.
function llvmmod(native_code)
    modref = ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                   (Ptr{Cvoid},), native_code.p)
    return LLVM.Module(modref)
end
# Holds the raw pointer that `irgen` receives from `jl_create_native`; the
# field `p` is what `llvmmod` and `dump_native` pass on to the runtime.
struct LLVMNativeCode # thin wrapper
p::Ptr{Cvoid}
end
"""
Returns an LLVMNativeCode object for the function call `f` with TupleTypes `tt`.
"""
function irgen(@nospecialize(f), @nospecialize(tt))
# get the method instance
world = typemax(UInt)
meth = which(f, tt)
sig_tt = Tuple{typeof(f), tt.parameters...}
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
(Any, Any), sig_tt, meth.sig)::Core.SimpleVector
meth = Base.func_for_method_checked(meth, ti)
linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any, UInt), meth, ti, env, world)
# set-up the compiler interface
params = Base.CodegenParams(track_allocations=false,
code_coverage=false,
static_alloc=false,
prefer_specsig=true)
# generate IR
local llvm_mod_ref
native_code = ccall(:jl_create_native, Ptr{Cvoid},
(Vector{Core.MethodInstance}, Base.CodegenParams), [linfo], params)
@assert native_code != C_NULL
m = llvmmod(LLVMNativeCode(native_code))
d = find_ccalls(f, tt)
fix_ccall!(m, d)
return m
end
"""
optimize!(mod::LLVM.Module)
Optimize the LLVM module `mod`. Crude for now.
Returns nothing.
"""
function optimize!(mod::LLVM.Module)
for llvmf in functions(mod)
startswith(LLVM.name(llvmf), "jfptr_") && unsafe_delete!(mod, llvmf)
startswith(LLVM.name(llvmf), "julia_") && LLVM.linkage!(llvmf, LLVM.API.LLVMExternalLinkage)
end
# triple = "wasm32-unknown-unknown-wasm"
# triple!(mod, triple)
# datalayout!(mod, "e-m:e-p:32:32-i64:64-n32:64-S128")
# LLVM.API.@apicall(:LLVMInitializeWebAssemblyTarget, Cvoid, ())
# LLVM.API.@apicall(:LLVMInitializeWebAssemblyTargetMC, Cvoid, ())
# LLVM.API.@apicall(:LLVMInitializeWebAssemblyTargetInfo, Cvoid, ())
triple = "i686-pc-linux-gnu"
tm = TargetMachine(Target(triple), triple)
ModulePassManager() do pm
# add_library_info!(pm, triple(mod))
add_transform_info!(pm, tm)
ccall(:jl_add_optimization_passes, Cvoid,
(LLVM.API.LLVMPassManagerRef, Cint, Cint),
LLVM.ref(pm), Base.JLOptions().opt_level, 1)
dead_arg_elimination!(pm)
global_optimizer!(pm)
global_dce!(pm)
strip_dead_prototypes!(pm)
run!(pm, mod)
end
mod
end
"""
Creates an object file from `x`.
"""
dump_native(x::LLVMNativeCode, filename) =
ccall(:jl_dump_native_lib, Nothing, (Ptr{Cvoid}, Cstring), x.p, filename)
dump_native(x::LLVM.Module, filename) =
ccall(:jl_dump_native_lib, Nothing, (Ptr{Cvoid}, Cstring), x.ref, filename)
"""
Returns a `Dict` mapping function address to symbol name for all `ccall`s and
`cglobal`s called from the function. This descends into other invocations
within the function.
"""
find_ccalls(@nospecialize(f), @nospecialize(tt)) = find_ccalls(reflect(f, tt))
function find_ccalls(ref::Reflection)
result = Dict{Ptr{Nothing}, Symbol}()
foreigncalls = filter((c) -> lookthrough(identify_foreigncall, c), ref.CI.code)
for fc in foreigncalls
sym = getsym(fc[2].args[1])
address = eval(:(cglobal($(sym))))
result[address] = sym isa Tuple ? sym[1] : sym.value
end
cglobals = filter((c) -> lookthrough(c -> c.head === :call && c.args[1] isa GlobalRef &&
c.args[1].name == :cglobal, c), ref.CI.code)
for fc in cglobals
sym = getsym(fc[2].args[2])
address = eval(:(cglobal($(sym))))
result[address] = sym isa Tuple ? sym[1] : sym.value
end
invokes = filter((c) -> lookthrough(identify_invoke, c), ref.CI.code)
invokes = map((arg) -> process_invoke(DefaultConsumer(), ref, arg...), invokes)
for fi in invokes
canreflect(fi) || continue
merge!(result, find_ccalls(reflect(fi)))
end
return result
end
# Extract the symbol reference from a foreigncall/cglobal argument.
# A `QuoteNode` passes through untouched; for an `Expr`, arguments 2 and 3 are
# evaluated and returned together as a tuple.
function getsym(x::QuoteNode)
    return x
end
function getsym(x::Expr)
    return (eval(x.args[2]), eval(x.args[3]))
end
# d = find_ccalls(Threads.nthreads, Tuple{})
# d = find_ccalls(time, Tuple{})
# d = find_ccalls(muladd, Tuple{Array{Float64,2},Array{Float64,2},Array{Float64,2}})
"""
    fix_ccall!(mod::LLVM.Module, d) -> Bool

Redirect calls in `mod` whose callee is an `inttoptr` constant expression to a
named function. `d` maps a pointer to the symbol it stands for.

For every such callee whose address is a key of `d`, a function named after the
symbol is added to `mod` and all uses of the constant expression are replaced by
it. Returns `true` if at least one callee was replaced, `false` otherwise.
"""
function fix_ccall!(mod::LLVM.Module, d)
    changed = false
    for fun in functions(mod), blk in blocks(fun), instr in instructions(blk)
        instr isa LLVM.CallInst || continue
        dest = called_value(instr)
        # Only callees spelled as an `inttoptr` constant expression are candidates.
        (dest isa ConstantExpr && occursin("inttoptr", string(dest))) || continue
        ptr = Ptr{Cvoid}(convert(Int, first(operands(dest))))
        # Diagnostic only; silent unless debug logging is enabled.
        @debug "fix_ccall!: inttoptr callee" ptr
        haskey(d, ptr) || continue
        sym = d[ptr]
        newdest = LLVM.Function(mod, string(sym), LLVM.FunctionType(llvmtype(dest)))
        replace_uses!(dest, newdest)
        changed = true
    end
    return changed
end
"""
Returns a `Dict` mapping function address to symbol name for all `GlobalRef`s
referenced from the function. This descends into other invocations
within the function.
Note: this will also return MethodInstances.
"""
find_globals(@nospecialize(f), @nospecialize(tt)) = find_globals(reflect(f, tt))
function find_globals(ref::Reflection)
result = Dict{Ptr{Nothing}, Any}()
globals = filter((c) -> lookthrough(c -> any(x -> !isimmutable(x), c.args), c), ref.CI.code)
for gl in globals
for c in gl[2].args
if !isimmutable(c)
result[pointer_from_objref(c)] = c
end
end
end
invokes = filter((c) -> lookthrough(identify_invoke, c), ref.CI.code)
invokes = map((arg) -> process_invoke(DefaultConsumer(), ref, arg...), invokes)
for fi in invokes
canreflect(fi) || continue
merge!(result, find_globals(reflect(fi)))
end
return result
end
# f() = "asdf"
# d = find_globals(f, Tuple{})
# d = find_globals(sin, Tuple{Float64})
# Visit the constant expressions reachable through the operands of `x`.
# `f` is called on each `ConstantExpr`; when it returns `true` the operands of
# that expression are not visited. Values of any other kind end the descent.
function walk(f, x)
    return nothing
end
function walk(f, x::Instruction)
    return foreach(operand -> walk(f, operand), operands(x))
end
function walk(f, x::ConstantExpr)
    f(x) && return true
    return foreach(operand -> walk(f, operand), operands(x))
end
"""
    fix_globals!(mod::LLVM.Module, d)

Replace references to `String` objects that are embedded in `mod` as raw
addresses (`inttoptr` constant expressions) with loads from module-level global
variables, and add a `jl_init_globals` function to `mod` that fills those
variables by calling `jl_pchar_to_string`.

`d` maps an object address to the object (see `find_globals`); only entries
whose value is a `String` are rewritten. `jl_init_globals` is always created and
always terminated with a `ret`, even when nothing was rewritten.
"""
function fix_globals!(mod::LLVM.Module, d)
    # Create a `jl_init_globals` function.
    jl_init_globals_func = LLVM.Function(mod, "jl_init_globals",
                                         LLVM.FunctionType(julia_to_llvm(Cvoid), LLVMType[]))
    jl_init_global_entry = BasicBlock(jl_init_globals_func, "entry", context(mod))
    # Definitions for utility functions
    func_type = LLVM.FunctionType(
        jl_value_t_ptr,
        LLVMType[#=str=# int8_t_ptr,
                 #=len=# julia_to_llvm(Csize_t)])
    jl_pchar_to_string_func = LLVM.Function(mod, "jl_pchar_to_string", func_type)
    LLVM.linkage!(jl_pchar_to_string_func, LLVM.API.LLVMExternalLinkage)
    Builder(context(mod)) do builder
        for fun in functions(mod), blk in blocks(fun), instr in instructions(blk)
            # Most recently visited value that is not itself a matching
            # `inttoptr` expression; its uses are what gets replaced below.
            lastop = Ref{Any}(instr)
            walk(instr) do op
                if occursin("inttoptr", string(op)) && !occursin("addrspacecast", string(op))
                    ptr = Ptr{Cvoid}(convert(Int, first(operands(op))))
                    if haskey(d, ptr)
                        obj = d[ptr]
                        if obj isa String
                            position!(builder, instr)
                            strdata = globalstring_ptr!(builder, obj, "jl.string.data")
                            # Create a pointer to the String.
                            strptr = GlobalVariable(mod, julia_to_llvm(String), "jl.string")
                            linkage!(strptr, LLVM.API.LLVMInternalLinkage)
                            # initializer!(strptr, null(julia_to_llvm(String)))
                            LLVM.API.LLVMSetInitializer(LLVM.ref(strptr), LLVM.ref(null(jl_value_t_ptr)))
                            strptr2 = load!(builder, strptr)
                            strptr3 = addrspacecast!(builder, strptr2, LLVM.PointerType(jl_value_t, 10))
                            replace_uses!(lastop[], strptr3)
                            # Create the String from `strdata` and include that in `init_fun`.
                            position!(builder, jl_init_global_entry)
                            # Call `jl_pchar_to_string(*str, len)`.
                            # `len` is a byte count, so use `sizeof`; `length`
                            # counts characters and is too small for non-ASCII text.
                            llvmargs = LLVM.Value[strdata, LLVM.ConstantInt(julia_to_llvm(Int), sizeof(obj))]
                            newstr = LLVM.call!(builder, jl_pchar_to_string_func, llvmargs)
                            LLVM.store!(builder, newstr, strptr)
                        end
                    end
                    # Matching expression handled (or ignored): do not descend further.
                    return true
                end
                lastop[] = op
                return false
            end
        end
        # Make sure the terminator lands at the end of `jl_init_globals`, also
        # when no string was rewritten and the builder was never positioned.
        position!(builder, jl_init_global_entry)
        ret!(builder)
    end
end
# Write the module `mod` to the file at `path` (opened in "w" mode) using the
# `write(io, mod)` method.
function Base.write(mod::LLVM.Module, path::String)
    open(path, "w") do io
        write(io, mod)
    end
end
# End-to-end demo: compile `f` to a shared library and call it from there.
# The steps below depend on each other and must run in this order.

# Function under test: returns a String literal.
f() = "abcdg"
# Collect the non-immutable objects referenced by `f`, generate its module,
# replace the embedded String addresses, optimize and verify.
d = find_globals(f, Tuple{})
m = irgen(f, Tuple{})
fix_globals!(m, d)
optimize!(m)
verify(m)
# Write the module to disk (uses the `Base.write(mod, path)` method defined above).
write(m, "test.bc")
bindir = string(Sys.BINDIR, "/../tools")
# NOTE(review): `libpath` is only referenced by the commented-out clang command below.
libpath = "./test.bc"
dylibpath = abspath("test.so")
# run(`$bindir/clang -shared -fPIC $libpath -o $dylibpath -L$bindir/../lib -ljulia`, wait = true)
# Compile test.bc to an object file with `llc`, then link it against libjulia with gcc.
run(`$bindir/llc -filetype=obj -o=test.o -relocation-model=pic test.bc`, wait = true)
run(`gcc -shared -fPIC -o test.so -L$bindir/../lib -ljulia test.o`, wait = true)
dylib = Libdl.dlopen(dylibpath)
# Garbage collection stays disabled while the library's functions are called.
GC.enable(false)
# Run the initializer generated by `fix_globals!` before calling the compiled function.
ccall(Libdl.dlsym(dylib, "jl_init_globals"), Cvoid, ())
# NOTE(review): `filter` here is the one imported from TypedCodeUtils, not Base;
# the trailing `[2]` presumably picks the name out of an (index, value) entry — confirm.
funname = first(filter(s->startswith(s, "julia"), LLVM.name.(functions(m))))[2]
# strptr = ccall(Libdl.dlsym(dylib, funname), Ptr{UInt8}, ())
# a = unsafe_wrap(Array{UInt64,1}, convert(Ptr{UInt64}, strptr-8), (3,))
# Call the compiled function and print the String it returns.
@show str = ccall(Libdl.dlsym(dylib, funname), String, ())
Libdl.dlclose(dylib)
GC.enable(true)
nothing
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Requires a Julia build from the `jn/codegen-norecursion` branch: https://github.com/JuliaLang/julia/tree/jn/codegen-norecursion.