@leesharma
Last active October 26, 2020 05:36
Mockup of how to use Flux to take gradients in Julia
import Pkg; Pkg.add("Flux")
using Flux: gradient, Chain, Dense, σ, softmax
# Partial derivatives of a basic function
#
# docs: https://fluxml.ai/Zygote.jl/dev/#Taking-Gradients-1
#
f(x,y) = 3x^2 + x*y + 5y
f_x(x,y) = gradient(f, x, y)[1] # f_x = 6x + y
f_y(x,y) = gradient(f, x, y)[2] # f_y = x + 5
f_xx(x,y) = gradient(f_x, x, y)[1] # f_xx = 6
f_xy(x,y) = gradient(f_x, x, y)[2] # f_xy = 1
f_yx(x,y) = gradient(f_y, x, y)[1] # f_yx = 1
f_yy(x,y) = gradient(f_y, x, y)[2] # f_yy = 0
# verifying this works as expected
@assert f(3,2) == 43
@assert f_x(3,2) == 20
@assert f_y(3,2) == 8
@assert f_xx(3,2) == 6
@assert f_xy(3,2) == 1
@assert f_yx(3,2) == 1
@assert isnothing(f_yy(3,2)) # an identically zero gradient comes back as `nothing`, not 0
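# `gradient` reports a derivative that is identically zero as `nothing` rather
# than 0, which is why the f_yy check above compares against `nothing`.
# A minimal sketch of a workaround, assuming a numeric 0 is preferred downstream.
# `zero_if_nothing` is a hypothetical helper, not part of Flux or Zygote.
zero_if_nothing(g) = something(g, 0) # Base.something returns its first non-`nothing` argument
@assert zero_if_nothing(nothing) == 0
@assert zero_if_nothing(6) == 6
# e.g. wrap a call as zero_if_nothing(f_yy(3,2)) to get a number back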
# Partial of a vector function
#
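# (`gradient` only handles scalar-valued functions, so each component is
# differentiated separately here; newer Zygote releases also provide `Zygote.jacobian`)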
v1(x,y) = x^2 - y^2
v2(x,y) = x*y
v3(x,y) = x*y^2 - y*x^2
v(x,y) = [v1(x,y)  # dv1/dx = 2x;         dv1/dy = -2y
          v2(x,y)  # dv2/dx = y;          dv2/dy = x
          v3(x,y)] # dv3/dx = y^2 - 2xy;  dv3/dy = 2xy - x^2
# this seems tedious... extract to function if there's much more of this...
v_x(x,y) = [gradient(v1,x,y)[1]
            gradient(v2,x,y)[1]
            gradient(v3,x,y)[1]]
v_y(x,y) = [gradient(v1,x,y)[2]
            gradient(v2,x,y)[2]
            gradient(v3,x,y)[2]]
# verify simple case
@assert isequal( v(1,1), [0; 1; 0] )
@assert isequal( v_x(1,1), [2; 1; -1] )
@assert isequal( v_y(1,1), [-2; 1; 1] )
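# The repetition in v_x and v_y could be factored out along these lines.
# A minimal sketch only: `partial` is a hypothetical helper (not part of Flux
# or Zygote) that takes a tuple of scalar-valued functions `fs` and returns
# the vector of their partial derivatives with respect to the i-th argument.
partial(fs, i, args...) = [gradient(f, args...)[i] for f in fs]
# same checks as above, using the component functions v1, v2, v3
@assert isequal( partial((v1, v2, v3), 1, 1, 1), [2; 1; -1] )
@assert isequal( partial((v1, v2, v3), 2, 1, 1), [-2; 1; 1] )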
# Partial of a neural network with vector-valued output
#
#
dummy_model = Chain(
    Dense(4, 5, σ),
    Dense(5, 3),
    softmax)
h(u,t) = dummy_model([reshape(u,(3,1)); t])
h1(u,t) = h(u,t)[1]
h2(u,t) = h(u,t)[2]
h3(u,t) = h(u,t)[3]
h_t(u,t) = [
    gradient(h1, u, t)[2] # `gradient` returns one entry per argument: (dh1/du, dh1/dt)
    gradient(h2, u, t)[2]
    gradient(h3, u, t)[2]
]
# Mocking up function F
# (we probably want a different interface in real life)
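function F(u,t; h=h,h_t=h_t)
    h1,h2,h3 = h(u,t)
    ht = h_t(u,t)
    # right-hand side of the Lorenz system with its usual parameter values
    N(x,y,z; σ=10.0,ρ=28.0,β=8/3) = [ σ*(y-x)
                                      x*(ρ-z)-y
                                      x*y-β*z ]
    # return the residual directly instead of assigning to a local also named F
    return ht - N(h1,h2,h3)
end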
###
u = [1.0; 0.0; 0.0]
t = 1
display(F(u,t))
println()
println("The script ran successfully.")