@merckxiaan

merckxiaan/model.jl

Last active Oct 26, 2018
This (attempt at a) seq2seq model in Flux causes Julia to crash (#29818).
using Flux, CuArrays, Statistics
mutable struct Lang; name; word2index; word2count; index2word; n_words end
Lang(name) = Lang(
    name,
    Dict{String, Int}(),
    Dict{String, Int}(),
    Dict{Int, String}(1 => "SOS", 2 => "EOS", 3 => "UNK", 4 => "PAD"),
    4)
function (l::Lang)(sentence::String)
    for word in split(sentence, " ")
        if word ∉ keys(l.word2index)
            # New word: assign the next free index and start its count.
            l.word2index[word] = l.n_words + 1
            l.word2count[word] = 1
            l.index2word[l.n_words + 1] = word
            l.n_words += 1
        else
            l.word2count[word] += 1
        end
    end
end
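# Illustrative check (not part of the model): a Lang accumulates vocabulary
# as it is called on sentences; ids 1-4 are reserved for SOS/EOS/UNK/PAD,
# so the first new word gets id 5.
let l = Lang("demo")
    l("hello world hello")
    @assert l.n_words == 6 && l.word2count["hello"] == 2
end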
function normalizeString(s)
    s = strip(lowercase(s))
    s = replace(s, r"([.!?,])" => s" \1")
    #s = replace(s, r"[^a-zA-Z.!?]+" => " ")
    return s
end
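# Quick illustrative check: the normalizer lowercases, strips, and pads
# punctuation with a leading space.
@assert normalizeString("Run!") == "run !"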
function readLangs(lang1, lang2; rev=false)
    println("Reading lines...")
    lines = readlines(FILE)
    pairs = [normalizeString.(pair) for pair in split.(lines, "\t")]
    if rev
        pairs = reverse.(pairs)
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    end
    return (input_lang, output_lang, pairs)
end
eng_prefixes = [
    "i am ", "i m ",
    "he is ", "he s ",
    "she is ", "she s ",
    "you are ", "you re ",
    "we are ", "we re ",
    "they are ", "they re "]
function filterPair(p)
    # Keep pairs where both sentences fit in MAX_LENGTH words and the
    # English side starts with one of the simple prefixes.
    return false ∉ (length.(split.(p, " ")) .<= MAX_LENGTH) && true ∈ startswith.(p[1], eng_prefixes)
end
function prepareData(lang1, lang2; rev=false)
    input_lang, output_lang, pairs = readLangs(lang1, lang2; rev=rev)
    println("Read $(length(pairs)) sentence pairs.")
    pairs = [pair for pair in pairs if filterPair(pair)]
    println("Trimmed to $(length(pairs)) sentence pairs.\n")
    xs = []
    ys = []
    for pair in pairs
        push!(xs, pair[1])
        push!(ys, pair[2])
    end
    println("Counting words...")
    for pair in pairs
        input_lang(pair[2])     # pair[2] is the lang1 ("fr") side
        output_lang(pair[1])    # pair[1] is the lang2 ("eng") side
    end
    println("Counted words:")
    println(input_lang.name, ": ", input_lang.n_words)
    println(output_lang.name, ": ", output_lang.n_words)
    return (input_lang, output_lang, xs, ys)
end
FILE = "D:/Downloads/fra-eng/fra.txt"
MAX_LENGTH = 10
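# Illustrative check of the pair filter, now that MAX_LENGTH is defined:
# both sides fit in MAX_LENGTH words and the English side starts with "i am ".
@assert filterPair(["i am cold .", "j ai froid ."])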
fr, eng, xs, ys = prepareData("fr", "eng");
# Glorot-style uniform init in ±sqrt(6/sum(dims)); the parentheses around
# (rand(dims...) .- 0.5) center the distribution before scaling.
initWeight(dims...) = param((rand(dims...) .- 0.5) .* sqrt(24.0/sum(dims)))
struct Embed; w; end
Embed(vocab::Int, embed::Int) = Embed(initWeight(embed, vocab))
(e::Embed)(x::Number) = e.w[:, Int(x)]          # look up a single token id
(e::Embed)(x::AbstractArray) = hcat(e.(x)...)   # look up a batch of ids
Flux.@treelike Embed
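# Shape sketch (illustrative): the weight matrix is (embed × vocab), ids
# select columns, and a (1 × n) matrix of ids yields an (embed × n) matrix.
let e = Embed(10, 3)    # vocab = 10, embedding dim = 3
    @assert size(e([1 2 5])) == (3, 3)
end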
struct EncoderRNN; hidden_size; embedding; rnn end
EncoderRNN(input_size, hidden_size) = EncoderRNN(
    hidden_size,
    Embed(input_size, hidden_size),
    GRU(hidden_size, hidden_size))
function (e::EncoderRNN)(x)
    x = e.embedding(x)
    x = e.rnn(x)
    return x
end
Flux.@treelike EncoderRNN
struct DecoderRNN; hidden_size; embedding; rnn; linear end
DecoderRNN(hidden_size, output_size) = DecoderRNN(
    hidden_size,
    Embed(output_size, hidden_size),
    Flux.GRUCell(hidden_size, hidden_size),
    Dense(hidden_size, output_size))
function (d::DecoderRNN)(hidden, x)
    x = d.embedding(x)
    x = relu.(x)
    x, hidden = d.rnn(hidden, x)
    x = softmax(d.linear(x))
    return (hidden, x)
end
Flux.@treelike DecoderRNN
struct AttnDecoderRNN; hidden_size; output_size; dropout_p; max_length; embedding; attn; attn_combine; rnn; out end
AttnDecoderRNN(hidden_size, output_size, dropout_p=0.1; max_length=MAX_LENGTH) = AttnDecoderRNN(
    hidden_size,
    output_size,
    dropout_p,
    max_length,
    Embed(output_size, hidden_size),
    Dense(hidden_size*2, max_length),
    Dense(hidden_size*2, hidden_size, relu),
    Flux.GRUCell(hidden_size, hidden_size),
    Dense(hidden_size, output_size))
function (d::AttnDecoderRNN)(x, hidden, encoder_outputs)
    encoder_outputs = cu(encoder_outputs)
    embedded = d.embedding(x)
    embedded = Dropout(d.dropout_p)(embedded)
    # Attention weights over the encoder time steps, computed from the
    # current embedding and hidden state.
    attn_weights = softmax(d.attn([embedded; hidden]))
    attn_applied = permutedims(sum(encoder_outputs .* attn_weights, dims=1), (3, 2, 1))[:, :]
    output = [attn_applied; embedded]
    output = d.attn_combine(output)
    hidden, output = d.rnn(hidden, output)
    output = softmax(d.out(output))
    return (output, hidden, attn_weights)
end
Flux.@treelike AttnDecoderRNN
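# Shape walkthrough for the attention step above (assuming batch size b):
# `embedded` and `hidden` are (hidden_size × b); their vcat feeds `attn`,
# giving (max_length × b) weights over `encoder_outputs`
# (max_length × b × hidden_size); the weighted sum over time is reshaped
# to (hidden_size × b) and concatenated with `embedded` for `attn_combine`.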
indexesFromSentence(lang, sentence) = get.(Ref(lang.word2index), split(lowercase(sentence), " "), 3)
function tensorsFromPair(input_lang, output_lang, pair)
    input = push!(indexesFromSentence(input_lang, pair[1]), 2)      # append EOS
    target = push!(indexesFromSentence(output_lang, pair[2]), 2)    # append EOS
    return (input, target)
end
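# Illustrative: unknown words fall back to UNK (= 3) via the `get` default,
# and `tensorsFromPair` appends EOS (= 2) to both sides.
let l = Lang("demo")
    l("hello world")
    @assert indexesFromSentence(l, "hello unseen") == [5, 3]
end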
function batch(data, bs)
    chunks = Iterators.partition(data, bs)
    batch_output = []
    for chunk in chunks
        max_size = maximum(length.(chunk))
        # Pad every sequence in the chunk to max_size with PAD (= 4);
        # length(chunk) avoids all-PAD columns in a short final chunk.
        placeholder = fill(4, max_size, length(chunk))
        for i in eachindex(chunk)
            placeholder[1:length(chunk[i]), i] = chunk[i]
        end
        push!(batch_output, placeholder)
    end
    return batch_output
end
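# Illustrative: each chunk is padded column-per-sequence to its longest
# member with PAD (= 4).
let b = batch([[1, 2], [1, 2, 3]], 2)
    @assert b[1] == [1 1; 2 2; 4 3]
end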
xs, ys = batch.([indexesFromSentence.([eng], xs), indexesFromSentence.([fr], ys)], 20);
# Hold out the last 10 batches for evaluation; end-9:end avoids overlapping
# the training slice.
trn_xs, trn_ys, test_xs, test_ys = [xs[1:end-10], ys[1:end-10], xs[end-9:end], ys[end-9:end]]
teacher_forcing_ratio = 0.5    # probability of feeding ground truth to the decoder
function train(input, target, encoder, decoder, criterion, max_length=MAX_LENGTH)
    Flux.reset!.([encoder, decoder])
    input = cu(input)
    target = cu(target)
    batch_size = size(input, 2)
    input_length = size(input, 1)
    target_length = size(target, 1)
    encoder_outputs = zeros(max_length, batch_size, encoder.hidden_size)
    loss = 0
    loss_times = 0
    # Encode the input one time step at a time, collecting the outputs
    # for the attention mechanism.
    for i in 1:input_length
        encoder_output = encoder(input[i, :])
        encoder_outputs[i, :, :] = permutedims(encoder_output.data, (2, 1))
    end
    decoder_input = ones(Int, 1, batch_size)    # SOS token
    decoder_hidden = encoder.rnn.state
    use_teacher_forcing = rand() < teacher_forcing_ratio
    if use_teacher_forcing
        # Teacher forcing: feed the ground-truth token as the next input.
        for i in 1:target_length
            decoder_output, decoder_hidden, decoder_attention =
                decoder(decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target[i, :])
            decoder_input = reshape(target[i, :], 1, batch_size)
            loss_times += 1
        end
    else
        # Free running: feed the decoder's own prediction as the next input.
        for i in 1:target_length
            decoder_output, decoder_hidden, decoder_attention =
                decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_input = reshape(Flux.onecold(decoder_output.data), 1, batch_size)
            loss += criterion(decoder_output, target[i, :])
            loss_times += 1
            all(decoder_input .== 2) && break    # every sequence emitted EOS
        end
    end
    #println(loss.data/loss_times)
    return loss / loss_times
end
function trainIters(encoder, decoder, α=0.0015)
    losses = []
    # Masked negative log-likelihood: PAD targets (= 4) contribute zero loss.
    criterion = function(output, target)
        mean([t == 4 ? 0.0 : -log(output[Int(t), i]) for (i, t) in enumerate(target)])
    end
    opt = ADAM(params(encoder, decoder), α)
    loss(x, y) = train(x, y, encoder, decoder, criterion)
    Flux.train!(loss, zip(trn_xs, trn_ys), opt,
        cb = () -> println(mean(train.(test_xs, test_ys, [encoder], [decoder], criterion))))
end
hidden_size = 256
encoder1 = EncoderRNN(eng.n_words, hidden_size) |> gpu
attn_decoder1 = AttnDecoderRNN(hidden_size, fr.n_words, 0.1) |> gpu
trainIters(encoder1, attn_decoder1)