Skip to content

Instantly share code, notes, and snippets.

@yusugomori
Last active March 21, 2019 02:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yusugomori/29244ef1804891c202a044f93aaf4433 to your computer and use it in GitHub Desktop.
import os
import subprocess
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
from torch.autograd import Variable
class PositionalEncoding(nn.Module):
    '''
    Positional encoding layer with sinusoid (Vaswani et al., 2017):

        pe[pos, 2i]   = sin(pos / 10000^(2i / output_dim))
        pe[pos, 2i+1] = cos(pos / 10000^(2i / output_dim))

    The table is registered as a buffer so it moves with the module across
    devices and is stored in the state_dict, but is never trained.
    '''
    def __init__(self, output_dim,
                 max_len=6000,
                 device='cpu'):
        super().__init__()
        self.output_dim = output_dim
        self.max_len = max_len
        pe = self.initializer()
        self.register_buffer('pe', pe)

    def forward(self, x, mask=None):
        '''
        # Arguments
            x: (batch, sequence, output_dim) embedded input
               (fixed: the docstring used to claim (batch, sequence), but the
               element-wise addition below requires the embedding dimension)
        # Returns
            x + positional encoding, same shape as x
        '''
        pe = self.pe[:x.size(1), :].unsqueeze(0)
        # BUGFIX/modernization: torch.autograd.Variable is deprecated since
        # PyTorch 0.4; .detach() gives the same "no gradient into the table"
        # semantics (the buffer does not require grad anyway).
        return x + pe.detach()

    def initializer(self):
        '''Precompute the (max_len, output_dim) sinusoid table as float32.'''
        pe = \
            np.array([[pos / np.power(10000, 2 * (i // 2) / self.output_dim)
                       for i in range(self.output_dim)]
                      for pos in range(self.max_len)])
        pe[:, 0::2] = np.sin(pe[:, 0::2])
        pe[:, 1::2] = np.cos(pe[:, 1::2])
        return torch.from_numpy(pe).float()
class ScaledDotProductAttention(nn.Module):
    '''
    Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v, with an
    optional boolean mask whose True entries receive zero attention weight.
    '''
    def __init__(self,
                 d_k,
                 device='cpu'):
        super().__init__()
        self.device = device
        self.scaler = np.sqrt(d_k)

    def forward(self, q, k, v, mask=None):
        '''
        # Arguments
            q, k, v: (batch, sequence, out_features)
            mask: (batch, sequence (, sequence)); True marks positions to
                  exclude from attention
        # Returns
            context: (batch, sequence, out_features)
        '''
        score = torch.einsum('ijk,ilk->ijl', (q, k)) / self.scaler
        # Subtract the row-wise max before exponentiating (softmax max trick,
        # numerically stable).
        score = score - torch.max(score,
                                  dim=-1,
                                  keepdim=True)[0]
        score = torch.exp(score)
        if mask is not None:
            # suppose `mask` is a mask of source
            # in source-target-attention, source is `k` and `v`
            if len(mask.size()) == 2:
                mask = mask.unsqueeze(1).repeat(1, score.size(1), 1)
            # BUGFIX: masked_fill_ on `.data` bypassed autograd, so gradients
            # ignored the masking. Use the out-of-place op on the tensor
            # itself; .bool() keeps uint8 masks working on modern PyTorch.
            score = score.masked_fill(mask.bool(), 0)
        a = score / torch.sum(score, dim=-1, keepdim=True)
        c = torch.einsum('ijk,ikl->ijl', (a, v))
        return c
class MultiHeadAttention(nn.Module):
    '''
    Multi-head attention with h parallel heads.

    Each head projects q, k, v from d_model down to d_k (= d_model // h),
    runs scaled dot-product attention, and the h head outputs are
    concatenated and mixed back to d_model by a final linear layer.
    '''
    def __init__(self,
                 h,
                 d_model,
                 device='cpu'):
        super().__init__()
        self.h = h
        self.d_model = d_model
        # Per-head dimensions; assumes d_model is divisible by h so that
        # concatenating h chunks of d_v restores d_model exactly.
        self.d_k = d_k = d_model // h
        self.d_v = d_v = d_model // h
        self.device = device
        # Projection weights stacked over heads: shape (h, d_model, d_k|d_v).
        self.W_q = nn.Parameter(torch.Tensor(h,
                                             d_model,
                                             d_k))
        self.W_k = nn.Parameter(torch.Tensor(h,
                                             d_model,
                                             d_k))
        self.W_v = nn.Parameter(torch.Tensor(h,
                                             d_model,
                                             d_v))
        nn.init.xavier_normal_(self.W_q)
        nn.init.xavier_normal_(self.W_k)
        nn.init.xavier_normal_(self.W_v)
        self.attn = ScaledDotProductAttention(d_k)
        self.linear = nn.Linear((h * d_v), d_model)
        nn.init.xavier_normal_(self.linear.weight)

    def forward(self, q, k, v, mask=None):
        '''
        # Arguments
            q, k, v: (batch, sequence, out_features)
            mask: (batch, sequence (, sequence))
        '''
        batch_size = q.size(0)
        # Project each input under every head at once:
        # (h, batch, seq, d_model) x (h, d_model, d_k) -> (h, batch, seq, d_k)
        q = torch.einsum('hijk,hkl->hijl',
                         (q.unsqueeze(0).repeat(self.h, 1, 1, 1),
                          self.W_q))
        k = torch.einsum('hijk,hkl->hijl',
                         (k.unsqueeze(0).repeat(self.h, 1, 1, 1),
                          self.W_k))
        v = torch.einsum('hijk,hkl->hijl',
                         (v.unsqueeze(0).repeat(self.h, 1, 1, 1),
                          self.W_v))
        # Fold heads into the batch dimension, head-major:
        # (h, batch, seq, d) -> (h * batch, seq, d)
        q = q.view(-1, q.size(-2), q.size(-1))
        k = k.view(-1, k.size(-2), k.size(-1))
        v = v.view(-1, v.size(-2), v.size(-1))
        if mask is not None:
            # Tile the mask h times along dim 0 so it lines up with the
            # head-major folding above (repeat tiles whole batches, matching
            # the (h, batch) -> h*batch flattening order of .view).
            multiples = [self.h] + [1] * (len(mask.size()) - 1)
            mask = mask.repeat(multiples)
        c = self.attn(q, k, v, mask=mask)
        # Unfold: split the h chunks of size batch_size and concatenate the
        # per-head outputs along the feature dimension -> (batch, seq, h*d_v).
        c = torch.split(c, batch_size, dim=0)
        c = torch.cat(c, dim=-1)
        out = self.linear(c)
        return out
class Encoder(nn.Module):
    '''
    Transformer encoder: token embedding plus positional encoding, followed
    by a stack of N identical encoder layers.
    '''
    def __init__(self,
                 depth_source,
                 N=6,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=128,
                 device='cpu'):
        super().__init__()
        self.device = device
        self.embedding = nn.Embedding(depth_source, d_model, padding_idx=0)
        self.pe = PositionalEncoding(d_model, max_len=max_len)
        stack = [EncoderLayer(h=h,
                              d_model=d_model,
                              d_ff=d_ff,
                              p_dropout=p_dropout,
                              max_len=max_len,
                              device=device)
                 for _ in range(N)]
        self.encs = nn.ModuleList(stack)

    def forward(self, x, mask=None):
        '''
        # Arguments
            x: (batch, sequence) source token ids
            mask: padding mask forwarded to every layer
        '''
        hidden = self.pe(self.embedding(x))
        for layer in self.encs:
            hidden = layer(hidden, mask=mask)
        return hidden
class EncoderLayer(nn.Module):
    '''
    Single encoder block: multi-head self-attention then a position-wise
    feed-forward network, each wrapped with dropout, a residual connection
    and LayerNorm (post-norm variant).
    '''
    def __init__(self,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=128,
                 device='cpu'):
        super().__init__()
        self.attn = MultiHeadAttention(h, d_model)
        self.dropout1 = nn.Dropout(p_dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FFN(d_model, d_ff)
        self.dropout2 = nn.Dropout(p_dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-attention sublayer with residual + LayerNorm.
        attn_out = self.dropout1(self.attn(x, x, x, mask=mask))
        hidden = self.norm1(x + attn_out)
        # Feed-forward sublayer with residual + LayerNorm.
        ff_out = self.dropout2(self.ff(hidden))
        return self.norm2(hidden + ff_out)
class Decoder(nn.Module):
    '''
    Transformer decoder: token embedding plus positional encoding, followed
    by a stack of N identical decoder layers attending to the encoder output.
    '''
    def __init__(self,
                 depth_target,
                 N=6,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=128,
                 device='cpu'):
        super().__init__()
        self.device = device
        self.embedding = nn.Embedding(depth_target, d_model, padding_idx=0)
        self.pe = PositionalEncoding(d_model, max_len=max_len)
        stack = [DecoderLayer(h=h,
                              d_model=d_model,
                              d_ff=d_ff,
                              p_dropout=p_dropout,
                              max_len=max_len,
                              device=device)
                 for _ in range(N)]
        self.decs = nn.ModuleList(stack)

    def forward(self, x, hs,
                mask=None,
                source_mask=None):
        '''
        # Arguments
            x: (batch, sequence) target token ids
            hs: encoder output
            mask: self-attention mask for the target side
            source_mask: mask applied in source-target attention
        '''
        hidden = self.pe(self.embedding(x))
        for layer in self.decs:
            hidden = layer(hidden, hs,
                           mask=mask,
                           source_mask=source_mask)
        return hidden
class DecoderLayer(nn.Module):
    '''
    Single decoder block: masked self-attention, source-target attention
    over the encoder output, then a position-wise feed-forward network —
    each sublayer wrapped with dropout, residual connection and LayerNorm.
    '''
    def __init__(self,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=128,
                 device='cpu'):
        super().__init__()
        self.self_attn = MultiHeadAttention(h, d_model)
        self.dropout1 = nn.Dropout(p_dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.src_tgt_attn = MultiHeadAttention(h, d_model)
        self.dropout2 = nn.Dropout(p_dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = FFN(d_model, d_ff)
        self.dropout3 = nn.Dropout(p_dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, hs,
                mask=None,
                source_mask=None):
        # Masked self-attention over the target sequence.
        self_out = self.dropout1(self.self_attn(x, x, x, mask=mask))
        hidden = self.norm1(x + self_out)
        # Attention over the encoder states (queries from the decoder).
        cross_out = self.dropout2(self.src_tgt_attn(hidden, hs, hs,
                                                    mask=source_mask))
        mixed = self.norm2(hidden + cross_out)
        # Position-wise feed-forward sublayer.
        ff_out = self.dropout3(self.ff(mixed))
        return self.norm3(mixed + ff_out)
class FFN(nn.Module):
    '''
    Position-wise Feed-Forward Network: two linear maps with a ReLU in
    between, applied identically at every sequence position.
    '''
    def __init__(self, d_model, d_ff,
                 device='cpu'):
        super().__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # d_model -> d_ff -> ReLU -> d_model
        return self.l2(torch.relu(self.l1(x)))
class Transformer(nn.Module):
    '''
    Encoder-decoder Transformer for sequence-to-sequence translation.

    With a `target` it runs teacher-forced decoding and returns logits of
    shape (batch, target_len, depth_target); without one it decodes greedily
    from <BOS> and returns token ids of shape (batch, <= max_len).
    '''
    def __init__(self,
                 depth_source,
                 depth_target,
                 N=6,
                 h=8,
                 d_model=512,
                 d_ff=2048,
                 p_dropout=0.1,
                 max_len=20,
                 bos_value=1,
                 device='cpu'):
        super().__init__()
        self.device = device
        self.encoder = Encoder(depth_source,
                               N=N,
                               h=h,
                               d_model=d_model,
                               d_ff=d_ff,
                               p_dropout=p_dropout,
                               max_len=max_len,
                               device=device)
        self.decoder = Decoder(depth_target,
                               N=N,
                               h=h,
                               d_model=d_model,
                               d_ff=d_ff,
                               p_dropout=p_dropout,
                               max_len=max_len,
                               device=device)
        self.out = nn.Linear(d_model, depth_target)
        nn.init.xavier_normal_(self.out.weight)
        self._BOS = bos_value
        self._max_len = max_len

    def forward(self, source, target=None):
        source_mask = self.sequence_mask(source)
        hs = self.encoder(source, mask=source_mask)
        if target is not None:
            # Teacher forcing: combine the padding mask with the subsequent
            # (look-ahead) mask; any position masked by either is excluded.
            target_mask = self.sequence_mask(target).unsqueeze(1)
            subsequent_mask = self.subsequence_mask(target)
            target_mask = torch.gt(target_mask + subsequent_mask, 0)
            y = self.decoder(target, hs,
                             mask=target_mask,
                             source_mask=source_mask)
            output = self.out(y)
        else:
            # Greedy decoding: start from <BOS>, append the argmax token
            # until _max_len tokens have been produced.
            batch_size = source.size(0)
            output = torch.ones((batch_size, 1),
                                dtype=torch.long,
                                device=self.device) * self._BOS
            # BUGFIX: was `device=device`, a global that only exists when the
            # file runs as a script; use the module's own device.
            for t in range(self._max_len - 1):
                target_mask = self.subsequence_mask(output)
                out = self.decoder(output, hs,
                                   mask=target_mask,
                                   source_mask=source_mask)
                out = self.out(out)[:, -1:, :]
                out = out.max(-1)[1]
                output = torch.cat((output, out), dim=1)
        return output

    def sequence_mask(self, x):
        # True where x holds the padding index (0).
        return x.eq(0)

    def subsequence_mask(self, x):
        # Upper-triangular look-ahead mask: True above the diagonal, i.e.
        # position i must not attend to positions > i.
        shape = (x.size(1), x.size(1))
        mask = torch.triu(torch.ones(shape, dtype=torch.uint8),
                          diagonal=1)
        return mask.unsqueeze(0).repeat(x.size(0), 1, 1).to(self.device)
def load_small_parallel_enja(path=None,
                             to_ja=True,
                             pad_value=0,
                             start_char=1,
                             end_char=2,
                             oov_char=3,
                             index_from=4,
                             pad='<PAD>',
                             bos='<BOS>',
                             eos='<EOS>',
                             oov='<UNK>',
                             add_bos=True,
                             add_eos=True):
    '''
    Download 50k En/Ja Parallel Corpus
    from https://github.com/odashi/small_parallel_enja
    and transform words to IDs.

    Original Source from:
    https://github.com/yusugomori/tftf/blob/master/tftf/datasets/small_parallel_enja.py

    # Arguments
        path: cache directory name under ~/.tftf/datasets/
        to_ja: if True, English is the source (x) and Japanese the target (y)
        pad_value / start_char / end_char / oov_char / index_from:
            reserved ids for the special tokens; real words start at index_from
        add_bos / add_eos: wrap every sentence with bos/eos tokens
    # Returns
        (x_train, y_train), (x_test, y_test),
        (num_X, num_y), (w2i_X, w2i_y), (i2w_X, i2w_y)
    '''
    url_base = 'https://raw.githubusercontent.com/' \
        'odashi/small_parallel_enja/master/'
    path = path or 'small_parallel_enja'
    dir_path = os.path.join(os.path.expanduser('~'),
                            '.tftf', 'datasets', path)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    f_ja = ['train.ja', 'test.ja']
    f_en = ['train.en', 'test.en']
    # Fetch each corpus file with `curl` only if it is not already cached.
    for f in (f_ja + f_en):
        f_path = os.path.join(dir_path, f)
        if not os.path.exists(f_path):
            url = url_base + f
            print('Downloading {}'.format(f))
            cmd = ['curl', '-o', f_path, url]
            subprocess.call(cmd)
    f_ja_train = os.path.join(dir_path, f_ja[0])
    f_test_ja = os.path.join(dir_path, f_ja[1])
    f_en_train = os.path.join(dir_path, f_en[0])
    f_test_en = os.path.join(dir_path, f_en[1])
    # Vocabularies are fit per language on the training split only.
    (ja_train, test_ja), num_words_ja, (w2i_ja, i2w_ja) = \
        _build(f_ja_train, f_test_ja,
               pad_value, start_char, end_char, oov_char, index_from,
               pad, bos, eos, oov, add_bos, add_eos)
    (en_train, test_en), num_words_en, (w2i_en, i2w_en) = \
        _build(f_en_train, f_test_en,
               pad_value, start_char, end_char, oov_char, index_from,
               pad, bos, eos, oov, add_bos, add_eos)
    # Assign source/target roles according to the translation direction.
    if to_ja:
        x_train, x_test, num_X, w2i_X, i2w_X = \
            en_train, test_en, num_words_en, w2i_en, i2w_en
        y_train, y_test, num_y, w2i_y, i2w_y = \
            ja_train, test_ja, num_words_ja, w2i_ja, i2w_ja
    else:
        x_train, x_test, num_X, w2i_X, i2w_X = \
            ja_train, test_ja, num_words_ja, w2i_ja, i2w_ja
        y_train, y_test, num_y, w2i_y, i2w_y = \
            en_train, test_en, num_words_en, w2i_en, i2w_en
    # Ragged sentences -> object arrays (rows keep different lengths).
    x_train, x_test = np.array(x_train), np.array(x_test)
    y_train, y_test = np.array(y_train), np.array(y_test)
    return (x_train, y_train), (x_test, y_test), \
        (num_X, num_y), (w2i_X, w2i_y), (i2w_X, i2w_y)
def _build(f_train, f_test,
           pad_value=0,
           start_char=1,
           end_char=2,
           oov_char=3,
           index_from=4,
           pad='<PAD>',
           bos='<BOS>',
           eos='<EOS>',
           oov='<UNK>',
           add_bos=True,
           add_eos=True):
    '''
    Fit a _Builder vocabulary on f_train, then encode both files.

    # Returns
        ((train, test), num_words, (w2i, i2w))
    '''
    builder = _Builder(pad_value=pad_value,
                       start_char=start_char,
                       end_char=end_char,
                       oov_char=oov_char,
                       index_from=index_from,
                       pad=pad,
                       bos=bos,
                       eos=eos,
                       oov=oov,
                       add_bos=add_bos,
                       add_eos=add_eos)
    builder.fit(f_train)
    encoded = (builder.transform(f_train), builder.transform(f_test))
    return encoded, builder.num_words, (builder.w2i, builder.i2w)
class _Builder(object):
def __init__(self,
pad_value=0,
start_char=1,
end_char=2,
oov_char=3,
index_from=4,
pad='<PAD>',
bos='<BOS>',
eos='<EOS>',
oov='<UNK>',
add_bos=True,
add_eos=True):
self._vocab = None
self._w2i = None
self._i2w = None
self.pad_value = pad_value
self.start_char = start_char
self.end_char = end_char
self.oov_char = oov_char
self.index_from = index_from
self.pad = pad
self.bos = bos
self.eos = eos
self.oov = oov
self.add_bos = add_bos
self.add_eos = add_eos
@property
def num_words(self):
return max(self._w2i.values()) + 1
@property
def w2i(self):
'''
Dict of word to index
'''
return self._w2i
@property
def i2w(self):
'''
Dict of index to word
'''
return self._i2w
def fit(self, f_path):
self._vocab = set()
self._w2i = {}
for line in open(f_path, encoding='utf-8'):
_sentence = line.strip().split()
self._vocab.update(_sentence)
self._w2i = {w: (i + self.index_from)
for i, w in enumerate(self._vocab)}
if self.pad_value >= 0:
self._w2i[self.pad] = self.pad_value
self._w2i[self.bos] = self.start_char
self._w2i[self.eos] = self.end_char
self._w2i[self.oov] = self.oov_char
self._i2w = {i: w for w, i in self._w2i.items()}
def transform(self, f_path):
if self._vocab is None or self._w2i is None:
raise AttributeError('`{}.fit` must be called before `transform`.'
''.format(self.__class__.__name__))
sentences = []
for line in open(f_path, encoding='utf-8'):
_sentence = line.strip().split()
# _sentence = [self.bos] + _sentence + [self.eos]
if self.add_bos:
_sentence = [self.bos] + _sentence
if self.add_eos:
_sentence = _sentence + [self.eos]
sentences.append(self._encode(_sentence))
return sentences
def _encode(self, sentence):
encoded = []
for w in sentence:
if w not in self._w2i:
id = self.oov_char
else:
id = self._w2i[w]
encoded.append(id)
return encoded
def pad_sequences(data,
                  padding='pre',
                  value=0):
    '''
    Pad a list of variable-length lists to a rectangular array.

    # Arguments
        data: list of lists / np.array of lists
        padding: 'pre' to left-pad, 'post' to right-pad
        value: fill value used for padding
    # Returns
        numpy.ndarray of shape (len(data), maxlen)
    # Raises
        ValueError: if rows are not lists or `padding` is unknown
    '''
    if len(data) == 0:
        # BUGFIX: empty input used to raise IndexError on data[0] before the
        # intended validation could run; return an empty array instead.
        return np.array(data)
    if type(data[0]) is not list:
        raise ValueError('`data` must be a list of lists')
    maxlen = max(len(seq) for seq in data)
    if padding == 'pre':
        padded = [[value] * (maxlen - len(seq)) + seq for seq in data]
    elif padding == 'post':
        padded = [seq + [value] * (maxlen - len(seq)) for seq in data]
    else:
        raise ValueError('`padding` must be one of \'pre\' or \'post\'')
    return np.array(padded)
def sort(data, target,
         order='ascend'):
    '''
    Sort `data` by sequence length and reorder `target` identically.

    # Arguments
        data, target: parallel sequences of equal length
        order: 'ascend'/'ascending' or 'descend'/'descending'
    # Returns
        (data, target) as reordered lists
    '''
    if order in ('ascend', 'ascending'):
        descending = False
    elif order in ('descend', 'descending'):
        descending = True
    else:
        raise ValueError('`order` must be of \'ascend\' or \'descend\'.')
    # Stable argsort by length (reverse=True is also stable, so ties keep
    # their original relative order either way).
    indices = sorted(range(len(data)),
                     key=lambda i: len(data[i]),
                     reverse=descending)
    return ([data[i] for i in indices],
            [target[i] for i in indices])
if __name__ == '__main__':
    np.random.seed(1234)
    torch.manual_seed(1234)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_loss(label, pred):
        # CrossEntropyLoss expects (N, C) logits and (N,) class ids.
        return criterion(pred, label)

    def train_step(x, t):
        '''One optimization step with shifted teacher forcing.'''
        model.train()
        # BUGFIX: shift the target by one position. The decoder consumes
        # t[:, :-1] and is trained to predict t[:, 1:]. Feeding the full
        # target and scoring it against itself let position i attend to
        # token i (the subsequent mask only hides positions > i), so the
        # model could simply learn the identity mapping — the likely cause
        # of the "not working" notice on this gist.
        preds = model(x, t[:, :-1])
        loss = compute_loss(t[:, 1:].contiguous().view(-1),
                            preds.contiguous().view(-1, preds.size(-1)))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss, preds

    def valid_step(x, t):
        '''Forward pass + loss only, same shifted targets as training.'''
        model.eval()
        preds = model(x, t[:, :-1])
        loss = compute_loss(t[:, 1:].contiguous().view(-1),
                            preds.contiguous().view(-1, preds.size(-1)))
        return loss, preds

    def test_step(x):
        '''Greedy decoding (no target given -> Transformer decodes itself).'''
        model.eval()
        preds = model(x)
        return preds

    def ids_to_sentence(ids, i2w):
        return [i2w[id] for id in ids]

    '''
    Load data
    '''
    class ParallelDataLoader(object):
        '''
        Mini-batch iterator over a (source, target) pair of datasets.
        Batches are length-sorted (descending) and padded to the batch max.
        '''
        def __init__(self, dataset,
                     batch_size=128,
                     shuffle=False,
                     random_state=None):
            if type(dataset) is not tuple:
                raise ValueError('argument `dataset` must be tuple,'
                                 ' not {}.'.format(type(dataset)))
            self.dataset = list(zip(dataset[0], dataset[1]))
            self.batch_size = batch_size
            self.shuffle = shuffle
            if random_state is None:
                random_state = np.random.RandomState(1234)
            self.random_state = random_state
            self._idx = 0

        def __len__(self):
            return len(self.dataset)

        def __iter__(self):
            return self

        def __next__(self):
            if self._idx >= len(self.dataset):
                self._reorder()
                raise StopIteration()
            x, y = zip(*self.dataset[self._idx:(self._idx + self.batch_size)])
            x, y = sort(x, y, order='descend')
            x = pad_sequences(x, padding='post')
            y = pad_sequences(y, padding='post')
            x = torch.LongTensor(x)  # not use .t()
            y = torch.LongTensor(y)  # not use .t()
            self._idx += self.batch_size
            return x, y

        def _reorder(self):
            # BUGFIX: `shuffle(...)` was an undefined name and its result
            # was assigned to the unused attribute `self.data` — epoch
            # boundaries with shuffle=True raised NameError. Shuffle
            # `self.dataset` in place with the held RandomState instead.
            if self.shuffle:
                self.random_state.shuffle(self.dataset)
            self._idx = 0

    (x_train, y_train), \
        (x_test, y_test), \
        (num_x, num_y), \
        (w2i_x, w2i_y), (i2w_x, i2w_y) = \
        load_small_parallel_enja(to_ja=True)

    train_dataloader = ParallelDataLoader((x_train, y_train),
                                          shuffle=True)
    valid_dataloader = ParallelDataLoader((x_test, y_test))
    test_dataloader = ParallelDataLoader((x_test, y_test),
                                         batch_size=1,
                                         shuffle=True)

    '''
    Build model
    '''
    model = Transformer(num_x,
                        num_y,
                        N=3,
                        h=4,
                        d_model=128,
                        d_ff=256,
                        max_len=20,
                        device=device).to(device)
    # ignore_index=0 excludes padding positions from the loss.
    criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=0)
    optimizer = optimizers.Adam(model.parameters())

    '''
    Train model
    '''
    epochs = 20

    for epoch in range(epochs):
        print('-' * 20)
        print('Epoch: {}'.format(epoch+1))

        train_loss = 0.
        valid_loss = 0.

        for idx, (source, target) in enumerate(train_dataloader):
            source, target = source.to(device), target.to(device)
            loss, _ = train_step(source, target)
            train_loss += loss.item()
        train_loss /= len(train_dataloader)

        for (source, target) in valid_dataloader:
            source, target = source.to(device), target.to(device)
            loss, _ = valid_step(source, target)
            valid_loss += loss.item()
        valid_loss /= len(valid_dataloader)
        print('Valid loss: {:.3}'.format(valid_loss))

        # Show a few greedy translations from the test split.
        for idx, (source, target) in enumerate(test_dataloader):
            source, target = source.to(device), target.to(device)
            out = test_step(source)
            out = out.view(-1).tolist()
            out = ' '.join(ids_to_sentence(out, i2w_y))
            source = ' '.join(ids_to_sentence(source.view(-1).tolist(), i2w_x))
            target = ' '.join(ids_to_sentence(target.view(-1).tolist(), i2w_y))
            print('>', source)
            print('=', target)
            print('<', out)
            print()
            if idx >= 10:
                break
@yusugomori
Copy link
Author

yusugomori commented Mar 19, 2019

NOTICE: This implementation is not working.
(Referenced on stackoverflow.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment