# Gist drozzy/ca9819203b410d8b5778d9b3d6843c08 (last active November 29, 2020)
using Flux
using Flux: reset!
# https://discourse.julialang.org/t/how-to-do-batching-in-fluxs-recurrent-sequence-model-to-take-advantage-of-gpu-during-training/28678
# example:
N = 7  # batch size
D = 2  # input dimension
T = 5  # sequence length (define before use!)
m = RNN(D, 3)
m.(rand(D, T))  # NB: broadcasts m over each scalar element (see the WRONG examples below)
m.state
Flux.reset!(m)
rand(D, N)  # one batch: DxN
seq_batch = [rand(D, N) for _ in 1:T]  # T-element sequence of DxN batches
y_batch = m.(seq_batch)
## Single input seq
H = 3  # input dimension
T = 5  # sequence length
m = RNN(H, 4)
x = [rand(H) for _ in 1:T]  # T-element sequence of H-vectors
y = m.(x)                   # T-element array, each element a 4-vector
m.state
size(y), size(y[1])
## sequences of batch
B = 10  # batch size
x = [rand(H, B) for _ in 1:T]  # T-element sequence of HxB batches
Flux.reset!(m)
r = m.(x)
r
size(r)     # (5,) - one entry per time step
size(r[1])  # (4, 10) - output dim x batch size
m.state
Flux.reset!(m)
m.state
## Matrix of batch? WRONG!
# # sequences of batch
# B = 10
# x = rand(H, B, T)  # instead of [rand(H, B) for _ in 1:T]
# Flux.reset!(m)
# r = m.(x)  # broadcasts over every scalar element of the 3-D array
# size(r)
# size(r[1])
# m.state
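# If the data does arrive as a 3-D HxBxT array, a sketch of a correct
# conversion into the sequence-of-batches form, using Base's `eachslice`
# (`Flux.unstack(x3d, 3)` should do the same, if your Flux version has it):
x3d = rand(H, B, T)                   # HxBxT array
xs = collect(eachslice(x3d; dims=3))  # T-element vector of HxB matrices
Flux.reset!(m)
r3 = m.(xs)                           # T-element vector, each 4xB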
# Starting from basic principles
# Assuming only that a Dense layer knows how to deal with batches.
f1 = Flux.Dense(10, 5)
## Single input is as follows:
x = rand(10)
y = f1(x)  # 5-element array
## Batch input is as follows:
xbatch = rand(10, 2)  # 10x2
# Dense knows how to handle batches: HxB
y = f1(xbatch)  # 5x2
## Sequence input is as follows:
xseq = [rand(10) for _ in 1:13]  # 13-elem array, each 10 in size
y = f1.(xseq)  # 13-elem array, each 5 in size
y[1]  # 5-element array
## Batch of sequences is simply a sequence where each item is a batch
seq_batch = [rand(10, 2) for _ in 1:13]  # 13-elem array - input sequence
seq_batch[1]  # 10x2 batch
y = f1.(seq_batch)  # 13-elem array - output sequence
y[1]  # 5x2 output batch at this position in the sequence
## Recurrent - x had better have the same dim as the first dim of the RNN!
f2 = Flux.RNN(10, 5)
## Single input - RNN basically acts as a Dense layer
x = rand(10)
### Call it once:
reset!(f2)
f2.state
y = f2(x)  # 5-elem array
f2.state   # 5-elem array state
### Call it on each item of x (WRONG!!!)
reset!(f2)
f2.state
y2 = f2.(x)  # 10-elem array: one call per scalar of x
f2.state
y2[1]  # 5x10 - each scalar input gets multiplied by the whole 5x10 Wxh
# matrix, so we get 10 outputs of size 5x10 (not the effect we want)
# Same as if we fed a single number ten times:
reset!(f2)
f2.state
y2 = [f2(x[1]) for _ in 1:10]
f2.state
y2[1]
## Batched input:
x = rand(10, 2)
### Call it once - OK!
reset!(f2)
y = f2(x)  # 5x2 - it understands batches
f2.state   # 5x2 - the size of the state depends on the batch size!
### Call it on each item of x (WRONG!!!) - scalar elements, not 10-dim inputs
reset!(f2)
y2 = f2.(x)  # 10x2 array of results, one per scalar element
y2[1]
f2.state  # 5x10
## Sequence of inputs
reset!(f2)
x = [rand(10) for _ ∈ 1:13]
y = f2.(x)  # 13-elem array, each of 5-dim
y[1]        # 5-elem array
f2.state    # 5
## Sequence of batched inputs (KEY IDEA!):
reset!(f2)
x = [rand(10, 2) for _ in 1:13]  # 13-elem array of 10x2 batches
x[1]
y = f2.(x)  # 13-elem array, each of 5x2
y[1]        # 5x2
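# The broadcast above is just shorthand for threading the state through
# time step by step; a sketch of the equivalent explicit loop (same `f2`
# and `x` as above - broadcast over a Vector evaluates in order in practice):
reset!(f2)
y_manual = [f2(xt) for xt in x]  # Recur updates f2.state at each step
size(y_manual[1])                # (5, 2), matching y[1] above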
## Aside: Simple RNN - https://fluxml.ai/Flux.jl/v0.4/models/recurrence.html
Wxh = randn(5, 10)
Whh = randn(5, 5)
b = randn(5)
x = rand(10)  # dummy data 10x1
h = rand(5)   # initial hidden state 5x1
Wxh * x  # 5x1
Whh * h  # 5x1
h = tanh.(Wxh * x .+ Whh * h .+ b)  # 5x1
function rnn(h, x)
    h = tanh.(Wxh * x .+ Whh * h .+ b)
    return h, h
end
h, y = rnn(h, x)
# What if x dim is 1?
x = 1.0
Wxh * x  # 5x10
Whh * h  # 5x1
h = tanh.(Wxh * x .+ Whh * h .+ b)  # 5x10
# -> so the output blows up to the shape of the Wxh matrix itself (5x10)
# The reason is that the scalar x is simply broadcast:
xh = Wxh * x  # 5x10  (NB: don't call these `a` and `b` - reusing `b` would clobber the bias that rnn uses below)
hh = Whh * h  # 5x10, since h itself is now 5x10
xh .+ hh      # 5x10
## stateful:
x = rand(10)
h = rand(5)
m = Flux.Recur(rnn, h)
y = m(x)
## - what if x is single?
x = 1
h = rand(5)
m = Flux.Recur(rnn, h)
y = m(x)  # 5x10 again - the scalar is broadcast through Wxh
### Example of a recurrent network that sums the inputs
accum(h, x) = (h + x, x*2)
rnn1 = Flux.Recur(accum, 0)
rnn1(1)     # 2 - Recur returns the second element (the output), not the state
rnn1(2)     # 4
rnn1.state  # 3 - the state has accumulated 1 + 2
rnn1.(1:10) # apply to a sequence: outputs are the doubled inputs (2, 4, ..., 20)
rnn1.state  # 58 = 3 + sum(1:10)
# Mess with the hidden state:
accum2(h, x) = (h * x, x*2)
rnn2 = Flux.Recur(accum2, ones(1))
rnn2.state        # 1-element vector
rnn2(rand(1, 2))  # state becomes 1x2: its shape follows the input
rnn2.state
rnn2(2)
rnn2.state
rnn2.(1:10)
rnn2.state
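# A minimal end-to-end training sketch on sequence batches (the model,
# data and loss here are made up for illustration; assumes a Flux version
# from around 2020, where `ADAM` and `Flux.Losses.mse` are available):
using Flux: params, ADAM
model = Chain(RNN(10, 5), Dense(5, 2))
xs = [rand(Float32, 10, 4) for _ in 1:13]  # one sequence batch: T=13, B=4
ys = [rand(Float32, 2, 4) for _ in 1:13]   # targets: one 2x4 batch per step
function seq_loss(xs, ys)
    Flux.reset!(model)  # clear the hidden state before each sequence
    sum(Flux.Losses.mse(model(x), y) for (x, y) in zip(xs, ys))
end
opt = ADAM()
gs = Flux.gradient(() -> seq_loss(xs, ys), params(model))
Flux.Optimise.update!(opt, params(model), gs)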