@HamletWantToCode
Created December 19, 2019 05:59
check nested AD #Flux #Tracker
using Flux   # assumes a Tracker-based Flux (≤ v0.9), which makes `Tracker` available
# using DifferentialEquations
using Random: seed!
seed!(32);
x0 = rand(2)
model = Chain(Dense(2, 10, σ), Dense(10, 1)) |> f64
f(x) = model(x)
function df(x)
    # inner (reverse-mode) gradient of the scalar sum(f(z)) w.r.t. the input;
    # `nest=true` keeps it differentiable so a second derivative can be taken
    g(x) = Tracker.gradient(z -> sum(f(z)), x, nest=true)[1]
    grads_collect = [g(x[:, i]) for i in 1:size(x, 2)]   # one gradient per column
    hcat(grads_collect...)
end
loss(x) = sum(df(x))   # scalar built from the *first* derivative of f
p = params(model)
ddfdp = Tracker.gradient(() -> loss(x0), p)   # mixed second derivative w.r.t. the parameters
# Notes on making the nested call work:
# 1. without `nest=true`: ERROR: Use `gradient(...; nest = true)` for nested derivatives :(
# 2. with `nest=true` but `f(z)[1]`: ERROR: Nested AD not defined for getindex :(
# 3. replacing `f(z)[1]` with `sum(f(z))` (equivalent here, since f outputs a single value) works fine :)
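# For reference, the same `nest=true` mechanics on a scalar function (a sketch
# added for illustration, not in the original gist; `h`, `dh`, `d2h` are
# made-up names, assuming the same Tracker-based Flux):
h(x) = 3x^2 + 2x + 1
dh(x) = Tracker.gradient(h, x, nest=true)[1]            # 6x + 2, still tracked
d2h(x) = Tracker.gradient(z -> dh(z), x, nest=true)[1]  # 6
d2h(1.0)   # expect ≈ 6 (tracked)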
# gradient check
using Calculus
W1 = model.layers[1].W.data; b1 = model.layers[1].b.data   # untracked parameter copies
W2 = model.layers[2].W.data; b2 = model.layers[2].b.data
plane_f(x) = W2 * σ.(W1*x .+ b1) .+ b2
plane_f(x0) ≈ f(x0) # true
df(x0) ≈ Calculus.gradient(z->plane_f(z)[1], x0) # true
# use forward-mode dual numbers to check the gradient w.r.t. the coefficients (e.g. b1)
## first derive df/dx by hand, then run dual numbers through that forward pass
function manual_df(b)
    z1 = W1*x0 .+ b
    t1 = σ.(z1)
    dt1 = @. exp(-z1) * t1^2    # σ'(z) = exp(-z) * σ(z)^2
    db = (W2' .* dt1)' * W1     # df/dx = W2 * Diagonal(σ'(z1)) * W1, a 1×2 row
    return db
end
dropdims(manual_df(b1), dims=1) ≈ df(x0) # true
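# Extra sanity check (a sketch, not in the original gist): the hand-derived
# row should also match a finite-difference gradient of plane_f, reusing the
# Calculus package loaded above.
dropdims(manual_df(b1), dims=1) ≈ Calculus.gradient(z -> plane_f(z)[1], x0)   # expect true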
b1_dual = Tracker.seed(b1, Val(length(b1)))   # seed b1 with dual numbers, one partial per entry
ll = sum(manual_df(b1_dual))                  # forward pass through the hand-written df/dx
_, ddfdb_forward = Tracker.extract(ll)        # value plus forward-mode gradient w.r.t. b1
ddfdb_back = ddfdp[model.layers[1].b]         # reverse-mode gradient from the nested call above
ddfdb_forward ≈ ddfdb_back                    # true: forward and reverse modes agree !
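# One more cross-check (a sketch, not in the original gist): manual_df is plain
# Julia, so the same parameter gradient can be estimated with finite differences
# via Calculus; `loss_b` is an illustrative helper name.
loss_b(b) = sum(manual_df(b))             # scalar function of b1 only
ddfdb_fd = Calculus.gradient(loss_b, b1)
ddfdb_fd ≈ Tracker.data(ddfdb_back)       # expect true, up to finite-difference error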