Skip to content

Instantly share code, notes, and snippets.

@armheb
Created May 14, 2019 11:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save armheb/e86adc8b21af599de175673a32f0e188 to your computer and use it in GitHub Desktop.
Save armheb/e86adc8b21af599de175673a32f0e188 to your computer and use it in GitHub Desktop.
def forward(self, inp, targ=None):
    """Encode ``inp`` and greedily decode up to ``self.max_len`` steps with attention.

    Args:
        inp: 2-D tensor of input token ids; dim 0 is used as the batch size
            (presumably (batch, src_len) — TODO confirm against the data pipeline).
        targ: optional 2-D tensor of target token ids, read as ``targ[:, i]``.
            When given, each step feeds the ground-truth token back into the
            decoder with probability ``self.pr_force`` (teacher forcing).

    Returns:
        The per-step outputs of ``self.out`` stacked along dim 1, i.e.
        (batch, n_steps, ...) with ``n_steps <= self.max_len``. Decoding stops
        early when every greedily predicted token equals ``self.pad_idx`` or,
        on a teacher-forced step, once ``i`` reaches ``targ.shape[1]``.
    """
    bs, _ = inp.size()  # unpacking also rejects inputs that are not 2-D
    hid = self.initHidden(bs)
    emb = self.emb_enc_drop(self.emb_enc(inp))
    enc_out, hid = self.encoder(emb, hid)
    # Assuming self.encoder is a bidirectional torch RNN (GRU/LSTM-style),
    # its final state is laid out layer-major: it views as
    # (n_layers, 2, bs, n_hid), where index 2*l is layer l forward and
    # 2*l + 1 is layer l backward. Viewing it direction-major as
    # (2, n_layers, ...) would pair states from different layers whenever
    # n_layers > 1.
    hid = hid.view(self.n_layers, 2, bs, self.n_hid).permute(0, 2, 1, 3).contiguous()
    # Concatenate both directions per layer -> (n_layers, bs, 2*n_hid), then
    # project with out_enc to the decoder's initial hidden state.
    hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2 * self.n_hid))

    # Every sequence starts decoding from the BOS token.
    dec_inp = inp.new_full((bs,), self.bos_idx, dtype=torch.long)
    # The encoder-side attention projection does not change across steps,
    # so it is computed once outside the loop.
    enc_att = self.enc_att(enc_out)
    res = []
    for i in range(self.max_len):
        # Score every encoder position against the last layer of the current
        # decoder hidden state.
        hid_att = self.hid_att(hid[-1])
        u = torch.tanh(enc_att + hid_att[:, None])
        # NOTE(review): no padding mask is applied — if `inp` contains padding,
        # those positions also receive attention weight.
        attn_wgts = F.softmax(u @ self.V, 1)
        # Attention-weighted sum of encoder outputs over the source positions.
        ctx = (attn_wgts[..., None] * enc_out).sum(1)
        emb = self.emb_dec(dec_inp)
        # One decoder step: input is [token embedding, context] as a length-1 sequence.
        outp, hid = self.decoder(torch.cat([emb, ctx], 1)[:, None], hid)
        outp = self.out(self.out_drop(outp[:, 0]))
        res.append(outp)
        # Greedy next token; argmax indices carry no gradient, so `.data` is unnecessary.
        dec_inp = outp.argmax(dim=1)
        if (dec_inp == self.pad_idx).all():
            break
        if targ is not None and random.random() < self.pr_force:
            # NOTE(review): only a teacher-forced step stops at the end of
            # `targ`; non-forced steps keep decoding, so the output length
            # can depend on the random draw.
            if i >= targ.shape[1]:
                break
            dec_inp = targ[:, i]
    return torch.stack(res, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment