# Gist ornithos/4363d88d33c8930ce75cb74da2e37037, last active September 5, 2018
# Call Chain(Dense(2, 400, NNlib.relu), Dense(400, 784, NNlib.σ))
# 100x with random input, where the matmul/add is timed as "preactiv"
# for layers 1 and 2, and the nonlinearities are timed separately.
# As usual with sigmoidal transformations, these account for much of
# the forward and backward pass time despite being element-wise
# operations.
#=
 ────────────────────────────────────────────────────────────────────
                               Time          Allocations
                       ──────────────────  ───────────────────────
  Tot / % measured:      8.50s / 15.8%       1.33GiB / 100%

 Section       ncalls    time  %tot    avg   alloc  %tot     avg
 ────────────────────────────────────────────────────────────────────
 l2_sigmoid       100   765ms 57.1% 7.65ms  299MiB 22.1% 2.99MiB
 l2_preactiv      100   428ms 31.9% 4.28ms  598MiB 44.1% 5.98MiB
 l1_preactiv      100   125ms 9.35% 1.25ms  306MiB 22.6% 3.06MiB
 l1_relu          100  22.3ms 1.66%  223μs  153MiB 11.3% 1.53MiB
 ────────────────────────────────────────────────────────────────────
=#
using Random
using Flux
using Flux.Tracker
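# The timing table above looks like TimerOutputs.jl output; below is a
# hedged sketch of a harness that could produce a similar per-layer
# breakdown. Assumptions (not confirmed by the gist): the table came
# from TimerOutputs, and Dense exposes its weights/bias as fields
# `W`/`b` (true for 2018-era Flux).
using TimerOutputs

let m = Chain(Dense(2, 400, relu), Dense(400, 784, σ)),
    to = TimerOutput()
    d1, d2 = m[1], m[2]
    for i in 1:100
        x = randn(2, 1000)
        # time the affine map and the nonlinearity of each layer separately
        h = @timeit to "l1_preactiv" d1.W * x .+ d1.b
        h = @timeit to "l1_relu"     relu.(h)
        z = @timeit to "l2_preactiv" d2.W * h .+ d2.b
        y = @timeit to "l2_sigmoid"  σ.(z)
    end
    print_timer(to)
end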
### ====== STANDARD HAND-CODED GRADIENT ===========
Random.seed!(42)
tmp = 0.0
tmpgrad1 = zeros(400, 1000)
@time for i in 1:100
    tmpcx = randn(400, 1000)
    global tmp += sum(σ.(tmpcx))
    global tmpgrad1 += σ.(tmpcx) .* (1 .- σ.(tmpcx))
end
# 2.031375 seconds (1000 allocations: 1.192 GiB, 23.59% gc time)
### ====== FLUX AUTOMATIC DIFFERENTIATION ===========
Random.seed!(42)
tmp = 0.0
tmpgrad2 = zeros(400, 1000)
@time for i in 1:100
    tmpcx = param(randn(400, 1000))
    global tmp = sum(σ.(tmpcx))
    Tracker.back!(tmp)
    global tmpgrad2 += tmpcx.tracker.grad
    tmpcx.tracker.grad .= 0
end
# 2.329787 seconds (5.00 k allocations: 2.384 GiB, 23.97% gc time)
### ====== MORE EFFICIENT HAND-CODED FWD/BWD ===========
### ====== (uses more memory, but grad allocates =======
### ======  this same footprint anyway)  ===============
Random.seed!(42)
tmp = 0.0
tmpgrad3 = zeros(400, 1000)
@time for i in 1:100
    tmpcx = randn(400, 1000)
    tmpsig = σ.(tmpcx)
    global tmp += sum(tmpsig)
    global tmpgrad3 += tmpsig .* (1 .- tmpsig)
end
# 1.062505 seconds (1000 allocations: 1.192 GiB, 24.93% gc time)
# i.e. if we had a specialised σ gradient which cached the result of
# the forward pass (ideally in its grad container, if it could be
# marked as 'dirty'), we could possibly reduce the time for the
# backward pass dramatically. This would also work for tanh. The
# overhead of calculating the gradient in this 3rd "efficient"
# fwd/bwd is about 25% of the forward pass: the forward pass alone
# is ≈ 840ms, and incl. the bwd ≈ 1060ms.
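# The caching idea above can be sketched with Tracker's custom-gradient
# API (`track` / `@grad`). `cached_sigma` is a hypothetical name, not
# part of Flux; the point is that the pullback closure reuses the
# forward result `s` instead of re-evaluating σ twice in the backward
# pass.
using Flux.Tracker: TrackedArray, track, data, @grad

cached_sigma(x) = σ.(x)
cached_sigma(x::TrackedArray) = track(cached_sigma, x)

@grad function cached_sigma(x)
    s = σ.(data(x))                       # forward pass, computed once
    return s, Δ -> (Δ .* s .* (1 .- s),)  # backward pass reuses cached s
end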
isapprox(tmpgrad1, tmpgrad2, atol=1e-5, rtol=1e-5)  # true
isapprox(tmpgrad1, tmpgrad3, atol=1e-5, rtol=1e-5)  # true