Skip to content

Instantly share code, notes, and snippets.

@maleadt
Last active September 28, 2016 14:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maleadt/ec07509c880903efce0d2fed6bded416 to your computer and use it in GitHub Desktop.
Save maleadt/ec07509c880903efce0d2fed6bded416 to your computer and use it in GitHub Desktop.
PTX JIT (WIP)
using LLVM
using CUDAdrv, CUDAnative
using Base.Test
# grab the CUDA device with ordinal 0 and query its capability
dev = CuDevice(0)
cap = capability(dev)
# Helpers for querying the version databases below. A database is a collection
# of `tool::VersionNumber => devices::Vector{VersionNumber}` pairs.

# Concatenate a collection of vectors into a single flat vector.
function flatten(vecvec)
    return vcat(vecvec...)
end

# Return the set of all values stored under the keys of `db` for which
# `predicate(key)` holds.
function search(db, predicate)
    selected = (entries for (tool, entries) in db if predicate(tool))
    return Set(flatten(selected))
end
# Devices supported by the LLVM library in use: each entry maps an LLVM release
# to the device versions that were introduced with it.
const llvm_db = [
    v"3.2" => [v"2.0", v"2.1", v"3.0", v"3.5"],
    v"3.5" => [v"5.0"],
    v"3.7" => [v"3.2", v"3.7", v"5.2", v"5.3"],
    v"3.9" => [v"6.0", v"6.1", v"6.2"]
]
llvm_ver = LLVM.version()
# every device listed under a release no newer than ours is supported
llvm_support = search(llvm_db, release -> release <= llvm_ver)
if isempty(llvm_support)
    error("LLVM library $llvm_ver does not support any compatible device")
end
# Devices supported by the CUDA toolkit in use: each entry maps a toolkit
# release to the device versions that were introduced with it.
const cuda_db = [
    v"4.0" => [v"2.0", v"2.1"],
    v"4.2" => [v"3.0"],
    v"5.0" => [v"3.5"],
    v"6.0" => [v"3.2", v"5.0"],
    v"6.5" => [v"3.7"],
    v"7.0" => [v"5.2"],
    v"7.5" => [v"5.3"],
    v"8.0" => [v"6.0", v"6.1"]  # NOTE: 6.2 should be supported, but `ptxas` complains
]
cuda_ver = CUDAdrv.version()
# every device listed under a release no newer than ours is supported
cuda_support = search(cuda_db, release -> release <= cuda_ver)
if isempty(cuda_support)
    error("CUDA toolkit $cuda_ver does not support any compatible device")
end
# select a target: the most recent device version that is supported by both the
# LLVM library and the CUDA toolkit, and that does not exceed the capability of
# the device we are going to run on (code generated for a newer architecture
# cannot be loaded on an older device)
target_support = filter(ver -> ver <= cap, intersect(llvm_support, cuda_support))
isempty(target_support) &&
    error("No target supported by LLVM $llvm_ver, CUDA $cuda_ver and a device of capability $cap")
target = maximum(target_support)
# name of the target processor as understood by the NVPTX back-end, e.g. "sm_35"
cpu = "sm_$(target.major)$(target.minor)"
# Pick the libdevice build to link against: the newest version in the list below
# that is not newer than the selected target.
const libdevice_db = [v"2.0", v"3.0", v"3.5"]
libdevice_compat = Set(ver for ver in libdevice_db if ver <= target)
if isempty(libdevice_compat)
    error("No compatible CUDA device library available")
end
libdevice = maximum(libdevice_compat)
# file name of the bitcode library, e.g. "libdevice.compute_35.10.bc"
libdevice_fn = string("libdevice.compute_", libdevice.major, libdevice.minor, ".10.bc")
## irgen
# the LLVM library we drive must be the very version Julia itself was built with
jl_llvm_ver = VersionNumber(Base.libllvm_version)
jl_llvm_ver == llvm_ver ||
    error("LLVM library $llvm_ver incompatible with Julia's LLVM $jl_llvm_ver")
# module that will hold the kernel; the triple follows the host's `Int` width
mod = LLVM.Module("vadd")
triple!(mod, Int === Int64 ? "nvptx64-nvidia-cuda" : "nvptx-nvidia-cuda")
# TODO: get this IR by calling into Julia
# for now the kernel IR is read from a file sitting next to this script
entry = "julia_kernel_vadd_64943"
ir = open(readstring, joinpath(@__DIR__, "ptxjit.ll"))
let parsed = parse(LLVM.Module, ir)
    name!(parsed, "parsed")
    # merge the parsed IR into our module
    link!(mod, parsed)
end
verify(mod)
## linking
# libdevice: locate the bitcode library, either in the directory named by the
# NVVMIR_LIBRARY_DIR environment variable or in a few well-known locations
libdevice_dirs = if haskey(ENV, "NVVMIR_LIBRARY_DIR")
    [ENV["NVVMIR_LIBRARY_DIR"]]
else
    ["/usr/lib/nvidia-cuda-toolkit/libdevice",
     "/usr/local/cuda/nvvm/libdevice",
     "/opt/cuda/nvvm/libdevice"]
end
if !any(isdir, libdevice_dirs)
    error("CUDA device library path not found -- specify using NVVMIR_LIBRARY_DIR")
end
# keep the candidate files that actually exist and use the first one
libdevice_paths = filter(isfile, [joinpath(dir, libdevice_fn) for dir in libdevice_dirs])
if isempty(libdevice_paths)
    error("CUDA device library $libdevice_fn not found")
end
libdevice_path = first(libdevice_paths)
# Link the libdevice bitcode into `mod`, then run a pass pipeline that
# internalizes and strips whatever the kernel module did not already reference.
open(libdevice_path) do libdevice
# parse the raw contents of the library file into an LLVM module
let libdevice_mod = parse(LLVM.Module, read(libdevice))
name!(libdevice_mod, "libdevice")
# override libdevice's nvptx-unknown-unknown triple to avoid warnings
triple!(libdevice_mod, triple(mod))
# 1. Save list of external functions
#    (names of the functions in `mod`, minus those that libdevice also has)
exports = map(f->LLVM.name(f), functions(mod))
filter!(fn->!haskey(functions(libdevice_mod), fn), exports)
# 2. Link with libdevice
link!(mod, libdevice_mod)
ModulePassManager() do pm
# 3. Internalize all functions not in list from (1)
internalize!(pm, exports)
# 4. Eliminate all unused internal functions
global_optimizer!(pm)
global_dce!(pm)
strip_dead_prototypes!(pm)
# 5. Run NVVMReflect pass
#    (with the `__CUDA_FTZ` setting mapped to 1)
nvvm_reflect!(pm, Dict("__CUDA_FTZ" => 1))
# 6. Run standard optimization pipeline
always_inliner!(pm)
# the passes above are only scheduled; this call applies them to the linked module
run!(pm, mod)
end
end
end
# kernel metadata: add an `nvvm.annotations` entry that tags the entry-point
# function as "kernel" with an i32 value of 1
fn = get(functions(mod), entry)
let annotation = MDNode([fn, MDString("kernel"), ConstantInt(LLVM.Int32Type(), 1)])
    push!(metadata(mod), "nvvm.annotations", annotation)
end
## optimize & mcgen
# Set up the NVPTX target, run Julia's lowering passes plus the standard
# optimization pipeline over `mod`, and emit the result as assembly text.
InitializeNVPTXTarget()
InitializeNVPTXTargetInfo()
# look up the target that corresponds to the module's triple
t = Target(triple(mod))
InitializeNVPTXTargetMC()
# target machine for that triple and the `cpu` string selected earlier
tm = TargetMachine(t, triple(mod), cpu)
# give the module the data layout of the target machine
dl = DataLayout(tm)
datalayout!(mod, dl)
ModulePassManager() do pm
# invoke Julia's custom passes
# (added through raw `ccall`s; both receive the metadata node obtained from
# `jl_get_tbaa_gcframe`)
tbaa_gcframe = MDNode(ccall(:jl_get_tbaa_gcframe, LLVM.API.LLVMValueRef, ()))
ccall(:LLVMAddLowerGCFramePass, Void,
(LLVM.API.LLVMPassManagerRef, LLVM.API.LLVMValueRef),
LLVM.ref(pm), LLVM.ref(tbaa_gcframe))
ccall(:LLVMAddLowerPTLSPass, Void,
(LLVM.API.LLVMPassManagerRef, LLVM.API.LLVMValueRef, Cint),
LLVM.ref(pm), LLVM.ref(tbaa_gcframe), 0)
# populate the manager from the target machine, then from a default pass manager builder
populate!(pm, tm)
PassManagerBuilder() do pmb
populate!(pm, pmb)
end
run!(pm, mod)
end
InitializeNVPTXAsmPrinter()
# emit the module as an assembly file and convert the output to a String
asm = convert(String, emit(tm, mod, LLVM.API.LLVMAssemblyFile))
## execution
# Load the generated assembly on the device, run the kernel on two small random
# arrays and compare the result against the sum computed on the host.
ctx = CuContext(dev)
cuda_mod = CuModule(asm)
vadd = CuFunction(cuda_mod, entry)
dims = (3,4)
# random inputs, rounded to whole numbers in 0:100
a = round.(rand(Float32, dims) * 100)
b = round.(rand(Float32, dims) * 100)
# upload the inputs and allocate a device array for the output
d_a = CuArray(a)
d_b = CuArray(b)
d_c = CuArray(Float32, dims)
# launch with a (1, len) configuration, `len` being the number of elements
len = prod(dims)
@cuda (1,len) vadd(d_a.ptr, d_b.ptr, d_c.ptr)
# download the result and check it against the host-side sum
c = Array(d_c)
@test a+b ≈ c
destroy(ctx)
; ModuleID = 'kernel_vadd' (unoptimized IR generated by Julia's irgen)
; Self-referential struct type: its only member is a pointer to the same type.
%jl_value_t = type { %jl_value_t* }
; Vector-addition kernel taking three float pointers (%0, %1, %2).
; It computes an index i = (ctaid.x + 1 - 1) * ntid.x + (tid.x + 1) from the
; PTX special registers, then stores %0[i-1] + %1[i-1] into %2[i-1].
; Every conditional branch below tests the constant `true`, so the `fail*`
; blocks (a trap followed by unreachable) are never taken.
define void @julia_kernel_vadd_64943(float*, float*, float*) #0 !dbg !6 {
top:
; stack slots for the two loaded operands (%3, %4) and for the register values
; and constants stored further down (%5 - %19)
%3 = alloca float
%4 = alloca float
%5 = alloca i32
%6 = alloca i32
%7 = alloca i32
%8 = alloca i32
%9 = alloca i32
%10 = alloca i32
%11 = alloca i32
%12 = alloca i32
%13 = alloca i32
%14 = alloca i32
%15 = alloca i32
%16 = alloca i32
%17 = alloca i32
%18 = alloca i32
%19 = alloca i32
; fetch the thread-local state pointer and load a pointer from its slot 2;
; the loaded value (%24) is not used anywhere else in this function
%20 = call %jl_value_t*** @jl_get_ptls_states()
%21 = bitcast %jl_value_t*** %20 to %jl_value_t**
%22 = getelementptr %jl_value_t*, %jl_value_t** %21, i64 2
%23 = bitcast %jl_value_t** %22 to i64**
%24 = load i64*, i64** %23, !tbaa !8
; stack slot for the computed element index
%i = alloca i64
; Filename: /home/tbesard/Projects/Julia-CUDA/CUDAnative/src/intrinsics.jl
; Source line: 93
; ctaid.x/y/z -> %19, %17, %15, each followed by a constant 1 -> %18, %16, %14
call void @julia.gcroot_flush(), !dbg !10
%25 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !10
store i32 %25, i32* %19, !dbg !10
br i1 true, label %pass, label %fail, !dbg !10
fail: ; preds = %top
call void @llvm.trap() #2, !dbg !10
unreachable, !dbg !10
pass: ; preds = %top
store i32 1, i32* %18, !dbg !10
call void @julia.gcroot_flush(), !dbg !16
%26 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !dbg !16
store i32 %26, i32* %17, !dbg !16
br i1 true, label %pass2, label %fail1, !dbg !16
fail1: ; preds = %pass
call void @llvm.trap() #2, !dbg !16
unreachable, !dbg !16
pass2: ; preds = %pass
store i32 1, i32* %16, !dbg !16
call void @julia.gcroot_flush(), !dbg !18
%27 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z(), !dbg !18
store i32 %27, i32* %15, !dbg !18
br i1 true, label %pass4, label %fail3, !dbg !18
fail3: ; preds = %pass2
call void @llvm.trap() #2, !dbg !18
unreachable, !dbg !18
pass4: ; preds = %pass2
store i32 1, i32* %14, !dbg !18
; Source line: 127
; ntid.x/y/z -> %13, %12, %11
call void @julia.gcroot_flush(), !dbg !20
%28 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x(), !dbg !20
store i32 %28, i32* %13, !dbg !20
call void @julia.gcroot_flush(), !dbg !20
%29 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y(), !dbg !20
store i32 %29, i32* %12, !dbg !20
call void @julia.gcroot_flush(), !dbg !20
%30 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z(), !dbg !20
store i32 %30, i32* %11, !dbg !20
; Source line: 93
; tid.x/y/z -> %10, %8, %6, each followed by a constant 1 -> %9, %7, %5
call void @julia.gcroot_flush(), !dbg !22
%31 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !22
store i32 %31, i32* %10, !dbg !22
br i1 true, label %pass6, label %fail5, !dbg !22
fail5: ; preds = %pass4
call void @llvm.trap() #2, !dbg !22
unreachable, !dbg !22
pass6: ; preds = %pass4
store i32 1, i32* %9, !dbg !22
call void @julia.gcroot_flush(), !dbg !26
%32 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y(), !dbg !26
store i32 %32, i32* %8, !dbg !26
br i1 true, label %pass8, label %fail7, !dbg !26
fail7: ; preds = %pass6
call void @llvm.trap() #2, !dbg !26
unreachable, !dbg !26
pass8: ; preds = %pass6
store i32 1, i32* %7, !dbg !26
call void @julia.gcroot_flush(), !dbg !28
%33 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z(), !dbg !28
store i32 %33, i32* %6, !dbg !28
br i1 true, label %pass10, label %fail9, !dbg !28
fail9: ; preds = %pass8
call void @llvm.trap() #2, !dbg !28
unreachable, !dbg !28
pass10: ; preds = %pass8
store i32 1, i32* %5, !dbg !28
; Filename: /home/tbesard/Projects/Julia-CUDA/CUDAnative/examples/vadd.jl
; Source line: 5
; i = ((ctaid.x + 1) - 1) * ntid.x + (tid.x + 1), computed in 64 bits
%34 = load i32, i32* %19, !dbg !15, !tbaa !30
%35 = load i32, i32* %18, !dbg !15, !tbaa !30
%36 = add i32 %34, %35, !dbg !15
%37 = sext i32 %36 to i64, !dbg !15
%38 = sub i64 %37, 1, !dbg !15
%39 = load i32, i32* %13, !dbg !15, !tbaa !30
%40 = sext i32 %39 to i64, !dbg !15
%41 = mul i64 %38, %40, !dbg !15
%42 = load i32, i32* %10, !dbg !15, !tbaa !30
%43 = load i32, i32* %9, !dbg !15, !tbaa !30
%44 = add i32 %42, %43, !dbg !15
%45 = sext i32 %44 to i64, !dbg !15
%46 = add i64 %41, %45, !dbg !15
store i64 %46, i64* %i, !dbg !15
; Source line: 6
; load element i-1 of the first argument into %4
%47 = load i64, i64* %i, !dbg !31, !tbaa !30
%48 = sub i64 %47, 1, !dbg !31
%49 = getelementptr float, float* %0, i64 %48, !dbg !31
%50 = load float, float* %49, align 1, !dbg !31, !tbaa !32
store float %50, float* %4, !dbg !31
; load element i-1 of the second argument into %3
%51 = load i64, i64* %i, !dbg !31, !tbaa !30
%52 = sub i64 %51, 1, !dbg !31
%53 = getelementptr float, float* %1, i64 %52, !dbg !31
%54 = load float, float* %53, align 1, !dbg !31, !tbaa !32
store float %54, float* %3, !dbg !31
; add both operands and store the sum to element i-1 of the third argument
%55 = load i64, i64* %i, !dbg !31, !tbaa !30
%56 = sub i64 %55, 1, !dbg !31
%57 = load float, float* %4, !dbg !31, !tbaa !30
%58 = load float, float* %3, !dbg !31, !tbaa !30
%59 = fadd float %57, %58, !dbg !31
%60 = getelementptr float, float* %2, i64 %56, !dbg !31
store float %59, float* %60, align 1, !dbg !31, !tbaa !32
; Source line: 8
ret void, !dbg !33
}
; Second function emitted next to the kernel, taking (%jl_value_t*, %jl_value_t**, i32).
; Its body consists solely of a trap followed by unreachable.
define %jl_value_t* @jlcall_kernel_vadd_64943(%jl_value_t*, %jl_value_t**, i32) #1 {
top:
call void @llvm.trap() #2
unreachable
}
; External declarations referenced by the two functions above: the trap
; intrinsic, Julia runtime/GC symbols, and the NVVM special-register readers
; for ctaid, ntid and tid (x/y/z each).
; Function Attrs: noreturn nounwind
declare void @llvm.trap() #2
declare %jl_value_t*** @jl_get_ptls_states()
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #3
declare void @julia.gcroot_flush()
declare void @julia.gc_root_kill(%jl_value_t**)
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() #3
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() #3
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #3
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y() #3
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z() #3
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #3
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.y() #3
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.z() #3
; Attribute groups: #0 is attached to the kernel, #1 to the jlcall function,
; #2 to the trap intrinsic and #3 to the special-register readers.
attributes #0 = { "jl_cgtarget"="ptx" "no-frame-pointer-elim"="true" }
attributes #1 = { "no-frame-pointer-elim"="true" }
attributes #2 = { noreturn nounwind }
attributes #3 = { nounwind readnone }
; Module-level metadata: module flags (!0 - !2), the debug-info compile unit
; (!3), and the numbered nodes referenced through !dbg / !tbaa attachments in
; the kernel above (subprograms, source locations and TBAA type nodes).
!llvm.module.flags = !{!0, !1, !2}
!llvm.dbg.cu = !{!3}
!0 = !{i32 2, !"Dwarf Version", i32 4}
!1 = !{i32 1, !"Debug Info Version", i32 3}
!2 = !{i32 1, !"Julia Codegen Target", !"ptx"}
!3 = distinct !DICompileUnit(language: DW_LANG_C89, file: !4, producer: "julia", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug, enums: !5)
!4 = !DIFile(filename: "/home/tbesard/Projects/Julia-CUDA/CUDAnative/examples/vadd.jl", directory: ".")
!5 = !{}
!6 = distinct !DISubprogram(name: "kernel_vadd", linkageName: "julia_kernel_vadd_64943", scope: null, file: !4, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!7 = !DISubroutineType(types: !5)
; TBAA nodes: !8 "jtbaa_const", !30 "jtbaa_stack" and !32 "jtbaa_data" all hang off the root !9 "jtbaa"
!8 = !{!"jtbaa_const", !9, i64 1}
!9 = !{!"jtbaa"}
; source locations and the subprograms (from intrinsics.jl, !12) they are scoped to
!10 = !DILocation(line: 93, scope: !11, inlinedAt: !13)
!11 = distinct !DISubprogram(name: "blockIdx_x;", linkageName: "blockIdx_x", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!12 = !DIFile(filename: "/home/tbesard/Projects/Julia-CUDA/CUDAnative/src/intrinsics.jl", directory: ".")
!13 = !DILocation(line: 128, scope: !14, inlinedAt: !15)
!14 = distinct !DISubprogram(name: "blockIdx;", linkageName: "blockIdx", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!15 = !DILocation(line: 5, scope: !6)
!16 = !DILocation(line: 93, scope: !17, inlinedAt: !13)
!17 = distinct !DISubprogram(name: "blockIdx_y;", linkageName: "blockIdx_y", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!18 = !DILocation(line: 93, scope: !19, inlinedAt: !13)
!19 = distinct !DISubprogram(name: "blockIdx_z;", linkageName: "blockIdx_z", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!20 = !DILocation(line: 127, scope: !21, inlinedAt: !15)
!21 = distinct !DISubprogram(name: "blockDim;", linkageName: "blockDim", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!22 = !DILocation(line: 93, scope: !23, inlinedAt: !24)
!23 = distinct !DISubprogram(name: "threadIdx_x;", linkageName: "threadIdx_x", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!24 = !DILocation(line: 126, scope: !25, inlinedAt: !15)
!25 = distinct !DISubprogram(name: "threadIdx;", linkageName: "threadIdx", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!26 = !DILocation(line: 93, scope: !27, inlinedAt: !24)
!27 = distinct !DISubprogram(name: "threadIdx_y;", linkageName: "threadIdx_y", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!28 = !DILocation(line: 93, scope: !29, inlinedAt: !24)
!29 = distinct !DISubprogram(name: "threadIdx_z;", linkageName: "threadIdx_z", scope: !12, file: !12, type: !7, isLocal: false, isDefinition: true, isOptimized: true, unit: !3, variables: !5)
!30 = !{!"jtbaa_stack", !9}
!31 = !DILocation(line: 6, scope: !6)
!32 = !{!"jtbaa_data", !9}
!33 = !DILocation(line: 8, scope: !6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment