@anna-hope
Last active April 22, 2019 06:05
Structured Self-Attention in PyTorch (Lin et al. 2017)
# Implementation of the Structured Self-Attention mechanism
# from Lin et al. 2017 (https://arxiv.org/pdf/1703.03130.pdf)
# Anton Melnikov

from typing import Tuple

import torch
import torch.nn as nn


class StructuredAttention(nn.Module):
    def __init__(self, *, input_dim: int, hidden_dim: int, attention_hops: int):
        super().__init__()
        # W_s1: (hidden_dim x input_dim), W_s2: (attention_hops x hidden_dim)
        self.w1 = nn.Parameter(torch.randn(size=(hidden_dim, input_dim)))
        self.w2 = nn.Parameter(torch.randn(attention_hops, hidden_dim))

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # hidden_states: (batch, seq_len, input_dim)
        # X: (batch, hidden_dim, seq_len)
        X = self.w1 @ hidden_states.transpose(2, 1)
        X = torch.tanh(X)
        # a: attention matrix A, (batch, attention_hops, seq_len);
        # softmax is taken over the sequence dimension
        a = torch.softmax(self.w2 @ X, dim=-1)
        # m: sentence embedding matrix M = A H, (batch, attention_hops, input_dim)
        m = a @ hidden_states
        return m, a


def get_attention_penalty(attention_matrix: torch.Tensor) -> torch.Tensor:
    # Penalisation term P = ||A A^T - I|| (Frobenius norm),
    # which encourages the attention hops to focus on different parts of the input
    identity = torch.eye(attention_matrix.shape[1],
                         device=attention_matrix.device)
    p = attention_matrix @ attention_matrix.transpose(2, 1) - identity
    p = torch.norm(p)
    return p
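
A minimal usage sketch (not part of the original gist): it runs the module on a random batch of encoder states and computes the penalty term. The batch size, sequence length, and dimensions below are arbitrary assumptions chosen for illustration.

# Example usage (illustrative only; all dimensions are assumptions)
batch_size, seq_len, input_dim = 4, 10, 64
hidden_dim, attention_hops = 32, 5

attention = StructuredAttention(input_dim=input_dim,
                                hidden_dim=hidden_dim,
                                attention_hops=attention_hops)

hidden_states = torch.randn(batch_size, seq_len, input_dim)
m, a = attention(hidden_states)
print(m.shape)  # torch.Size([4, 5, 64])  -- sentence embedding matrix M
print(a.shape)  # torch.Size([4, 5, 10])  -- attention matrix A

penalty = get_attention_penalty(a)  # add to the task loss, weighted by a coefficient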