Last active
January 25, 2020 12:42
-
-
Save sdewaele/9a9c705eb4e8da9d798056d18d04c383 to your computer and use it in GitHub Desktop.
setindex with copy to enable Zygote autodiff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using ZygoteRules:@adjoint | |
using Zygote | |
using Random | |
"""
    ngradient(f, xs::AbstractArray...; δ = sqrt(eps()))

Estimate the gradient of the scalar-valued function `f` with respect to every
array argument in `xs` using central finite differences.

For each element `x[i]` of each argument `x`, `f(xs...)` is evaluated with
`x[i]` shifted to `x[i] - δ/2` and to `x[i] + δ/2`. The slope
`(f₊ - f₋) / δ` is stored at index `i` of the matching result array.

The arrays in `xs` are perturbed in place while the estimate is computed.
Every perturbed element is restored to its original value afterwards, also
when `f` throws.

# Keywords
- `δ`: total width of the central-difference stencil (default `sqrt(eps())`).

# Returns
A tuple with one array per argument, built with `zero.(xs)`, holding the
finite-difference estimate of `∂f/∂x[i]` for every element.
"""
function ngradient(f, xs::AbstractArray...; δ = sqrt(eps()))
    grads = zero.(xs)
    halfstep = δ / 2  # loop-invariant, so computed once
    # `eachindex` (rather than `1:length(x)`) is valid for any index style,
    # including arrays whose indices do not start at 1.
    for (x, Δ) in zip(xs, grads), i in eachindex(x, Δ)
        tmp = x[i]
        try
            x[i] = tmp - halfstep
            y1 = f(xs...)
            x[i] = tmp + halfstep
            y2 = f(xs...)
            Δ[i] = (y2 - y1) / δ
        finally
            # Never leave the caller's array perturbed, even if `f` errors.
            x[i] = tmp
        end
    end
    return grads
end
"""
    setindex(A, X, inds...)

Non-mutating counterpart of `setindex!`. Return a copy of `A` in which the
entries selected by `inds...` are replaced by `X`. `A` itself is not modified.

This is an alternative to `setindex!` that allows automatic differentiation
with Zygote.

# Example
```
A = setindex(A, X, 1:2, 1:2)
```
replaces `A[1:2, 1:2] = X`.
"""
function setindex(A, X, inds...)
    B = copy(A)
    B[inds...] = X
    # Return `B` explicitly: the array interface does not specify what
    # `setindex!` returns, so its return value must not be relied on.
    return B
end
## Adjoint
# Shape the cotangent of `X` like `X`. For an array `X` the selected part of
# the output cotangent is reshaped to `size(X)`. This mirrors how the primal
# stores `X` elementwise into `A[inds...]`, whose shape may differ from
# `size(X)`, e.g. a vector written into a 1×n slice.
_cotangent_like(X::AbstractArray, g) = reshape(g, size(X))
# A scalar `X` can only be stored through scalar indices, so `g` is already
# the matching scalar cotangent.
_cotangent_like(::Any, g) = g

@adjoint setindex(A, X, inds...) = begin
    B = setindex(A, X, inds...)
    function setindex_pullback(Δ)
        # Entries of `A` that were overwritten by `X` do not reach `B`, so
        # their cotangent is zero; every other entry passes `Δ` through.
        # Materialising into `similar(A, eltype(Δ))` instead of `copy(Δ)`
        # does not rely on `Δ` itself being a writable array.
        bA = copyto!(similar(A, eltype(Δ)), Δ)
        bA[inds...] .= zero(eltype(bA))
        # `X` was written exactly into `A[inds...]`, so its cotangent is
        # `Δ[inds...]` in the shape of `X`. Unlike `similar(X)`, this also
        # works for scalar `X` and does not force `eltype(X)` on the cotangent.
        bX = _cotangent_like(X, Δ[inds...])
        # Indices are not differentiable.
        return (bA, bX, map(_ -> nothing, inds)...)
    end
    return B, setindex_pullback
end
## Test setindex
rng = MersenneTwister(234238)
n = 3; m = 4;
A = rand(rng, n, m);
B = copy(A)
X = [1.0 2
     10 11]
A[1:2, 1:2] = X
A2 = setindex(B, X, 1:2, 1:2)
@show A == A2 # = true
@show B != A2 # = true: setindex left its input B unchanged
## Test Zygote adjoint with respect to X
# `Zygote.pullback` is the current name of the former `Zygote.forward`.
B, back = Zygote.pullback(X -> setindex(A, X, 1:2, 1:2), X)
bB = rand(rng, size(B)...)
bX = back(bB)[1]
@show bX == bB[1:2, 1:2] # = true
## Test Zygote adjoint with respect to A
B, backA = Zygote.pullback(A -> setindex(A, X, 1:2, 1:2), A)
bA = backA(bB)[1]
# Overwritten entries of A receive zero cotangent; all others pass bB through.
@show bA == setindex(bB, zeros(2, 2), 1:2, 1:2) # = true
## Test with numerical gradient
# - adjoint with respect to X
ftestX = X -> sum(sin.(setindex(A, X, 1:2, 1:2)))
gradX = gradient(ftestX, X)[1]
ngradX = ngradient(ftestX, X)[1]
@show isapprox(gradX, ngradX, atol=1e-5) # = true
# - Adjoint with respect to A
ftestA = A -> sum(sin.(setindex(A, X, 1:2, 1:2)))
gradA = gradient(ftestA, A)[1]
ngradA = ngradient(ftestA, A)[1]
@show isapprox(gradA, ngradA, atol=1e-5) # = true
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment