char-level seq2seq with attention
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random
import numpy as np
import re, unicodedata
random.seed(1024)
# gpu configuration
USE_CUDA = torch.cuda.is_available()
gpus = [0]
if USE_CUDA:
    torch.cuda.set_device(gpus[0])
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor
flatten = lambda lst: [item for sublist in lst for item in sublist]
def get_batch(batch_size, train_data):
    # shuffle once per epoch, then yield contiguous mini-batches
    random.shuffle(train_data)
    start_idx = 0
    end_idx = batch_size
    while end_idx < len(train_data):
        batch = train_data[start_idx:end_idx]
        start_idx = end_idx
        end_idx = end_idx + batch_size
        yield batch
    # yield the remaining (possibly smaller) final batch
    if end_idx >= len(train_data):
        yield train_data[start_idx:]
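# (Added note, not part of the original gist) get_batch yields ceil(len(train_data) / batch_size)
# batches per epoch; e.g. list(get_batch(2, list(range(5)))) gives three batches of sizes 2, 2 and 1.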
def pad_to_batch(batch, source2idx, target2idx):
    # sort data points in descending order of source length (required by pack_padded_sequence)
    sorted_batch = sorted(batch, key=lambda x: x[0].size(1), reverse=True)
    x, y = list(zip(*sorted_batch))
    max_source_length = max([seq.size(1) for seq in x])
    max_target_length = max([seq.size(1) for seq in y])
    padded_x, padded_y = [], []
    for i in range(len(batch)):
        if x[i].size(1) < max_source_length:
            paddings = LongTensor([source2idx["<PAD>"]] * (max_source_length - x[i].size(1)))
            paddings = paddings.view(1, -1)
            padded = torch.cat([x[i], paddings], dim=1)
            padded_x.append(padded)
        else:
            padded_x.append(x[i])
        if y[i].size(1) < max_target_length:
            paddings = LongTensor([target2idx["<PAD>"]] * (max_target_length - y[i].size(1)))
            paddings = paddings.view(1, -1)
            padded = torch.cat([y[i], paddings], dim=1)
            padded_y.append(padded)
        else:
            padded_y.append(y[i])
    input_x = torch.cat(padded_x, dim=0)
    target_y = torch.cat(padded_y, dim=0)
    # actual (unpadded) lengths: count the non-<PAD> tokens in each row
    input_len = [list(map(lambda s: s == 0, t.data)).count(False) for t in input_x]
    target_len = [list(map(lambda s: s == 0, t.data)).count(False) for t in target_y]
    return input_x, target_y, input_len, target_len
# convert tokens to indices
def prepare_sequence(seq, to_idx):
    indices = list(map(lambda w: to_idx[w] if w in to_idx else to_idx["<UNK>"], seq))
    return LongTensor(indices)
def unicode_to_ascii(s):
    return "".join(
        c for c in unicodedata.normalize("NFD", s)
        if unicodedata.category(c) != "Mn"
    )
def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([,.!?])", r" \1 ", s)
    s = re.sub(r"[^a-zA-Z,.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s
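# (Added example, not part of the original gist) normalize_string lower-cases, strips accents
# and isolates basic punctuation, e.g. normalize_string("Être ou ne pas être !") -> "etre ou ne pas etre !".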
corpus = open("fra-eng/fra.txt", "r", encoding="utf-8").readlines()[:-1]
x_raw, y_raw = [], []
for parallel in corpus:
    source, target = parallel[:-1].split("\t")
    if source.strip() == "" or target.strip() == "":
        continue
    normalized_source = normalize_string(source).split()
    normalized_target = normalize_string(target).split()
    x_raw.append(normalized_source)
    y_raw.append(normalized_target)
# construct vocab for source and target language
source_vocab = list(set(flatten(x_raw)))
target_vocab = list(set(flatten(y_raw)))
source2idx = {"<PAD>": 0, "<UNK>": 1, "<s>": 2, "</s>": 3}
for vocab in source_vocab:
    if vocab not in source2idx:
        source2idx[vocab] = len(source2idx)
idx2source = {idx: vocab for vocab, idx in source2idx.items()}
target2idx = {"<PAD>": 0, "<UNK>": 1, "<s>": 2, "</s>": 3}
for vocab in target_vocab:
    if vocab not in target2idx:
        target2idx[vocab] = len(target2idx)
idx2target = {idx: vocab for vocab, idx in target2idx.items()}
padded_x, padded_y = [], []
for source, target in zip(x_raw, y_raw):
    padded_x.append(prepare_sequence(source + ["</s>"], source2idx).view(1, -1))
    padded_y.append(prepare_sequence(target + ["</s>"], target2idx).view(1, -1))
train_data = list(zip(padded_x, padded_y))
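# (Added check, not part of the original gist) quick sanity check of the prepared data;
# the exact counts depend on the fra.txt snapshot used.
print("pairs: %d, source vocab: %d, target vocab: %d"
      % (len(train_data), len(source2idx), len(target2idx)))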
class Encoder(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, n_layers=1, bidirectional=False):
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_size, embedding_size)
        if bidirectional:
            self.n_direction = 2
            self.gru = nn.GRU(embedding_size, hidden_size, n_layers, batch_first=True, bidirectional=True)
        else:
            self.n_direction = 1
            self.gru = nn.GRU(embedding_size, hidden_size, n_layers, batch_first=True, bidirectional=False)
    def init_hidden(self, inputs):
        hidden = torch.zeros((self.n_layers * self.n_direction, inputs.size(0), self.hidden_size))
        return hidden.cuda() if USE_CUDA else hidden
    def init_weight(self):
        # xavier-initialize the embedding and the first GRU layer in place
        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.xavier_uniform_(self.gru.weight_hh_l0)
        nn.init.xavier_uniform_(self.gru.weight_ih_l0)
    def forward(self, inputs, input_lengths):
        # inputs: [B, T] LongTensor
        # input_lengths: list of actual lengths of the input batch, sorted in descending order
        hidden = self.init_hidden(inputs)
        embedded = self.embedding(inputs)
        packed = pack_padded_sequence(embedded, input_lengths, batch_first=True)
        # outputs: [B, T, D * n_direction], hidden: [n_layers * n_direction, B, D]
        outputs, hidden = self.gru(packed, hidden)
        outputs, output_lengths = pad_packed_sequence(outputs, batch_first=True)
        if self.n_layers > 1:
            if self.n_direction == 2:
                hidden = hidden[-2:]  # keep the last layer, both directions
            else:
                hidden = hidden[-1:]  # keep the last layer; preserve the layer dim for the cat below
        hidden = torch.cat([h for h in hidden], dim=1)  # [B, D * n_direction]
        hidden = hidden.unsqueeze(dim=1)  # [B, 1, D']
        return outputs, hidden
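# (Added note, not part of the original gist) Shape sketch for the configuration used below
# (hidden_size=512, n_layers=3, bidirectional=True): for a padded batch of shape [B, T],
# forward() returns outputs of shape [B, T, 1024] and a context vector of shape [B, 1, 1024]
# built from the last layer's forward and backward hidden states.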
class Decoder(nn.Module):
    def __init__(self, input_size, embedding_size,
                 hidden_size, n_layers=1, dropout_prob=0.1):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.dropout = nn.Dropout(dropout_prob)
        # because the GRU input is [embedded, context_vector], its size is embedding_size + hidden_size
        self.gru = nn.GRU(embedding_size + hidden_size, hidden_size, n_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, input_size)  # final projection to vocabulary logits
        self.attn_linear = nn.Linear(self.hidden_size, self.hidden_size)  # for Luong attention
        # for Bahdanau (additive) attention
        self.encoder_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.decoder_linear = nn.Linear(hidden_size, hidden_size)
        self.attention_output = nn.Linear(hidden_size, 1)
    def init_hidden(self, inputs):
        hidden = torch.zeros((self.n_layers, inputs.size(0), self.hidden_size))
        return hidden.cuda() if USE_CUDA else hidden
    def init_weight(self):
        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.xavier_uniform_(self.gru.weight_hh_l0)
        nn.init.xavier_uniform_(self.gru.weight_ih_l0)
        nn.init.xavier_uniform_(self.attn_linear.weight)
    def bahdanau_attention(self, decoder_hidden, encoder_outputs, encoder_mask):
        # decoder_hidden: [1, b, d]
        # encoder_outputs: [b, t, d]
        # encoder_mask: [b, t] mask, True at <PAD> positions (or None)
        decoder_hidden = decoder_hidden[0].unsqueeze(dim=1)  # [b, 1, d]
        decoder_feature = self.decoder_linear(decoder_hidden)
        encoder_features = self.encoder_linear(encoder_outputs)
        attn_features = encoder_features + decoder_feature
        attn_features = torch.tanh(attn_features)
        score = self.attention_output(attn_features).squeeze(dim=2)  # [b, t, 1] -> [b, t]
        if encoder_mask is not None:
            score = score.masked_fill(encoder_mask, value=-1e12)
        score = F.softmax(score, dim=1)
        score = score.unsqueeze(dim=1)  # [b, 1, t]
        context_vector = torch.matmul(score, encoder_outputs)  # [b, 1, d]
        return context_vector, score
    def luong_attention(self, decoder_hidden, encoder_outputs, encoder_mask):
        # decoder_hidden: [1, b, d]
        # encoder_outputs: [b, t, d]
        # encoder_mask: [b, t] mask, True at <PAD> positions (or None)
        decoder_hidden = decoder_hidden[0].unsqueeze(2)  # [1, b, d] -> [b, d, 1]
        energies = self.attn_linear(encoder_outputs)  # [b, t, d]
        attention_energies = torch.matmul(energies, decoder_hidden).squeeze(2)  # [b, t]
        # mask out <PAD> positions before the softmax
        if encoder_mask is not None:
            attention_energies = attention_energies.masked_fill(encoder_mask, value=-1e12)
        alpha = F.softmax(attention_energies, dim=1)
        alpha = alpha.unsqueeze(1)  # [b, 1, t]
        context_vector = torch.matmul(alpha, encoder_outputs)  # [b, 1, d]
        return context_vector, alpha
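    # (Added note, not part of the original gist) forward() and decode() below use
    # bahdanau_attention; luong_attention is an unused drop-in alternative with the same
    # interface. The additive score is score(h_enc, h_dec) = v^T tanh(W_enc h_enc + W_dec h_dec),
    # implemented by encoder_linear, decoder_linear and attention_output above.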
    def forward(self, inputs, context, max_length, encoder_outputs, encoder_maskings, is_training=False):
        """
        :param inputs: [b, 1] LongTensor, start symbol
        :param context: [b, 1, d]
        :param max_length: max length to decode
        :param encoder_outputs: [b, t, d]
        :param encoder_maskings: [b, t] mask, True at <PAD> positions
        :param is_training: boolean
        :return: [b * max_length, vocab_size] logits
        """
        embedded = self.embedding(inputs)
        hidden = self.init_hidden(inputs)
        if is_training:
            embedded = self.dropout(embedded)
        decode = []
        # unroll the GRU one step at a time
        for t in range(max_length):
            concat_inputs = torch.cat((embedded, context), dim=2)
            _, hidden = self.gru(concat_inputs, hidden)
            concat_outputs = torch.cat((hidden, context.transpose(0, 1)), dim=2)
            logit = self.linear(concat_outputs.squeeze(0))
            score = F.log_softmax(logit, 1)
            decode.append(logit)  # keep raw logits for CrossEntropyLoss
            decoded = torch.argmax(score, dim=1)
            # the prediction is the input for the next time step (no teacher forcing)
            embedded = self.embedding(decoded).unsqueeze(1)
            if is_training:
                embedded = self.dropout(embedded)
            context, alpha = self.bahdanau_attention(hidden, encoder_outputs, encoder_maskings)
        scores = torch.cat(decode, dim=1)
        batch_size = inputs.size(0)
        return scores.view(batch_size * max_length, -1)
    def decode(self, context, encoder_outputs, max_decode_length=100):
        # greedy decoding for a single example (batch size 1)
        start_decode = LongTensor([[target2idx["<s>"]] * 1])
        embedded = self.embedding(start_decode)
        hidden = self.init_hidden(start_decode)
        decodes = []
        attentions = []
        decoded = LongTensor([target2idx["<s>"]])
        while decoded.tolist()[0] != target2idx["</s>"] and len(decodes) < max_decode_length:
            concat_input = torch.cat((embedded, context), dim=2)
            _, hidden = self.gru(concat_input, hidden)
            concat_output = torch.cat((hidden, context.transpose(0, 1)), dim=2)
            score = self.linear(concat_output.squeeze(0))  # [1, vocab_size]
            score = F.log_softmax(score, dim=1)
            decodes.append(score)
            decoded = torch.argmax(score, dim=1)  # [1]
            embedded = self.embedding(decoded).unsqueeze(1)  # [1, d] -> [1, 1, d]
            context, alpha = self.bahdanau_attention(hidden, encoder_outputs, None)
            attentions.append(alpha.squeeze(1))
        indices = torch.cat(decodes, dim=0).max(1)[1]
        attentions = torch.cat(attentions, dim=0)
        return indices, attentions
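# (Added note, not part of the original gist) Rough shape summary for the decoder as
# configured below (hidden_size passed in is 1024 = 2 * encoder hidden): forward() returns
# [batch * target_length, len(target2idx)] logits, and decode() returns the greedily decoded
# token indices plus one attention distribution over source positions per generated step.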
num_epochs = 50
batch_size = 64
embedding_size = 300
hidden_size = 512
lr = 1e-3
decoder_lr_ratio = 5.0
rescheduled = False
encoder = Encoder(len(source2idx), embedding_size, hidden_size, 3, True)
decoder = Decoder(len(target2idx), embedding_size, hidden_size * 2)
encoder.init_weight()
decoder.init_weight()
if USE_CUDA:
    encoder = encoder.cuda()
    decoder = decoder.cuda()
loss_function = nn.CrossEntropyLoss(ignore_index=0)
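# (Added note, not part of the original gist) ignore_index=0 makes the loss skip <PAD>
# target positions; CrossEntropyLoss expects raw logits, which is why Decoder.forward
# collects logits rather than log-probabilities.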
enc_optimizer = optim.Adam(encoder.parameters(), lr=lr)
dec_optimizer = optim.Adam(decoder.parameters(), lr=lr * decoder_lr_ratio)
for epoch in range(num_epochs):
    losses = []
    for i, batch in enumerate(get_batch(batch_size, train_data)):
        # prepare inputs
        inputs, targets, input_lengths, target_lengths = pad_to_batch(batch, source2idx, target2idx)
        input_masks = inputs.eq(0)  # [b, t], True at <PAD> positions
        start_decode = LongTensor([[target2idx["<s>"]]] * targets.size(0))
        encoder.zero_grad()
        decoder.zero_grad()
        output, hidden_c = encoder(inputs, input_lengths)
        preds = decoder(start_decode, hidden_c, targets.size(1), output, input_masks, True)
        loss = loss_function(preds, targets.view(-1))
        losses.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(encoder.parameters(), 5.0)
        nn.utils.clip_grad_norm_(decoder.parameters(), 5.0)
        enc_optimizer.step()
        dec_optimizer.step()
        if i % 200 == 0:
            print("[%02d/%d] [%03d/%d] mean_loss : %0.2f" % (
                epoch, num_epochs, i, len(train_data) // batch_size, np.mean(losses)))
    # decay the learning rate once, halfway through training
    if rescheduled is False and epoch == num_epochs // 2:
        lr *= 0.01
        enc_optimizer = optim.Adam(encoder.parameters(), lr=lr)
        dec_optimizer = optim.Adam(decoder.parameters(), lr=lr * decoder_lr_ratio)
        rescheduled = True
# greedy decoding on a training example
test = train_data[0]
source = test[0]
target = test[1]
input_text = [idx2source[idx] for idx in source.tolist()[0]]
target_text = [idx2target[idx] for idx in target.tolist()[0]]
output, hidden = encoder(source, [source.size(1)])
pred, attn = decoder.decode(hidden, output)
pred_text = [idx2target[idx] for idx in pred.tolist()]
print("source :", " ".join(input_text))
print("target :", " ".join(target_text))
print("pred   :", " ".join(pred_text))