Created
October 27, 2018 08:56
-
-
Save lyger/b93e653cdfbf100898a4f0c61c607595 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from typing import Tuple, Union | |
from .autopackingrnn import AutoPackingRNN | |
# Type aliases for the legacy typed-tensor classes of either the CPU or the
# CUDA backend. LongType is paired with the per-example sequence lengths in
# the (inputs, lengths) tuples below; FloatType with the embedding inputs.
LongType = Union[torch.LongTensor, torch.cuda.LongTensor]
FloatType = Union[torch.FloatTensor, torch.cuda.FloatTensor]
class AAE(nn.Module):
    '''
    Alignment Auto-Encoder.

    Learns an alignment by learning to reconstruct each input from the attended
    representation of the other input.
    '''

    def __init__(self, emb_dims: int, enc_dims: int, dec_dims: int):
        '''
        :param emb_dims: Dimensions of input embeddings (and final outputs).
        :param enc_dims: Encoder hidden size per direction (the encoder LSTM is
                         bidirectional, so its outputs are 2 * enc_dims wide).
        :param dec_dims: Decoder hidden size per direction (the decoder LSTM is
                         bidirectional, so its outputs are 2 * dec_dims wide).
        '''
        super().__init__()
        # A single encoder is shared by both sequences; its outputs are only
        # used to score the alignment (see `align`).
        self.encoder = AutoPackingRNN(nn.LSTM(emb_dims, enc_dims, bidirectional=True), True)
        # A single decoder is shared by both directions. Its input size is
        # emb_dims because it reads attended *input embeddings*, not encoder
        # states.
        self.decoder = AutoPackingRNN(nn.LSTM(emb_dims, dec_dims, bidirectional=True), True)
        # Projects concatenated forward/backward decoder states back to
        # embedding space.
        self.fc = nn.Linear(2 * dec_dims, emb_dims)

    def forward(self, stup1: Tuple[FloatType, LongType], stup2: Tuple[FloatType, LongType]):
        '''
        :param stup1: Tuple of (batch, sequence, embedding) inputs and (batch) lengths.
        :param stup2: Tuple of (batch, sequence, embedding) inputs and (batch) lengths.
        :returns: e1: Predicted embeddings for sequence 1, based on aligned sequence 2.
                  e2: Predicted embeddings for sequence 2, based on aligned sequence 1.
                  Positions at or past each example's length are zero.
        '''
        s1, lens1 = stup1
        s2, lens2 = stup2
        # Alignment matrix; padded rows/columns hold a large negative value.
        # (batch, sequence 1, sequence 2)
        M = self.align(stup1, stup2)
        # Attended representation: each position of one sequence becomes a
        # softmax-weighted average of the *input embeddings* of the other.
        # a1: (batch, sequence 1, embedding); a2: (batch, sequence 2, embedding)
        a1 = F.softmax(M, dim=2).bmm(s2)
        a2 = F.softmax(M.transpose(1, 2), dim=2).bmm(s1)
        # Decoder outputs.
        # (batch, sequence, decoder * 2)
        d1 = self.decoder(a1, lens1)
        d2 = self.decoder(a2, lens2)
        # Predicted output embeddings, zeroed past the end of each sequence.
        # (batch, sequence, embedding)
        e1 = self._zero_padding(self.fc(d1), lens1)
        e2 = self._zero_padding(self.fc(d2), lens2)
        return e1, e2

    def align(self, stup1: Tuple[FloatType, LongType], stup2: Tuple[FloatType, LongType]):
        '''
        :param stup1: Tuple of (batch, sequence, embedding) inputs and (batch) lengths.
        :param stup2: Tuple of (batch, sequence, embedding) inputs and (batch) lengths.
        :returns: M: Batch of dot-product alignment matrices, shaped
                  (batch, sequence 1, sequence 2). Rows past the end of
                  sequence 1 and columns past the end of sequence 2 are set to
                  the most negative finite value of M's dtype, so they get
                  (near-)zero weight when softmaxed.
        '''
        _, lens1 = stup1
        _, lens2 = stup2
        # Encoder outputs.
        # (batch, sequence, encoder * 2)
        h1 = self.encoder(*stup1)
        h2 = self.encoder(*stup2)
        # Alignment scores.
        # (batch, sequence 1, sequence 2)
        M = h1.bmm(h2.transpose(1, 2))
        # A cell is valid only if its row lies inside sequence 1 AND its
        # column lies inside sequence 2.
        valid = (self._length_mask(lens1, M.size(1), M.device).unsqueeze(2)
                 & self._length_mask(lens2, M.size(2), M.device).unsqueeze(1))
        # The fill must be large and negative: a tiny value such as -1e-5 is
        # ~0 and would leave padding with real attention weight. It is kept
        # finite (not -inf) so fully-masked rows stay NaN-free under softmax,
        # and taken from the dtype so it cannot overflow in half precision.
        return M.masked_fill(~valid, torch.finfo(M.dtype).min)

    @staticmethod
    def _length_mask(lens: LongType, max_len: int, device) -> torch.Tensor:
        '''
        Build a (batch, max_len) boolean mask that is True where position < length.

        :param lens: (batch) lengths, as a tensor on any device or a sequence of ints.
        :param max_len: Size of the sequence dimension to mask.
        :param device: Device on which to build the mask.
        '''
        lens = torch.as_tensor(lens, device=device)
        return torch.arange(max_len, device=device).unsqueeze(0) < lens.unsqueeze(1)

    @classmethod
    def _zero_padding(cls, x: FloatType, lens: LongType) -> torch.Tensor:
        '''
        Return a copy of a (batch, sequence, features) tensor with every
        position at or past each example's length set to zero.
        '''
        valid = cls._length_mask(lens, x.size(1), x.device).unsqueeze(-1)
        return x.masked_fill(~valid, 0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment