Implementation of the transformer block used by BERT
#!/usr/bin/env python3
"""Implementation of the transformer block used by BERT.

I saw an excellent implementation of the complete BERT model here:
https://github.com/codertimo/BERT-pytorch

I rewrote a simplified version of the transformer block below. This was mainly
for my own understanding (so that I could get a grasp of the dimensions and
how the whole attention mechanism works), but I tried to document it pretty
thoroughly so that other people can understand it without having to go too far
into the weeds. The training task at the bottom is just a proof-of-concept,
where the model learns to output the input sequence.
"""

import math
from typing import Optional

import torch
from torch import (
    nn,
    optim,
    Tensor,
)
from torch.nn import functional as F


class GELU(nn.Module):
    """Defines the Gaussian Error Linear Unit (GELU) activation function.

    Input:
        float tensor of any shape
    Output:
        float tensor with the same shape.
    """

    def forward(self, x: Tensor) -> Tensor:
        a = math.sqrt(2 / math.pi)
        b = 0.044715
        return 0.5 * x * (1 + torch.tanh(a * (x + b * torch.pow(x, 3))))
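

# Sanity-check sketch (not part of the original gist): the forward pass above
# is the tanh approximation to GELU; the exact form is
# x * 0.5 * (1 + erf(x / sqrt(2))). The two agree closely for typical inputs.
def _gelu_example() -> None:
    x = torch.randn(2, 5, 20)
    approx = GELU()(x)
    exact = x * 0.5 * (1 + torch.erf(x / math.sqrt(2)))
    assert approx.shape == x.shape
    assert torch.allclose(approx, exact, atol=1e-2)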


class TwoLayerLinear(nn.Module):
    """Defines a module with two linear layers, with dropout.

    Args:
        num_input: int, number of input dimensions.
        num_hidden: int, number of hidden dimensions (between first and second).
        dropout: float, the dropout rate.
    Input:
        x: float, (batch_size, time_steps, num_input)
    Output:
        float, (batch_size, time_steps, num_input)
    """

    def __init__(self,
                 num_input: int,
                 num_hidden: int,
                 dropout: float = 0.1) -> None:
        super(TwoLayerLinear, self).__init__()
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.dropout = dropout
        self.first_layer = nn.Linear(num_input, num_hidden)
        self.second_layer = nn.Linear(num_hidden, num_input)
        self.dropout_layer = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.first_layer(x)
        x = self.activation(x)
        x = self.dropout_layer(x)
        x = self.second_layer(x)
        return x
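

# Shape sketch (not part of the original gist): the feed-forward block maps
# back to its input width, so it can sit inside a residual connection. The
# sizes below are arbitrary.
def _two_layer_linear_example() -> None:
    feed_forward = TwoLayerLinear(num_input=20, num_hidden=80)
    x = torch.randn(4, 10, 20)  # (batch_size, time_steps, num_input)
    assert feed_forward(x).shape == (4, 10, 20)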


class Encoder(nn.Module):
    """Defines a general attention encoder.

    Args:
        num_input: int, number of input dimensions.
        num_heads: int, number of attention heads to encode.
        num_dimensions: int, number of encoding dimensions.
    Input:
        x: float, (batch_size, time_steps, num_input)
    Output:
        float, (batch_size, num_heads, time_steps, num_dimensions)
    """

    def __init__(self,
                 num_input: int,
                 num_heads: int,
                 num_dimensions: int) -> None:
        super(Encoder, self).__init__()
        self.num_input = num_input
        self.num_heads = num_heads
        self.num_dimensions = num_dimensions
        self.layer = nn.Linear(num_input, num_heads * num_dimensions)

    def forward(self, x: Tensor) -> Tensor:
        batch_size, time_steps, _ = x.size()
        shape = (batch_size, time_steps, self.num_heads, self.num_dimensions)
        # project, split the features into heads, then move the head axis
        # in front of the time axis
        x = self.layer(x).view(*shape)
        return x.transpose(1, 2)
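

# Shape sketch (not part of the original gist): each time step is projected to
# num_heads * num_dimensions features, which are then split per head.
def _encoder_example() -> None:
    encoder = Encoder(num_input=20, num_heads=3, num_dimensions=8)
    x = torch.randn(4, 10, 20)  # (batch_size, time_steps, num_input)
    # (batch_size, num_heads, time_steps, num_dimensions)
    assert encoder(x).shape == (4, 3, 10, 8)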


class AttentionLayer(nn.Module):
    """Defines a multi-headed scaled dot-product attention model.

    Args:
        num_input: int, number of input dimensions.
        num_heads: int, number of attention heads to use.
        num_hidden: int, number of hidden dimensions.
        num_key_dims: int, number of dimensions in the key encoder. Defaults to
            num_hidden.
        num_value_dims: int, number of dimensions in the value encoder.
            Defaults to num_hidden.
        dropout: float, dropout to apply to attention weights.
    Input:
        x: float, (batch_size, time_steps, num_input)
        mask: byte, (batch_size, time_steps), where masked dims are nonzero.
    Output:
        float, (batch_size, time_steps, num_hidden)
    """

    def __init__(self,
                 num_input: int,
                 num_heads: int,
                 num_hidden: int,
                 num_key_dims: Optional[int] = None,
                 num_value_dims: Optional[int] = None,
                 dropout: float = 0.1) -> None:
        super(AttentionLayer, self).__init__()
        num_key_dims = num_key_dims or num_hidden
        num_value_dims = num_value_dims or num_hidden
        self.num_input = num_input
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        self.num_key_dims = num_key_dims
        self.num_value_dims = num_value_dims
        self.dropout = dropout
        self.scale = math.sqrt(num_key_dims)
        self.query_layer = Encoder(num_input, num_heads, num_key_dims)
        self.key_layer = Encoder(num_input, num_heads, num_key_dims)
        self.value_layer = Encoder(num_input, num_heads, num_value_dims)
        self.dropout_layer = nn.Dropout(dropout)
        self.decoder_layer = nn.Linear(num_heads * num_value_dims, num_hidden)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        batch_size, time_steps, _ = x.size()
        # (batch_size, num_heads, time_steps, num_key_dims)
        query = self.query_layer(x)
        # (batch_size, num_heads, time_steps, num_key_dims)
        key = self.key_layer(x)
        # (batch_size, num_heads, time_steps, num_value_dims)
        value = self.value_layer(x)
        # scaled dot-product attention logits:
        # (batch_size, num_heads, time_steps, time_steps)
        logits = torch.matmul(query, key.transpose(-1, -2)) / self.scale
        if mask is not None:
            # broadcast the padding mask over heads and query positions:
            # (batch_size, 1, time_steps, time_steps)
            mask = mask.unsqueeze(1).repeat(1, time_steps, 1).unsqueeze(1)
            logits = logits.masked_fill(mask, -1e9)
        softmax_weights = self.dropout_layer(F.softmax(logits, dim=-1))
        # weighted sum of values per head:
        # (batch_size, num_heads, time_steps, num_value_dims)
        values = torch.matmul(softmax_weights, value)
        # concatenate heads: (batch_size, time_steps, num_heads * num_value_dims)
        values = values.transpose(1, 2).contiguous().view(
            batch_size,
            time_steps,
            self.num_heads * self.num_value_dims,
        )
        return self.decoder_layer(values)
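

# Masking sketch (not part of the original gist): nonzero mask entries get a
# -1e9 logit, so those time steps receive (effectively) zero attention weight.
# A bool mask is used here; a byte mask also works on older PyTorch versions.
def _attention_mask_example() -> None:
    attention = AttentionLayer(num_input=20, num_heads=3, num_hidden=20)
    x = torch.randn(4, 10, 20)  # (batch_size, time_steps, num_input)
    mask = torch.zeros(4, 10, dtype=torch.bool)
    mask[:, 7:] = True  # treat the last three time steps as padding
    # (batch_size, time_steps, num_hidden)
    assert attention(x, mask).shape == (4, 10, 20)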


class LayerNorm(nn.Module):
    """Defines a layer normalization layer.

    See "Layer Normalization" (Ba et al., 2016) for more details.

    Args:
        features: int, the number of input features.
        eps: float, epsilon parameter (to avoid divide by zero).
    Input:
        float, (..., features)
    Output:
        float, (..., features)
    """

    def __init__(self, features: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        # normalize over the feature dimension, then rescale and shift
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta
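

# Comparison note (not part of the original gist): this differs slightly from
# torch.nn.LayerNorm, since x.std() applies Bessel's correction and eps is
# added to the std rather than to the variance. The sketch below just checks
# that each feature vector comes out (roughly) zero-mean at initialization.
def _layer_norm_example() -> None:
    x = torch.randn(4, 10, 20)
    out = LayerNorm(20)(x)
    assert out.shape == (4, 10, 20)
    assert torch.allclose(out.mean(-1), torch.zeros(4, 10), atol=1e-5)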


class Transformer(nn.Module):
    """Defines the transformer block used by BERT.

    This transformer looks at the left and right context for a word to try to
    disambiguate its meaning. It can be used for a variety of NLP tasks.

    Input:
        x: float, (batch_size, time_steps, num_hidden)
        mask: byte, (batch_size, time_steps), where masked dims are nonzero.
    Output:
        float, (batch_size, time_steps, num_hidden)
    """

    def __init__(self,
                 num_hidden: int,
                 num_heads: int,
                 num_linear_hidden: int,
                 dropout: float = 0.1) -> None:
        super(Transformer, self).__init__()
        self.num_hidden = num_hidden
        self.num_heads = num_heads
        self.num_linear_hidden = num_linear_hidden
        self.dropout = dropout
        self.attention = AttentionLayer(num_hidden, num_heads, num_hidden)
        self.attention_norm = LayerNorm(num_hidden)
        self.linear = TwoLayerLinear(num_hidden, num_linear_hidden)
        self.linear_norm = LayerNorm(num_hidden)
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        # self-attention sublayer with a residual connection
        x = self.dropout_layer(self.attention(self.attention_norm(x), mask)) + x
        # position-wise feed-forward sublayer with a residual connection
        x = self.dropout_layer(self.linear(self.linear_norm(x))) + x
        return x
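

# End-to-end sketch (not part of the original gist): one transformer block
# applied to a batch whose last two time steps are treated as padding. The
# residual connections keep the hidden width unchanged.
def _transformer_example() -> None:
    block = Transformer(num_hidden=20, num_heads=3, num_linear_hidden=80)
    x = torch.randn(4, 10, 20)  # (batch_size, time_steps, num_hidden)
    mask = torch.zeros(4, 10, dtype=torch.bool)
    mask[:, 8:] = True  # nonzero entries are masked out of the attention
    assert block(x, mask).shape == (4, 10, 20)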


if __name__ == '__main__':
    # Proof-of-concept training task: the model learns to output its input
    # sequence, so the L1 loss should steadily decrease.
    model = Transformer(20, 3, 80)
    optimizer = optim.Adam(model.parameters())
    loss_function = nn.L1Loss()
    for _ in range(1000):
        input = torch.randn(128, 10, 20)
        optimizer.zero_grad()
        output = model(input)
        loss = loss_function(output, input)
        print(loss.item())
        loss.backward()
        optimizer.step()