# Code for a PyTorch forums question
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import collections


class Dummy(nn.Module):
    """Identity placeholder used when a layer is disabled (its shape entry is None)."""
    def forward(self, x):
        return x


class Base(nn.Module):
    def nshape(self, *args):
        return self._nshape_base(*args)

    def _nshape_base(self, nin, nh, nout):
        # (input size, output size) per layer; None disables the layer.
        return {
            'rnn1': (nin, nh),
            'fc1':  (nh*2, nh),
            'rnn2': (nh, nh),
            'fc2':  (nh*2, nout),
        }

    def _get_shape(self, *args):
        shape = self.nshape(*args)
        return [
            # rnns
            shape['rnn1'][0], shape['rnn1'][1],
            shape['rnn2'][0], shape['rnn2'][1],
            # fcs
            shape['fc1'][0], shape['fc1'][1],
            shape['fc2'][0], shape['fc2'][1],
        ]

    def __init__(self, nIn, nHidden, nOut):
        super(Base, self).__init__()
        sh = self._get_shape(nIn, nHidden, nOut)
        rnns, embs = [], []
        for i, j in [(0, 1), (2, 3)]:
            if None in (sh[i], sh[j]):
                rnn = Dummy()
            else:
                rnn = nn.LSTM(sh[i], sh[j], bidirectional=True)
            rnns.append(rnn)
        for i, j in [(4, 5), (6, 7)]:
            if None in (sh[i], sh[j]):
                emb = Dummy()
            else:
                emb = nn.Linear(sh[i], sh[j])
            embs.append(emb)
        self.rnns = nn.ModuleList(rnns)
        self.embs = nn.ModuleList(embs)
        self.nh = nHidden
        self.nin = nIn
        self.nout = nOut

    def _forward(self, ix, inp, to_fc=None, drop_fc=False):
        if isinstance(inp, (list, tuple)):
            inp = torch.cat(tuple(inp), dim=2)
        rnn, _ = self.rnns[ix](inp)
        fc_inp = torch.cat((rnn, *to_fc), dim=2) if to_fc else rnn
        T, b, h = fc_inp.size()
        # print('I', fc_inp.size())
        fc = None
        if not drop_fc:
            fc_inp = fc_inp.view(T * b, h)
            fc = self.embs[ix](fc_inp)
            fc = fc.view(T, b, -1)
        return fc, rnn

    def forward(self, x):
        x, _ = self._forward(0, x)
        x, _ = self._forward(1, x)
        return x
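

# A minimal shape check for Base (not part of the original gist; the sizes
# below are illustrative assumptions). The model maps a (T, b, nIn) sequence
# to a (T, b, nOut) sequence.
def _base_shape_check(T=10, b=4, nin=32, nh=64, nout=16):
    model = Base(nin, nh, nout)
    x = Variable(torch.randn(T, b, nin))
    out = model(x)
    assert out.size() == (T, b, nout)
    return out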


class BaseLSTM(Base):
    pass

#
# A lot of stuff omitted here ....
#
#
# Attention
#
class Attention(nn.Module):
    """Dot global attention from https://arxiv.org/abs/1508.04025

    Input ::
        x   : batch x dim
        ctx : batch x seq x dim
    Output ::
        out : batch x dim
    """
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim*2, dim, bias=False)

    def forward(self, x, context):
        assert x.size(0) == context.size(0), \
            f' x: {x.size()} ctx : {context.size()} I'   # x: batch x dim
        assert x.size(1) == context.size(2), \
            f' x: {x.size()} ctx : {context.size()} II'  # context: batch x seq x dim
        # context (bsz x seq x dim) @ x (bsz x dim x 1) -> bsz x seq x 1 -> bsz x seq
        attn = F.softmax(context.bmm(x.unsqueeze(2)).squeeze(2), dim=1)
        weighted_context = attn.unsqueeze(1)              # bsz x 1 x seq
        weighted_context = weighted_context.bmm(context)  # bsz x 1 x dim
        weighted_context = weighted_context.squeeze(1)    # bsz x dim
        o = self.linear(torch.cat((x, weighted_context), 1))
        return F.tanh(o)
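

# Illustrative shape check for Attention (not part of the original gist; the
# sizes are assumptions): x ~ (batch, dim), ctx ~ (batch, seq, dim), output
# ~ (batch, dim).
def _attention_shape_check(batch=4, seq=7, dim=16):
    attn = Attention(dim)
    x = Variable(torch.randn(batch, dim))
    ctx = Variable(torch.randn(batch, seq, dim))
    out = attn(x, ctx)
    assert out.size() == (batch, dim)
    return out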


class AttnLSTM(nn.Module):
    def __init__(self, d_inp, d_hidden):
        super().__init__()
        # self.rnn = nn.LSTM(d_inp, d_hidden)
        self.rnn = nn.GRU(d_inp, d_hidden)
        self.d_hidden = d_hidden
        self.attn = Attention(d_hidden)

    def init_hidden(self, bsz):
        # put the initial state on the same device as the module's parameters
        cuda = next(self.parameters()).is_cuda
        tt = torch.cuda if cuda else torch  # use cuda tensors or not
        if isinstance(self.rnn, nn.LSTM):
            # initial hidden state and initial cell state
            h = Variable(tt.FloatTensor(1, bsz, self.d_hidden).zero_())
            c = Variable(tt.FloatTensor(1, bsz, self.d_hidden).zero_())
            return (h, c)
        else:  # GRU
            h = Variable(tt.FloatTensor(1, bsz, self.d_hidden).zero_())
            return h

    def forward(self, xs, context):
        # xs ~ seq x batch x dim
        o = []
        hidden = self.init_hidden(xs.size(1))
        for x in xs:
            res, hidden = self.rnn(x.unsqueeze(0), hidden)
            o.append(self.attn(res.squeeze(0), context))
        return torch.stack(o, 0)
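

# Illustrative shape check for AttnLSTM (not part of the original gist; the
# sizes are assumptions): xs ~ (seq, batch, d_inp), context ~ (batch,
# ctx_seq, d_hidden), output ~ (seq, batch, d_hidden).
def _attn_lstm_shape_check(seq=5, batch=3, d_inp=8, d_hidden=12, ctx_seq=9):
    rnn = AttnLSTM(d_inp, d_hidden)
    xs = Variable(torch.randn(seq, batch, d_inp))
    ctx = Variable(torch.randn(batch, ctx_seq, d_hidden))
    out = rnn(xs, ctx)
    assert out.size() == (seq, batch, d_hidden)
    return out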


class AttentionLSTM1(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (None, None), 'fc1': (None, None),
            'rnn2': (nin, nh),    'fc2': (nh*2, nout),
        }

    def __init__(self, *args):
        super().__init__(*args)
        self.rnn1 = AttnLSTM(self.nin, self.nin)

    def forward(self, inp):
        rnn1 = self.rnn1(inp, context=inp.transpose(0, 1))
        fc2, rnn2 = self._forward(1, rnn1)
        return fc2


class AttentionLSTM2(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (nin, nh),    'fc1': (None, None),
            'rnn2': (None, None), 'fc2': (None, None),
        }

    def __init__(self, *args):
        super().__init__(*args)
        self.rnn2 = AttnLSTM(self.nh*2, self.nh*2)
        self.fc2 = nn.Linear(self.nh*2, self.nout)

    def forward(self, inp):
        fc1, rnn1 = self._forward(0, inp, drop_fc=True)
        rnn2 = self.rnn2(rnn1, context=rnn1.transpose(0, 1))
        return self.fc2(rnn2)


class AttentionLSTM3(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (nin, nh),  'fc1': (None, None),
            'rnn2': (nh*2, nh), 'fc2': (nh*2, nout),
        }

    def __init__(self, *args):
        super().__init__(*args)
        self.w = nn.Linear(128, 128)  # assumes nin == 128

    def attention(self, ctx, x):
        # ctx/x ~ seq x bs x dim/dim'
        ctx = ctx.transpose(0, 1)
        x = x.transpose(0, 1)
        # ctx/x ~ bs x seq x dim/dim'
        scores = self.w(ctx).bmm(ctx.transpose(1, 2))  # bs x seq x seq
        scores = F.softmax(scores, dim=1)
        print('SC', scores.size())
        res = scores.bmm(x)  # bs x seq x dim'
        print('RE', res.size())
        return res

    def forward(self, inp):
        fc1, rnn1 = self._forward(0, inp, drop_fc=True)
        att = self.attention(ctx=inp, x=rnn1)
        fc2, rnn2 = self._forward(1, att)
        return fc2


class AttentionLSTM4(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (nin, nh),  'fc1': (None, None),
            'rnn2': (nh*2, nh), 'fc2': (nh*2, nout),
        }

    def __init__(self, *args):
        super().__init__(*args)
        s = 512  # assumes nh*2 == 512
        self.w = nn.Linear(s, s)

    def scores(self, ctx, ctx_bch):
        # ctx     ~ seq x bs x dim
        # ctx_bch ~ seq_bch x bs x dim
        ctx = ctx.transpose(0, 1)
        ctx_bch = ctx_bch.permute(1, 2, 0)
        # ctx     ~ bs x seq x dim
        # ctx_bch ~ bs x dim x seq_bch
        w = self.w(ctx)
        scores = w.bmm(ctx_bch)  # bs x seq x seq_bch
        return F.softmax(scores, dim=1)
        # res = scores.bmm(x)  # [bs x seq_bch x dim]

    def forward(self, inp):
        _, rnn1 = self._forward(0, inp, drop_fc=True)
        weighted = []
        for xs in rnn1.split(16, dim=0):
            scores = self.scores(ctx=rnn1, ctx_bch=xs)
            w = scores.transpose(1, 2).bmm(rnn1.transpose(0, 1))
            weighted.append(w)
            print('OK')
        res = torch.cat(weighted, dim=1)
        fc2, rnn2 = self._forward(1, res)
        return fc2


class WinAttentionLSTM(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (nin, nh),  'fc1': (None, None),
            'rnn2': (nh*2, nh), 'fc2': (nh*2, nout),
        }

    def __init__(self, *args):
        super().__init__(*args)
        s = 512  # assumes nh*2 == 512
        self.w = nn.Linear(s, s)

    def scores(self, ctx):
        # ctx ~ seq_bch x bs x dim
        ctx = ctx.transpose(0, 1)
        # ctx ~ bs x seq_bch x dim
        scores = self.w(ctx).bmm(ctx.transpose(1, 2))  # bs x seq_bch x seq_bch
        scores = F.softmax(scores, dim=1)
        return scores

    def forward(self, inp):
        _, rnn1 = self._forward(0, inp, drop_fc=True)
        weighted = []
        for xs in rnn1.split(2, dim=0):
            scores = self.scores(ctx=xs)
            w = scores.bmm(xs.transpose(0, 1))
            weighted.append(w)
            print('OK')
        res = torch.cat(weighted, dim=1)
        fc2, rnn2 = self._forward(1, res)
        return fc2
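

# Hedged smoke test (not part of the original gist). The sizes below are
# assumptions chosen to match the hard-coded widths above (nin = 128 for
# AttentionLSTM3's nn.Linear(128, 128), nh = 256 so that nh*2 == 512 for
# AttentionLSTM4 / WinAttentionLSTM); adjust them to your data.
if __name__ == '__main__':
    T, b, nin, nh, nout = 32, 4, 128, 256, 10
    x = Variable(torch.randn(T, b, nin))
    for cls in (Base, AttentionLSTM1, AttentionLSTM2,
                AttentionLSTM3, AttentionLSTM4, WinAttentionLSTM):
        out = cls(nin, nh, nout)(x)
        print(cls.__name__, out.size())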