@drozzy
Last active November 29, 2020 07:43
using Flux
using Flux: reset!
# https://discourse.julialang.org/t/how-to-do-batching-in-fluxs-recurrent-sequence-model-to-take-advantage-of-gpu-during-training/28678
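# Example: a batch of N sequences, each T steps long with D features per step.
N = 7  # batch size
D = 2  # features per time step
T = 5  # sequence length
m = RNN(D, 3)
seq_batch = [rand(D, N) for _ in 1:T]  # T-element vector, each a DxN batch
y_batch = m.(seq_batch)                # T-element vector, each 3xN
m.state                                # 3xN: one hidden column per sequence
Flux.reset!(m)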
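## Sketch (plain Julia, no Flux): why a sequence of DxN batches works
# Assumption: a vanilla RNN update h = tanh.(Wx*x .+ Wh*h .+ b) starting from a zero state.
# `toy_unroll` is a hypothetical helper, not a Flux function. Every column of the batch
# goes through the same matrix ops, so each of the N sequences evolves independently.
function toy_unroll(Wx, Wh, bias, xs)
    h = zeros(size(Wh, 1), size(xs[1], 2))  # one hidden column per batch item
    ys = Matrix{Float64}[]
    for xt in xs                            # walk the sequence in time order
        h = tanh.(Wx * xt .+ Wh * h .+ bias)
        push!(ys, h)
    end
    return ys
end
toy_Wx, toy_Wh, toy_b = randn(3, 2), randn(3, 3), randn(3)
toy_seq = [rand(2, 7) for _ in 1:5]                  # T = 5 steps, D = 2, N = 7
toy_ys = toy_unroll(toy_Wx, toy_Wh, toy_b, toy_seq)  # 5-elem vector of 3x7
# Column 1 of the batched run matches running sequence 1 on its own:
toy_ys[end][:, 1] ≈ toy_unroll(toy_Wx, toy_Wh, toy_b, [xt[:, 1:1] for xt in toy_seq])[end][:, 1]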
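## Single input sequence
H = 3  # features per time step
T = 5  # sequence length
m = RNN(H, 4)
x = [rand(H) for _ in 1:T]
y = m.(x)    # T-element vector, each a 4-element output
h = m.state  # 4-element hidden state after the last step
size(y), size(h)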
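## Sequence of batches
B = 10  # batch size
x = [rand(H, B) for _ in 1:T]  # T-element vector, each HxB
Flux.reset!(m)
r = m.(x)
r
size(r)     # (T,)
size(r[1])  # (4, B)
m.state     # 4xB: the state now carries one column per batch item
Flux.reset!(m)
m.state     # back to the initial 4-element state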
## A 3-D array (H x B x T) as input? WRONG!
# m.(x) would broadcast over every scalar element of x, not over the time steps.
# # sequences of batch
# B = 10
# x = rand(H, B, T) #[rand(H, B) for _ in 1:T]
# Flux.reset!(m)
# r = m.(x)
# size(r)
# size(r[1])
# m.state
# size(y), size(h)
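## Sketch: turning an H x B x T array into the layout m.(x) expects
# Slice along the last (time) dimension to get a T-element vector of HxB batches.
# `to_seq_of_batches` is a hypothetical helper name, not part of Flux.
to_seq_of_batches(x3::AbstractArray{<:Real,3}) = [x3[:, :, t] for t in axes(x3, 3)]
x3_demo = rand(3, 10, 5)               # H = 3, B = 10, T = 5
seq_demo = to_seq_of_batches(x3_demo)  # 5-elem vector, each 3x10
# seq_demo can then be broadcast over as above, e.g. Flux.reset!(m); m.(seq_demo)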
# Starting from basic principles
# Assuming only that a Dense layer knows how to deal with batches.
f1 = Flux.Dense(10, 5)
## Single input is as follows:
x = rand(10)
y = f1(x) # 5-element array
## Batch input is as follows
xbatch = rand(10, 2) # 10x2
# Dense knows how to handle batches: each column of an HxB input is one sample
y = f1(xbatch) # 5x2
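## Sketch: why a Dense layer handles batches "for free"
# Assumption: Flux's Dense computes σ.(W*x .+ b); σ = identity here. A matrix product
# treats each column of x independently, so a 10x2 batch is just two 10-element inputs
# side by side. `dense_W`, `dense_b` and `dense_apply` are illustrative names.
dense_W, dense_b = randn(5, 10), randn(5)
dense_apply(x) = dense_W * x .+ dense_b
xb_demo = rand(10, 2)
yb_demo = dense_apply(xb_demo)              # 5x2
yb_demo[:, 2] ≈ dense_apply(xb_demo[:, 2])  # true: column 2 on its own gives the same result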
## Sequence input is as follows:
xseq = [rand(10) for _ in 1:13] # 13-elem array, each 10 in size
y = f1.(xseq) # 13-elem array, each 5 in size
y[1] # 5-element array
## A batch of sequences is simply a sequence in which each item is a batch
seq_batch = [rand(10, 2) for _ in 1:13] # 13-elem array - input sequence
seq_batch[1] # 10x2 batch
y = f1.(seq_batch) # 13-elem array, output sequence
y[1] # 5x2 output batch at this position in the sequence
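## Sketch: a stateless layer can fold time into the batch dimension
# Dense keeps no state, so applying it per time step gives the same numbers as one big
# product on all steps concatenated side by side (10 x (2*13)).
# `seq_W` and `seq_b` are illustrative weights, not Flux parameters.
seq_W, seq_b = randn(5, 10), randn(5)
seq_in = [rand(10, 2) for _ in 1:13]                 # 13 steps of 10x2 batches
per_step = [seq_W * xt .+ seq_b for xt in seq_in]    # 13-elem vector of 5x2
all_at_once = seq_W * reduce(hcat, seq_in) .+ seq_b  # 5x26
reduce(hcat, per_step) ≈ all_at_once                 # true
# This does NOT hold for an RNN, whose output at step t depends on steps 1..t-1.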
## Recurrent: the first dim of x must match the input size (first argument) of RNN!
f2 = Flux.RNN(10, 5)
## Single input - RNN basically acts as a Dense layer
x = rand(10)
### Call it once:
reset!(f2)
f2.state
y = f2(x) # 5-elem array
f2.state # 5-elem array state
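## Sketch: on its first call, a vanilla RNN acts like a Dense layer with tanh
# Assumption: the cell computes h = tanh.(Wx*x .+ Wh*h .+ b) and the initial state is
# all zeros (check f2.state right after reset! in your Flux version).
# With h = 0 the Wh*h term vanishes. Names below are illustrative.
cell_Wx, cell_Wh, cell_b = randn(5, 10), randn(5, 5), randn(5)
x_first = rand(10)
h_first = tanh.(cell_Wx * x_first .+ cell_Wh * zeros(5) .+ cell_b)
h_first ≈ tanh.(cell_Wx * x_first .+ cell_b)  # true: the same map as a Dense(10, 5, tanh)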
### Call it on each item of x (WRONG!!!)
reset!(f2)
f2.state
y2 = f2.(x)
f2.state
y2[1] # 5x10: each scalar x[i] multiplies the whole 5x10 input-weight matrix, so every
# output has that matrix's shape instead of being a 5-elem vector (not the effect we want)
# Same shapes as if we fed a single number 10 times:
reset!(f2)
f2.state
y2 = [f2(x[1]) for _ in 1:10]
f2.state
y2[1]
## Batched input:
x = rand(10, 2)
### Call it once - OK!
reset!(f2)
y = f2(x) # 5x2 - it understands batches
f2.state # 5x2: the size of the state depends on the batch size!
### Call it on each item of x (WRONG!!!) - broadcasting feeds the 20 scalar elements of x one at a time
reset!(f2)
y2 = f2.(x) # 10x2 array of outputs, one per scalar element of x
y2[1] # 5x10
f2.state # 5x10
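## Sketch: why the state must be reset when the batch size changes
# This is a plain-Julia vanilla RNN step, assumed to mirror the cell used above.
# After a batch of 2, the state is 5x2, so a following batch of 3 cannot be broadcast
# against it. This is one reason to call reset! before each new batch.
st_Wx, st_Wh, st_b = randn(5, 10), randn(5, 5), randn(5)
st_step(h, x) = tanh.(st_Wx * x .+ st_Wh * h .+ st_b)
st_h = st_step(zeros(5), rand(10, 2))  # 5x2 state after a batch of 2
st_err = try
    st_step(st_h, rand(10, 3))         # 5x3 .+ 5x2 -> DimensionMismatch
    nothing
catch e
    e
end
st_err isa DimensionMismatch           # true
st_ok = st_step(zeros(5), rand(10, 3)) # fine after "resetting" to a zero vector: 5x3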
## Sequence of inputs
reset!(f2)
x = [rand(10) for _ ∈ 1:13]
y = f2.(x) # 13-elem array, each of 5-dim
y[1] # 5-elem array
f2.state # 5-elem array
## Sequence of batched inputs (KEY IDEA!):
reset!(f2)
x = [rand(10, 2) for _ in 1:13] # 13-elem array
x[1]
y = f2.(x) # 13-elem array, each of 5x2
y[1] # 5x2
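## Sketch: packing the outputs back into one array
# y above is a 13-elem vector of 5x2 batches. Concatenating along a new third dimension
# gives a single 5 x 2 x 13 (features x batch x time) array.
# This is plain Julia; `outs_demo` stands in for y.
outs_demo = [rand(5, 2) for _ in 1:13]
outs_3d = cat(outs_demo...; dims = 3)  # 5x2x13
outs_3d[:, :, 4] == outs_demo[4]       # true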
## Aside: Simple RNN - https://fluxml.ai/Flux.jl/v0.4/models/recurrence.html
Wxh = randn(5, 10)
Whh = randn(5, 5)
b = randn(5)
x = rand(10) # dummy data 10x1
h = rand(5) # initial hidden state 5x1
Wxh * x # 5x1
Whh * h # 5x1
h = tanh.(Wxh * x .+ Whh * h .+ b) # 5x1
function rnn(h, x)
h = tanh.(Wxh * x .+ Whh * h .+ b)
return h, h
end
h, y = rnn(h, x)
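## Sketch: running a cell function (h, x) -> (h, y) over a whole sequence
# A stateful wrapper automates this loop: it threads h from one step to the next and
# collects the outputs. `unroll_cell` is a hypothetical helper. It works with any cell
# of that shape, e.g. the `rnn` above or the tiny `sum_cell` used here.
function unroll_cell(cell, h, xs)
    ys = []
    for xt in xs
        h, yt = cell(h, xt)  # new state replaces the old one
        push!(ys, yt)
    end
    return h, ys
end
sum_cell(h, x) = (h + x, 2x)  # state: running sum; output: input doubled
h_end, ys_end = unroll_cell(sum_cell, 0, 1:4)
# h_end == 10, ys_end == [2, 4, 6, 8]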
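# What if x is a single scalar (input dimension 1)?
x = 1.0
Wxh * x # 5x10
Whh * h # 5x1
h = tanh.(Wxh * x .+ Whh * h .+ b) # 5x10
# -> So the output takes the shape of the Wxh matrix (5x10), not that of the hidden state.
# The reason: the scalar x simply scales Wxh, and the 5-element Whh * h is broadcast across its 10 columns:
h = rand(5)   # fresh 5-element state (h is 5x10 after the line above)
wx = Wxh * x  # 5x10
wh = Whh * h  # 5x1 (named wh so the bias b is not overwritten)
wx .+ wh      # 5x10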
## Stateful: Flux.Recur stores the hidden state and threads it between calls
x = rand(10)
h = rand(5)
m = Flux.Recur(rnn, h)
y = m(x)
## What if x is a single scalar?
x = 1
h = rand(5)
m = Flux.Recur(rnn, h)
y = m(x) # 5x10, as in the aside above
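## Sketch: a minimal stand-in for Flux.Recur
# Assumption: Recur keeps an initial and a current state and calls cell(state, x) -> (new_state, y).
# It stores new_state, returns y, and reset! restores the initial state. `ToyRecur` and
# `toy_reset!` are hypothetical names; the real type has more machinery (e.g. for gradients).
mutable struct ToyRecur
    cell
    init
    state
end
ToyRecur(cell, h) = ToyRecur(cell, h, h)  # start with state == init
function (r::ToyRecur)(x)
    h, y = r.cell(r.state, x)
    r.state = h                           # keep the new state for the next call
    return y
end
toy_reset!(r::ToyRecur) = (r.state = r.init; r)
toy_counter = ToyRecur((h, x) -> (h + x, 2x), 0)
toy_counter(1), toy_counter(2)  # (2, 4)
toy_counter.state               # 3
toy_reset!(toy_counter).state   # 0
# ToyRecur(rnn, rand(5)) would wrap the `rnn` cell defined above in the same way.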
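### Example of a recurrent network that sums its inputs into the state and outputs each input doubled
accum(h, x) = (h + x, x*2) # (new state, output)
rnn1 = Flux.Recur(accum, 0)
rnn1(1) # 2 (state is now 1)
rnn1(2) # 4 (state is now 3)
rnn1.state # 3
rnn1.(1:10) # apply to a sequence: returns each input doubled, i.e. 2, 4, ..., 20
rnn1.state # 58 = 3 + (1 + 2 + ... + 10)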
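# Let the hidden state change shape:
accum2(h, x) = (h * x, x*2) # (new state, output)
rnn2 = Flux.Recur(accum2, ones(1))
rnn2.state # 1-element vector
rnn2(rand(1, 2))
rnn2.state # 1x2: the state takes the shape of the (batched) input
rnn2(2)
rnn2.state # still 1x2, scaled by 2
rnn2.(1:10)
rnn2.state # still 1x2, multiplied by 1*2*...*10 = 3628800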