Derivative of RNN approximating piecewise linear function
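# Script 1: a pre-trained RNN(1=>10) -> Dense(10=>1) network approximating a
# piecewise linear function is loaded from disk (the training loop is kept in
# comments); its derivative along the input sequence is then reconstructed from
# per-step Jacobians and compared against finite differences of the output.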
using BSON
using Flux
import Flux: @epochs, reset!
import Flux.Losses: mse
using MLUtils
import Zygote: jacobian
using Plots
_x = range(0f0, 1f0, length=101)
n = length(_x)
_y = vcat(7*_x[1:n÷2], 7*_x[n÷2] .+ 2*(_x[n÷2+1:end] .- _x[n÷2]))
plot(_x, _y, label="training data")
x = [ [x] for x in _x ]
y = [ [y] for y in _y ]
train_data = DataLoader((x,y), batchsize=n)  # one batch holding the whole sequence
# nn = Chain(
#     RNN(1=>10),
#     Dense(10=>1, identity)
# )
## load pre-trained net
BSON.@load "rnn_piecewise_linear.bson" nn opt
# sequence loss: reset the hidden state, then sum the per-step MSE over the sequence
function loss!(x, y)
    reset!(nn)
    sum(mse(nn(xi), yi) for (xi, yi) in zip(x, y))
end
loss!(x,y)
# opt = ADAM(0.05)
# ps = Flux.params(nn)
# @epochs 1000 begin
#     reset!(nn)
#     Flux.train!(loss!, ps, train_data, opt)
#     @show loss!(x,y)
#     opt.eta = max(0.01, opt.eta*0.99)
# end
# prediction
ypred = begin
    reset!(nn)
    [ nn(xi)[1] for xi in x ]
end
plot!(_x, ypred, label="prediction")
# internal states of the recurrent unit
hs = begin
    reset!(nn)
    [
        begin
            nn(xi)
            nn.layers[1].state
        end for xi in x[1:end-1]
    ]
end
pushfirst!(hs, nn.layers[1].cell.state0)
# evaluate the full network with the recurrent state set explicitly to y
function thenn(x, y)
    nn.layers[1].state = y
    nn(x)
end
# Jacobians of the RNN at (x_k, h_{k-1})
js = [
    jacobian(thenn, x[i], hs[i-1])
    for i in 2:n
]
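# `jacobian` is Zygote's: called with several arguments it returns one Jacobian
# per argument, so each entry of `js` is a tuple (Jacobian w.r.t. the input,
# Jacobian w.r.t. the hidden state), each of size (length(output), length(argument)).
# A tiny standalone sketch of that behaviour (illustration only, not part of the gist):
#   f(a, b) = tanh.(2 .* a .+ 3 .* b)
#   Ja, Jb = jacobian(f, [0.1f0], [0.2f0, 0.3f0])   # Ja is 2x1, Jb is 2x2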
δx = step(_x)
# Finite-difference slopes of the network output along the sequence
jnum = [
    (thenn(x[idx+1], hs[idx+1]) .- thenn(x[idx], hs[idx]))[1] / δx
    for idx in 1:n-1
]
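# if the fit is good these slopes should sit near 7 on the first half of the
# interval and near 2 on the second, the two slopes of the training data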
plot(_x[2:end], jnum)
idx = 20
# derivative in x
map(2:length(_x)) do idx
    # wiggle x only
    (thenn(x[idx] .+ δx, hs[idx-1]) .- thenn(x[idx], hs[idx-1]))[1] / δx,
    js[idx-1][1][1]
end |> p -> scatter(p; label="")
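# each point pairs the finite-difference estimate (x-axis) with the corresponding
# Jacobian entry (y-axis); if the two agree the points fall on the diagonal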
rnn, dl = nn.layers
# recurrent unit only
function thernn(x, y)
    rnn.state = y
    rnn(x)
end
# Jacobians of the recurrent unit alone
jsrnn = [
    jacobian(thernn, x[i], hs[i-1])
    for i in 2:n
]
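# The accumulation below applies the chain rule step by step: with h_k = f(x_k, h_{k-1})
# and each input step moving with the sweep variable t (dx_k/dt = 1),
#   dh_k/dt = ∂f/∂x_k + ∂f/∂h_{k-1} * dh_{k-1}/dt,
# seeded with the difference of the first two recorded hidden states. The output
# derivative dy/dt then follows by applying the dense layer's Jacobian to dh_k/dt.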
# dh/dt accumulated over the sequence
js_accum = accumulate(jsrnn, init=hs[2] .- hs[1]) do acc, j
    j[1] .+ j[2]*acc
end
# finally multiply by Jacobian of dense layer
dydt = map(js_accum) do j; (dl.weight*j)[1] end
plot(_x[2:end], jnum, label="finite diff.")
plot!(_x[2:end], dydt, label="derivative")
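# Script 2: the same experiment, but each input step is the pair of consecutive
# sample points (x_{k-1}, x_k), so the recurrent unit takes two inputs
# (RNN(2=>10)) and the network is trained from scratch rather than loaded.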
using BSON
using Flux
import Flux: @epochs, reset!
import Flux.Losses: mse
using MLUtils
import Zygote: jacobian
using Plots
_x = range(0f0, 1f0, length=101)
n = length(_x)
_y = vcat(7*_x[1:n÷2], 7*_x[n÷2] .+ 2*(_x[n÷2+1:end] .- _x[n÷2]))
plot(_x, _y, label="training data")
# each input is now the pair of consecutive sample points (x_{k-1}, x_k),
# with (0, x_1) prepended for the first step
x = [ [_x[i-1], _x[i]] for i in 2:lastindex(_x) ]
pushfirst!(x, [ 0f0, _x[1] ])
y = [ [y] for y in _y ]
train_data = DataLoader((x,y), batchsize=n)
nn = Chain(
    RNN(2=>10),
    Dense(10=>1, identity)
)
## load pre-trained net
# BSON.@load "rnn_piecewise_linear.bson" nn opt
function loss!(x, y)
    reset!(nn)
    sum(mse(nn(xi), yi) for (xi, yi) in zip(x, y))
end
loss!(x,y)
opt = ADAM(0.05)
ps = Flux.params(nn)
@epochs 1000 begin
    reset!(nn)
    Flux.train!(loss!, ps, train_data, opt)
    @show loss!(x, y)
    opt.eta = max(0.05, opt.eta*0.995)
end
# prediction
ypred = begin
    reset!(nn)
    [ nn(xi)[1] for xi in x ]
end
plot!(_x, ypred, label="prediction")
# internal states of the recurrent unit
hs = begin
    reset!(nn)
    [
        begin
            nn(xi)
            nn.layers[1].state
        end for xi in x[1:end-1]
    ]
end
pushfirst!(hs, nn.layers[1].cell.state0)
function thenn(x, y)
    nn.layers[1].state = y
    nn(x)
end
# Jacobians of the RNN at (x_k, h_{k-1})
js = [
    jacobian(thenn, x[i], hs[i-1])
    for i in 2:n
]
δx = step(_x)
# Finite-difference slopes of the network output along the sequence
jnum = [
    (thenn(x[idx+1], hs[idx+1]) .- thenn(x[idx], hs[idx]))[1] / δx
    for idx in 1:n-1
]
plot(_x[2:end], jnum)
idx = 20
# derivative in x
map(2:length(_x)) do idx
    # wiggle x only
    (thenn(x[idx] .+ δx, hs[idx-1]) .- thenn(x[idx], hs[idx-1]))[1] / δx,
    js[idx-1][1][1]
end |> p -> scatter(p; label="")
rnn, dl = nn.layers
# recurrent unit only
function thernn(x, y)
    rnn.state = y
    rnn(x)
end
# Jacobians of the recurrent unit alone
jsrnn = [
    jacobian(thernn, x[i], hs[i-1])
    for i in 2:n
]
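# As in the first script, the chain rule is accumulated step by step; here both
# input components move with the sweep variable t, so both columns of the input
# Jacobian contribute:
#   dh_k/dt = ∂f/∂x_{k-1} + ∂f/∂x_k + ∂f/∂h_{k-1} * dh_{k-1}/dt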
# dh/dt accumulated over the sequence
js_accum = accumulate(jsrnn, init=hs[2] .- hs[1]) do acc, j
    j[1][:,1] .+ j[1][:,2] .+ j[2]*acc
end
# finally multiply by Jacobian of dense layer
dydt = map(js_accum) do j; (dl.weight*j)[1] end
plot(_x[2:end], jnum, label="finite diff.")
plot!(_x[2:end], dydt, label="derivative")