-
-
Save mcabbott/ecb9a7756c0530e8fae0ef444761ffcd to your computer and use it in GitHub Desktop.
Another look at prod(x) gradients
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##### Taking another look at prod(x) gradients | |
# see https://github.com/FluxML/Flux.jl/pull/524 | |
using Flux, Zygote, ForwardDiff, BenchmarkTools # Flux v0.7.3, Zygote 10 March 2019 | |
# Test matrices: M is small (5x10), MM is larger (10x1000); all benchmarks below
# differentiate sum(prod(x, dims=1)) with respect to these.
M = rand(5,10); | |
MM = rand(10,1000); | |
# Baseline: cost of the forward pass alone, for comparison with the gradients.
@btime prod($M, dims=1) # 79.748 ns (1 allocation: 160 bytes) | |
@btime prod($MM, dims=1) # 4.157 μs (1 allocation: 7.94 KiB) | |
# Sanity-check the three AD systems before any custom rules are defined.
ForwardDiff.gradient(x->sum(prod(x, dims=1)), M) | |
Flux.gradient(x->sum(prod(x, dims=1)), M)[1] # works correctly | |
Zygote.gradient(x->sum(prod(x, dims=1)), M)[1] # wrong! | |
@btime ForwardDiff.gradient(x->sum(prod(x, dims=1)), $M)[1] # 3.115 μs (9 allocations: 10.66 KiB) | |
@btime ForwardDiff.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 148.075 ms (1674 allocations: 83.85 MiB) | |
# Flux is using TrackedReal, slower than ForwardDiff | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $M)[1] # 16.693 μs (763 allocations: 45.83 KiB) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 276.165 ms (155015 allocations: 767.66 MiB) | |
##### Simplest possible gradient definition | |
import Base: * | |
using Flux.Tracker: TrackedArray, track, @grad, data, nobacksies | |
using Zygote: @adjoint | |
# Route prod on tracked arrays through a shared pullback builder `_prod`,
# so one rule serves both Flux's Tracker and Zygote.
Base.prod(xs::TrackedArray; dims=:) = track(prod, xs; dims=dims) | |
@grad prod(xs; dims=:) = _prod(xs.data, prod(xs.data, dims=dims), dims) | |
# Naive rule: d(prod)/dx_i = p / x_i — fast, but NaN whenever x contains a zero.
_prod(xd, p, dims) = p, Δ -> (p ./ xd .* Δ,) | |
@adjoint prod(xs; dims=:) = _prod(xs, prod(xs, dims=dims), dims) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $M)[1] # 905.100 ns (25 allocations: 2.45 KiB) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 36.058 μs (28 allocations: 258.95 KiB) | |
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $M)[1] # 1.479 μs (18 allocations: 1008 bytes) | |
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 15.641 μs (19 allocations: 86.48 KiB) | |
##### But that is not correct if the matrix contains zeros. | |
# Z and ZZ are copies of M and MM with zeros planted in several columns.
Z = copy(M); | |
Z[1,1] = 0; Z[3,3] = 0; Z[4,4] = 0; Z[3,5] = 0; Z[5,5] = 0; | |
ZZ = copy(MM); | |
for i=1:10 | |
ZZ[i,i:20:end] .= 0 # half the columns have a zero | |
end | |
# ForwardDiff handles zeros correctly; the naive p ./ x rule divides by zero.
ForwardDiff.gradient(x->sum(prod(x, dims=1)), Z) | |
Flux.gradient(x->sum(prod(x, dims=1)), Z)[1] # lots of NaN | |
Zygote.gradient(x->sum(prod(x, dims=1)), Z)[1] # lots of NaN | |
##### My proposal is to mapslices(∇prod, x, dims) a function which understands zero: | |
# dims == Colon(): whole-array product, gradient computed in one shot.
_prod(xd, p, ::Colon) = p, Δ -> (nobacksies(:prod, ∇prod(xd, p, data(Δ)) ),) | |
# Otherwise apply the zero-aware ∇prod slice-by-slice via mapslices,
# and scale by the incoming sensitivity Δ afterwards (∇prod sees Δ=1).
_prod(xd, p, dims) = (p, Δ -> (nobacksies(:prod, mapslices(∇prod, xd; dims=dims) .* data(Δ)),)) | |
# Gradient of prod over one slice, safe in the presence of zeros.
# No zeros: the usual p ./ x formula applies. Two or more zeros: every
# "all but one" cofactor product still contains a zero, so the gradient
# vanishes everywhere. Exactly one zero: delegate to ∇prod_one.
function ∇prod(x, p=prod(x), Δ=1)
    nzeros = count(iszero, x)
    nzeros == 0 && return p ./ x .* Δ
    nzeros == 1 && return ∇prod_one(x, Δ)
    return zero(x)
end
# Gradient of prod for a slice containing exactly one zero: only the zero
# position has a nonzero partial derivative, equal to the product of all
# the other entries, times the incoming sensitivity Δ.
function ∇prod_one(x, Δ)
    i0 = findfirst(iszero, x)
    out = copy(x)
    out[i0] = 1              # temporarily replace the zero...
    cofactor = prod(out) * Δ # ...so prod gives the product of the rest
    fill!(out, 0)
    out[i0] = cofactor
    return out
end
# Fast on large arrays, not great on small ones: | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 13.891 μs (135 allocations: 8.06 KiB) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 636.920 μs (7546 allocations: 659.45 KiB) | |
# But also now slow on arrays without zeros: | |
# (the mapslices path is taken unconditionally by the _prod defined above)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $M)[1] # 13.770 μs (135 allocations: 8.06 KiB) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 648.196 μs (7546 allocations: 659.45 KiB) | |
# Same for Zygote: | |
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 11.912 μs (111 allocations: 6.28 KiB) | |
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 614.808 μs (7520 allocations: 486.67 KiB) | |
##### To make that faster, #524 first tests for the presence of zeros: | |
# If no slice-product in p is zero, then no element of xd is zero either,
# so the fast p ./ xd formula is safe; only otherwise pay for mapslices.
_prod(xd, p, dims) = count(iszero, p) == 0 ? | |
(p, Δ -> (nobacksies(:prod, p ./ xd .* data(Δ) ),)) : | |
(p, Δ -> (nobacksies(:prod, mapslices(∇prod, xd; dims=dims) .* data(Δ)),)) | |
# similar speed on Z, but on M much better: | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $M)[1] # 1.071 μs (31 allocations: 2.63 KiB) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 36.367 μs (34 allocations: 259.13 KiB) | |
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $M)[1] # 2.691 μs (25 allocations: 1.17 KiB) | |
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 17.456 μs (26 allocations: 86.67 KiB) | |
##### The problem is that ∇prod_one isn't going to work on CuArrays, if they contain zeros. | |
# You can use this circshift thing, but it is extremely slow: | |
# (elementwise product of all length(x)-1 nontrivial circular shifts of x;
# note Δ here scales the finished cofactor products, so it enters once)
∇prod_one(x, Δ) = reshape(.*(circshift.((x,), 1:length(x)-1)...), size(x)) .* Δ | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 37.132 μs (203 allocations: 11.78 KiB) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 6.899 ms (29048 allocations: 2.04 MiB) | |
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 39.671 μs (195 allocations: 10.33 KiB) | |
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 6.827 ms (29038 allocations: 1.87 MiB) | |
# It also seems to crash Julia if you take the product of something too big: | |
# ∇prod_one(rand(1000), 1) # not recommended | |
##### Here's another idea, broadcasting to a matrix and then reshaping that slightly: | |
# GPU-friendly ∇prod_one, no scalar writes: lay out length(x)-1 cyclic
# copies of x as the columns of an (n, n+1) matrix, so prod(m, dims=1)
# gives, for each position, the product of every *other* entry; reverse
# aligns those cofactors with the original positions.
# FIX: Δ used to multiply every matrix element, so each column product of
# n entries carried Δ^n instead of Δ (e.g. ∇prod_one([2,0,3], 5) returned
# [0,150,0] rather than [0,30,0]). It now scales the finished cofactor
# products exactly once. The bug was masked in the mapslices path only
# because Δ defaults to 1 there.
function ∇prod_one(x, Δ)
    n = length(x) - 1
    # m = reshape(vec(x) .* trues(n)' .* Δ, (n,:)) # doesn't work on CuArrays
    o = x[1:end-1] # copies! gives a same-eltype buffer; ones(x)[1:n] also fails on CuArrays
    o .= 1
    m = reshape(vec(x) .* o', (n, :))
    v = reverse(vec(prod(m, dims=1))) .* Δ # apply Δ once, after the products
    return reshape(v, size(x))
end
# compare first to earlier candidates, as above, just with new names: | |
# Reference single-zero kernel: scalar indexing makes it fast on the CPU
# but unusable with CuArrays. Only the zero position gets a nonzero
# partial — the product of all other entries times Δ.
function ∇prod_one_fast(x, Δ)
    pos = findfirst(iszero, x)
    grad = copy(x)
    grad[pos] = 1
    val = prod(grad) * Δ
    grad .= 0
    grad[pos] = val
    return grad
end
∇prod_one_circ(x, Δ) = reshape(.*(circshift.((x,), 1:length(x)-1)...), size(x)) .* Δ | |
# V: a length-10 vector with a single zero, for timing the one-zero kernels.
V = rand(10); V[3]=0 | |
# time in isolation - only 10x not 200x worse... | |
@btime ∇prod_one($V, 100) # 343.461 ns (8 allocations: 1.44 KiB) | |
@btime ∇prod_one_fast($V, 100) # 42.710 ns (1 allocation: 160 bytes) | |
@btime ∇prod_one_circ($V, 100) # 10.819 μs (43 allocations: 2.92 KiB) | |
# times when called by the above mapslices stuff -- not bad! | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 15.121 μs (158 allocations: 10.05 KiB) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 800.394 μs (11048 allocations: 1.27 MiB) | |
# you could even go one step further, and use this for ∇prod not ∇prod_one: | |
# (ignore p entirely and always run the matrix-reshape kernel, zeros or not)
∇prod(x, p=nothing, Δ=1) = ∇prod_one(x, Δ) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 16.980 μs (207 allocations: 14.53 KiB) | |
@btime Flux.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 950.762 μs (14548 allocations: 1.90 MiB) | |
# if I'm not mistaken, that may also allow all this to work for 2nd derivatives. | |
# I haven't yet made it work on GPU, all of these fail right now! | |
using CuArrays | |
# error loudly on scalar indexing instead of silently being slow:
CuArrays.allowscalar(false) | |
cV = cu(V) | |
∇prod_one(cV, 100) # reverse(::CuArray{Float32,1}, does not seem to work, https://github.com/JuliaGPU/CuArrays.jl/issues/299 | |
∇prod_one_fast(cV, 100) # obviously fails, scalar indexing | |
∇prod_one_circ(cV, 100) # runs into https://github.com/JuliaGPU/CuArrays.jl/issues/161 , fixed on master it seems | |
##### Here's that version in isolation: | |
import Base: * | |
using Flux.Tracker: TrackedArray, track, @grad, data | |
using Zygote: @adjoint | |
# Hook prod on Tracker/Zygote arrays up to the shared _prod pullback below.
Base.prod(xs::TrackedArray; dims=:) = track(prod, xs; dims=dims) | |
@grad prod(xs; dims=:) = _prod(xs.data, prod(xs.data, dims=dims), dims) # is .data correct here? | |
@adjoint prod(xs; dims=:) = _prod(xs, prod(xs, dims=dims), dims) | |
# Shared pullback constructor for prod(xs; dims=dims). Returns (p, back)
# where back(Δ) yields the gradient tuple. Fast path: if no slice product
# in p is zero then no element of xd is zero, so p ./ xd is safe.
# Otherwise fall back to the zero-aware ∇prod, either over the whole
# array (dims == :) or slice-by-slice via mapslices (∇prod sees Δ=1 there,
# Δ is applied afterwards).
function _prod(xd, p, dims)
    if all(!iszero, p)
        return p, Δ -> (p ./ xd .* Δ,)
    elseif dims isa Colon
        return p, Δ -> (∇prod(xd, Δ),)
    end
    return p, Δ -> (mapslices(∇prod, xd; dims=dims) .* Δ,)
end
# Zero-safe gradient of prod over the whole array, GPU-friendly: the
# columns of `mat` are length(x)-1 cyclic copies of x, so per-column
# products are the "all but one" cofactors; reverse aligns them with the
# original entries.
# FIX: Δ used to multiply every matrix element, so each column product
# carried Δ^(length(x)-1) instead of Δ. This mattered in the
# dims == Colon() branch of _prod, which passes the real Δ through; it is
# now applied exactly once, after prod.
function ∇prod(x, Δ=1)
    onerow = transpose(x[1:end-1]) .= 1 # same-eltype row of ones; works on GPU... try similar(x, length(x)-1) perhaps
    mat = reshape(vec(x) .* onerow, (length(x) - 1, :))
    grad = reverse(vec(prod(mat, dims=1))) .* Δ # reverse doesn't work on GPU right now
    return reshape(grad, size(x))
end
# Test using the same M,MM,Z,ZZ as defined above: | |
# at worst 30% worse than the best complicated version above, on these samples. | |
# You could also use onerow = Zygote.FillArray(one(eltype(x)), (1,length(x)-1,)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment