from math import sqrt

import torch
from fast_transformers.attention_registry import (
    AttentionRegistry, EventDispatcherInstance, Float, Optional,
    RecurrentAttentionRegistry, RecurrentCrossAttentionRegistry)
from fast_transformers.builders import (RecurrentDecoderBuilder,
                                        RecurrentEncoderBuilder,
                                        TransformerDecoderBuilder,
                                        TransformerEncoderBuilder)
from fast_transformers.events import AttentionEvent, EventDispatcher
from fast_transformers.masking import FullMask, LengthMask
from fast_transformers.recurrent import (RecurrentTranformerEncoderLayer,
                                         RecurrentTransformerDecoderLayer)
from fast_transformers.transformers import (TransformerDecoderLayer,
                                            TransformerEncoderLayer)
from torch.nn import Dropout, LayerNorm, Module


class CV_RTEL(RecurrentTranformerEncoderLayer):
    """Attention to the previous inputs and feed forward with skip connections.

    This transformer encoder layer is the recurrent dual of
    fast_transformers.transformers.TransformerEncoderLayer . The results should
    be identical given the same inputs and a lower triangular mask.

    Arguments
    ---------
        attention: The attention implementation to use given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: gelu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation='gelu',
        event_dispatcher='',
    ):
        super().__init__(
            attention,
            d_model,
            d_ff,
            dropout,
            activation,
            event_dispatcher,
        )
        self.norm3 = LayerNorm(d_model)
        self.norm4 = LayerNorm(d_model)

    def forward(self, x, state=None):
        """Apply the transformer encoder to the input x using the provided
        memory.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            state: The state can vary depending on the attention implementation
        """
        # Apply the self attention with "sandwich" normalization and a residual
        x_ = self.norm1(x / torch.max(x))
        x_, state = self.attention(x_, x_, x_, state)
        x_ = self.norm2(x_ / torch.max(x_))
        x += self.dropout(x_)

        # Run the fully connected part of the layer, again normalized on both sides
        x_ = self.norm3(x / torch.max(x))
        x_ = self.activation(self.linear1(x_))
        x_ = self.linear2(self.dropout(x_))
        x_ = self.norm4(x_ / torch.max(x_))
        x += self.dropout(x_)

        return x, state
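

# Why divide by torch.max before every LayerNorm? This is the CogView-style
# precision trick these layers share: LayerNorm is invariant to rescaling its
# input by a positive constant, so the division changes nothing in exact
# arithmetic while keeping half-precision activations well away from overflow.
# The helper below is only an illustrative sketch we added (its name is not
# part of the gist's API), assuming the max of the activations is positive.
def _layernorm_scale_invariance_demo(d_model=8):
    norm = LayerNorm(d_model)
    x = torch.randn(2, d_model).abs() + 1.0  # ensure torch.max(x) > 0
    c = torch.max(x)
    # Equal up to LayerNorm's eps and floating point error.
    assert torch.allclose(norm(x), norm(x / c), atol=1e-3)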


class CV_REB(RecurrentEncoderBuilder):
    def _get_encoder_layer_class(self):
        return CV_RTEL


class CV_RTDL(RecurrentTransformerDecoderLayer):
    """Attention to the previous inputs and a preprocessed memory.

    This transformer decoder layer is the recurrent dual of
    fast_transformers.transformers.TransformerDecoderLayer . The results should
    be identical given the same inputs and a lower triangular mask for x_mask.

    Arguments
    ---------
        self_attention: The attention implementation to use for self attention
                        given as a nn.Module
        cross_attention: The attention implementation to use for cross
                         attention given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: gelu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation='gelu',
        event_dispatcher='',
    ):
        super().__init__(
            self_attention,
            cross_attention,
            d_model,
            d_ff,
            dropout,
            activation,
            event_dispatcher,
        )
        self.norm4 = LayerNorm(d_model)
        self.norm5 = LayerNorm(d_model)
        self.norm6 = LayerNorm(d_model)

    def forward(self, x, memory, memory_length_mask=None, state=None):
        """Apply the transformer decoder to the input x and also attend to
        memory.

        Note the memory mask is assumed to be a full mask.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            memory: A sequence of features (N, S, E) that the input will attend
                    to. S is the sequence length and E is the same as for x.
            memory_length_mask: An implementation of a BaseMask that encodes
                                how many elements each memory sequence in the
                                batch consists of.
            state: The state varies depending on the attention implementations
                   but it allows for recurrent implementation.
        """
        # Normalize the mask
        N = x.shape[0]
        L = memory.shape[1]
        memory_length_mask = memory_length_mask or \
            LengthMask(x.new_full((N,), L, dtype=torch.int64))

        # Extract the individual states for the self attention and the cross
        # attention
        self_state, cross_state = state or [None, None]

        # First apply the self attention and add it to the input
        x_ = self.norm1(x / torch.max(x))
        x_, self_state = self.self_attention(
            x_,
            x_,
            x_,
            state=self_state,
        )
        x_ = self.norm2(x_ / torch.max(x_))
        x += self.dropout(x_)

        # Secondly apply the cross attention and add it to the previous output
        x_ = self.norm3(x / torch.max(x))
        x_, cross_state = self.cross_attention(
            x_,
            memory,
            memory,
            memory_length_mask,
            state=cross_state,
        )
        x_ = self.norm4(x_ / torch.max(x_))
        x += self.dropout(x_)

        # Finally run the fully connected part of the layer
        x_ = self.norm5(x / torch.max(x))
        x_ = self.activation(self.linear1(x_))
        x_ = self.linear2(self.dropout(x_))
        x_ = self.norm6(x_ / torch.max(x_))
        x += self.dropout(x_)

        return x, [self_state, cross_state]


class CV_RDB(RecurrentDecoderBuilder):
    def _get_decoder_layer_class(self):
        return CV_RTDL


class CV_TEL(TransformerEncoderLayer):
    """Self attention and feed forward network with skip connections.

    This transformer encoder layer implements the same encoder layer as
    PyTorch but is a bit more open for extension by receiving the attention
    implementation as a constructor argument.

    Arguments
    ---------
        attention: The attention implementation to use given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: gelu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="gelu",
        event_dispatcher="",
    ):
        super().__init__(
            attention,
            d_model,
            d_ff,
            dropout,
            activation,
            event_dispatcher,
        )
        self.norm3 = LayerNorm(d_model)
        self.norm4 = LayerNorm(d_model)

    def forward(self, x, attn_mask=None, length_mask=None):
        """Apply the transformer encoder to the input x.

        Arguments
        ---------
            x: The input features of shape (N, L, E) where N is the batch size,
               L is the sequence length (padded) and E is d_model passed in the
               constructor.
            attn_mask: An implementation of fast_transformers.masking.BaseMask
                       that encodes where each element of x can attend to.
            length_mask: An implementation of
                         fast_transformers.masking.BaseMask that encodes how
                         many elements each sequence in the batch consists of.
        """
        # Normalize the masks
        N = x.shape[0]
        L = x.shape[1]
        attn_mask = attn_mask or FullMask(L, device=x.device)
        length_mask = length_mask or \
            LengthMask(x.new_full((N,), L, dtype=torch.int64))

        # Run self attention and add it to the input
        x_ = self.norm1(x / torch.max(x))
        x_ = self.attention(
            x_,
            x_,
            x_,
            attn_mask=attn_mask,
            query_lengths=length_mask,
            key_lengths=length_mask,
        )
        x_ = self.norm2(x_ / torch.max(x_))
        x += self.dropout(x_)

        # Run the fully connected part of the layer
        x_ = self.norm3(x / torch.max(x))
        x_ = self.activation(self.linear1(x_))
        x_ = self.linear2(x_)
        x_ = self.norm4(x_ / torch.max(x_))
        x += self.dropout(x_)

        return x


class CV_TEB(TransformerEncoderBuilder):
    def _get_encoder_layer_class(self):
        return CV_TEL


class CV_TDL(TransformerDecoderLayer):
    """The decoder layer from "Attention Is All You Need".

    Similar to the encoder layer, this layer implements the decoder that
    PyTorch implements but can be used with any attention implementation
    because it receives the attention layers as constructor arguments.

    Arguments
    ---------
        self_attention: The attention implementation to use for self attention
                        given as a nn.Module
        cross_attention: The attention implementation to use for cross
                         attention given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: gelu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="gelu",
        event_dispatcher="",
    ):
        super().__init__(
            self_attention,
            cross_attention,
            d_model,
            d_ff,
            dropout,
            activation,
            event_dispatcher,
        )
        self.norm4 = LayerNorm(d_model)
        self.norm5 = LayerNorm(d_model)
        self.norm6 = LayerNorm(d_model)

    def forward(
        self,
        x,
        memory,
        x_mask=None,
        x_length_mask=None,
        memory_mask=None,
        memory_length_mask=None,
    ):
        """Apply the transformer decoder to the input x using the memory
        `memory`.

        Arguments
        ---------
            x: The input features of shape (N, L, E) where N is the batch size,
               L is the sequence length (padded) and E should be the same as
               the d_model passed in the constructor.
            memory: The memory features of shape (N, L', E) where N is the
                    batch size, L' is the memory's sequence length (padded) and
                    E should be the same as the d_model.
            x_mask: An implementation of fast_transformers.masking.BaseMask
                    that encodes where each element of x can attend to in x.
                    Namely the self attention mask.
            x_length_mask: An implementation of a BaseMask that encodes how
                           many elements each sequence in the batch consists
                           of.
            memory_mask: An implementation of BaseMask that encodes where each
                         element of x can attend to in the memory. Namely the
                         cross attention mask.
            memory_length_mask: An implementation of a BaseMask that encodes
                                how many elements each memory sequence in the
                                batch consists of.
        """
        # Normalize the masks
        N = x.shape[0]
        L = x.shape[1]
        L_prime = memory.shape[1]
        x_mask = x_mask or FullMask(L, device=x.device)
        x_length_mask = x_length_mask or \
            LengthMask(x.new_full((N,), L, dtype=torch.int64))
        memory_mask = memory_mask or FullMask(L, L_prime, device=x.device)
        memory_length_mask = memory_length_mask or \
            LengthMask(x.new_full((N,), L_prime, dtype=torch.int64))

        # First apply the self attention and add it to the input
        x_ = self.norm1(x / torch.max(x))
        x_ = self.self_attention(
            x_,
            x_,
            x_,
            attn_mask=x_mask,
            query_lengths=x_length_mask,
            key_lengths=x_length_mask,
        )
        x_ = self.norm2(x_ / torch.max(x_))
        x += self.dropout(x_)

        # Secondly apply the cross attention and add it to the previous output
        x_ = self.norm3(x / torch.max(x))
        x_ = self.cross_attention(
            x_,
            memory,
            memory,
            attn_mask=memory_mask,
            query_lengths=x_length_mask,
            key_lengths=memory_length_mask,
        )
        x_ = self.norm4(x_ / torch.max(x_))
        x += self.dropout(x_)

        # Finally run the fully connected part of the layer
        x_ = self.norm5(x / torch.max(x))
        x_ = self.activation(self.linear1(x_))
        x_ = self.linear2(x_)
        x_ = self.norm6(x_ / torch.max(x_))
        x += self.dropout(x_)

        return x


class CV_TDB(TransformerDecoderBuilder):
    def _get_decoder_layer_class(self):
        return CV_TDL


class CV_RFA(Module):
    """Implement the full softmax attention with PB-Relax as a recurrent module.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        alpha: constant for preventing overflow (default: 32)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(
        self,
        softmax_temp=None,
        alpha=32.0,
        attention_dropout=0.1,
        event_dispatcher="",
    ):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.alpha = alpha

    def forward(self, query, key, value, state=None):
        # Extract some shapes and compute the temperature
        N, H, E = query.shape
        _, _, D = value.shape
        softmax_temp = self.softmax_temp or 1. / sqrt(E)

        # Aggregate the list of keys and values
        if state is not None:
            keys, values = state
            keys = torch.cat([keys, key[:, :, None]], dim=2)
            values = torch.cat([values, value[:, :, None]], dim=2)
        else:
            keys = key[:, :, None]
            values = value[:, :, None]

        # Scale the queries instead of applying the softmax temperature to the
        # dot products
        query /= self.alpha
        query *= softmax_temp

        # Compute the unnormalized attention
        QK = torch.einsum("nhe,nhse->nhs", query, keys)

        # PB-Relax: subtract the max and undo the extra 1/alpha scaling before
        # the softmax so the exponentials stay in range
        QK -= torch.max(QK)
        QK *= self.alpha

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(QK, dim=-1))
        V = torch.einsum("nhs,nhsd->nhd", A, values)

        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that what we return is contiguous
        return V.contiguous(), [keys, values]


# Register the attention implementation so that it becomes available in
# builders
RecurrentAttentionRegistry.register("cv_full", CV_RFA, [
    ("softmax_temp", Optional(Float)),
    ("alpha", Optional(Float, 32.0)),
    ("attention_dropout", Optional(Float, 0.1)),
    ("event_dispatcher", Optional(EventDispatcherInstance, "")),
])
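

# A minimal usage sketch (our addition, assuming the stock fast_transformers
# builder API): with "cv_full" registered above, the CV_REB builder can
# instantiate CV_RTEL layers driven by CV_RFA. All hyperparameters and shapes
# below are placeholders, not values from this gist.
def _example_recurrent_encoder_steps():
    encoder = CV_REB.from_kwargs(
        n_layers=2,
        n_heads=4,
        query_dimensions=32,
        value_dimensions=32,
        feed_forward_dimensions=512,
        attention_type="cv_full",
    ).get().eval()
    x = torch.randn(3, 4 * 32)  # one timestep of features, shape (N, d_model)
    state = None
    with torch.no_grad():
        for _ in range(5):  # feed a few steps, threading the recurrent state
            y, state = encoder(x, state=state)
    return y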


class CV_RCFA(Module):
    """Implement the full softmax cross attention with PB-Relax as a recurrent
    module.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        alpha: constant for preventing overflow (default: 32)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(
        self,
        softmax_temp=None,
        alpha=32.0,
        attention_dropout=0.1,
        event_dispatcher="",
    ):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.alpha = alpha

    def forward(self, query, keys, values, key_lengths, state=None):
        # Extract some shapes and compute the temperature
        N, H, E = query.shape
        softmax_temp = self.softmax_temp or 1. / sqrt(E)

        # Extract the keys and values either from the arguments or the state
        if state is not None:
            keys, values = state

        # Scale the queries instead of applying the softmax temperature to the
        # dot products
        query /= self.alpha
        query *= softmax_temp

        # Compute the unnormalized attention and apply the key length mask
        QK = torch.einsum("nhe,nshe->nsh", query, keys)
        QK += key_lengths.additive_matrix[:, :, None]

        # PB-Relax: subtract the max and undo the extra 1/alpha scaling before
        # the softmax so the exponentials stay in range
        QK -= torch.max(QK)
        QK *= self.alpha

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(QK, dim=1))
        V = torch.einsum("nsh,nshd->nhd", A, values)

        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that we return a contiguous value
        return V.contiguous(), [keys, values]


# Register the attention implementation so that it becomes available in
# builders
RecurrentCrossAttentionRegistry.register("cv_full", CV_RCFA, [
    ("softmax_temp", Optional(Float)),
    ("alpha", Optional(Float, 32.0)),
    ("attention_dropout", Optional(Float, 0.1)),
    ("event_dispatcher", Optional(EventDispatcherInstance, "")),
])
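

# A generation-style sketch (our addition, assuming the stock fast_transformers
# recurrent decoder API, including the self_attention_type / cross_attention_type
# kwargs): CV_RDB builds CV_RTDL layers with CV_RFA self attention and CV_RCFA
# cross attention. Dimensions below are placeholders.
def _example_recurrent_decoder_steps():
    decoder = CV_RDB.from_kwargs(
        n_layers=2,
        n_heads=4,
        query_dimensions=32,
        value_dimensions=32,
        feed_forward_dimensions=512,
        self_attention_type="cv_full",
        cross_attention_type="cv_full",
    ).get().eval()
    memory = torch.randn(3, 10, 4 * 32)  # encoder output, shape (N, S, d_model)
    x = torch.randn(3, 4 * 32)           # first decoder input, shape (N, d_model)
    state = None
    with torch.no_grad():
        for _ in range(5):  # step token by token, carrying the per-layer states
            x, state = decoder(x, memory, state=state)
    return x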


class CV_FA(Module):
    """Implement the scaled dot product attention with PB-Relax and softmax.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        alpha: constant for preventing overflow (default: 32)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(
        self,
        softmax_temp=None,
        alpha=32.0,
        attention_dropout=0.1,
        event_dispatcher="",
    ):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.alpha = alpha
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
        """
        # Extract some shapes and compute the temperature
        N, L, H, E = queries.shape
        _, S, _, D = values.shape
        softmax_temp = self.softmax_temp or 1. / sqrt(E)

        # Scale the queries instead of applying the softmax temperature to the
        # dot products
        queries /= self.alpha
        queries *= softmax_temp

        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
        if not attn_mask.all_ones:
            QK += attn_mask.additive_matrix
        if not key_lengths.all_ones:
            QK += key_lengths.additive_matrix[:, None, None]

        # PB-Relax: subtract the max and undo the extra 1/alpha scaling before
        # the softmax so the exponentials stay in range
        QK -= torch.max(QK)
        QK *= self.alpha

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(QK, dim=-1))
        V = torch.einsum("nhls,nshd->nlhd", A, values)

        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that what we return is contiguous
        return V.contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register("cv_full", CV_FA, [
    ("softmax_temp", Optional(Float)),
    ("alpha", Optional(Float, 32.0)),
    ("attention_dropout", Optional(Float, 0.1)),
    ("event_dispatcher", Optional(EventDispatcherInstance, "")),
])
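

# A parallel, full-sequence sketch (our addition, assuming the stock
# fast_transformers builder and masking API): CV_TEB / CV_TDB now resolve
# "cv_full" to CV_FA for whole-sequence forward passes, while the recurrent
# builders above cover step-by-step inference. Shapes and hyperparameters are
# placeholders.
def _example_batch_encoder_decoder():
    from fast_transformers.masking import TriangularCausalMask

    encoder = CV_TEB.from_kwargs(
        n_layers=2,
        n_heads=4,
        query_dimensions=32,
        value_dimensions=32,
        feed_forward_dimensions=512,
        attention_type="cv_full",
    ).get().eval()
    decoder = CV_TDB.from_kwargs(
        n_layers=2,
        n_heads=4,
        query_dimensions=32,
        value_dimensions=32,
        feed_forward_dimensions=512,
        self_attention_type="cv_full",
        cross_attention_type="cv_full",
    ).get().eval()
    src = torch.randn(3, 10, 4 * 32)  # (N, S, d_model)
    tgt = torch.randn(3, 7, 4 * 32)   # (N, L, d_model)
    with torch.no_grad():
        memory = encoder(src)
        # Causal mask so each target position only attends to earlier positions.
        out = decoder(tgt, memory, x_mask=TriangularCausalMask(tgt.shape[1]))
    return out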