##### Taking another look at prod(x) gradients
# see https://github.com/FluxML/Flux.jl/pull/524
using Flux, Zygote, ForwardDiff, BenchmarkTools # Flux v0.7.3, Zygote 10 March 2019
M = rand(5,10);
MM = rand(10,1000);
@btime prod($M, dims=1) # 79.748 ns (1 allocation: 160 bytes)
@btime prod($MM, dims=1) # 4.157 μs (1 allocation: 7.94 KiB)
ForwardDiff.gradient(x->sum(prod(x, dims=1)), M)
Flux.gradient(x->sum(prod(x, dims=1)), M)[1] # works correctly
Zygote.gradient(x->sum(prod(x, dims=1)), M)[1] # wrong!
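# (Added check, not part of the timing runs: compare both against ForwardDiff, taken as ground truth here.)
Flux.data(Flux.gradient(x->sum(prod(x, dims=1)), M)[1]) ≈ ForwardDiff.gradient(x->sum(prod(x, dims=1)), M) # expect true
Zygote.gradient(x->sum(prod(x, dims=1)), M)[1] ≈ ForwardDiff.gradient(x->sum(prod(x, dims=1)), M) # expect false, per the note above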
@btime ForwardDiff.gradient(x->sum(prod(x, dims=1)), $M)[1] # 3.115 μs (9 allocations: 10.66 KiB)
@btime ForwardDiff.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 148.075 ms (1674 allocations: 83.85 MiB)
# Flux is using TrackedReal, slower than ForwardDiff
@btime Flux.gradient(x->sum(prod(x, dims=1)), $M)[1] # 16.693 μs (763 allocations: 45.83 KiB)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 276.165 ms (155015 allocations: 767.66 MiB)
##### Simplest possible gradient definition
import Base: *
using Flux.Tracker: TrackedArray, track, @grad, data, nobacksies
using Zygote: @adjoint
Base.prod(xs::TrackedArray; dims=:) = track(prod, xs; dims=dims)
@grad prod(xs; dims=:) = _prod(xs.data, prod(xs.data, dims=dims), dims)
_prod(xd, p, dims) = p, Δ -> (p ./ xd .* Δ,)
@adjoint prod(xs; dims=:) = _prod(xs, prod(xs, dims=dims), dims)
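# Why this is the natural rule: for p = prod(x), ∂p/∂xᵢ = p / xᵢ whenever xᵢ ≠ 0,
# so the pullback is just p ./ x .* Δ. An added sanity check on the zero-free M:
Flux.data(Flux.gradient(x->sum(prod(x, dims=1)), M)[1]) ≈ ForwardDiff.gradient(x->sum(prod(x, dims=1)), M) # expect true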
@btime Flux.gradient(x->sum(prod(x, dims=1)), $M)[1] # 905.100 ns (25 allocations: 2.45 KiB)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 36.058 μs (28 allocations: 258.95 KiB)
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $M)[1] # 1.479 μs (18 allocations: 1008 bytes)
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 15.641 μs (19 allocations: 86.48 KiB)
##### But that is not correct if the matrix contains zeros.
Z = copy(M);
Z[1,1] = 0; Z[3,3] = 0; Z[4,4] = 0; Z[3,5] = 0; Z[5,5] = 0;
ZZ = copy(MM);
for i = 1:10
    ZZ[i, i:20:end] .= 0 # half the columns have a zero
end
ForwardDiff.gradient(x->sum(prod(x, dims=1)), Z)
Flux.gradient(x->sum(prod(x, dims=1)), Z)[1] # lots of NaN
Zygote.gradient(x->sum(prod(x, dims=1)), Z)[1] # lots of NaN
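# The NaNs come from dividing by zero in p ./ x: a zero entry makes p zero too,
# and 0/0 = NaN. A minimal added illustration:
prod([0.0, 2.0]) ./ [0.0, 2.0] # gives [NaN, 0.0], but the true gradient is [2.0, 0.0]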
##### My proposal: mapslices(∇prod, x; dims=dims), with a ∇prod which understands zeros:
_prod(xd, p, ::Colon) = p, Δ -> (nobacksies(:prod, ∇prod(xd, p, data(Δ)) ),)
_prod(xd, p, dims) = (p, Δ -> (nobacksies(:prod, mapslices(∇prod, xd; dims=dims) .* data(Δ)),))
function ∇prod(x, p=prod(x), Δ=1)
    numzero = count(iszero, x)
    if numzero == 0      # no zeros: the usual p ./ x rule is safe
        ∇ = p ./ x .* Δ
    elseif numzero > 1   # two or more zeros: every partial product vanishes
        ∇ = zero(x)
    else                 # exactly one zero needs special care
        ∇ = ∇prod_one(x, Δ)
    end
    ∇
end
function ∇prod_one(x, Δ)
    zloc = findfirst(iszero, x)  # position of the single zero
    ∇ = copy(x)
    ∇[zloc] = 1
    nonzero = prod(∇) * Δ        # product of all the other entries
    ∇ .= 0                       # every other gradient entry is zero
    ∇[zloc] = nonzero
    ∇                            # return the array, not the last assignment's value
end
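# Added hand-checked examples covering the three branches of ∇prod:
∇prod([1.0, 2.0, 3.0]) # [6.0, 3.0, 2.0] -- no zeros, ordinary p ./ x
∇prod([1.0, 0.0, 3.0]) # [0.0, 3.0, 0.0] -- only the zero's own entry survives
∇prod([0.0, 2.0, 0.0]) # [0.0, 0.0, 0.0] -- two zeros kill every term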
# Fast on large arrays, not great on small ones:
@btime Flux.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 13.891 μs (135 allocations: 8.06 KiB)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 636.920 μs (7546 allocations: 659.45 KiB)
# But also now slow on arrays without zeros:
@btime Flux.gradient(x->sum(prod(x, dims=1)), $M)[1] # 13.770 μs (135 allocations: 8.06 KiB)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 648.196 μs (7546 allocations: 659.45 KiB)
# Same for Zygote:
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 11.912 μs (111 allocations: 6.28 KiB)
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 614.808 μs (7520 allocations: 486.67 KiB)
##### To make that faster, #524 first tests for the presence of zeros:
_prod(xd, p, dims) = count(iszero, p) == 0 ?
    (p, Δ -> (nobacksies(:prod, p ./ xd .* data(Δ) ),)) :
    (p, Δ -> (nobacksies(:prod, mapslices(∇prod, xd; dims=dims) .* data(Δ)),))
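# Checking p rather than xd suffices: a slice's product is zero exactly when the
# slice contains a zero, and p has far fewer entries to scan. E.g. on Z above:
count(iszero, prod(Z, dims=1)) # expect 4 -- four of Z's ten columns contain a zero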
# similar speed on Z, but on M much better:
@btime Flux.gradient(x->sum(prod(x, dims=1)), $M)[1] # 1.071 μs (31 allocations: 2.63 KiB)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 36.367 μs (34 allocations: 259.13 KiB)
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $M)[1] # 2.691 μs (25 allocations: 1.17 KiB)
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $MM)[1] # 17.456 μs (26 allocations: 86.67 KiB)
##### The problem is that ∇prod_one isn't going to work on CuArrays, if they contain zeros.
# You can use this circshift thing, but it is extremely slow:
∇prod_one(x, Δ) = reshape(.*(circshift.((x,), 1:length(x)-1)...), size(x)) .* Δ
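# Why this works, e.g. for x = [a,b,c]: the shifts are [c,a,b] and [b,c,a], whose
# elementwise product is [bc, ac, ab], i.e. prod(x)/x without ever dividing:
∇prod_one([1.0, 0.0, 3.0], 1) # expect [0.0, 3.0, 0.0]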
@btime Flux.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 37.132 μs (203 allocations: 11.78 KiB)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 6.899 ms (29048 allocations: 2.04 MiB)
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 39.671 μs (195 allocations: 10.33 KiB)
@btime Zygote.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 6.827 ms (29038 allocations: 1.87 MiB)
# It also seems to crash Julia if you take the product of something too big:
# ∇prod_one(rand(1000), 1) # not recommended
##### Here's another idea, broadcasting to a matrix and then reshaping that slightly:
function ∇prod_one(x, Δ)
    n = length(x) - 1
    # m = reshape(vec(x) .* trues(n)' .* Δ, (n,:)) # doesn't work on CuArrays, nor does ones(x)[1:n]
    o = x[1:end-1] # copies!
    o .= 1
    m = reshape(vec(x) .* o' .* Δ, (n,:))
    v = reverse(vec(prod(m, dims=1)))
    reshape(v, size(x))
end
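# Why this works, e.g. for x = [a,b,c]: vec(x) .* o' repeats x in two columns, so
# the flattened data is [a,b,c,a,b,c]; reshaped to 2 rows the columns are [a,b],
# [c,a], [b,c], whose products, reversed, give [bc, ca, ab] = prod(x)/x, again
# with no division. An added spot check:
∇prod_one([1.0, 0.0, 3.0], 1) # expect [0.0, 3.0, 0.0]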
# compare first to earlier candidates, as above, just with new names:
function ∇prod_one_fast(x, Δ)
    zloc = findfirst(iszero, x)
    ∇ = copy(x)
    ∇[zloc] = 1
    nonzero = prod(∇) * Δ
    ∇ .= 0
    ∇[zloc] = nonzero
    ∇ # as above, return the array
end
∇prod_one_circ(x, Δ) = reshape(.*(circshift.((x,), 1:length(x)-1)...), size(x)) .* Δ
V = rand(10); V[3]=0
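# Added check: all three candidates should agree on a vector with a single zero:
∇prod_one(V, 100) ≈ ∇prod_one_fast(V, 100) ≈ ∇prod_one_circ(V, 100) # expect true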
# timed in isolation, the reshape version is only ~10x worse than the fast one, not ~200x like circshift:
@btime ∇prod_one($V, 100) # 343.461 ns (8 allocations: 1.44 KiB)
@btime ∇prod_one_fast($V, 100) # 42.710 ns (1 allocation: 160 bytes)
@btime ∇prod_one_circ($V, 100) # 10.819 μs (43 allocations: 2.92 KiB)
# times when called by the above mapslices stuff -- not bad!
@btime Flux.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 15.121 μs (158 allocations: 10.05 KiB)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 800.394 μs (11048 allocations: 1.27 MiB)
# you could even go one step further, and use this for ∇prod, not just ∇prod_one
# (the reshape trick computes each product-of-others directly, so it handles any
# number of zeros, making the count(iszero, x) branching unnecessary):
∇prod(x, p=nothing, Δ=1) = ∇prod_one(x, Δ)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $Z)[1] # 16.980 μs (207 allocations: 14.53 KiB)
@btime Flux.gradient(x->sum(prod(x, dims=1)), $ZZ)[1] # 950.762 μs (14548 allocations: 1.90 MiB)
# if I'm not mistaken, that may also allow all this to work for 2nd derivatives.
# I haven't yet made any of this work on the GPU; all of these fail right now!
using CuArrays
CuArrays.allowscalar(false)
cV = cu(V)
∇prod_one(cV, 100) # reverse(::CuArray{Float32,1}) does not seem to work, https://github.com/JuliaGPU/CuArrays.jl/issues/299
∇prod_one_fast(cV, 100) # obviously fails, scalar indexing
∇prod_one_circ(cV, 100) # runs into https://github.com/JuliaGPU/CuArrays.jl/issues/161 , fixed on master it seems
##### Here's that version in isolation:
import Base: *
using Flux.Tracker: TrackedArray, track, @grad, data
using Zygote: @adjoint
Base.prod(xs::TrackedArray; dims=:) = track(prod, xs; dims=dims)
@grad prod(xs; dims=:) = _prod(xs.data, prod(xs.data, dims=dims), dims) # is .data correct here?
@adjoint prod(xs; dims=:) = _prod(xs, prod(xs, dims=dims), dims)
function _prod(xd, p, dims)
    if count(iszero, p) == 0      # no zero products anywhere: safe to divide
        p, Δ -> (p ./ xd .* Δ ,)
    elseif dims == Colon()        # whole-array product, with a zero somewhere
        p, Δ -> (∇prod(xd, Δ) ,)
    else                          # slice-wise product, with a zero somewhere
        p, Δ -> (mapslices(∇prod, xd; dims=dims) .* Δ,)
    end
end
function ∇prod(x, Δ=1)
    onerow = transpose(x[1:end-1]) .= 1 # this works on GPU... try similar(x, length(x)-1) perhaps
    mat = reshape(vec(x) .* onerow .* Δ, (length(x)-1,:))
    grad = reverse(vec(prod(mat, dims=1))) # reverse doesn't work on GPU right now
    reshape(grad, size(x))
end
# Test using the same M,MM,Z,ZZ as defined above:
# at worst 30% worse than the best complicated version above, on these samples.
# You could also use onerow = Zygote.FillArray(one(eltype(x)), (1,length(x)-1,))
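# Added correctness spot check (speed aside), again against ForwardDiff:
Flux.data(Flux.gradient(x->sum(prod(x, dims=1)), Z)[1]) ≈ ForwardDiff.gradient(x->sum(prod(x, dims=1)), Z) # expect true, with no NaNs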