@rejuvyesh
Created November 19, 2019 18:27
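using Flux  # assumes a Flux release that still bundles Flux.Tracker (pre-0.10)

qdim = 2                                   # input dimension
nn = Chain(Dense(qdim, 32, tanh), Dense(32, 2));
q = rand(qdim, 5);                         # batch of 5 inputs, one per column

# Rebuild the network output column by column. Numerically this equals nn(x),
# but the slicing o[:, i] routes the computation through getindex.
function jac(x)
    o = nn(x)
    return reduce(hcat, [o[:, i] for i in 1:size(x, 2)])
end

# First-order gradient w.r.t. the inputs; nest=true keeps the result tracked
# so that it can be differentiated a second time.
Cv = Flux.Tracker.gradient(x -> sum(jac(x)), q; nest=true)[1]

# Second-order step: differentiate sum(Cv) w.r.t. the network parameters.
g = Flux.Tracker.gradient(() -> sum(Cv), params(nn))
# ERROR: Nested AD not defined for getindex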
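For reference, the quantity `Cv` that the first gradient call produces can be written in closed form for this two-layer architecture, which makes it easier to see what the nested call is trying to differentiate. The sketch below is plain Julia with no Flux dependency. The names `W1`, `b1`, `W2`, `b2`, `f`, `input_grad` and `fd_grad` are hypothetical stand-ins, assuming each `Dense` layer computes `σ.(W*x .+ b)`; the central finite difference is only a sanity check, not part of the original repro.

```julia
# Hypothetical standalone weights mirroring Chain(Dense(qdim, 32, tanh), Dense(32, 2)).
qdim, nhidden, nout = 2, 32, 2
W1, b1 = randn(nhidden, qdim), randn(nhidden)
W2, b2 = randn(nout, nhidden), randn(nout)

# Forward pass for a batch of column vectors q (qdim × N).
f(q) = W2 * tanh.(W1 * q .+ b1) .+ b2

# Closed-form gradient of sum(f(q)) w.r.t. q. Column j depends only on input column j:
#   ∂/∂q_j = W1' * ((1 .- tanh.(z_j).^2) .* (W2' * ones(nout))),  with z_j = W1*q_j + b1
function input_grad(q)
    z = W1 * q .+ b1                       # nhidden × N pre-activations
    s = W2' * ones(nout)                   # nhidden vector: d sum(output) / d hidden
    return W1' * ((1 .- tanh.(z) .^ 2) .* s)
end

# Central finite-difference estimate of the same gradient, entry by entry.
function fd_grad(q; h = 1e-6)
    g = similar(q)
    for k in eachindex(q)
        qp = copy(q); qp[k] += h
        qm = copy(q); qm[k] -= h
        g[k] = (sum(f(qp)) - sum(f(qm))) / (2h)
    end
    return g
end

q = rand(qdim, 5)
println(maximum(abs.(input_grad(q) .- fd_grad(q))))  # should be tiny
```

Note that the output bias `b2` drops out of `Cv` entirely, since it does not depend on the input; the second `gradient` call in the repro therefore only has nonzero contributions from `W1`, `b1` and `W2`.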