Skip to content

Instantly share code, notes, and snippets.

@sdewaele
Last active January 25, 2020 12:42
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sdewaele/9a9c705eb4e8da9d798056d18d04c383 to your computer and use it in GitHub Desktop.
Save sdewaele/9a9c705eb4e8da9d798056d18d04c383 to your computer and use it in GitHub Desktop.
setindex with copy to enable Zygote autodiff
using ZygoteRules:@adjoint
using Zygote
using Random
"""
    ngradient(f, xs::AbstractArray...)

Estimate the gradient of `f` with respect to every array in `xs` using central finite
differences.

Each element is perturbed in place by `±δ/2` with `δ = sqrt(eps())`, `f(xs...)` is
evaluated at both points, and the element is restored afterwards, also when `f` throws.

# Arguments
- `f`: function called as `f(xs...)`; the difference of two results is divided by `δ`.
- `xs`: mutable arrays whose elements can hold the perturbed values.

# Returns
A tuple with one array per entry of `xs` (built with `zero.(xs)`) holding the estimates.
"""
function ngradient(f, xs::AbstractArray...)
    grads = zero.(xs)
    δ = sqrt(eps())  # step size does not depend on the loop variables
    for (x, Δ) in zip(xs, grads), i in eachindex(x)
        tmp = x[i]
        try
            x[i] = tmp - δ / 2
            y1 = f(xs...)
            x[i] = tmp + δ / 2
            y2 = f(xs...)
            Δ[i] = (y2 - y1) / δ
        finally
            # Always undo the perturbation so the caller's array is never left modified.
            x[i] = tmp
        end
    end
    return grads
end
"""
setindex(A,X,inds...)
`setindex` with copy. An alternative for `setindex!`
that allows automatic differentation with Zygote.
Example:
```
A = setindex(A,X,1:2,1:2)
```
Replaces `A[1:2,1:2] = X`
"""
setindex(A,X,inds...) = setindex!(copy(A),X,inds...)
## Adjoint
# Cotangent for `X`: an array-valued `X` gets its own shape back, while a scalar
# cotangent (produced by scalar indices) is passed through unchanged.
_setindex_xgrad(X::AbstractArray, ΔX) = reshape(ΔX, size(X))
_setindex_xgrad(X, ΔX) = ΔX

# Custom Zygote adjoint for `setindex(A, X, inds...)`.
# For an output cotangent `Δ` the pullback returns
#   - for `A`: a copy of `Δ` with the entries at `inds...` zeroed (those entries of the
#     result come from `X`, not from `A`),
#   - for `X`: the entries of `Δ` at `inds...`,
#   - for each index: `nothing`.
@adjoint setindex(A, X, inds...) = begin
    B = setindex(A, X, inds...)
    adj = function (Δ)
        bA = copy(Δ)
        # `view` also works for all-scalar indices (0-dimensional view), where
        # `bA[inds...] .= ...` would try to broadcast into a scalar element and fail.
        view(bA, inds...) .= zero(eltype(bA))
        # Dispatch on `X` instead of `similar(X)`: supports scalar `X` and does not
        # force the element type of `X` onto the cotangent.
        bX = _setindex_xgrad(X, Δ[inds...])
        binds = map(_ -> nothing, inds)  # tuple of `nothing`, one per index
        return bA, bX, binds...
    end
    return B, adj
end
## Test setindex
# Compare the copying `setindex` against the in-place assignment on identical data.
rng = MersenneTwister(234238)
n = 3; m = 4;
A = rand(rng, n, m);
B = copy(A)
X = [1.0 2
     10 11]
A[1:2, 1:2] = X
A2 = setindex(B, X, 1:2, 1:2)
@show A == A2 # = true
## Test Zygote adjoint with respect to X
# `Zygote.pullback` is the current name of the former `Zygote.forward`.
B, back = Zygote.pullback(X -> setindex(A, X, 1:2, 1:2), X)
bB = rand(rng, size(B)...)
bX = back(bB)[1]
## Test Zygote adjoint with respect to A
B, backA = Zygote.pullback(A -> setindex(A, X, 1:2, 1:2), A)
# Stored under its own name so the X-adjoint result above is not overwritten.
bA = backA(bB)[1]
## Test with numerical gradient
# - adjoint with respect to X
ftestX = X -> sum(sin.(setindex(A, X, 1:2, 1:2)))
gradX = gradient(ftestX, X)[1]
ngradX = ngradient(ftestX, X)[1]
@show isapprox(gradX, ngradX, atol=1e-5) # = true
# - Adjoint with respect to A
ftestA = A -> sum(sin.(setindex(A, X, 1:2, 1:2)))
gradA = gradient(ftestA, A)[1]
ngradA = ngradient(ftestA, A)[1]
@show isapprox(gradA, ngradA, atol=1e-5) # = true
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment