@cwarny
Last active June 14, 2022 23:30
Software 2.0
class Acc:
    def __init__(self, ignore_index=1):
        self.ignore_index = ignore_index

    def __call__(self, pred, tgt):
        # both pred and tgt have shape (bs, seq_len)
        mask = tgt != self.ignore_index
        pred *= mask
        tgt *= mask
        correct = torch.eq(pred, tgt).all(1).sum()
        return correct.item()
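# A minimal sketch (toy tensors, not from the gist) of how this exact-match
# accuracy behaves: positions where the target equals ignore_index are zeroed
# out in both tensors before comparing whole sequences.
import torch
acc = Acc(ignore_index=1)
pred = torch.tensor([[5, 7, 1, 1], [5, 8, 9, 1]])
tgt = torch.tensor([[5, 7, 1, 1], [5, 8, 2, 1]])
print(acc(pred, tgt))  # 1: only the first sequence matches on every non-pad position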
assert roman_to_integer('MCXCIII') == 1193
max_len = 20
vocab = build_vocab_from_iterator(
    list(map(str, range(10)))               # Arabic symbols
    + ['I', 'V', 'X', 'L', 'C', 'D', 'M'],  # Roman symbols
    specials=['<bos>', '<pad>', '<eos>']    # special symbols
)
vocab_size = len(vocab)
pad_idx = vocab['<pad>']
collate_fn = partial(collate, pad_idx=pad_idx, max_len=max_len)
proc = Processor(vocab)
train_ds = NumberDataset.from_file('train', processor=proc)
valid_ds = NumberDataset.from_file('valid', processor=proc)
train_dl = DataLoader(train_ds, batch_size=10, collate_fn=collate_fn)
valid_dl = DataLoader(valid_ds, batch_size=10, collate_fn=collate_fn)
import torch
import torch.nn as nn  # needed by collate's ConstantPad1d below
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from functools import partial
class NumberDataset(Dataset):
    valid_targets = ['roman', 'integer']

    def __init__(self, lst, processor=None, target='roman'):
        assert target in self.valid_targets, f'Target needs to be one of {self.valid_targets}'
        self.target = target
        self.processor = processor
        self.lst = lst

    def __getitem__(self, i):
        i, r = self.lst[i]
        t = (i, r) if self.target == 'roman' else (r, i)
        return list(map(self.processor.process, t))

    def __len__(self):
        return len(self.lst)

    @classmethod
    def from_file(cls, fn, root=None, extension='txt', **kwargs):
        if root is None: root = default_data_path
        url = root/('.'.join([fn, extension]))
        with open(url) as f: lines = [line.split() for line in f]
        return cls(lines, **kwargs)
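# Hypothetical data-generation helper (not part of the gist): from_file expects
# plain-text files with one "<integer> <roman>" pair per line under
# default_data_path, which the gist leaves undefined; a pathlib.Path to a data
# folder is assumed here, and integer_to_roman is the rule-based converter
# defined further down.
from pathlib import Path
import random

default_data_path = Path('data')  # assumption

def write_split(fn, numbers, root=default_data_path):
    root.mkdir(parents=True, exist_ok=True)
    with open(root/f'{fn}.txt', 'w') as f:
        for n in numbers:
            f.write(f'{n} {integer_to_roman(n)}\n')

# e.g. write_split('train', random.sample(range(1, 4000), 3000))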
class Processor:
    def __init__(self, vocab):
        self.vocab = vocab

    def process(self, x):
        seq = ['<bos>'] + list(x) + ['<eos>']
        return [self.vocab[tok] for tok in seq]

    def deprocess(self, x):
        out = []
        for idx in x:
            tok = self.vocab.lookup_token(idx)
            if tok == '<bos>': continue
            elif tok == '<eos>': return ''.join(out)
            else: out.append(tok)
        return ''.join(out)
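# Hypothetical round trip through the processor above, using the vocab built
# earlier: process wraps a string's characters in <bos>/<eos> and maps them to
# indices; deprocess strips the specials and rebuilds the string.
proc = Processor(vocab)
ids = proc.process('1193')  # [vocab['<bos>'], vocab['1'], vocab['1'], vocab['9'], vocab['3'], vocab['<eos>']]
assert proc.deprocess(ids) == '1193'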
def collate(batch, max_len=20, pad_idx=1):
    src_lst, tgt_lst = [], []
    for src, tgt in batch:
        src, tgt = map(torch.tensor, [src, tgt])
        src_lst.append(src)
        tgt_lst.append(tgt)
    # pad the first sequence of each list up to max_len so that pad_sequence
    # pads every batch to the same fixed length
    src_lst[0] = nn.ConstantPad1d((0, max_len-src_lst[0].size(0)), pad_idx)(src_lst[0])
    tgt_lst[0] = nn.ConstantPad1d((0, max_len-tgt_lst[0].size(0)), pad_idx)(tgt_lst[0])
    return list(map(partial(pad_sequence, padding_value=pad_idx, batch_first=True), [src_lst, tgt_lst]))
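# A small sketch of what collate returns, assuming the vocab, processor and
# pad_idx defined earlier: both tensors are right-padded to (batch_size, max_len).
batch = [
    (proc.process('12'), proc.process('XII')),
    (proc.process('4'), proc.process('IV'))
]
src, tgt = collate(batch, max_len=20, pad_idx=pad_idx)
print(src.shape, tgt.shape)  # torch.Size([2, 20]) torch.Size([2, 20])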
def evaluate(mdl, dl, loss_fn, metric):
    mdl.eval()
    epoch_loss = 0
    correct = 0
    with torch.no_grad():
        for i, (src, tgt) in enumerate(dl):
            out = mdl(src, tgt, teacher_forcing_proba=0) # turn off teacher forcing
            bs, seq_len, out_dim = out.shape
            out = out.view(-1, out_dim)
            tgt = tgt[:,1:].contiguous().view(-1)
            loss = loss_fn(out, tgt)
            epoch_loss += loss.item()
            pred = out.argmax(-1)
            m = metric(pred.view(bs, -1), tgt.view(bs, -1))
            correct += m
    n = (i+1)*bs
    return epoch_loss/n, correct/n
best_i2r_mdl = fit(100, i2r_mdl, train_dl, valid_dl, opt, criterion, metric, patience=3)
import json
import copy

def fit(epochs, mdl, train_dl, valid_dl, opt, criterion, metric, patience=2):
    fmt = lambda x: f'{x:.3f}'
    best_valid_loss = float('inf')
    best_mdl = None
    irritation = 0
    for epoch in range(epochs):
        print(f'Epoch: {epoch+1:02}')
        train_loss, train_metric = train(mdl, train_dl, opt, criterion, metric)
        valid_loss, valid_metric = evaluate(mdl, valid_dl, criterion, metric)
        print('\t' + json.dumps({
            'train': {
                'loss': fmt(train_loss),
                'metric': fmt(train_metric)
            },
            'valid': {
                'loss': fmt(valid_loss),
                'metric': fmt(valid_metric)
            }
        }, indent=4))
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            best_mdl = copy.deepcopy(mdl)  # snapshot the best weights rather than keeping a reference
            torch.save(mdl.state_dict(), 'model.pt')
            irritation = 0
        else:
            irritation += 1
            if irritation == patience: break
    return best_mdl
def integer_to_roman(n):
    div = 1
    while n >= div: div *= 10
    div //= 10
    out = []
    while n:
        d = n // div # get most significant digit via floor division by a power of 10
        if d < 4:
            o = i2r[div]*d
        elif d == 4:
            o = i2r[div] + i2r[div*5]
        elif d < 9:
            o = i2r[div*5] + (d-5)*i2r[div]
        else:
            o = i2r[div] + i2r[div*10]
        out.append(o)
        n = n % div # the new integer is the remainder
        div //= 10
    return ''.join(out)
r2i = {
    'I': 1,
    'V': 5,
    'X': 10,
    'L': 50,
    'C': 100,
    'D': 500,
    'M': 1000
}
i2r = {v: k for k, v in r2i.items()} # reverse the mapping
import random  # used for teacher forcing below
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim, dropout):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True, dropout=dropout, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.project = nn.Linear(hidden_dim*2, hidden_dim)

    def forward(self, x):
        bs = x.size(0)
        x = self.dropout(self.emb(x))
        h, h_last = self.rnn(x)
        h_last = h_last.permute(1,0,2).contiguous().view(bs, -1) # (bs,hidden_dim*2)
        h, h_last = map(self.project, [h, h_last])
        h_last = h_last.unsqueeze(0)
        return h, h_last # (bs,seq_len,hidden_dim), (1,bs,hidden_dim)
class Attention(nn.Module):
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        super().__init__()
        self.decoder_hidden_dim = torch.tensor(decoder_hidden_dim, dtype=torch.float)
        self.w = nn.Parameter(torch.FloatTensor(decoder_hidden_dim, encoder_hidden_dim).uniform_(-0.1, 0.1))

    def forward(self, query, values):
        # query: (bs,decoder_hidden_dim), values: (bs,seq_len,encoder_hidden_dim)
        # score: (bs,1,seq_len); normalize over the sequence (last) dimension
        score = (query.unsqueeze(1) @ self.w @ values.permute(0,2,1))/torch.sqrt(self.decoder_hidden_dim)
        attention_weights = F.softmax(score, -1)
        context = attention_weights @ values # (bs,1,encoder_hidden_dim)
        return context
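# Hypothetical shape check for the attention module above: a (bs, hidden) query
# scored against (bs, seq_len, hidden) encoder states yields one context vector
# per example.
attn = Attention(encoder_hidden_dim=20, decoder_hidden_dim=20)
query = torch.randn(2, 20)      # (bs, decoder_hidden_dim)
values = torch.randn(2, 7, 20)  # (bs, seq_len, encoder_hidden_dim)
print(attn(query, values).shape)  # torch.Size([2, 1, 20])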
class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim, dropout):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.GRU(hidden_dim*2, hidden_dim, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, h_in, context):
        y = self.dropout(self.emb(y.unsqueeze(1)))
        y = torch.cat([context, y], -1) # (bs,1,hidden_dim*2)
        h, h_last = self.rnn(y, h_in) # (bs,1,hidden_dim), (1,bs,hidden_dim)
        return h.squeeze(1), h_last
class Seq2SeqWithAttention(nn.Module):
    def __init__(self, vocab_size, hidden_dim=20, dropout=.5):
        super().__init__()
        self.encode = Encoder(vocab_size, hidden_dim, dropout)
        self.attend = Attention(hidden_dim, hidden_dim)
        self.decode = Decoder(vocab_size, hidden_dim, dropout)
        self.project = nn.Linear(hidden_dim, vocab_size)

    def forward(self, src, tgt, teacher_forcing_proba=.5):
        bs, tgt_len = tgt.shape
        h, h_last = self.encode(src)
        s = h_last.squeeze(0) # (bs,hidden_dim)
        y = tgt[:,0]
        logits = []
        for t in range(1, tgt_len):
            context = self.attend(s, h) # context: (bs,1,hidden_dim)
            s, h_last = self.decode(y, h_last, context)
            logit = self.project(s)
            logits.append(logit)
            teacher_force = random.random() < teacher_forcing_proba
            y = tgt[:,t] if teacher_force else logit.argmax(-1)
        return torch.stack(logits, 1)
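# Hypothetical smoke test of the full model on random token ids, using the
# vocab_size computed earlier: the output carries one logit vector per target
# position after <bos>, i.e. (bs, tgt_len-1, vocab_size).
mdl = Seq2SeqWithAttention(vocab_size, hidden_dim=30, dropout=.3)
src = torch.randint(0, vocab_size, (4, 20))
tgt = torch.randint(0, vocab_size, (4, 20))
print(mdl(src, tgt).shape)  # torch.Size([4, 19, vocab_size])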
import torch.optim as optim
i2r_mdl = Seq2SeqWithAttention(vocab_size, hidden_dim=30, dropout=.3)
opt = optim.Adam(i2r_mdl.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
metric = Acc(ignore_index=pad_idx)
to_predict = [i for i,r in test_set]
preds = predict(best_i2r_mdl, to_predict, proc, collate_fn)
for (i,r),pred in zip(test_set,preds):
print(f'{i} -> {pred} ({r})')
to_predict = [r for i,r in test_set]
preds = predict(best_r2i_mdl, to_predict, proc, collate_fn)
for (i,r),pred in zip(test_set,preds):
print(f'{r} -> {pred} ({i})')
def predict(mdl, tests, proc, collate):
    mdl.eval()
    tests = list(zip(tests, tests))
    test_ds = NumberDataset(tests, processor=proc)
    test_dl = DataLoader(test_ds, batch_size=len(tests), collate_fn=collate)
    with torch.no_grad():
        for src, tgt in test_dl:
            out = mdl(src, tgt, teacher_forcing_proba=0)
            bs, seq_len, out_dim = out.shape
            out = out.view(-1, out_dim)
            tgt = tgt[:,1:].contiguous().view(-1)
            pred = out.argmax(-1)
            return [proc.deprocess(seq) for seq in pred.view(bs, -1)]
train_ds = NumberDataset.from_file('train', processor=proc, target='integer')
valid_ds = NumberDataset.from_file('valid', processor=proc, target='integer')
train_dl = DataLoader(train_ds, batch_size=10, collate_fn=collate_fn)
valid_dl = DataLoader(valid_ds, batch_size=10, collate_fn=collate_fn)
r2i_mdl = Seq2SeqWithAttention(vocab_size, hidden_dim=30, dropout=.3)
opt = optim.Adam(r2i_mdl.parameters(), lr=1e-3)
best_r2i_mdl = fit(100, r2i_mdl, train_dl, valid_dl, opt, criterion, metric, patience=3)
def roman_to_integer(s):
    l = len(s)
    tot = 0
    prev_n = 0
    for i in range(l):
        current_n = r2i[s[i]]
        next_n = r2i[s[i+1]] if i+1 < l else 0
        if current_n >= next_n:
            tot += (current_n - prev_n)
            prev_n = 0
        else:
            prev_n = current_n
    return tot
test_ds = NumberDataset.from_file('test', processor=proc)
test_dl = DataLoader(test_ds, batch_size=5, collate_fn=collate_fn)
_, test_metric = evaluate(best_i2r_mdl, test_dl, criterion, metric)
print(test_metric)
test_ds = NumberDataset.from_file('test', processor=proc, target='integer')
test_dl = DataLoader(test_ds, batch_size=5, collate_fn=collate_fn)
_, test_metric = evaluate(best_r2i_mdl, test_dl, criterion, metric)
print(test_metric)
test_set = [
    ('4', 'IV'),
    ('1193', 'MCXCIII'),
    ('548', 'DXLVIII'),
    ('3616', 'MMMDCXVI'),
    ('21', 'XXI')
]
for src, tgt in test_set:
    assert integer_to_roman(int(src)) == tgt
print('Success')
def train(mdl, dl, opt, loss_fn, metric):
    mdl.train()
    epoch_loss = 0
    correct = 0
    for i, (src, tgt) in enumerate(dl):
        opt.zero_grad()
        out = mdl(src, tgt)
        bs, seq_len, out_dim = out.shape
        assert out.size(1) == tgt.size(1)-1 # we skipped the first element in the output
        # collapse seq and batch dims
        out = out.view(-1, out_dim)
        tgt = tgt[:,1:].contiguous().view(-1) # skip the first element in the ground truth
        loss = loss_fn(out, tgt)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
        pred = out.argmax(-1)
        m = metric(pred.view(bs, -1), tgt.view(bs, -1))
        correct += m
        if i > 0 and i % 1e4 == 0:
            print(f'\t{i}: {epoch_loss/i:.3f}')
    n = (i+1)*bs
    return epoch_loss/n, correct/n