Skip to content

Instantly share code, notes, and snippets.

@lyger
Created October 27, 2018 08:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lyger/b93e653cdfbf100898a4f0c61c607595 to your computer and use it in GitHub Desktop.
Save lyger/b93e653cdfbf100898a4f0c61c607595 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union
from .autopackingrnn import AutoPackingRNN
# Tensor type aliases for annotations: each accepts either the CPU or the
# CUDA legacy tensor class of the given element type.
FloatType = Union[(torch.FloatTensor, torch.cuda.FloatTensor)]
LongType = Union[(torch.LongTensor, torch.cuda.LongTensor)]
class AAE(nn.Module):
    '''
    Alignment Auto-Encoder.

    Learns an alignment by learning to reconstruct each input from the attended
    representation of the other input.
    '''

    def __init__(self, emb_dims: int, enc_dims: int, dec_dims: int):
        '''
        :param emb_dims: Dimensions of input embeddings (and final outputs).
        :param enc_dims: Encoder hidden dimensions per direction (the
            bidirectional encoder emits 2 * enc_dims features).
        :param dec_dims: Decoder hidden dimensions per direction (the
            bidirectional decoder emits 2 * dec_dims features).
        '''
        super().__init__()
        # NOTE(review): the second AutoPackingRNN argument is presumably
        # batch_first (all tensors here are (batch, sequence, ...)); confirm
        # against AutoPackingRNN.
        self.encoder = AutoPackingRNN(nn.LSTM(emb_dims, enc_dims, bidirectional=True), True)
        # The decoder reads attended *input embeddings*, so its input width is
        # emb_dims rather than the encoder's output width.
        self.decoder = AutoPackingRNN(nn.LSTM(emb_dims, dec_dims, bidirectional=True), True)
        # Projects bidirectional decoder features back to the embedding space.
        self.fc = nn.Linear(2 * dec_dims, emb_dims)

    @staticmethod
    def _length_mask(lens, max_len: int, device: torch.device) -> torch.Tensor:
        '''
        Build a (batch, max_len) boolean mask that is True where the position
        index is less than that example's length (i.e. inside the sequence)
        and False in the padding.

        :param lens: (batch) lengths, as a tensor or a sequence of ints.
        :param max_len: Number of positions along the sequence axis.
        :param device: Device on which to build the mask.
        '''
        lens = torch.as_tensor(lens, device=device)
        positions = torch.arange(max_len, device=device)
        return positions.unsqueeze(0) < lens.unsqueeze(1)

    def forward(self, stup1: Tuple[FloatType, LongType], stup2: Tuple[FloatType, LongType]):
        '''
        Reconstruct each sequence from its soft alignment to the other one.

        :param stup1: Tuple of (batch, sequence, embedding) inputs and (batch) lengths.
        :param stup2: Tuple of (batch, sequence, embedding) inputs and (batch) lengths.
        :returns: e1: Predicted embeddings based on aligned sequence 2.
                  e2: Predicted embeddings based on aligned sequence 1.
                  Both are (batch, sequence, embedding) and are zero at
                  positions past the corresponding sequence's length.
        '''
        s1, lens1 = stup1
        s2, lens2 = stup2
        # Alignment matrix; padded rows/columns hold a large negative value.
        # (batch, sequence 1, sequence 2)
        M = self.align(stup1, stup2)
        # Attended representation: each position becomes a softmax-weighted
        # mix of the other sequence's *input embeddings* (not encoder states).
        # (batch, sequence, embedding)
        a1 = F.softmax(M, dim=2).bmm(s2)
        a2 = F.softmax(M.transpose(1, 2), dim=2).bmm(s1)
        # Decoder outputs.
        # (batch, sequence, decoder * 2)
        d1 = self.decoder(a1, lens1)
        d2 = self.decoder(a2, lens2)
        # Predicted output embeddings.
        # (batch, sequence, embedding)
        e1 = self.fc(d1)
        e2 = self.fc(d2)
        # Zero out outputs past end of sequence, out-of-place and vectorized
        # over the batch. The mask is sized from the decoder output itself in
        # case its sequence axis differs from the input's.
        mask1 = self._length_mask(lens1, e1.size(1), e1.device)
        mask2 = self._length_mask(lens2, e2.size(1), e2.device)
        e1 = e1.masked_fill(~mask1.unsqueeze(2), 0.0)
        e2 = e2.masked_fill(~mask2.unsqueeze(2), 0.0)
        return e1, e2

    def align(self, stup1: Tuple[FloatType, LongType], stup2: Tuple[FloatType, LongType]):
        '''
        Compute batched dot-product alignment scores between encoded sequences.

        :param stup1: Tuple of (batch, sequence, embedding) inputs and (batch) lengths.
        :param stup2: Tuple of (batch, sequence, embedding) inputs and (batch) lengths.
        :returns: M: (batch, sequence 1, sequence 2) alignment scores. Entries
            whose row is past the end of sequence 1, or whose column is past
            the end of sequence 2, are set to the most negative finite value of
            M's dtype, so a softmax gives them (numerically) zero weight. A row
            or column that is entirely padding softmaxes to a uniform
            distribution rather than NaN.
        '''
        lens1 = stup1[1]
        lens2 = stup2[1]
        # Encoder outputs.
        # (batch, sequence, encoder * 2)
        h1 = self.encoder(*stup1)
        h2 = self.encoder(*stup2)
        # Dot-product alignment scores.
        # (batch, sequence 1, sequence 2)
        M = h1.bmm(h2.transpose(1, 2))
        # A position is valid only if its row is inside sequence 1 AND its
        # column is inside sequence 2.
        valid1 = self._length_mask(lens1, M.size(1), M.device)
        valid2 = self._length_mask(lens2, M.size(2), M.device)
        valid = valid1.unsqueeze(2) & valid2.unsqueeze(1)
        # A tiny value such as -1e-5 is ~0 and would leave padding with real
        # softmax weight; -inf would make fully padded rows NaN. The dtype's
        # most negative finite value avoids both problems.
        return M.masked_fill(~valid, torch.finfo(M.dtype).min)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment