Hierarchical Attention Network (Yang et al. 2016) in PyTorch
# Implementation of the Hierarchical Attention Network from Yang et al. 2016
# https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf
# Anton Melnikov

import torch
from torch import nn
import torch.nn.functional as F


class SequenceClassifierAttention(nn.Module):
    # this follows the word-level attention from Yang et al. 2016
    # https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf
    # the same module is reused for sentence-level attention

    def __init__(self, n_hidden: int, *, batch_first: bool = False):
        super().__init__()
        self.mlp = nn.Linear(n_hidden, n_hidden)
        # word context vector
        self.u_w = nn.Parameter(torch.rand(n_hidden))
        self.batch_first = batch_first

    def forward(self, X):
        if not self.batch_first:
            # make the input (batch_size, timesteps, features)
            X = X.transpose(1, 0)
        # get the hidden representation of the sequence
        u_it = torch.tanh(self.mlp(X))
        # get attention weights for each timestep
        alpha = F.softmax(torch.matmul(u_it, self.u_w), dim=1)
        # scale each timestep's features by its attention weight
        # (alpha gets a size-1 dimension so it broadcasts over the features)
        # and sum over timesteps to get the weighted sequence representation
        weighted_sequence = X * alpha.unsqueeze(2)
        out = torch.sum(weighted_sequence, dim=1)
        return out, alpha
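

# A quick shape check (added sketch, not part of the original gist; sizes are
# illustrative). Given encoder outputs of shape (batch, timesteps, features),
# the module returns a (batch, features) summary and (batch, timesteps) weights:
#
#     attention = SequenceClassifierAttention(128, batch_first=True)
#     summary, alpha = attention(torch.rand(4, 10, 128))
#     # summary.shape == (4, 128); alpha.shape == (4, 10)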


class HierarchicalAttentionNetwork(nn.Module):
    def __init__(self, *, n_hidden: int, n_classes: int,
                 vocab_size: int, embedding_dim: int,
                 embedding_weights=None, padding_idx=None):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim,
                                  padding_idx=padding_idx)
        if embedding_weights is not None:
            # initialise the embedding layer with pretrained weights
            self.embed.weight.data.copy_(embedding_weights)
        self.word_encoder = nn.GRU(embedding_dim, n_hidden, bidirectional=True,
                                   batch_first=True)
        self.word_attention = SequenceClassifierAttention(n_hidden * 2,
                                                          batch_first=True)
        self.sentence_encoder = nn.GRU(n_hidden * 2, n_hidden, bidirectional=True,
                                       batch_first=True)
        self.sentence_attention = SequenceClassifierAttention(n_hidden * 2,
                                                              batch_first=True)
        self.out = nn.Linear(n_hidden * 2, n_classes)

    def forward(self, X):
        # X: (batch_size, n_sents, n_words) of word indices
        batch_size, n_sents, n_words = X.shape
        encoded_sents_word = []
        sentence_alphas = []
        # there might be a more efficient way of encoding the sentences
        # than one sentence at a time
        for i in range(n_sents):
            sentence_words = X[:, i, :]
            words_embedded = self.embed(sentence_words)
            words_encoded, _ = self.word_encoder(words_embedded)
            sentence_vector, sentence_alpha = self.word_attention(words_encoded)
            # unsqueeze the sentence vector to insert a dummy "sentence timestep"
            # dimension so that the sentence vectors can be concatenated along it
            encoded_sents_word.append(sentence_vector.unsqueeze(1))
            sentence_alphas.append(sentence_alpha)
        encoded_sents_word = torch.cat(encoded_sents_word, dim=1)
        encoded_sents, _ = self.sentence_encoder(encoded_sents_word)
        encoded_docs, document_alpha = self.sentence_attention(encoded_sents)
        out = self.out(encoded_docs)
        return out, sentence_alphas, document_alpha
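

# Minimal smoke test (added sketch, not part of the original gist; all sizes
# are made up for illustration): build a small model and run a forward pass
# on a random batch of 2 documents, each with 3 sentences of 7 word indices.
if __name__ == "__main__":
    model = HierarchicalAttentionNetwork(n_hidden=50, n_classes=5,
                                         vocab_size=1000, embedding_dim=100,
                                         padding_idx=0)
    docs = torch.randint(0, 1000, (2, 3, 7))  # (batch, sentences, words)
    logits, sentence_alphas, document_alpha = model(docs)
    print(logits.shape)          # torch.Size([2, 5])
    print(len(sentence_alphas))  # 3 word-level attention maps, one per sentence
    print(document_alpha.shape)  # torch.Size([2, 3])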