Created
October 18, 2023 15:41
-
-
Save msjgriffiths/90614c8aaa31ec5261d77cc52db0f81a to your computer and use it in GitHub Desktop.
SRU in Julia
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Flux | |
using Random | |
"""
    SRUCell{M, V, T}

Parameters and recurrent state of a Simple Recurrent Unit (SRU) cell.

`W`, `Wⱼ`, `Wᵣ` are the parameter matrices, and `vⱼ`, `vᵣ`, `bⱼ`, `bᵣ` are the
parameter vectors to be learnt during training.

Note we change f -> j relative to the paper for notational convenience
(there is no unicode subscript for f).
"""
mutable struct SRUCell{M, V, T<:AbstractFloat}
    W::M
    Wⱼ::M
    Wᵣ::M
    vⱼ::V
    vᵣ::V
    bⱼ::V
    bᵣ::V
    cₜ::V  # Hidden state
    # Scaling factor for highway bias; set to √3.
    # Parametric (rather than `::AbstractFloat`) so the field has a concrete type.
    α::T
end
"""
    sru_init(d...; rng = Random.GLOBAL_RNG)

Return a `Float32` array of size `d` whose entries are drawn uniformly from
`[-b, b)` with `b = √(3 / sum(d))`.

# Keywords
- `rng`: random number generator to sample from. Pass a seeded generator to
  make the initialisation reproducible.
"""
function sru_init(d...; rng::Random.AbstractRNG = Random.GLOBAL_RNG)
    bound = √(3.0f0 / sum(d))
    return (2rand(rng, Float32, d...) .- 1f0) .* bound
end
"""
    SRUCell(in::Integer, out::Integer, init = sru_init)

Build an `SRUCell` for `in` input features and `out` hidden units.

The three weight matrices are created with `init(out, in)` and the four
parameter vectors with `init(out)`. The hidden state starts as `out` zeros and
the highway scaling factor `α` is set to √3.
"""
function SRUCell(in::Integer, out::Integer, init = sru_init)
    # α is kept in Float32 like the hidden state: the forward pass multiplies
    # the input by α, so a Float64 α would promote the output to Float64.
    return SRUCell(
        init(out, in), init(out, in), init(out, in),
        init(out), init(out), init(out), init(out),
        zeros(Float32, out),
        sqrt(3f0),
    )
end
# Equations implemented below (paper's f is written j in the code):
# jₜ = σ(Wⱼxₜ + vⱼ ⊙ cₜ₋₁ + bⱼ)
# cₜ = jₜ ⊙ cₜ₋₁ + (1 - jₜ) ⊙ Wxₜ
# rₜ = σ(Wᵣxₜ + vᵣ ⊙ cₜ₋₁ + bᵣ)
# hₜ = rₜ ⊙ cₜ + (1 - rₜ) ⊙ (α xₜ)
"""
    x ⊙ y

Element-wise product of `x` and `y`, i.e. `x .* y`.
"""
function ⊙(x, y)
    return x .* y
end
# One SRU step: takes the previous state cₜ₋₁ and the input xₜ, and returns
# the tuple (new state cₜ, output hₜ).
function (m::SRUCell)(cₜ₋₁, xₜ)
    # Pre-activations of the two gates; both read the *previous* state.
    gateⱼ = m.Wⱼ * xₜ + m.vⱼ .* cₜ₋₁ + m.bⱼ
    gateᵣ = m.Wᵣ * xₜ + m.vᵣ .* cₜ₋₁ + m.bᵣ
    jₜ = σ.(gateⱼ)
    rₜ = σ.(gateᵣ)

    # New state: gate-weighted mix of the old state and the projected input.
    candidate = m.W * xₜ
    cₜ = jₜ .* cₜ₋₁ + (1 .- jₜ) .* candidate

    # Output: mix of the new state and the α-scaled input (highway connection).
    highway = xₜ * m.α
    hₜ = rₜ .* cₜ + (1 .- rₜ) .* highway

    return (cₜ, hₜ)
end
# The recurrent state handed to Flux is the cell's `cₜ` field.
Flux.hidden(cell::SRUCell) = cell.cₜ

# Register the weight matrices and gate vectors with Flux; `cₜ` and `α` are
# deliberately not in the list.
Flux.@functor SRUCell (W, Wⱼ, Wᵣ, vⱼ, vᵣ, bⱼ, bᵣ)

# Convenience constructor: an `SRUCell` wrapped in `Flux.Recur`. All arguments
# are forwarded to `SRUCell`.
function SRU(a...; ka...)
    return Flux.Recur(SRUCell(a...; ka...))
end
# Test this: compare against the Python `sru` package
using PyCall

# Handle to the Python module that provides the reference implementation
sru = PyCall.pyimport("sru")
# Add helper for indexing: `W[n, m]` with two ranges reads each addressed
# element of the Python object one by one and returns them as a
# `length(n) × length(m)` `Matrix{Float64}`.
# NOTE(review): this adds a `Base.getindex` method for PyCall's `PyObject`;
# neither the function nor the type is defined in this file (type piracy).
function Base.getindex(W::PyObject, n::AbstractRange, m::AbstractRange)
    return Float64[W[row, col] for row in n, col in m]
end
# One-dimensional counterpart: `W[n]` with a range reads each addressed element
# of the Python object and returns them as a `Vector{Float64}`.
function Base.getindex(W::PyObject, n::AbstractRange)
    return Float64[W[idx] for idx in n]
end
# Reference model: single-layer 10 -> 10 SRU from the Python package
py_sru = sru.SRU(10, 10; num_layers = 1, rescale = true, highway_bias = 0)

# Pull the raw parameter tensors out of the model's state dict
py_weights = py_sru.state_dict()
Ws = py_weights["rnn_lst.0.weight"]
Vs = py_weights["rnn_lst.0.weight_c"]
Bs = py_weights["rnn_lst.0.bias"]

# The Python SRU keeps W, Wⱼ, and Wᵣ interleaved column-wise in one weight
# tensor. When it computes U' * x, it then reshapes the result to get the
# three products Wxₜ, Wⱼxₜ, and Wᵣxₜ. The reshaping is a little odd,
# to reshape(Ws[1:10, 1:30]' * xₜ, (3, 10))'
# So we can take every third column, and then transpose...
# The gate vectors and biases are stored as two consecutive halves (ⱼ, then ᵣ).
weights = Dict(
    :W  => adjoint(Ws[1:10, 1:3:30]),
    :Wⱼ => adjoint(Ws[1:10, 2:3:30]),
    :Wᵣ => adjoint(Ws[1:10, 3:3:30]),
    :vⱼ => Vs[1:10],
    :vᵣ => Vs[11:20],
    :bⱼ => Bs[1:10],
    :bᵣ => Bs[11:20],
)
# Create Julia SRU with the same 10 -> 10 shape as the reference model
m = SRU(10, 10)

# Port weights over: each Dict key names the cell field it overwrites
for name in keys(weights)
    setproperty!(m.cell, name, weights[name])
end

# Reset the recurrent state after swapping the parameters
Flux.reset!(m)
using Test

torch = pyimport("torch")

# One random input, indexed as x[1][1][1:10] to get a 10-element Julia vector
x = torch.rand(1, 1, 10)
xⱼᵤₗᵢₐ = x[1][1][1:10]

# Run a single step through both implementations
Flux.reset!(m)
rⱼᵤₗᵢₐ = m(xⱼᵤₗᵢₐ)
rₛᵣᵤ = py_sru(x)[1][1][1][1:10]

# Fail loudly on a mismatch: a bare `isapprox(...)` at the end of a script
# discards its result, so a disagreement would go unnoticed.
@test isapprox(rⱼᵤₗᵢₐ, rₛᵣᵤ; atol=1e-6)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment