Skip to content

Instantly share code, notes, and snippets.

def forward(self, inp, targ=None):
bs,sl = inp.size()
hid = self.initHidden(bs)
emb = self.emb_enc_drop(self.emb_enc(inp))
enc_out, hid = self.encoder(emb, hid)
hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()
hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))
dec_inp = inp.new_zeros(bs).long() + self.bos_idx