Skip to content

Instantly share code, notes, and snippets.

@msjgriffiths
Created October 18, 2023 15:41
Show Gist options
  • Save msjgriffiths/90614c8aaa31ec5261d77cc52db0f81a to your computer and use it in GitHub Desktop.
Save msjgriffiths/90614c8aaa31ec5261d77cc52db0f81a to your computer and use it in GitHub Desktop.
SRU in Julia
using Flux
using Random
"""
    SRUCell{M, V, S}

Parameters and state of one SRU cell.

`W`, `Wⱼ`, `Wᵣ` are the parameter matrices, and `vⱼ`, `vᵣ`, `bⱼ`, `bᵣ` are the
parameter vectors to be learnt during training.

Note we change f -> j relative to the paper for notational convenience
(there is no unicode subscript for f).

# Fields
- `W`, `Wⱼ`, `Wᵣ`: parameter matrices, all of type `M`.
- `vⱼ`, `vᵣ`, `bⱼ`, `bᵣ`: parameter vectors, all of type `V`.
- `cₜ`: hidden state, of type `V`.
- `α`: scaling factor for highway bias; set to √3. Its concrete float type is the
  type parameter `S`, so reading `m.α` is type-stable.
"""
mutable struct SRUCell{M, V, S<:AbstractFloat}
    W::M
    Wⱼ::M
    Wᵣ::M
    vⱼ::V
    vᵣ::V
    bⱼ::V
    bᵣ::V
    cₜ::V # Hidden state
    α::S  # Scaling factor for highway bias; set to √3
end

# Keep `SRUCell{M, V}(...)` (the two-parameter spelling) constructible now that the
# type of `α` is a third type parameter: infer `S` from the value that is passed.
function SRUCell{M, V}(W, Wⱼ, Wᵣ, vⱼ, vᵣ, bⱼ, bᵣ, cₜ, α::S) where {M, V, S<:AbstractFloat}
    return SRUCell{M, V, S}(W, Wⱼ, Wᵣ, vⱼ, vᵣ, bⱼ, bᵣ, cₜ, α)
end
"""
    sru_init(d...)

Return a `Float32` array with dimensions `d` whose entries are drawn from the global
RNG, mapped from `[0, 1)` to `[-1, 1)`, and then scaled by `√(3 / sum(d))`.
"""
function sru_init(d...)
    scale = sqrt(3.0f0 / sum(d))
    centred = 2 .* rand(Random.GLOBAL_RNG, Float32, d...) .- 1f0
    return centred .* scale
end
"""
    SRUCell(in::Integer, out::Integer, init = sru_init)

Build an `SRUCell` whose three parameter matrices come from `init(out, in)` and whose
four parameter vectors come from `init(out)`. The hidden state starts as `out` zeros
of type `Float32`, and the scaling factor `α` is √3.
"""
function SRUCell(in::Integer, out::Integer, init = sru_init)
    W, Wⱼ, Wᵣ = init(out, in), init(out, in), init(out, in)
    vⱼ, vᵣ, bⱼ, bᵣ = init(out), init(out), init(out), init(out)
    cₜ = zeros(Float32, out)
    # √3 as a Float32: a Float64 `α` would promote `xₜ * m.α`, and with it the
    # cell output, to Float64 even though the state is Float32.
    α = sqrt(3.0f0)
    return SRUCell(W, Wⱼ, Wᵣ, vⱼ, vᵣ, bⱼ, bᵣ, cₜ, α)
end
# The paper's forget gate f is written j in the code (see the SRUCell field names):
# jₜ = σ(Wⱼxₜ + vⱼ ⊙ cₜ₋₁ + bⱼ)
# cₜ = jₜ ⊙ cₜ₋₁ + (1 - jₜ) ⊙ Wxₜ
# rₜ = σ(Wᵣxₜ + vᵣ ⊙ cₜ₋₁ + bᵣ)
# hₜ = rₜ ⊙ cₜ + (1 - rₜ) ⊙ αxₜ
# Element-wise product of `x` and `y` (with broadcasting), written ⊙ as in the
# equations in the comments above.
⊙(x, y) = broadcast(*, x, y)
# Advance the cell by one step.
#
# Arguments:
#   cₜ₋₁ — the previous state.
#   xₜ   — the current input. It is added element-wise to the output through the
#          highway term, so its leading dimension must match the state's.
#
# Returns the tuple `(cₜ, hₜ)`: the new state and the output for this step.
function (m::SRUCell)(cₜ₋₁, xₜ)
    # `.+` instead of array `+`: the vector parameters and a vector state then
    # broadcast across the columns of a matrix input, while results for vector
    # inputs are unchanged.
    jₜ = σ.(m.Wⱼ * xₜ .+ m.vⱼ ⊙ cₜ₋₁ .+ m.bⱼ)
    cₜ = jₜ ⊙ cₜ₋₁ .+ (1 .- jₜ) ⊙ (m.W * xₜ)
    # The reset gate reads the previous state cₜ₋₁, not the freshly computed cₜ.
    rₜ = σ.(m.Wᵣ * xₜ .+ m.vᵣ ⊙ cₜ₋₁ .+ m.bᵣ)
    # Highway connection: the input is scaled by α before being mixed in.
    hₜ = rₜ ⊙ cₜ .+ (1 .- rₜ) ⊙ (xₜ * m.α)
    return cₜ, hₜ
end
# Report the cell's stored `cₜ` field as its hidden state.
# NOTE(review): `Flux.hidden` and the one-argument `Flux.Recur` constructor are
# not present in every Flux release — confirm against the pinned Flux version.
function Flux.hidden(m::SRUCell)
    return m.cₜ
end

# Register the parameter matrices and vectors with Flux; `cₜ` and `α` are not listed.
Flux.@functor SRUCell (W, Wⱼ, Wᵣ, vⱼ, vᵣ, bⱼ, bᵣ)

# Build an `SRUCell` from the given arguments and wrap it in `Flux.Recur`.
function SRU(args...; kwargs...)
    cell = SRUCell(args...; kwargs...)
    return Flux.Recur(cell)
end
# Test this
using PyCall
# Handle to the Python `sru` module, loaded through PyCall; the script below builds
# its comparison model from it.
sru = pyimport("sru")
# Indexing helper: read the elements of `W` addressed by the ranges `n` (first
# index) and `m` (second index) one at a time and return them as a
# `length(n) × length(m)` Float64 matrix.
# NOTE(review): neither `getindex` nor `PyObject` is defined in this file, so this
# method is type piracy if it is ever moved out of a throwaway script.
function Base.getindex(W::PyObject, n::AbstractRange, m::AbstractRange)
    return Float64[W[row, col] for row in n, col in m]
end
# One-dimensional counterpart: read `W[idx]` for every `idx` in the range `n` and
# return the values as a Float64 vector of length `length(n)`.
function Base.getindex(W::PyObject, n::AbstractRange)
    return Float64[W[idx] for idx in n]
end
# Reference model from the Python package: SRU(10, 10) with a single layer,
# rescaling enabled and a highway bias of 0.
py_sru = sru.SRU(10, 10, num_layers=1, rescale=true, highway_bias=0)
# Pull out the three state_dict entries that are ported to the Julia cell below.
py_weights = py_sru.state_dict()
Ws = py_weights["rnn_lst.0.weight"]
Vs = py_weights["rnn_lst.0.weight_c"]
Bs = py_weights["rnn_lst.0.bias"]
# Map each SRUCell field name to the matching slice of the Python weights.
weights = Dict(
# `Ws` is read here as a 10×30 matrix from which W, Wⱼ, and Wᵣ are recovered
# (NOTE(review): presumably the Python SRU packs the three matrices into this one
# tensor — confirm against the sru package).
# When it computes U' * x, it then reshapes the result to get the
# three products Wxₜ, Wⱼxₜ, and Wᵣxₜ. The reshaping is a little odd:
# reshape(Ws[1:10, 1:30]' * xₜ, (3, 10))'
# So we take every third column, and then transpose.
:Wⱼ => Ws[1:10, 2:3:30]',
:W => Ws[1:10, 1:3:30]',
:Wᵣ => Ws[1:10, 3:3:30]',
# The first ten entries of `Vs` and `Bs` go to the j gate, the next ten to the r gate.
:vⱼ => Vs[1:10],
:vᵣ => Vs[11:20],
:bⱼ => Bs[1:10],
:bᵣ => Bs[11:20],
)
# Create the Julia SRU with the same (10, 10) sizes as the Python model.
m = SRU(10, 10)
# Port weights over: overwrite each named field of the wrapped cell with the value
# taken from the Python model.
for (key, value) in weights
setproperty!(m.cell, key, value)
end
# Reset the recurrent wrapper before it is evaluated.
Flux.reset!(m)
torch = pyimport("torch")
# Random input, created on the Python side with dimensions (1, 1, 10).
x = torch.rand(1, 1, 10)
# The same ten input values as a Julia vector, read through the range-indexing
# helper defined earlier in this file.
xⱼᵤₗᵢₐ = x[1][1][1:10]
# NOTE(review): `m` was already reset right after the weights were ported and has
# not been called since; this second reset looks redundant.
Flux.reset!(m)
# Run one step through each implementation.
rⱼᵤₗᵢₐ = m(xⱼᵤₗᵢₐ)
rₛᵣᵤ = py_sru(x)[1][1][1][1:10]
# Compare the two outputs to an absolute tolerance of 1e-6.
# NOTE(review): the result is neither stored nor asserted, so it is only visible
# when this line is evaluated interactively.
isapprox(rⱼᵤₗᵢₐ, rₛᵣᵤ; atol=1e-6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment