Skip to content

Instantly share code, notes, and snippets.

@tshort
Created July 3, 2019 12:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tshort/5e5c9e20e44649a45835538deb2fb63a to your computer and use it in GitHub Desktop.
Save tshort/5e5c9e20e44649a45835538deb2fb63a to your computer and use it in GitHub Desktop.
Code to replace ccalls
using LLVM
using LLVM.Interop
using TypedCodeUtils
import TypedCodeUtils: reflect, filter, lookthrough, canreflect,
DefaultConsumer, Reflection, Callsite,
identify_invoke, identify_call, identify_foreigncall,
process_invoke, process_call
using MacroTools
"""
Replace the pointer in any calls including an `inttoptr` with the symbol in Dict `d` that maps
a pointer to a symbol.
"""
function fix_ccall!(mod::LLVM.Module, d)
    # `d` is looked up with `Ptr{Cvoid}` keys and its values are turned into
    # function names via `string`. Returns `true` when at least one call target
    # was replaced, `false` otherwise.
    changed = false
    for fun in functions(mod), blk in blocks(fun), instr in instructions(blk)
        instr isa LLVM.CallInst || continue
        dest = called_value(instr)
        # Only calls whose target is a constant `inttoptr` expression are rewritten.
        (dest isa ConstantExpr && occursin("inttoptr", string(dest))) || continue
        ptr = Ptr{Cvoid}(convert(Int, first(operands(dest))))
        # Diagnostic only; goes through the logger instead of printing to stdout.
        @debug "fix_ccall!: call through absolute address" ptr
        haskey(d, ptr) || continue
        sym = d[ptr]
        # NOTE(review): `llvmtype(dest)` is passed as the first argument of
        # `LLVM.FunctionType` -- confirm this yields the intended function type.
        newdest = LLVM.Function(mod, string(sym), LLVM.FunctionType(llvmtype(dest)))
        replace_uses!(dest, newdest)
        changed = true
    end
    return changed
end
# Normalize the callee expression of a `ccall`/`cglobal` statement.
# A `QuoteNode` is handed back untouched.
function getsym(x::QuoteNode)
    return x
end

# For an `Expr`, evaluate its second and third arguments and return both as a 2-tuple.
function getsym(x::Expr)
    first_part = eval(x.args[2])
    second_part = eval(x.args[3])
    return (first_part, second_part)
end
"""
Returns a `Dict` mapping function address to symbol name for all `ccall`s and
`cglobal`s called from the function. This descends into other invocations
within the function.
"""
find_ccalls(f, tt) = find_ccalls(reflect(f, tt))

function find_ccalls(ref::Reflection)
    result = Dict{Ptr{Nothing}, Symbol}()

    # `ccall`s: the callee is the first argument of the matched statement.
    foreigncalls = filter((c) -> lookthrough(identify_foreigncall, c), ref.CI.code)
    for fc in foreigncalls
        sym = getsym(fc[2].args[1])
        # Resolve the symbol to its address by evaluating `cglobal(sym)`.
        address = eval(:(cglobal($(sym))))
        result[address] = sym isa Tuple ? sym[1] : sym.value
    end

    # `cglobal` calls. Statements are not always `Expr`s, so the type is checked
    # before `head`/`args` are touched.
    iscglobalcall(c) = c isa Expr && c.head === :call &&
                       c.args[1] isa GlobalRef && c.args[1].name == :cglobal
    cglobals = filter((c) -> lookthrough(iscglobalcall, c), ref.CI.code)
    for fc in cglobals
        sym = getsym(fc[2].args[2])
        address = eval(:(cglobal($(sym))))
        result[address] = sym isa Tuple ? sym[1] : sym.value
    end

    # Recurse into every invoked method that can be reflected on.
    invokes = filter((c) -> lookthrough(identify_invoke, c), ref.CI.code)
    invokes = map((arg) -> process_invoke(DefaultConsumer(), ref, arg...), invokes)
    for fi in invokes
        canreflect(fi) || continue
        merge!(result, find_ccalls(reflect(fi)))
    end
    return result
end
## Examples:
# find_ccalls(Threads.nthreads, Tuple{})
# find_ccalls(time, Tuple{})
# find_ccalls(muladd, Tuple{Array{Float64,2},Array{Float64,2},Array{Float64,2}})
# Ask the Julia runtime (`jl_get_llvm_module`) for the module reference behind
# `native_code.p` and wrap it in an `LLVM.Module`.
function llvmmod(native_code)
    modref = ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                   (Ptr{Cvoid},), native_code.p)
    return LLVM.Module(modref)
end
# Holds the raw pointer that `irgen` receives from `jl_create_native`;
# `llvmmod` reads it back through the `p` field.
struct LLVMNativeCode # thin wrapper
p::Ptr{Cvoid}
end
"""
Returns an LLVM module for the function call `f` with TupleTypes `tt`.
"""
function irgen(@nospecialize(f), @nospecialize(tt))
    # get the method instance
    world = typemax(UInt)
    meth = which(f, tt)
    sig_tt = Tuple{typeof(f), tt.parameters...}
    (ti, env) = ccall(:jl_type_intersection_with_env, Any,
                      (Any, Any), sig_tt, meth.sig)::Core.SimpleVector
    meth = Base.func_for_method_checked(meth, ti)
    linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                  (Any, Any, Any, UInt), meth, ti, env, world)

    # set-up the compiler interface
    params = Base.CodegenParams(track_allocations=false,
                                code_coverage=false,
                                static_alloc=false,
                                prefer_specsig=true)

    # generate IR
    native_code = ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{Core.MethodInstance}, Base.CodegenParams), [linfo], params)
    # Checked explicitly rather than with `@assert`, which is not guaranteed to run;
    # a null pointer must never reach `jl_get_llvm_module`.
    native_code == C_NULL &&
        error("jl_create_native returned a null pointer for $f with argument types $tt")

    # Wrap the result as an LLVM module and replace absolute ccall addresses by symbols.
    m = llvmmod(LLVMNativeCode(native_code))
    d = find_ccalls(f, tt)
    fix_ccall!(m, d)
    return m
end
# `irgen` already calls `fix_ccall!(m, d)` on the module it returns. The former
# one-argument call `fix_ccall!(m)` matched no method (the only definition takes
# the module and the address dictionary) and raised a `MethodError`.
m = irgen(time, Tuple{})
@tshort
Copy link
Author

tshort commented Jul 3, 2019

@tshort
Copy link
Author

tshort commented Jul 25, 2019

Here's an update with code added to replace references to global variables. Right now, it just supports strings:

import Libdl

using LLVM
using LLVM.Interop
using TypedCodeUtils
import TypedCodeUtils: reflect, filter, lookthrough, canreflect,
                       DefaultConsumer, Reflection, Callsite,
                       identify_invoke, identify_call, identify_foreigncall,
                       process_invoke, process_call
using MacroTools

# Ask the Julia runtime (`jl_type_to_llvm`) for the LLVM type that corresponds to
# the Julia type `x` and wrap the returned reference in an `LLVMType`.
function julia_to_llvm(@nospecialize x)
    # Passed by reference to the C call; its value is never read afterwards.
    boxed_flag = Ref{UInt8}()
    typeref = ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef, (Any, Ref{UInt8}), x, boxed_flag)
    return LLVMType(typeref)
end

# LLVM types used when emitting IR by hand. Each one is obtained from
# `julia_to_llvm`, so it is whatever the running Julia maps that type to.

# LLVM type for `Any`, the element type it points to, and a pointer to it.
const jl_value_t_ptr = julia_to_llvm(Any)
const jl_value_t = eltype(jl_value_t_ptr)
const jl_value_t_ptr_ptr = LLVM.PointerType(jl_value_t_ptr)
# cheat on these for now:
# (all of the following are plain aliases of `jl_value_t_ptr`)
const jl_datatype_t_ptr = jl_value_t_ptr
const jl_unionall_t_ptr = jl_value_t_ptr
const jl_typename_t_ptr = jl_value_t_ptr
const jl_sym_t_ptr = jl_value_t_ptr
const jl_svec_t_ptr = jl_value_t_ptr
const jl_module_t_ptr = jl_value_t_ptr
const jl_array_t_ptr = jl_value_t_ptr

# Scalar types, named after their C counterparts.
const bool_t  = julia_to_llvm(Bool)
const int8_t  = julia_to_llvm(Int8)
const int16_t = julia_to_llvm(Int16)
const int32_t = julia_to_llvm(Int32)
const int64_t = julia_to_llvm(Int64)
const uint8_t  = julia_to_llvm(UInt8)
const uint16_t = julia_to_llvm(UInt16)
const uint32_t = julia_to_llvm(UInt32)
const uint64_t = julia_to_llvm(UInt64)
# `float_t`/`float32_t` and `double_t`/`float64_t` are two names for the same type.
const float_t  = julia_to_llvm(Float32)
const double_t = julia_to_llvm(Float64)
const float32_t = julia_to_llvm(Float32)
const float64_t = julia_to_llvm(Float64)
const void_t    = julia_to_llvm(Nothing)
# Mapped from `Int`, i.e. the native word size of the running Julia.
const size_t    = julia_to_llvm(Int)

# Pointer types built on top of the scalars above.
const int8_t_ptr  = LLVM.PointerType(int8_t)
const void_t_ptr  = LLVM.PointerType(void_t)

# Ask the Julia runtime (`jl_get_llvm_module`) for the module reference behind
# `native_code.p` and wrap it in an `LLVM.Module`.
function llvmmod(native_code)
    modref = ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                   (Ptr{Cvoid},), native_code.p)
    return LLVM.Module(modref)
end

# Holds the raw pointer that `irgen` receives from `jl_create_native`;
# `llvmmod` and `dump_native` read it back through the `p` field.
struct LLVMNativeCode    # thin wrapper
    p::Ptr{Cvoid}
end

"""
Returns an LLVM module for the function call `f` with TupleTypes `tt`.
"""
function irgen(@nospecialize(f), @nospecialize(tt))
    # get the method instance
    world = typemax(UInt)
    meth = which(f, tt)
    sig_tt = Tuple{typeof(f), tt.parameters...}
    (ti, env) = ccall(:jl_type_intersection_with_env, Any,
                      (Any, Any), sig_tt, meth.sig)::Core.SimpleVector
    meth = Base.func_for_method_checked(meth, ti)
    linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                  (Any, Any, Any, UInt), meth, ti, env, world)

    # set-up the compiler interface
    params = Base.CodegenParams(track_allocations=false,
                                code_coverage=false,
                                static_alloc=false,
                                prefer_specsig=true)

    # generate IR
    native_code = ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{Core.MethodInstance}, Base.CodegenParams), [linfo], params)
    # Checked explicitly rather than with `@assert`, which is not guaranteed to run;
    # a null pointer must never reach `jl_get_llvm_module`.
    native_code == C_NULL &&
        error("jl_create_native returned a null pointer for $f with argument types $tt")

    # Wrap the result as an LLVM module and replace absolute ccall addresses by symbols.
    m = llvmmod(LLVMNativeCode(native_code))
    d = find_ccalls(f, tt)
    fix_ccall!(m, d)
    return m
end


"""
    optimize!(mod::LLVM.Module)

Optimize the LLVM module `mod`. Crude for now.
Returns `mod`.
"""
function optimize!(mod::LLVM.Module)
    # Iterate over a snapshot: deleting a function while walking `functions(mod)`
    # removes the element the iterator is standing on.
    for llvmf in collect(functions(mod))
        # Read the name once; it must not be queried again after deletion.
        fname = LLVM.name(llvmf)
        if startswith(fname, "jfptr_")
            unsafe_delete!(mod, llvmf)
        elseif startswith(fname, "julia_")
            LLVM.linkage!(llvmf, LLVM.API.LLVMExternalLinkage)
        end
    end
    # triple = "wasm32-unknown-unknown-wasm"
    # triple!(mod, triple)
    # datalayout!(mod, "e-m:e-p:32:32-i64:64-n32:64-S128")
    # LLVM.API.@apicall(:LLVMInitializeWebAssemblyTarget, Cvoid, ())
    # LLVM.API.@apicall(:LLVMInitializeWebAssemblyTargetMC, Cvoid, ())
    # LLVM.API.@apicall(:LLVMInitializeWebAssemblyTargetInfo, Cvoid, ())
    # NOTE(review): the target triple is hard-coded -- confirm it matches the
    # platform the module is later compiled for.
    triple = "i686-pc-linux-gnu"
    tm = TargetMachine(Target(triple), triple)

    ModulePassManager() do pm
        # add_library_info!(pm, triple(mod))
        add_transform_info!(pm, tm)
        # Julia's own pass pipeline at the session's optimization level.
        ccall(:jl_add_optimization_passes, Cvoid,
              (LLVM.API.LLVMPassManagerRef, Cint, Cint),
              LLVM.ref(pm), Base.JLOptions().opt_level, 1)

        # Extra clean-up passes on top of that pipeline.
        dead_arg_elimination!(pm)
        global_optimizer!(pm)
        global_dce!(pm)
        strip_dead_prototypes!(pm)

        run!(pm, mod)
    end
    return mod
end


"""
Creates an object file from `x`.
"""
function dump_native(x::LLVMNativeCode, filename)
    # Hands the wrapped pointer to the runtime's `jl_dump_native_lib`.
    return ccall(:jl_dump_native_lib, Nothing, (Ptr{Cvoid}, Cstring), x.p, filename)
end

function dump_native(x::LLVM.Module, filename)
    # Same runtime call, but with the module's `ref` field as the pointer.
    return ccall(:jl_dump_native_lib, Nothing, (Ptr{Cvoid}, Cstring), x.ref, filename)
end


"""
Returns a `Dict` mapping function address to symbol name for all `ccall`s and
`cglobal`s called from the function. This descends into other invocations
within the function.
"""
find_ccalls(@nospecialize(f), @nospecialize(tt)) = find_ccalls(reflect(f, tt))

function find_ccalls(ref::Reflection)
    result = Dict{Ptr{Nothing}, Symbol}()

    # `ccall`s: the callee is the first argument of the matched statement.
    foreigncalls = filter((c) -> lookthrough(identify_foreigncall, c), ref.CI.code)
    for fc in foreigncalls
        sym = getsym(fc[2].args[1])
        # Resolve the symbol to its address by evaluating `cglobal(sym)`.
        address = eval(:(cglobal($(sym))))
        result[address] = sym isa Tuple ? sym[1] : sym.value
    end

    # `cglobal` calls. Statements are not always `Expr`s, so the type is checked
    # before `head`/`args` are touched.
    iscglobalcall(c) = c isa Expr && c.head === :call &&
                       c.args[1] isa GlobalRef && c.args[1].name == :cglobal
    cglobals = filter((c) -> lookthrough(iscglobalcall, c), ref.CI.code)
    for fc in cglobals
        sym = getsym(fc[2].args[2])
        address = eval(:(cglobal($(sym))))
        result[address] = sym isa Tuple ? sym[1] : sym.value
    end

    # Recurse into every invoked method that can be reflected on.
    invokes = filter((c) -> lookthrough(identify_invoke, c), ref.CI.code)
    invokes = map((arg) -> process_invoke(DefaultConsumer(), ref, arg...), invokes)
    for fi in invokes
        canreflect(fi) || continue
        merge!(result, find_ccalls(reflect(fi)))
    end
    return result
end

# Normalize the callee expression of a `ccall`/`cglobal` statement.
# A `QuoteNode` is handed back untouched.
function getsym(x::QuoteNode)
    return x
end

# For an `Expr`, evaluate its second and third arguments and return both as a 2-tuple.
function getsym(x::Expr)
    first_part = eval(x.args[2])
    second_part = eval(x.args[3])
    return (first_part, second_part)
end

# d = find_ccalls(Threads.nthreads, Tuple{})
# d = find_ccalls(time, Tuple{})
# d = find_ccalls(muladd, Tuple{Array{Float64,2},Array{Float64,2},Array{Float64,2}})


function fix_ccall!(mod::LLVM.Module, d)
    # Replaces call targets written as a constant `inttoptr` address by a function
    # declaration named after the symbol stored in `d` for that address.
    # `d` is looked up with `Ptr{Cvoid}` keys and its values are turned into
    # function names via `string`. Returns `true` when at least one call target
    # was replaced, `false` otherwise.
    changed = false
    for fun in functions(mod), blk in blocks(fun), instr in instructions(blk)
        instr isa LLVM.CallInst || continue
        dest = called_value(instr)
        # Only calls whose target is a constant `inttoptr` expression are rewritten.
        (dest isa ConstantExpr && occursin("inttoptr", string(dest))) || continue
        ptr = Ptr{Cvoid}(convert(Int, first(operands(dest))))
        # Diagnostic only; goes through the logger instead of printing to stdout.
        @debug "fix_ccall!: call through absolute address" ptr
        haskey(d, ptr) || continue
        sym = d[ptr]
        # NOTE(review): `llvmtype(dest)` is passed as the first argument of
        # `LLVM.FunctionType` -- confirm this yields the intended function type.
        newdest = LLVM.Function(mod, string(sym), LLVM.FunctionType(llvmtype(dest)))
        replace_uses!(dest, newdest)
        changed = true
    end
    return changed
end

"""
Returns a `Dict` mapping object address to the object itself for every mutable
value that appears as a statement argument in the function. This descends into
other invocations within the function.

Note: this will also return MethodInstances.
"""
find_globals(@nospecialize(f), @nospecialize(tt)) = find_globals(reflect(f, tt))

function find_globals(ref::Reflection)
    # Keys are `pointer_from_objref` addresses, values are the objects themselves.
    result = Dict{Ptr{Nothing}, Any}()

    # Statements are not always `Expr`s; only those carry an `args` list to inspect.
    hasmutablearg(c) = c isa Expr && any(x -> !isimmutable(x), c.args)
    globals = filter((c) -> lookthrough(hasmutablearg, c), ref.CI.code)
    for gl in globals
        for c in gl[2].args
            if !isimmutable(c)
                result[pointer_from_objref(c)] = c
            end
        end
    end

    # Recurse into every invoked method that can be reflected on.
    invokes = filter((c) -> lookthrough(identify_invoke, c), ref.CI.code)
    invokes = map((arg) -> process_invoke(DefaultConsumer(), ref, arg...), invokes)
    for fi in invokes
        canreflect(fi) || continue
        merge!(result, find_globals(reflect(fi)))
    end
    return result
end

# f() = "asdf"
# d = find_globals(f, Tuple{})
# d = find_globals(sin, Tuple{Float64})

# Fallback: anything that is neither an instruction nor a constant expression
# is a leaf and is ignored.
function walk(f, x)
    return nothing
end

# Instructions are not passed to `f`; only their operands are traversed.
function walk(f, x::Instruction)
    for operand in operands(x)
        walk(f, operand)
    end
    return nothing
end

# Constant expressions are offered to `f` first; a `true` result stops the descent
# below this node, otherwise the operands are traversed as well.
function walk(f, x::ConstantExpr)
    f(x) && return true
    for operand in operands(x)
        walk(f, operand)
    end
    return nothing
end

function fix_globals!(mod::LLVM.Module, d)
    # Replaces absolute `inttoptr` references to objects listed in `d` (looked up by
    # `Ptr{Cvoid}` address) with loads from module-level globals, and emits a
    # `jl_init_globals` function that fills those globals. Only `String` objects
    # are handled; other entries of `d` are left untouched.

    # Create a `jl_init_globals` function.
    jl_init_globals_func = LLVM.Function(mod, "jl_init_globals",
                                         LLVM.FunctionType(julia_to_llvm(Cvoid), LLVMType[]))
    jl_init_global_entry = BasicBlock(jl_init_globals_func, "entry", context(mod))

    # Definitions for utility functions
    func_type = LLVM.FunctionType(
            jl_value_t_ptr,
            LLVMType[#=str=# int8_t_ptr,
                     #=len=# julia_to_llvm(Csize_t)])
    jl_pchar_to_string_func = LLVM.Function(mod, "jl_pchar_to_string", func_type)
    LLVM.linkage!(jl_pchar_to_string_func, LLVM.API.LLVMExternalLinkage)

    Builder(context(mod)) do builder
        for fun in functions(mod), blk in blocks(fun), instr in instructions(blk)
            # Tracks the most recently visited expression that did not match, so the
            # expression wrapping the `inttoptr` can be replaced as a whole.
            lastop = Ref{Any}(instr)
            walk(instr) do op
                if occursin("inttoptr", string(op)) && !occursin("addrspacecast", string(op))
                    ptr = Ptr{Cvoid}(convert(Int, first(operands(op))))
                    if haskey(d, ptr)
                        obj = d[ptr]
                        if obj isa String
                            position!(builder, instr)
                            strdata = globalstring_ptr!(builder, obj, "jl.string.data")
                            # Create a pointer to the String.
                            strptr = GlobalVariable(mod, julia_to_llvm(String), "jl.string")
                            linkage!(strptr, LLVM.API.LLVMInternalLinkage)
                            # initializer!(strptr, null(julia_to_llvm(String)))
                            LLVM.API.LLVMSetInitializer(LLVM.ref(strptr), LLVM.ref(null(jl_value_t_ptr)))
                            strptr2 = load!(builder, strptr)
                            strptr3 = addrspacecast!(builder, strptr2, LLVM.PointerType(jl_value_t, 10))
                            replace_uses!(lastop[], strptr3)
                            # Create the String from `strdata` and include that in `init_fun`.
                            position!(builder, jl_init_global_entry)
                            # Call `jl_pchar_to_string(*str, len)`.
                            # The buffer size is in bytes: `sizeof`, not `length`, which
                            # counts characters and is smaller for non-ASCII strings.
                            llvmargs = LLVM.Value[strdata, LLVM.ConstantInt(julia_to_llvm(Int), sizeof(obj))]
                            newstr = LLVM.call!(builder, jl_pchar_to_string_func, llvmargs)
                            LLVM.store!(builder, newstr, strptr)
                        end
                    end
                    return true
                end
                lastop[] = op
                return false
            end
        end
        # Terminate `jl_init_globals`. Position explicitly: when no string was
        # replaced above, the builder has not been placed in the entry block yet.
        position!(builder, jl_init_global_entry)
        ret!(builder)
    end
end

# Write `mod` to the file at `path`, opened for writing; the result of the inner
# `write(io, mod)` call is returned.
function Base.write(mod::LLVM.Module, path::String)
    return open(path, "w") do io
        write(io, mod)
    end
end

# End-to-end example: compile `f` to a shared library, load it, and call it.

# Function under test: returns a string literal.
f() = "abcdg"
# Collect the objects referenced by `f`, generate its LLVM module, and replace
# the absolute references to those objects.
d = find_globals(f, Tuple{})
m = irgen(f, Tuple{})
fix_globals!(m, d)
optimize!(m)
verify(m)
# Uses the `Base.write(::LLVM.Module, ::String)` method defined above.
write(m, "test.bc")
# Tool directory relative to the running Julia binary.
bindir = string(Sys.BINDIR, "/../tools")
libpath = "./test.bc"
dylibpath = abspath("test.so")

# run(`$bindir/clang -shared -fPIC $libpath -o $dylibpath -L$bindir/../lib -ljulia`, wait = true)
# Compile `test.bc` to an object file, then link it against libjulia into `test.so`.
run(`$bindir/llc -filetype=obj -o=test.o -relocation-model=pic test.bc`, wait = true)
run(`gcc -shared -fPIC -o test.so -L$bindir/../lib -ljulia test.o`, wait = true)
dylib = Libdl.dlopen(dylibpath)
# The GC is switched off for as long as the library is in use.
GC.enable(false)
# Run the generated initializer before calling the compiled function.
ccall(Libdl.dlsym(dylib, "jl_init_globals"), Cvoid, ())
# `filter` is the one imported from TypedCodeUtils at the top of the file, not
# `Base.filter`; presumably it yields (index, value) pairs, so `[2]` selects the
# function name -- TODO confirm.
funname = first(filter(s->startswith(s, "julia"), LLVM.name.(functions(m))))[2]
# strptr = ccall(Libdl.dlsym(dylib, funname), Ptr{UInt8}, ())
# a = unsafe_wrap(Array{UInt64,1}, convert(Ptr{UInt64}, strptr-8), (3,))

# Call the compiled function and show the `String` it returns.
@show str = ccall(Libdl.dlsym(dylib, funname), String, ())

# Unload the library and switch the GC back on.
Libdl.dlclose(dylib)
GC.enable(true)

nothing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment