@mfalt
Created March 3, 2022 11:11
using Flux
# Small LeNet-style CNN, moved to the GPU
const nn = gpu(Chain(
    Conv((5, 5), 3=>6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6=>16, relu),
    MaxPool((2, 2)),
    flatten,
    Dense(256, 120, relu),
    Dense(120, 84, relu),
    Dense(84, 4)
))
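
# Run a forward pass on the GPU, copy the output to the CPU,
# and reduce it to a single scalar
function inference(imgs)
    out = nn(imgs)
    return maximum(cpu(out))
end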
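
# Batch of 1000 random 28x28 images with 3 channels (WHCN layout)
imgs = gpu(randn(28, 28, 3, 1000))
# Number of inference tasks spawned per run
N = 100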

# Example 1: threaded execution, all tasks spawned concurrently.
# Crashes after about 10 runs on a GPU with 11 GB of memory.
res = Any[nothing for i in 1:N]
@sync for i in 1:N
    res[i] = Threads.@spawn inference(imgs)
end
s = sum(fetch.(res))

# Example 2: serial execution, each task is spawned and fetched before the next starts.
# Also crashes after about 10 runs on a GPU with 11 GB of memory.
for j in 1:10
    res = Any[nothing for i in 1:N]
    for i in 1:N
        res[i] = fetch(Threads.@spawn inference(imgs))
    end
    s = sum(res)
    println(s)
end