@jeremiedb
Created October 28, 2020 02:28
Diverging RNN behavior between CPU and GPU (CUDA): a minimal custom RNN is trained twice from the same initialization, once on CPU and once on GPU, and the per-iteration losses plus a few parameter values are printed so the drift between the two backends is visible.
using Revise
using Flux
using Zygote
using CUDA
using Statistics: mean
################################################
# Define structs
################################################

# Minimal re-implementation of Flux's Recur wrapper: holds a cell together with
# its initial and current hidden state, and threads the state through successive calls.
mutable struct MyRecur{T}
    cell::T
    init
    state
end
MyRecur(m, h = hidden(m)) = MyRecur(m, h, h)

function (m::MyRecur)(xs...)
    m.state, y = m.cell(m.state, xs...)
    return y
end

Flux.@functor MyRecur
Flux.trainable(a::MyRecur) = (a.cell,)

reset!(m::MyRecur) = (m.state = m.init)
reset!(m) = foreach(reset!, Flux.functor(m)[1])
# Vanilla RNN cell: h_new = σ.(Wi * x .+ Wh * h .+ b); returns (new_state, output)
struct MyRNNCell{F,A,V}
    σ::F
    Wi::A
    Wh::A
    b::V
end

MyRNNCell(in::Integer, out::Integer, σ = tanh; init = Flux.glorot_uniform) =
    MyRNNCell(σ, init(out, in), init(out, out), init(out))

function (m::MyRNNCell)(h, x)
    σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
    h = σ.(Wi*x .+ Wh*h .+ b)
    return h, h
end
# Default hidden state for a bare cell: a zero vector matching the bias length
# (MyRNNCell stores no explicit state field of its own).
hidden(m::MyRNNCell) = zeros(length(m.b))
Flux.@functor MyRNNCell

MyRecur(m::MyRNNCell) = MyRecur(m, zeros(length(m.b)), zeros(length(m.b)))
MyRNN(a...; ka...) = MyRecur(MyRNNCell(a...; ka...))
########################################
### end of struct definitions
########################################
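
# Quick sanity check of the custom layer (illustrative addition, not part of the
# original gist; `demo_rnn` and `demo_out` are throwaway names): a MyRNN(in, out)
# layer maps an (in, batch) input matrix to an (out, batch) output matrix while
# updating the wrapped state.
demo_rnn = MyRNN(3, 5)
demo_out = demo_rnn(rand(Float32, 3, 7))
@assert size(demo_out) == (5, 7)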
# Illustrate diverging behavior of GPU execution
feat = 32
h_size = 64
seq_len = 20
batch_size = 100

rnn = Chain(MyRNN(feat, h_size),
            Dense(h_size, 1, σ),
            x -> reshape(x, :))

X = [rand(Float32, feat, batch_size) for i in 1:seq_len]
Y = rand(Float32, batch_size, seq_len) ./ 10
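
# Shape check (illustrative addition, not part of the original gist): each call
# rnn(x) returns a length-batch_size vector, so stacking the seq_len outputs
# along dim 2 yields a (batch_size, seq_len) matrix matching Y.
@assert size(Flux.stack(map(rnn, X), 2)) == size(Y)
reset!(rnn)  # restore the initial state so training below starts from scratch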
#### transfer to gpu ####
rnn_gpu = rnn |> gpu
X_gpu = gpu(X)
Y_gpu = gpu(Y)
θ = Flux.params(rnn)
θ_gpu = Flux.params(rnn_gpu)
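
# Sanity check (illustrative addition, assuming a working CUDA device): moving
# the model to the GPU should yield an independent set of CuArray parameters,
# leaving the CPU parameters untouched.
@assert first(θ_gpu) isa CuArray
@assert first(θ) isa Array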
function loss(x, y)
    l = mean((Flux.stack(map(rnn, x), 2) .- y) .^ 2f0)
    # Flux.reset!(rnn)
    return l
end

function loss_gpu(x, y)
    l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y) .^ 2f0)
    # Flux.reset!(rnn_gpu)
    return l
end
opt = Descent(1e-2)
opt_gpu = Descent(1e-2)
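
# Note on the printed parameters below (assuming Flux.params collects in functor
# field order): θ[1:3] should be the RNN cell's Wi, Wh and b, and θ[4:5] the
# Dense layer's weight and bias, so θ[3] and θ[4] track one parameter from each layer.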
for i in 1:50
    println("iter: ", i)
    Flux.train!(loss, θ, [(X, Y)], opt)
    Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu)
    println("loss_cpu: ", loss(X, Y))
    println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu))
    println("θ[3][1:2]: ", θ[3][1:2])
    println("θ[4][1:2]: ", θ[4][1:2])
    println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2])
    println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2])
end
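
# Illustrative follow-up (not part of the original gist): quantify how far the
# GPU parameters have drifted from their CPU counterparts after training.
for (p_cpu, p_gpu) in zip(θ, θ_gpu)
    println("max abs param diff: ", maximum(abs.(p_cpu .- cpu(p_gpu))))
end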