Skip to content

Instantly share code, notes, and snippets.

@jrevels
Created November 8, 2019 17:12
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jrevels/26c760131c80662771d36b5965239b85 to your computer and use it in GitHub Desktop.
Save jrevels/26c760131c80662771d36b5965239b85 to your computer and use it in GitHub Desktop.
using CuArrays
using CuArrays: CuArray, unsafe_free!, adapt, Mem, CuPtr
using BenchmarkTools
###############################################################################
"""
    SimpleCuIterator(batches)

Iterator wrapper around `batches` whose iteration yields each batch with its
elements converted by `adapt(CuArray, ·)`.

The `previous` field holds the device batch returned by the most recent
`iterate` call so it can be released with `unsafe_free!` on the next call; the
constructor leaves it undefined.
"""
mutable struct SimpleCuIterator{B}
    batches::B
    previous::Any
    function SimpleCuIterator(batches)
        # Only `batches` is initialized; `previous` stays undefined until the
        # first batch is produced.
        return new{typeof(batches)}(batches)
    end
end
# Iteration protocol for `SimpleCuIterator`.
#
# Advances `c.batches` and returns `(cubatch, next_state)`, where `cubatch` is
# the host batch with every element converted by `adapt(CuArray, ·)`, or
# `nothing` once `c.batches` is exhausted.
#
# Each call first releases (via `unsafe_free!`) the device arrays yielded by the
# previous call. A yielded batch must therefore not be used after the following
# `iterate` call, including the final call that returns `nothing`.
function Base.iterate(c::SimpleCuIterator, state...)
    item = iterate(c.batches, state...)
    if isdefined(c, :previous)
        foreach(unsafe_free!, c.previous)
        # Forget the batch that was just freed. Otherwise the first `iterate` of
        # a later pass (or a repeated call after exhaustion) would hand the same
        # already-freed arrays to `unsafe_free!` again. `()` keeps the
        # `foreach` above a no-op until a new batch is stored.
        c.previous = ()
    end
    item === nothing && return nothing
    batch, next_state = item
    cubatch = map(x -> adapt(CuArray, x), batch)
    c.previous = cubatch
    return cubatch, next_state
end
###############################################################################
# Byte boundary on which every array is placed inside a `CuIterator` pool.
# Packing arrays back to back would misalign an array that follows one whose
# byte size is not a multiple of the next element type's alignment (e.g. a
# Float64 array after a length-3 Float32 array). 16 bytes is at least the
# alignment of Julia's primitive bits types.
const _CUITER_ALIGNMENT = 16

# Round `nbytes` up to the next multiple of `_CUITER_ALIGNMENT`.
_cuiter_padded(nbytes::Integer) = cld(nbytes, _CUITER_ALIGNMENT) * _CUITER_ALIGNMENT

"""
    CuIterator(batches, initial_pool_size=0)

Wrap `batches`, an iterator whose items are collections (e.g. tuples) of
`AbstractArray`s. Iterating the wrapper yields each batch copied into
`CuArray`s that wrap memory inside a single device buffer (the "pool").

At the start of every pass the pool is allocated with at least
`initial_pool_size` bytes. It is replaced by a larger one whenever a batch does
not fit, and it is freed once `batches` is exhausted. `initial_pool_size` is
raised to the largest size needed so far, so later passes allocate once up
front.

Consecutive batches reuse the same memory. The arrays of a batch are therefore
only valid until the next call to `iterate`.
"""
mutable struct CuIterator{B}
    batches::B
    initial_pool_size::Int
    # `nothing` whenever no live device buffer is held (before the first pass
    # and after every free), so a buffer is never freed twice or leaked.
    pool::Union{Nothing,Mem.DeviceBuffer}
    function CuIterator(batches, initial_pool_size=0)
        return new{typeof(batches)}(batches, initial_pool_size, nothing)
    end
end

# Free the pool held by `c`, if any, and record that none is held anymore.
function _free_pool!(c::CuIterator)
    held = c.pool
    if held !== nothing
        Mem.free(held)
        c.pool = nothing
    end
    return nothing
end

# Iteration protocol for `CuIterator`: returns `(cubatch, next_state)` with the
# batch's arrays copied into the pool, or `nothing` (after freeing the pool)
# once `c.batches` is exhausted.
function Base.iterate(c::CuIterator, state...)
    item = iterate(c.batches, state...)
    if item === nothing
        _free_pool!(c)
        return nothing
    end
    batch, next_state = item
    # Validate before touching the pool so a bad batch allocates nothing.
    for array in batch
        array isa AbstractArray || throw(ArgumentError(
            "CuIterator batches must contain only AbstractArrays, got $(typeof(array))"))
    end
    required_pool_size = sum(array -> _cuiter_padded(sizeof(array)), batch)
    current = c.pool
    if isempty(state) || current === nothing || required_pool_size > sizeof(current)
        # (Re)allocate in two cases:
        #  - at the start of a pass. Freeing first releases a pool left over
        #    from a pass abandoned before exhaustion, instead of leaking it.
        #  - when this batch does not fit in the current pool.
        _free_pool!(c)
        c.initial_pool_size = max(required_pool_size, c.initial_pool_size)
        c.pool = Mem.alloc(Mem.DeviceBuffer, c.initial_pool_size)
    end
    pool = c.pool::Mem.DeviceBuffer
    # A `Ref` avoids boxing a reassigned variable captured by the closure.
    offset = Ref(0)
    cubatch = map(batch) do array
        ptr = Base.unsafe_convert(CuPtr{eltype(array)}, pool.ptr + offset[])
        # `own=false`: the pool, not the wrapper, owns (and later frees) the memory.
        cuarray = unsafe_wrap(CuArray, ptr, size(array); own=false)
        copyto!(cuarray, array)
        offset[] += _cuiter_padded(sizeof(array))
        return cuarray
    end
    return cubatch, next_state
end
###############################################################################
# Benchmark workload: `n` weight matrices in each of three groups, with inner
# widths `k1`, `k2`, `k3` (the lengths of the three vectors in every batch).
const n = 1500
const k1, k2, k3 = 64, 128, 256
# Each group holds `n` random 4096×k Float32 matrices moved to the GPU with `cu`.
const ws1 = map(_ -> cu(rand(Float32, 4096, k1)), 1:n);
const ws2 = map(_ -> cu(rand(Float32, 4096, k2)), 1:n);
const ws3 = map(_ -> cu(rand(Float32, 4096, k3)), 1:n);
# Reduce one batch `(x1, x2, x3)` to a scalar. For every weight matrix `w` in
# group `ws1`/`ws2`/`ws3`, sum the entries of `w * x1`/`w * x2`/`w * x3`, then
# add the three group totals.
function kernel(batch)
    x1, x2, x3 = batch
    total1 = sum(w -> sum(w * x1), ws1)
    total2 = sum(w -> sum(w * x2), ws2)
    total3 = sum(w -> sum(w * x3), ws3)
    return total1 + total2 + total3
end
"""
    make_batches(nbatches)

Return `(basic, cuiter, simple_cuiter)`, three lazy iterators built on the same
host-batch generator. That generator yields `nbatches` tuples of random
`Float32` vectors of lengths `k1`, `k2`, `k3`, drawing fresh data with `rand`
on every pass.

- `basic` converts each batch with `adapt(CuArray, ·)` and never frees the
  device arrays explicitly.
- `cuiter` and `simple_cuiter` wrap the generator in `CuIterator` and
  `SimpleCuIterator`, respectively.
"""
function make_batches(nbatches)
    host_batches = (
        (rand(Float32, k1), rand(Float32, k2), rand(Float32, k3)) for _ in 1:nbatches
    )
    to_device(batch) = map(x -> adapt(CuArray, x), batch)
    basic = (to_device(batch) for batch in host_batches)
    return basic, CuIterator(host_batches), SimpleCuIterator(host_batches)
end
# Benchmark the three batch-transfer strategies for 2^7 … 2^12 batches.
#
# The batch count is named `nbatches`, not `n`. Assigning `n` inside this
# top-level loop would collide with the global `const n` (weight matrices per
# group): in a script it silently becomes a new local (Julia ≥ 1.5 warns that
# this is ambiguous), while at the REPL it attempts to reassign the constant.
for i in 7:12
    nbatches = 2^i
    println("############################### benchmarking for size ", nbatches)
    basic, cuiter, simple_cuiter = make_batches(nbatches)
    strategies = (("basic", basic), ("SimpleCuIterator", simple_cuiter), ("CuIterator", cuiter))
    for (label, iter) in strategies
        println("###### ", label, ":")
        # `evals=1`: each sample consumes the iterator in one full pass.
        show(stdout, MIME("text/plain"), @benchmark(sum(kernel, $iter), evals=1))
        println(); println();
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment