Skip to content

Instantly share code, notes, and snippets.

@swiesend
Last active August 5, 2018 12:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save swiesend/8be8c7927ddaa6edca4894b7c0a98712 to your computer and use it in GitHub Desktop.
Parallel Recur layers for Flux.jl - with map/reduce
# Demo: three Parallel wrappers around the same pair of LSTM layers.
layers = vcat(LSTM(10,5),LSTM(10,5))
xs = rand(10)
# p1: plain fan-out, per-layer outputs concatenated (default reduce).
p1 = Parallel(layers)
# p2: layer 2 sees the input passed through `reverse` (bidirectional-style map).
p2 = Parallel(layers, map = Dict{Int64,Function}(2 => reverse))
# p3: as p2, but outputs are averaged instead of concatenated.
p3 = Parallel(layers, map = Dict{Int64,Function}(2 => reverse), reduce = average)
r1 = p1(xs)
r2 = p2(xs)
r3 = p3(xs)
# Toy scalar loss; anything producing a tracked scalar works here.
loss = x -> sum(x)
l1 = loss(r1)
l2 = loss(r2)
l3 = loss(r3)
# Flux 0.x tracker API: back-propagate through each scalar loss...
Flux.back!(l1)
Flux.back!(l2)
Flux.back!(l3)
# ...then inspect the gradient accumulated on each tracked scalar.
l1.tracker.grad
l2.tracker.grad
l3.tracker.grad
# uncomment to run on gpu, if available - and if you know how to tell the reflection to infer correct sizes, I don't...
#using CuArrays
using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition
include("Parallel.jl")
cd(@__DIR__)
# Fetch the Shakespeare corpus once and cache it next to this script.
isfile("input.txt") ||
download("http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
"input.txt")
# NOTE: `readstring` is the Julia 0.6 API (`read(f, String)` in later versions).
text = collect(readstring("input.txt"))
# '_' is appended as the padding/stop symbol used by batchseq below.
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)
N = length(alphabet)
seqlen = 50
nbatch = 50
hidden = 128
epochs = 1
# Split the one-hot stream into nbatch parallel sequences, batched into
# seqlen-step chunks; Ys is Xs shifted by one char (next-char targets).
Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen))
# Original single-direction model kept for reference:
# m = Chain(
# LSTM(N, hidden),
# LSTM(hidden, hidden),
# Dense(hidden, N),
# softmax)
# Two stacked bidirectional LSTMs with dropout in between. With the default
# `concat` reduce, each BiLSTM splits `hidden` across its two directions,
# so the combined output width is still `hidden`.
m = Chain(
BiLSTM(N, hidden),
Dropout(0.66),
BiLSTM(hidden, hidden),
Dense(hidden, N),
softmax)
# No-op on CPU; moves parameters to the GPU when CuArrays is loaded.
m = gpu(m)
"""
    loss(xs, ys)

Sum the per-timestep cross-entropies of the model's predictions over the
batch, then truncate the recurrent gradient history so backprop does not
flow across batch boundaries.
"""
function loss(xs, ys)
    preds = m.(gpu.(xs))
    total = sum(crossentropy.(preds, gpu.(ys)))
    # Detach hidden state from the tape after each loss evaluation.
    truncate!(m)
    return total
end
opt = ADAM(params(m), 0.01)
"""
    sample(m, alphabet, len; temp = 1)

Draw `len` characters autoregressively from model `m`, feeding each sampled
character back in one-hot form. `temp` is the sampling temperature: values
below 1 sharpen the predicted distribution, values above 1 flatten it.
The original version accepted `temp` but never used it; `temp = 1` keeps
the previous behavior exactly.
"""
function sample(m, alphabet, len; temp = 1)
    reset!(m)
    buf = IOBuffer()
    c = rand(alphabet)
    for i = 1:len
        write(buf, c)
        ps = m(onehot(c, alphabet)).data
        # Softmax-temperature reweighting: p .^ (1/temp). wsample accepts
        # unnormalized weights, so no explicit renormalization is needed.
        ps = ps .^ (1 / temp)
        c = wsample(alphabet, ps)
    end
    return String(take!(buf))
end
# Held-out batch for quick loss reporting (moved to GPU once, up front).
tx, ty = (gpu.(Xs[5]), gpu.(Ys[5]))
# evalcb = () -> @show loss(tx, ty)
# Periodic training callback: log the current loss, then print a short sample.
evalcb = function ()
info("loss: ", sprint(showcompact, loss(Xs[5], Ys[5]).tracker.data))
# deepcopy so sampling's reset!/state mutation cannot disturb the live model.
println(sample(deepcopy(m), alphabet, 100))
end
for e in 1:epochs
info("Epoch: $e")
# throttle: invoke the callback at most once every 30 seconds.
Flux.train!(loss, zip(Xs, Ys), opt,
cb = throttle(evalcb, 30))
end
# Sampling
# Move the trained model back to CPU before generating text.
m = cpu(m)
sample(m, alphabet, 1000) |> println
# see: https://github.com/FluxML/model-zoo/blob/master/text/char-rnn/char-rnn.jl
# Fresh, untrained copy of the same architecture (usage notes below).
hidden = 128
m = Chain(
BiLSTM(N, hidden),
Dropout(0.66),
BiLSTM(hidden ,hidden),
Dense(hidden, N),
softmax)
# NOTE: Parallel needs to use an updated version for truncate!(m) and reset!(m) use:
truncate!(m) # instead of Flux.truncate!(m)
reset!(m) # instead of Flux.reset!(m)
using Flux
using Flux: treelike, Recur, _truncate, prefor
"""
    average(mapped)

Element-wise mean of a non-empty collection of equally-sized values.
Intended as a `reduce` function for `Parallel`, e.g. to merge the two
directions of a bidirectional layer without doubling the feature size.
"""
function average(mapped)
    n = length(mapped)
    # Broadcasted accumulation (`.+`) keeps this working for scalars,
    # plain arrays and tracked arrays alike; `reduce` replaces the
    # original hand-rolled index loop.
    return reduce((acc, x) -> acc .+ x, mapped) ./ n
end
"""
    concat(values)

Vertically concatenate the collected layer outputs into a single array;
the default `reduce` for `Parallel`.
"""
concat(values) = vcat(values...)
"""
    Parallel{L<:Recur}

Fan-out container: feeds the same input to every layer in `layers` (after
passing it through that layer's entry in `map`), then combines the
per-layer outputs with `reduce`.
"""
mutable struct Parallel{L<:Recur}
layers::Vector{L} # the wrapped recurrent layers, all fed the same input
map::Vector{Function} # per-layer input transform, `identity` by default
reduce::Function # combines the per-layer outputs into one value
end
# Positional convenience constructor: no per-layer mapping, concatenate outputs.
# NOTE(review): this only matches a Vector with abstract eltype `Recur`; a
# vector of concrete `Recur{...}` values dispatches to the keyword
# constructor below instead — confirm which one callers actually hit.
Parallel(layers::Vector{Recur}) = Parallel(layers, fill(identity, length(layers)), concat)
"""
    Parallel(layers; map = Dict{Int64,Function}(), reduce = concat)

Keyword constructor: `map` assigns an input transform to selected layer
indices (all other layers get `identity`); `reduce` combines the outputs.
"""
function Parallel(layers::Vector{L};
map::Dict{Int64,Function} = Dict{Int64,Function}(),
reduce::Function = concat) where L<:Recur
mappings::Vector{Function} = fill(identity, length(layers))
for (k,v) in map
# Throws BoundsError if a key lies outside 1:length(layers).
mappings[k] = v
end
return Parallel(layers, mappings, reduce)
end
"""
    (p::Parallel)(xs)

Apply every layer to `xs` (each after its own `map` transform) and combine
the per-layer results with `p.reduce`.
"""
function (p::Parallel)(xs)
# NOTE(review): the locals `map`/`reduce` shadow the Base functions here.
layers, map, reduce = p.layers, p.map, p.reduce
# is ok for nprocs() == 1
# Base.pmap
mapped = Base.pmap(l-> layers[l](map[l](xs)), eachindex(layers))
# mapped = Vector{Any}(length(layers))
# Threads.@threads for l in eachindex(layers)
# mapped[l] = layers[l](map[l](xs))
# end
reduce(mapped)
end
# Register fields with Flux so params/gpu/mapleaves traverse Parallel.
treelike(Parallel)
# Truncation visitor for `prefor`, written with multiple dispatch instead
# of isa-branching. `prefor` does not descend into the Vector field of a
# Parallel on its own, so the Parallel method visits its layers explicitly.
_prefor_truncate(x) = nothing

function _prefor_truncate(x::Recur)
    # Detach the running state from the tape (truncated BPTT).
    x.state = _truncate(x.state)
    return nothing
end

function _prefor_truncate(p::Parallel)
    foreach(_prefor_truncate, p.layers)
    return nothing
end
"""
    truncate!(m)

Walk model `m` with `prefor` and truncate the backprop history of every
`Recur`, including those wrapped inside a `Parallel`. Use this instead of
`Flux.truncate!`, which does not know about `Parallel`.
"""
function truncate!(m)
prefor(_prefor_truncate, m)
end
# Reset visitor for `prefor`, written with multiple dispatch instead of
# isa-branching; the fallback is a no-op for every other node type.
_prefor_reset(x) = nothing

function _prefor_reset(x::Recur)
    # Discard the running state and restore the (learned) initial state.
    x.state = x.init
    return nothing
end

function _prefor_reset(p::Parallel)
    # prefor does not descend into the Vector field, so visit layers here.
    foreach(_prefor_reset, p.layers)
    return nothing
end
"""
    reset!(m)

Walk model `m` with `prefor` and reset every `Recur` to its initial state,
including those wrapped inside a `Parallel`. Use this instead of
`Flux.reset!`, which does not know about `Parallel`.
"""
function reset!(m)
prefor(_prefor_reset, m)
end
# Reverse a one-hot encoded sequence by flipping the column (time) order,
# keeping the alphabet height. Used so the backward direction of `Bi`
# reads the sequence reversed.
# NOTE(review): this extends Base.reverse for a Flux-owned type (type
# piracy) — acceptable in a script, risky in a package.
function Base.reverse(M::Flux.OneHotMatrix{Array{Flux.OneHotVector,1}})
Flux.OneHotMatrix(M.height, reverse(M.data))
end
# A single one-hot vector is one timestep: there is no time axis to flip,
# so reversal is the identity. (Extends Base.reverse for a Flux type.)
Base.reverse(v::Flux.OneHotVector) = v
function Base.reverse(ta::TrackedArray)
# For a 2-D tracked array, flip along dim 2 (the time axis in this setup).
# NOTE(review): `flipdim(ta.data, 2)` acts on the raw data and returns a
# plain Array, detaching the result from the gradient tape — confirm this
# is intended (it is harmless when applied to untracked model inputs).
if length(size(ta.data)) == 2
flipdim(ta.data,2)
else
# Non-matrix tracked values (a single timestep) are returned unchanged.
ta
end
end
# A scalar has no ordering to reverse; returning it unchanged lets the
# `map` slot apply `reverse` uniformly to whatever an input happens to be.
Base.reverse(b::Bool) = b
# see:
# "SPEECH RECOGNITION WITH DEEP RECURRENT NEURAL NETWORKS" https://arxiv.org/pdf/1303.5778.pdf
# "Bidirectional LSTM-CRF Models for Sequence Tagging" https://arxiv.org/pdf/1508.01991.pdf
"""
    Bi(recur, reduce = concat)

Wrap `recur` in a bidirectional `Parallel`: the original layer reads the
input as-is, an identical copy reads it reversed, and the two outputs are
combined by `reduce`.
"""
function Bi(recur::Recur, reduce::Function = concat)
    backward = deepcopy(recur)
    directions = [recur, backward]
    reversing = Dict{Int64,Function}(2 => reverse)
    return Parallel(directions, map = reversing, reduce = reduce)
end
"""
    BiLSTM(in, out, reduce = concat)

Construct a bidirectional LSTM whose combined output has length `out`.

With `reduce == concat` each direction gets `out ÷ 2` units, so `out`
must be even; with `reduce == average` each direction gets `out` units.
Any other reduce function is rejected explicitly (the original silently
returned `nothing`).

# Throws
- `DimensionMismatch` when `reduce == concat` and `out` is odd.
- `ArgumentError` for any reduce other than `concat`/`average`.
"""
function BiLSTM(in::Int, out::Int, reduce::Function = concat)
    if reduce == average
        return Bi(LSTM(in, out), reduce)
    elseif reduce == concat
        iseven(out) ||
            throw(DimensionMismatch("`out` must be a multiple of two for `concat` as reduce function."))
        # `out ÷ 2` replaces convert(Int64, out/2): same value, no Float detour.
        return Bi(LSTM(in, out ÷ 2), reduce)
    else
        throw(ArgumentError("unsupported reduce function; use `concat` or `average`"))
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment