Code for a PyTorch forums question
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import collections
class Dummy(nn.Module):
    """Identity placeholder used when a layer is disabled (its shape entry is None)."""
    def forward(self, x):
        return x
class Base(nn.Module):
    def nshape(self, *args):
        return self._nshape_base(*args)

    def _nshape_base(self, nin, nh, nout):
        # (input_dim, output_dim) per layer; the fc layers take nh*2 because
        # the LSTMs are bidirectional.
        return {
            'rnn1': (nin,  nh),
            'fc1':  (nh*2, nh),
            'rnn2': (nh,   nh),
            'fc2':  (nh*2, nout),
        }

    def _get_shape(self, *args):
        # Flatten the shape table into the order __init__ expects:
        # [rnn1_in, rnn1_out, rnn2_in, rnn2_out, fc1_in, fc1_out, fc2_in, fc2_out]
        shape = self.nshape(*args)
        return [
            # rnns
            shape['rnn1'][0], shape['rnn1'][1],
            shape['rnn2'][0], shape['rnn2'][1],
            # fcs
            shape['fc1'][0], shape['fc1'][1],
            shape['fc2'][0], shape['fc2'][1],
        ]
    def __init__(self, nIn, nHidden, nOut):
        super(Base, self).__init__()
        sh = self._get_shape(nIn, nHidden, nOut)
        rnns, embs = [], []
        # sh[0:4] hold the rnn shapes, sh[4:8] the fc shapes; a None entry
        # means "skip this layer" and gets an identity Dummy instead.
        for i, j in [(0, 1), (2, 3)]:
            if None in (sh[i], sh[j]):
                rnn = Dummy()
            else:
                rnn = nn.LSTM(sh[i], sh[j], bidirectional=True)
            rnns.append(rnn)
        for i, j in [(4, 5), (6, 7)]:
            if None in (sh[i], sh[j]):
                emb = Dummy()
            else:
                emb = nn.Linear(sh[i], sh[j])
            embs.append(emb)
        self.rnns = nn.ModuleList(rnns)
        self.embs = nn.ModuleList(embs)
        self.nh = nHidden
        self.nin = nIn
        self.nout = nOut
    def _forward(self, ix, inp, to_fc=None, drop_fc=False):
        if isinstance(inp, (list, tuple)):
            inp = torch.cat(tuple(inp), dim=2)
        rnn, _ = self.rnns[ix](inp)
        fc_inp = torch.cat((rnn, *to_fc), dim=2) if to_fc else rnn
        T, b, h = fc_inp.size()
        fc = None
        if not drop_fc:
            # Fold time into the batch dim so nn.Linear sees a 2-D input,
            # then restore the (T, b, -1) layout.
            fc_inp = fc_inp.view(T * b, h)
            fc = self.embs[ix](fc_inp)
            fc = fc.view(T, b, -1)
        return fc, rnn

    def forward(self, x):
        x, _ = self._forward(0, x)
        x, _ = self._forward(1, x)
        return x
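
# A quick shape walk-through of Base (a sketch; the sizes below are
# arbitrary, not from the original question):
#
#   m = Base(nIn=64, nHidden=128, nOut=10)
#   x = torch.randn(50, 4, 64)   # seq x batch x nin
#   y = m(x)                     # -> 50 x 4 x 10
#
# Internally: rnn1 -> 50 x 4 x 256 (the bidirectional LSTM doubles nh),
# fc1 -> 50 x 4 x 128, rnn2 -> 50 x 4 x 256, fc2 -> 50 x 4 x 10.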
class BaseLSTM(Base):
    pass

#
# A lot of stuff omitted here ....
#

#
# Attention
#
class Attention(nn.Module):
    """Dot global attention from https://arxiv.org/abs/1508.04025

    Input ::
        x   : batch x dim
        ctx : batch x seq x dim
    Output ::
        out : batch x dim
    """
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim*2, dim, bias=False)

    def forward(self, x, context):
        assert x.size(0) == context.size(0), \
            f' x: {x.size()} ctx : {context.size()} I'   # x: batch x dim
        assert x.size(1) == context.size(2), \
            f' x: {x.size()} ctx : {context.size()} II'  # context: batch x seq x dim
        attn = F.softmax(
            context.bmm(
                x.unsqueeze(2)   # bsz x dim x 1
            )                    # bsz x seq x 1
            .squeeze(2),         # bsz x seq
            dim=1)
        weighted_context = attn.unsqueeze(1)              # bsz x 1 x seq
        weighted_context = weighted_context.bmm(context)  # bsz x 1 x dim
        weighted_context = weighted_context.squeeze(1)    # bsz x dim
        o = self.linear(torch.cat((x, weighted_context), 1))
        return torch.tanh(o)
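
# Minimal shape check for Attention (illustrative; the sizes are made up):
#
#   attn = Attention(dim=32)
#   q = torch.randn(4, 32)        # batch x dim
#   ctx = torch.randn(4, 10, 32)  # batch x seq x dim
#   out = attn(q, ctx)            # -> 4 x 32, the attentional hidden state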
class AttnLSTM(nn.Module):
    def __init__(self, d_inp, d_hidden):
        super().__init__()
        #self.rnn = nn.LSTM(d_inp, d_hidden)
        self.rnn = nn.GRU(d_inp, d_hidden)
        self.d_hidden = d_hidden
        self.attn = Attention(d_hidden)

    def init_hidden(self, bsz):
        # Create the initial state on the same device as the module's
        # weights, so the model also runs on CPU.
        device = next(self.parameters()).device
        if isinstance(self.rnn, nn.LSTM):
            # LSTM needs an initial hidden state and an initial cell state
            h = torch.zeros(1, bsz, self.d_hidden, device=device)
            c = torch.zeros(1, bsz, self.d_hidden, device=device)
            return (h, c)
        else:  # GRU only needs a hidden state
            return torch.zeros(1, bsz, self.d_hidden, device=device)

    def forward(self, xs, context):
        # xs ~ seq x batch x dim
        o = []
        hidden = self.init_hidden(xs.size(1))
        for x in xs:
            res, hidden = self.rnn(x.unsqueeze(0), hidden)
            o.append(self.attn(res.squeeze(0), context))
        return torch.stack(o, 0)
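
# How the classes below drive AttnLSTM (a sketch, not from the original
# gist): the RNN advances one time step per iteration and attends over the
# same context at every step. Attention is built with d_hidden, so the
# context's last dim must equal d_hidden; using the input as context works
# here because d_inp == d_hidden:
#
#   rnn = AttnLSTM(d_inp=32, d_hidden=32)
#   xs = torch.randn(10, 4, 32)        # seq x batch x d_inp
#   out = rnn(xs, xs.transpose(0, 1))  # -> 10 x 4 x 32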
class AttentionLSTM1(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (None, None), 'fc1': (None, None),
            'rnn2': (nin,  nh),   'fc2': (nh*2, nout),
        }

    def __init__(self, *args):
        super().__init__(*args)
        self.rnn1 = AttnLSTM(self.nin, self.nin)

    def forward(self, inp):
        rnn1 = self.rnn1(inp, context=inp.transpose(0, 1))
        fc2, rnn2 = self._forward(1, rnn1)
        return fc2
class AttentionLSTM2(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (nin,  nh),   'fc1': (None, None),
            'rnn2': (None, None), 'fc2': (None, None),
        }

    def __init__(self, *args):
        super().__init__(*args)
        self.rnn2 = AttnLSTM(self.nh*2, self.nh*2)
        self.fc2 = nn.Linear(self.nh*2, self.nout)

    def forward(self, inp):
        fc1, rnn1 = self._forward(0, inp, drop_fc=True)
        rnn2 = self.rnn2(rnn1, context=rnn1.transpose(0, 1))
        return self.fc2(rnn2)
class AttentionLSTM3(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (nin,  nh), 'fc1': (None, None),
            'rnn2': (nh*2, nh), 'fc2': (nh*2, nout),
        }

    def __init__(self, *args):
        super().__init__(*args)
        # Score projection over the raw input, so it must be nin wide
        self.w = nn.Linear(self.nin, self.nin)

    def attention(self, ctx, x):
        # ctx/x ~ seq x bs x dim/dim'
        ctx = ctx.transpose(0, 1)
        x = x.transpose(0, 1)
        # ctx/x ~ bs x seq x dim/dim'
        scores = self.w(ctx).bmm(ctx.transpose(1, 2))  # bs x seq x seq
        # Normalize over the positions summed in the bmm below
        scores = F.softmax(scores, dim=2)
        res = scores.bmm(x)  # bs x seq x dim'
        return res

    def forward(self, inp):
        fc1, rnn1 = self._forward(0, inp, drop_fc=True)
        att = self.attention(ctx=inp, x=rnn1)               # bs x seq x dim'
        fc2, rnn2 = self._forward(1, att.transpose(0, 1))   # seq-first for the LSTM
        return fc2
class AttentionLSTM4(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (nin,  nh), 'fc1': (None, None),
            'rnn2': (nh*2, nh), 'fc2': (nh*2, nout),
        }

    def __init__(self, *args):
        super().__init__(*args)
        # Score projection over rnn1's bidirectional output, so nh*2 wide
        s = self.nh*2
        self.w = nn.Linear(s, s)

    def scores(self, ctx, ctx_bch):
        # ctx     ~ seq x bs x dim
        # ctx_bch ~ seq_bch x bs x dim
        ctx = ctx.transpose(0, 1)           # bs x seq x dim
        ctx_bch = ctx_bch.permute(1, 2, 0)  # bs x dim x seq_bch
        w = self.w(ctx)
        scores = w.bmm(ctx_bch)  # bs x seq x seq_bch
        # dim=1 normalizes over the full-context positions, which is the
        # axis summed over after the transpose in forward()
        return F.softmax(scores, dim=1)

    def forward(self, inp):
        _, rnn1 = self._forward(0, inp, drop_fc=True)
        weighted = []
        # Attend in chunks of 16 time steps to keep the score matrix small
        for xs in rnn1.split(16, dim=0):
            scores = self.scores(ctx=rnn1, ctx_bch=xs)
            w = scores.transpose(1, 2).bmm(rnn1.transpose(0, 1))  # bs x seq_bch x dim
            weighted.append(w)
        res = torch.cat(weighted, dim=1)                    # bs x seq x dim
        fc2, rnn2 = self._forward(1, res.transpose(0, 1))   # seq-first for the LSTM
        return fc2
class WinAttentionLSTM(Base):
    def nshape(self, nin, nh, nout):
        return {
            'rnn1': (nin,  nh), 'fc1': (None, None),
            'rnn2': (nh*2, nh), 'fc2': (nh*2, nout),
        }

    def __init__(self, *args):
        super().__init__(*args)
        # Score projection over rnn1's bidirectional output, so nh*2 wide
        s = self.nh*2
        self.w = nn.Linear(s, s)

    def scores(self, ctx):
        # ctx ~ seq_bch x bs x dim
        ctx = ctx.transpose(0, 1)
        # ctx ~ bs x seq_bch x dim
        scores = self.w(ctx).bmm(ctx.transpose(1, 2))  # bs x seq_bch x seq_bch
        # Normalize over the positions summed in the bmm below
        return F.softmax(scores, dim=2)

    def forward(self, inp):
        _, rnn1 = self._forward(0, inp, drop_fc=True)
        weighted = []
        # Attend only within non-overlapping windows of 2 time steps
        for xs in rnn1.split(2, dim=0):
            scores = self.scores(ctx=xs)
            w = scores.bmm(xs.transpose(0, 1))  # bs x seq_bch x dim
            weighted.append(w)
        res = torch.cat(weighted, dim=1)                    # bs x seq x dim
        fc2, rnn2 = self._forward(1, res.transpose(0, 1))   # seq-first for the LSTM
        return fc2
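
if __name__ == '__main__':
    # Smoke test (not part of the original gist): run each variant on a
    # random batch and print the output shape. The sequence length is a
    # multiple of 16 so the chunked/windowed variants split evenly.
    nin, nh, nout = 64, 128, 10
    x = torch.randn(32, 4, nin)  # seq x batch x nin
    for cls in (BaseLSTM, AttentionLSTM1, AttentionLSTM2,
                AttentionLSTM3, AttentionLSTM4, WinAttentionLSTM):
        model = cls(nin, nh, nout)
        print(cls.__name__, tuple(model(x).size()))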