@jsams
Last active May 4, 2018 05:48
How to get the gradient wrt some inputs but not others? Is this a scoping issue?
using ReverseDiff
using Test  # on Julia 0.6 this module was `Base.Test`
mutable struct data
    X::Array{Float64, 2}
end
const D = data(zeros(Float64, 2, 2))
function f1(params)
    X = [1 2;
         3 4]
    sum(params[1]' * X[:, 1] - (params[1] .* params[2])' * X[:, 2].^2)
end
"query database to get new X, want gradient wrt params"
function f2(params, d)
    sum(params[1]' * d.X[:, 1] - (params[1] .* params[2])' * d.X[:, 2].^2)
end
f2(params) = f2(params, D)
function scope_test()
    D.X = float.([6 7; 1 3])
    # the tape is recorded while D.X holds the values above
    f2_tape = ReverseDiff.GradientTape(f2, [1, 2])
    # D.X changes after recording, before the tape is replayed
    D.X = float.([1 2; 3 4])
    grad = ReverseDiff.gradient!(f2_tape, [3,4])
    return grad
end
function scope_test2()
    D.X = float.([1 2; 3 4])
    # here D.X is the same at recording time and at replay time
    f2_tape = ReverseDiff.GradientTape(f2, [1, 2])
    D.X = float.([1 2; 3 4])
    grad = ReverseDiff.gradient!(f2_tape, [3,4])
    return grad
end
D.X = float.([1 2; 3 4])
@test f1([3,4]) == f2([3, 4], D)
f1_tape = ReverseDiff.GradientTape(f1, [3,4])
@test ReverseDiff.gradient!(f1_tape, [3,4]) == scope_test() # fails! the tape still uses the D.X values from when it was recorded
@test ReverseDiff.gradient!(f1_tape, [3,4]) == scope_test2()
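
One possible workaround, sketched under assumptions: skip the pre-recorded tape and call `ReverseDiff.gradient` on a closure that captures the current data matrix. The operations are then re-recorded on every call, so the gradient follows whatever `X` is at that moment, and it is taken with respect to `params` only. The names `g` and `grad_params` are hypothetical and not part of the gist, and `g` drops the scalar adjoints from `f2`, which should not change the value. The expected values in the tests come from differentiating the expression by hand.

```julia
using ReverseDiff
using Test

# Hypothetical variant of f2 that takes the data matrix as a plain argument.
g(params, X) = sum(params[1] * X[:, 1] - (params[1] * params[2]) * X[:, 2].^2)

# Gradient wrt params only: X is captured by the closure and treated as a constant.
# No tape is reused here, so each call re-records the operations with the current X.
grad_params(params, X) = ReverseDiff.gradient(p -> g(p, X), params)

# Hand-derived gradient of g, used only as a check:
#   d/dp1 = sum(X[:, 1]) - p2 * sum(X[:, 2].^2)
#   d/dp2 = -p1 * sum(X[:, 2].^2)
analytic(p, X) = [sum(X[:, 1]) - p[2] * sum(X[:, 2].^2), -p[1] * sum(X[:, 2].^2)]

X_old = [6.0 7.0; 1.0 3.0]
X_new = [1.0 2.0; 3.0 4.0]
p = [3.0, 4.0]

@test grad_params(p, X_old) ≈ analytic(p, X_old)
@test grad_params(p, X_new) ≈ analytic(p, X_new)
```

The trade-off is that the cost of recording is paid on every call. If reusing a tape matters, re-recording it whenever `D.X` changes would be another option.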