# Benchmark comparing three ways of feeding host batches to the GPU with CuArrays.jl:
# plain `adapt` per batch, `SimpleCuIterator`, and the pooled `CuIterator`.
# Source: GitHub gist jrevels/26c760131c80662771d36b5965239b85 (created November 8, 2019 17:12).
using CuArrays
using CuArrays: CuArray, unsafe_free!, adapt, Mem, CuPtr
using BenchmarkTools
############################################################################### | |
"""
    SimpleCuIterator(batches)

Wrap an iterator of batches so that each batch is moved to the GPU with
`adapt(CuArray, ⋅)` as it is produced. The GPU batch handed out by the previous `iterate`
call is released (via `unsafe_free!`) when the next batch is requested, so a batch must not
be used after the following `iterate` call.
"""
mutable struct SimpleCuIterator{B}
    batches::B
    # Last GPU batch handed out; left undefined until the first `iterate` call.
    previous::Any
    SimpleCuIterator(batches::B) where {B} = new{B}(batches)
end
# Eagerly release device memory of one element of a batch produced by `SimpleCuIterator`.
# Only `CuArray`s are freed; anything else `adapt` produced (scalars, wrappers) is left to
# the GC instead of raising a `MethodError` from `unsafe_free!`.
_free_batch_item!(x::CuArray) = unsafe_free!(x)
_free_batch_item!(x) = nothing

"""
    iterate(c::SimpleCuIterator, state...)

Free the GPU batch returned by the previous call, then fetch the next host batch from
`c.batches`, move every element to the GPU with `adapt(CuArray, ⋅)`, remember it in
`c.previous`, and return `(gpu_batch, next_state)`. Returns `nothing` once `c.batches` is
exhausted (after freeing the last batch).
"""
function Base.iterate(c::SimpleCuIterator, state...)
    item = iterate(c.batches, state...)
    if isdefined(c, :previous)
        foreach(_free_batch_item!, c.previous)
        # Drop the freed arrays so that a later pass over the same iterator (e.g. repeated
        # benchmark samples) does not free them a second time or keep them reachable.
        c.previous = ()
    end
    item === nothing && return nothing
    batch, next_state = item
    cubatch = map(x -> adapt(CuArray, x), batch)
    c.previous = cubatch
    return cubatch, next_state
end
############################################################################### | |
# Byte alignment of every array carved out of a `CuIterator` pool. The CUDA programming
# guide states that device allocations are aligned to at least 256 bytes, so padding each
# sub-array to this boundary keeps it as well aligned as a standalone allocation would be.
const CUITERATOR_ALIGNMENT = 256

# Round `nbytes` up to the next multiple of `alignment`.
_align_up(nbytes::Integer, alignment::Integer) = cld(nbytes, alignment) * alignment

# Pool bytes reserved for one batch element: its size padded to `CUITERATOR_ALIGNMENT`.
_pool_bytes(array::AbstractArray) = _align_up(sizeof(array), CUITERATOR_ALIGNMENT)
function _pool_bytes(x)
    throw(ArgumentError("CuIterator batches may only contain arrays, got a $(typeof(x))"))
end

"""
    CuIterator(batches, initial_pool_size=0)

Iterate over `batches` (an iterator of collections of arrays), copying each batch into one
device buffer ("pool") that is reused across iterations, and yielding the batch as
`CuArray`s that wrap (without owning) consecutive, 256-byte-aligned slices of that pool.

The pool holds at least `initial_pool_size` bytes and is freed and reallocated larger
whenever a batch does not fit, so the arrays yielded for a batch are only valid until the
next call to `iterate`. The pool is freed when `batches` is exhausted. If a pass is abandoned
early, the pool is kept and reused by the next pass instead of being leaked.
"""
mutable struct CuIterator{B}
    batches::B
    # Pool capacity target in bytes; only ever grows.
    initial_pool_size::Int
    # Device buffer backing the current batch, or `nothing` when none is allocated.
    pool::Union{Nothing,Mem.DeviceBuffer}
    function CuIterator(batches, initial_pool_size=0)
        return new{typeof(batches)}(batches, initial_pool_size, nothing)
    end
end

# Return a pool of at least `nbytes` bytes, reusing `c.pool` when it is large enough and
# otherwise freeing it and allocating a new one (stored back into `c.pool`).
function _reserve_pool!(c::CuIterator, nbytes::Integer)
    pool = c.pool
    if pool !== nothing && sizeof(pool) >= nbytes
        return pool
    end
    pool === nothing || Mem.free(pool)
    newpool = Mem.alloc(Mem.DeviceBuffer, nbytes)
    c.pool = newpool
    return newpool
end

# Free the pool (if any) and mark it as absent, so it is never freed twice.
function _release_pool!(c::CuIterator)
    pool = c.pool
    pool === nothing || Mem.free(pool)
    c.pool = nothing
    return nothing
end

function Base.iterate(c::CuIterator, state...)
    item = iterate(c.batches, state...)
    if item === nothing
        _release_pool!(c)
        return nothing
    end
    batch, next_state = item
    # Also validates that every element is an array, before any device memory is touched.
    required_pool_size = sum(_pool_bytes, batch)
    c.initial_pool_size = max(c.initial_pool_size, required_pool_size)
    pool = _reserve_pool!(c, c.initial_pool_size)
    # A `Ref` instead of a reassigned local, so the closure below does not box it.
    offset = Ref(0)
    cubatch = map(batch) do array
        ptr = Base.unsafe_convert(CuPtr{eltype(array)}, pool.ptr + offset[])
        cuarray = unsafe_wrap(CuArray, ptr, size(array); own=false)
        copyto!(cuarray, array)
        offset[] += _pool_bytes(array)
        return cuarray
    end
    return cubatch, next_state
end
############################################################################### | |
# Benchmark fixture: for each batch-element width k1, k2, k3, `n` random 4096×k Float32
# weight matrices living on the GPU.
const n = 1500
const k1, k2, k3 = 64, 128, 256
random_weights(k) = map(_ -> cu(rand(Float32, 4096, k)), 1:n)
const ws1 = random_weights(k1)
const ws2 = random_weights(k2)
const ws3 = random_weights(k3)

# Sum of all entries of `w * x`, accumulated over every weight matrix `w` in `ws`.
summed_response(ws, x) = sum(w -> sum(w * x), ws)

"""
    kernel((x1, x2, x3))

Toy workload applied to one batch: multiply each batch element by every weight matrix of
the matching width (`ws1`, `ws2`, `ws3`) and add up all resulting entries.
"""
function kernel(batch)
    x1, x2, x3 = batch
    return summed_response(ws1, x1) + summed_response(ws2, x2) + summed_response(ws3, x3)
end
"""
    make_batches(n)

Build three iterators over the same lazily generated sequence of `n` host batches, where each
batch is a tuple of random `Float32` vectors of lengths `(k1, k2, k3)`, and return them as
`(basic, cuiter, simple_cuiter)`:

- `basic` moves every batch to the GPU with `adapt(CuArray, ⋅)` and never frees it explicitly;
- `cuiter` is a `CuIterator` over the host batches;
- `simple_cuiter` is a `SimpleCuIterator` over the host batches.

Because generation is lazy, every pass over any of them draws fresh random data.
"""
function make_batches(n)
    widths = (k1, k2, k3)
    host_batches = (map(k -> rand(Float32, k), widths) for _ in 1:n)
    to_device(batch) = map(x -> adapt(CuArray, x), batch)
    basic = (to_device(batch) for batch in host_batches)
    cuiter = CuIterator(host_batches)
    simple_cuiter = SimpleCuIterator(host_batches)
    return basic, cuiter, simple_cuiter
end
"""
    print_benchmark(label, batches)

Benchmark one full pass of `sum(kernel, batches)` (one evaluation per sample), print the
resulting trial to `stdout` under a `###### label:` heading followed by a blank line, and
return the trial.
"""
function print_benchmark(label, batches)
    println("###### ", label, ":")
    trial = @benchmark(sum(kernel, $batches), evals=1)
    show(stdout, MIME("text/plain"), trial)
    println(); println()
    return trial
end

# The batch count is named `nbatches`, not `n`: assigning `n` inside this top-level loop
# would collide with the global `const n` (an ambiguous soft-scope assignment on Julia ≥ 1.5).
for i in 7:12
    nbatches = 2^i
    println("############################### benchmarking for size ", nbatches)
    basic, cuiter, simple_cuiter = make_batches(nbatches)
    print_benchmark("basic", basic)
    print_benchmark("SimpleCuIterator", simple_cuiter)
    print_benchmark("CuIterator", cuiter)
end
# End of gist.