Last active
August 5, 2018 12:53
-
-
Save swiesend/8be8c7927ddaa6edca4894b7c0a98712 to your computer and use it in GitHub Desktop.
Parallel Recur layers for Flux.jl - with map/reduce
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demo: exercise the Parallel layer (see Parallel.jl) in three configurations.
# Two independent 10->5 LSTM branches fed the same input.
layers = vcat(LSTM(10,5),LSTM(10,5))
xs = rand(10)
# p1: plain parallel branches, outputs concatenated (default reduce = concat).
p1 = Parallel(layers)
# p2: branch 2 sees the input reversed (bidirectional-style wiring).
p2 = Parallel(layers, map = Dict{Int64,Function}(2 => reverse))
# p3: reversed branch 2, branch outputs averaged instead of concatenated.
p3 = Parallel(layers, map = Dict{Int64,Function}(2 => reverse), reduce = average)
r1 = p1(xs)
r2 = p2(xs)
r3 = p3(xs)
# Reduce each output to a scalar so back! can run (old Flux tracker API).
loss = x -> sum(x)
l1 = loss(r1)
l2 = loss(r2)
l3 = loss(r3)
Flux.back!(l1)
Flux.back!(l2)
Flux.back!(l3)
# Inspect the gradients accumulated on the tracked scalars.
l1.tracker.grad
l2.tracker.grad
l3.tracker.grad
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# uncomment to run on gpu, if available - and if you know how to tell the reflection to infer correct sizes, I don't...
#using CuArrays
using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition
include("Parallel.jl")  # provides Parallel, BiLSTM, truncate!, reset!
cd(@__DIR__)
# Download the char-rnn Shakespeare corpus once.
isfile("input.txt") ||
  download("http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
           "input.txt")
text = collect(readstring("input.txt"))  # readstring: Julia 0.6 API
# '_' is appended as the padding/stop symbol used by batchseq below.
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)  # one-hot encode every character
stop = onehot('_', alphabet)
N = length(alphabet)
seqlen = 50  # truncated-BPTT sequence length
nbatch = 50  # sequences per batch
hidden = 128
epochs = 1
# Xs/Ys are batched, padded sequences; Ys is the text shifted by one
# character, i.e. next-character prediction targets.
Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen))
# Baseline unidirectional model, kept for comparison:
# m = Chain(
# LSTM(N, hidden),
# LSTM(hidden, hidden),
# Dense(hidden, N),
# softmax)
# Bidirectional variant: each BiLSTM runs a forward and a time-reversed
# LSTM branch (see Parallel.jl) and concatenates their outputs.
m = Chain(
  BiLSTM(N, hidden),
  Dropout(0.66),
  BiLSTM(hidden, hidden),
  Dense(hidden, N),
  softmax)
m = gpu(m)  # no-op unless a GPU backend (CuArrays) is loaded
# Sequence loss: sum of per-timestep cross-entropies over the batch,
# truncating the gradient tape afterwards (truncated BPTT).
function loss(xs, ys)
    total = sum(crossentropy.(m.(gpu.(xs)), gpu.(ys)))
    truncate!(m)  # cut gradient history between batches
    return total
end
opt = ADAM(params(m), 0.01)
# Draw `len` characters from the model, starting from a random seed char.
# NOTE(review): `temp` is accepted but unused, as in the original.
function sample(m, alphabet, len; temp = 1)
    reset!(m)
    out = IOBuffer()
    ch = rand(alphabet)
    for _ in 1:len
        write(out, ch)
        # sample the next character from the predicted distribution
        ch = wsample(alphabet, m(onehot(ch, alphabet)).data)
    end
    return String(take!(out))
end
tx, ty = (gpu.(Xs[5]), gpu.(Ys[5]))  # fixed probe batch for progress reporting
# evalcb = () -> @show loss(tx, ty)
# Throttled training callback: report loss on the probe batch and print a
# 100-character sample from the model.
evalcb = function ()
  info("loss: ", sprint(showcompact, loss(Xs[5], Ys[5]).tracker.data))
  # deepcopy presumably so sampling's reset! does not clobber the live
  # training state — TODO confirm
  println(sample(deepcopy(m), alphabet, 100))
end
for e in 1:epochs
  info("Epoch: $e")
  Flux.train!(loss, zip(Xs, Ys), opt,
              cb = throttle(evalcb, 30))
end
# Sampling
m = cpu(m)  # move weights back to the CPU before sampling
sample(m, alphabet, 1000) |> println
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage notes: bidirectional model definition for the char-rnn example.
# see: https://github.com/FluxML/model-zoo/blob/master/text/char-rnn/char-rnn.jl
hidden = 128
m = Chain(
  BiLSTM(N, hidden),
  Dropout(0.66),
  BiLSTM(hidden ,hidden),
  Dense(hidden, N),
  softmax)
# NOTE: Parallel needs to use an updated version for truncate!(m) and reset!(m) use:
# (the local variants below know how to recurse into Parallel layers)
truncate!(m) # instead of Flux.truncate!(m)
reset!(m) # instead of Flux.reset!(m)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Flux | |
using Flux: treelike, Recur, _truncate, prefor | |
"""
    average(mapped)

Elementwise mean of a non-empty collection of same-shaped arrays
(the per-branch outputs of a `Parallel` layer).

Throws an `ArgumentError` for an empty collection (the original raised an
opaque `BoundsError`).
"""
function average(mapped)
    isempty(mapped) && throw(ArgumentError("average: need at least one branch output"))
    # `sum` over a collection of arrays uses elementwise array `+`, so this
    # is the elementwise mean without the manual accumulation loop.
    return sum(mapped) ./ length(mapped)
end
"""
    concat(values)

Concatenate branch outputs along the first dimension — the default
`reduce` for `Parallel` (equivalent to `vcat(values...)`).
"""
concat(values) = vcat(values...)
# A set of Recur layers applied to the same input in parallel.
# `map[i]` preprocesses the input for branch i (e.g. `reverse` for the
# backward direction of a BiLSTM); `reduce` folds the branch outputs into
# a single result (`concat` or `average`).
mutable struct Parallel{L<:Recur}
    layers::Vector{L}
    map::Vector{Function}
    reduce::Function
end

# NOTE(review): parametric types are invariant, so this method only matches
# a literal `Vector{Recur}`; concretely-typed layer vectors dispatch to the
# keyword constructor below — confirm whether this method is still needed.
Parallel(layers::Vector{Recur}) = Parallel(layers, fill(identity, length(layers)), concat)

function Parallel(layers::Vector{L};
                  map::Dict{Int64,Function} = Dict{Int64,Function}(),
                  reduce::Function = concat) where L<:Recur
    # every branch defaults to `identity`; then apply the sparse overrides
    mappings::Vector{Function} = fill(identity, length(layers))
    for (idx, f) in map
        mappings[idx] = f
    end
    return Parallel(layers, mappings, reduce)
end
# Apply each branch to (its mapped view of) the input, then fold the results.
# `Base.pmap` is spelled out because the fields are accessed via `p.map` /
# `p.reduce` (the original destructured them, shadowing the Base functions);
# with a single worker process pmap degenerates to a plain map.
function (p::Parallel)(xs)
    branch = i -> p.layers[i](p.map[i](xs))
    outputs = Base.pmap(branch, eachindex(p.layers))
    return p.reduce(outputs)
end
treelike(Parallel)
# Truncation visitor, written as dispatch methods instead of the original
# isa-chain (idiomatic Julia: dispatch on types, don't branch on `isa`).
_prefor_truncate(x) = nothing  # fallback: node carries no recurrent state
# Detach the hidden state from the gradient tape.
_prefor_truncate(r::Recur) = (r.state = _truncate(r.state))
# Parallel keeps its Recur layers in a Vector field, so recurse manually.
_prefor_truncate(p::Parallel) = foreach(_prefor_truncate, p.layers)

"""
    truncate!(m)

Truncate the gradient history of every `Recur` in `m`, including those
nested inside `Parallel` layers. Use instead of `Flux.truncate!` for
models built with `Parallel`.
"""
truncate!(m) = prefor(_prefor_truncate, m)
# Reset visitor, written as dispatch methods instead of the original
# isa-chain (idiomatic Julia: dispatch on types, don't branch on `isa`).
_prefor_reset(x) = nothing  # fallback: node carries no recurrent state
# Restore the hidden state to its initial value.
_prefor_reset(r::Recur) = (r.state = r.init)
# Parallel keeps its Recur layers in a Vector field, so recurse manually.
_prefor_reset(p::Parallel) = foreach(_prefor_reset, p.layers)

"""
    reset!(m)

Reset every `Recur` in `m` — including those nested inside `Parallel`
layers — to its initial state. Use instead of `Flux.reset!` for models
built with `Parallel`.
"""
reset!(m) = prefor(_prefor_reset, m)
# NOTE(review): the methods below are type piracy — they add `Base.reverse`
# methods for types owned by Flux/Base, which changes behavior for every
# other user of those types once this file is loaded.

# Reverse a one-hot sequence (matrix of one-hot columns) in time.
function Base.reverse(M::Flux.OneHotMatrix{Array{Flux.OneHotVector,1}})
    Flux.OneHotMatrix(M.height, reverse(M.data))
end
# A single one-hot vector is one timestep: nothing to reverse.
function Base.reverse(v::Flux.OneHotVector)
    v
end
# Reverse a tracked matrix along dim 2 (time); 1-D arrays returned as-is.
# NOTE(review): `flipdim(ta.data, 2)` reads `.data`, which detaches the
# result from the tracker — gradients will not flow back through this
# reversal. Confirm whether that is intended.
function Base.reverse(ta::TrackedArray)
    if length(size(ta.data)) == 2
        flipdim(ta.data,2)
    else
        ta
    end
end
# Scalar no-op; presumably guards Bool values reaching a branch's map
# function during broadcasting — TODO confirm why this method is needed.
function Base.reverse(b::Bool)
    b
end
# see: | |
# "SPEECH RECOGNITION WITH DEEP RECURRENT NEURAL NETWORKS" https://arxiv.org/pdf/1303.5778.pdf | |
# "Bidirectional LSTM-CRF Models for Sequence Tagging" https://arxiv.org/pdf/1508.01991.pdf | |
# Wrap a Recur into a bidirectional pair: the second, deep-copied branch
# sees its input reversed in time, and `reduce` merges the two outputs.
function Bi(recur::Recur, reduce::Function = concat)
    backward = deepcopy(recur)  # independent weights for the reverse direction
    directions = Dict{Int64,Function}(2 => reverse)
    return Parallel([recur, backward], map = directions, reduce = reduce)
end
"""
    BiLSTM(in, out, reduce = concat)

Bidirectional LSTM: two `LSTM` branches over the input merged with
`reduce`. With `concat` each direction has `out ÷ 2` units so the
concatenated output has size `out` (`out` must be even, else
`DimensionMismatch`); with `average` each direction has `out` units.
Any other `reduce` raises an `ArgumentError` (the original silently
returned `nothing`).
"""
function BiLSTM(in::Int, out::Int, reduce::Function = concat)
    if reduce == average
        return Bi(LSTM(in, out), reduce)
    elseif reduce == concat
        # integer division instead of the original convert(Int64, out/2)
        # float round-trip
        iseven(out) ||
            throw(DimensionMismatch("`out` must be a multiple of two for `concat` as reduce function."))
        return Bi(LSTM(in, out ÷ 2), reduce)
    else
        throw(ArgumentError("unsupported reduce function; use `concat` or `average`"))
    end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment