Created
December 19, 2019 05:59
-
-
Save HamletWantToCode/fb7c8b85b4543a8a1b19df4fa2bae35b to your computer and use it in GitHub Desktop.
check nested AD #Flux #Tracker
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dependencies: Flux provides Chain/Dense/σ and the Tracker AD backend.
using Flux
# using DifferentialEquations
using Random: seed!

# Fix the RNG so the whole check is reproducible.
# NOTE: statement order matters — `rand(2)` and the layer initialisation
# both draw from the global RNG.
seed!(32);
x0 = rand(2)

# Small 2-in / 1-out MLP, promoted to Float64 precision.
model = Chain(Dense(2, 10, σ), Dense(10, 1)) |> f64
f(x) = model(x)
"""
    df(x)

Columnwise gradient of `f`: column `i` of the result is the gradient of
`sum(f(·))` evaluated at `x[:, i]`.  Uses nested Tracker AD
(`nest = true`) so the returned gradients stay differentiable with
respect to the model parameters.
"""
function df(x)
    # `sum(f(z))` collapses the 1-element output to a scalar; indexing via
    # `f(z)[1]` is not supported by nested Tracker AD (getindex has no rule).
    g(z) = Tracker.gradient(w -> sum(f(w)), z, nest=true)[1]
    cols = [g(x[:, i]) for i in 1:size(x, 2)]
    # reduce(hcat, cols) instead of hcat(cols...): avoids splatting a
    # potentially long collection into varargs.
    return reduce(hcat, cols)
end
# Scalar objective built on the nested gradient: differentiating `loss`
# w.r.t. the parameters produces a gradient-of-gradient (second-order) term.
loss(x) = sum(df(x))
p = params(model)
# Outer reverse-mode pass over the nested forward computation in `df`.
ddfdp = Tracker.gradient(()->loss(x0), p)
# Debugging history for the nested-AD call inside `df`:
# 1. ERROR: Use `gradient(...; nest = true)` for nested derivatives :(
# 2. ADD: nest=true, ERROR: Nested AD not defined for getindex :(
# 3. CHANGE: f(z)[1]->sum(f(z)) (this is ok since f(z) outputs a scalar), works fine :)
# --- Gradient check against finite differences (Calculus.jl) ---
using Calculus

# Extract the raw (untracked) parameter arrays from each layer.
layer1 = model.layers[1]
layer2 = model.layers[2]
W1 = layer1.W.data; b1 = layer1.b.data
W2 = layer2.W.data; b2 = layer2.b.data

# Plain-Julia forward pass with no tracking involved.
plane_f(x) = W2 * σ.(W1*x .+ b1) .+ b2

plane_f(x0) ≈ f(x0)                                # true
df(x0) ≈ Calculus.gradient(z -> plane_f(z)[1], x0) # true
# Forward-mode check of the coefficient gradient (here: b1).
# Step 1: hand-derive df/dx and spell it out as an explicit forward pass
# that generic enough to also accept dual numbers for `b`.
function manual_df(b)
    # Hidden-layer pre-activation and activation.
    pre = W1*x0 .+ b
    act = σ.(pre)
    # σ'(z) expressed as exp(-z)·σ(z)^2 — same formula as the original,
    # kept verbatim so dual numbers propagate identically.
    dact = @. exp(-pre)*act^2
    # Chain rule: (W2 ⊙ σ')ᵀ applied to the input weights.
    return (W2' .* dact)' * W1
end
# BUG FIX: the original compared against `df(x1)`, but `x1` is never
# defined anywhere in this script — the point being evaluated is `x0`.
dropdims(manual_df(b1), dims=1) ≈ df(x0) # true

# Seed `b1` with dual numbers, rerun the manual forward pass, and extract
# the forward-mode derivative of the summed gradient w.r.t. b1.
b1_dual = Tracker.seed(b1, Val(length(b1)))
ll = sum(manual_df(b1_dual))
_, ddfdb_forward = Tracker.extract(ll)

# Reverse-mode result for the same quantity, read out of `ddfdp` above.
ddfdb_back = ddfdp[model.layers[1].b]
ddfdb_forward ≈ ddfdb_back # true !
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment