from math import sqrt

import torch
from fast_transformers.attention_registry import (
    AttentionRegistry, EventDispatcherInstance, Float, Optional,
    RecurrentAttentionRegistry, RecurrentCrossAttentionRegistry)
from fast_transformers.builders import (RecurrentDecoderBuilder,
                                         RecurrentEncoderBuilder,
                                         TransformerDecoderBuilder,
                                         TransformerEncoderBuilder)
from fast_transformers.events import AttentionEvent, EventDispatcher
from fast_transformers.masking import FullMask, LengthMask
from fast_transformers.recurrent import (RecurrentTransformerEncoderLayer,
                                          RecurrentTransformerDecoderLayer)
from fast_transformers.transformers import (TransformerDecoderLayer,
                                             TransformerEncoderLayer)
from torch.nn import Dropout, LayerNorm, Module


class CV_RTEL(RecurrentTransformerEncoderLayer):
    """Attention to the previous inputs and feed forward with skip connections.

    This transformer encoder layer is the recurrent dual of
    fast_transformers.transformers.TransformerEncoderLayer . The results should
    be identical given the same inputs and a lower triangular mask.

    Arguments
    ---------
        attention: The attention implementation to use given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: gelu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation='gelu',
        event_dispatcher='',
    ):
        super().__init__(
            attention,
            d_model,
            d_ff,
            dropout,
            activation,
            event_dispatcher,
        )
        self.norm3 = LayerNorm(d_model)
        self.norm4 = LayerNorm(d_model)

    def forward(self, x, state=None):
        """Apply the transformer encoder to the input x using the provided
        memory.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            state: The state can vary depending on the attention implementation
        """
        x_ = self.norm1(x / torch.max(x))
        x_, state = self.attention(x_, x_, x_, state)
        x_ = self.norm2(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        x_ = self.norm3(x / torch.max(x))
        x_ = self.activation(self.linear1(x_))
        x_ = self.linear2(self.dropout(x_))
        x_ = self.norm4(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        return x, state


class CV_REB(RecurrentEncoderBuilder):
    def _get_encoder_layer_class(self):
        return CV_RTEL
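

# Note (added): the fast_transformers builders pick their layer class via
# _get_encoder_layer_class / _get_decoder_layer_class, so overriding it here
# makes CV_REB build recurrent encoders out of CV_RTEL layers. The other CV_*
# builder subclasses below follow the same pattern.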


class CV_RTDL(RecurrentTransformerDecoderLayer):
    """Attention to the previous inputs and a preprocessed memory.

    This transformer decoder layer is the recurrent dual of
    fast_transformers.transformers.TransformerDecoderLayer . The results should
    be identical given the same inputs and a lower triangular mask for x_mask.

    Arguments
    ---------
        self_attention: The attention implementation to use for self attention
                        given as a nn.Module
        cross_attention: The attention implementation to use for cross
                         attention given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: gelu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation='gelu',
        event_dispatcher='',
    ):
        super().__init__(
            self_attention,
            cross_attention,
            d_model,
            d_ff,
            dropout,
            activation,
            event_dispatcher,
        )
        self.norm4 = LayerNorm(d_model)
        self.norm5 = LayerNorm(d_model)
        self.norm6 = LayerNorm(d_model)

    def forward(self, x, memory, memory_length_mask=None, state=None):
        """Apply the transformer decoder to the input x and also attend to
        memory.

        Note the memory mask is assumed to be a full mask.

        Arguments
        ---------
            x: The input features of shape (N, E) where N is the batch size and
               E is d_model passed in the constructor
            memory: A sequence of features (N, S, E) that the input will attend
                    to. S is the sequence length and E is the same as for x.
            memory_length_mask: An implementation of a BaseMask that encodes
                                how many elements each memory sequence in the
                                batch consists of.
            state: The state varies depending on the attention implementations
                   but it allows for recurrent implementation.
        """
        # Normalize the mask
        N = x.shape[0]
        L = memory.shape[1]
        memory_length_mask = memory_length_mask or \
            LengthMask(x.new_full((N,), L, dtype=torch.int64))
        # Extract the individual states for the self attention and the cross
        # attention
        self_state, cross_state = state or [None, None]
        # First apply the self attention and add it to the input
        x_ = self.norm1(x / torch.max(x))
        x_, self_state = self.self_attention(
            x_,
            x_,
            x_,
            state=self_state,
        )
        x_ = self.norm2(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        # Secondly apply the cross attention and add it to the previous output
        x_ = self.norm3(x / torch.max(x))
        x_, cross_state = self.cross_attention(
            x_,
            memory,
            memory,
            memory_length_mask,
            state=cross_state,
        )
        x_ = self.norm4(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        # Finally run the fully connected part of the layer
        x_ = self.norm5(x / torch.max(x))
        x_ = self.activation(self.linear1(x_))
        x_ = self.linear2(self.dropout(x_))
        x_ = self.norm6(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        return x, [self_state, cross_state]


class CV_RDB(RecurrentDecoderBuilder):
    def _get_decoder_layer_class(self):
        return CV_RTDL


class CV_TEL(TransformerEncoderLayer):
    """Self attention and feed forward network with skip connections.

    This transformer encoder layer implements the same encoder layer as
    PyTorch but is a bit more open for extension by receiving the attention
    implementation as a constructor argument.

    Arguments
    ---------
        attention: The attention implementation to use given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: gelu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="gelu",
        event_dispatcher="",
    ):
        super().__init__(
            attention,
            d_model,
            d_ff,
            dropout,
            activation,
            event_dispatcher,
        )
        self.norm3 = LayerNorm(d_model)
        self.norm4 = LayerNorm(d_model)

    def forward(self, x, attn_mask=None, length_mask=None):
        """Apply the transformer encoder to the input x.

        Arguments
        ---------
            x: The input features of shape (N, L, E) where N is the batch size,
               L is the sequence length (padded) and E is d_model passed in the
               constructor.
            attn_mask: An implementation of fast_transformers.masking.BaseMask
                       that encodes where each element of x can attend to.
            length_mask: An implementation of
                         fast_transformers.masking.BaseMask that encodes how
                         many elements each sequence in the batch consists of.
        """
        # Normalize the masks
        N = x.shape[0]
        L = x.shape[1]
        attn_mask = attn_mask or FullMask(L, device=x.device)
        length_mask = length_mask or \
            LengthMask(x.new_full((N,), L, dtype=torch.int64))
        # Run self attention and add it to the input
        x_ = self.norm1(x / torch.max(x))
        x_ = self.attention(
            x_,
            x_,
            x_,
            attn_mask=attn_mask,
            query_lengths=length_mask,
            key_lengths=length_mask,
        )
        x_ = self.norm2(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        # Run the fully connected part of the layer
        x_ = self.norm3(x / torch.max(x))
        x_ = self.activation(self.linear1(x_))
        x_ = self.linear2(x_)
        x_ = self.norm4(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        return x


class CV_TEB(TransformerEncoderBuilder):
    def _get_encoder_layer_class(self):
        return CV_TEL


class CV_TDL(TransformerDecoderLayer):
    """The decoder layer from "Attention Is All You Need".

    Similar to the encoder layer, this layer implements the decoder that
    PyTorch implements but can be used with any attention implementation
    because it receives the attention layers as constructor arguments.

    Arguments
    ---------
        self_attention: The attention implementation to use for self attention
                        given as a nn.Module
        cross_attention: The attention implementation to use for cross
                         attention given as a nn.Module
        d_model: The input feature dimensionality
        d_ff: The dimensionality of the intermediate features after the
              attention (default: d_model*4)
        dropout: The dropout rate to apply to the intermediate features
                 (default: 0.1)
        activation: {'relu', 'gelu'} Which activation to use for the feed
                    forward part of the layer (default: gelu)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="gelu",
        event_dispatcher="",
    ):
        super().__init__(
            self_attention,
            cross_attention,
            d_model,
            d_ff,
            dropout,
            activation,
            event_dispatcher,
        )
        self.norm4 = LayerNorm(d_model)
        self.norm5 = LayerNorm(d_model)
        self.norm6 = LayerNorm(d_model)

    def forward(
        self,
        x,
        memory,
        x_mask=None,
        x_length_mask=None,
        memory_mask=None,
        memory_length_mask=None,
    ):
        """Apply the transformer decoder to the input x using the memory
        `memory`.

        Arguments
        ---------
            x: The input features of shape (N, L, E) where N is the batch size,
               L is the sequence length (padded) and E should be the same as
               the d_model passed in the constructor.
            memory: The memory features of shape (N, L', E) where N is the
                    batch size, L' is the memory's sequence length (padded) and
                    E should be the same as the d_model.
            x_mask: An implementation of fast_transformers.masking.BaseMask
                    that encodes where each element of x can attend to in x.
                    Namely the self attention mask.
            x_length_mask: An implementation of a BaseMask that encodes how
                           many elements each sequence in the batch consists
                           of.
            memory_mask: An implementation of BaseMask that encodes where each
                         element of x can attend to in the memory. Namely the
                         cross attention mask.
            memory_length_mask: An implementation of a BaseMask that encodes how
                                many elements each memory sequence in the batch
                                consists of.
        """
        # Normalize the masks
        N = x.shape[0]
        L = x.shape[1]
        L_prime = memory.shape[1]
        x_mask = x_mask or FullMask(L, device=x.device)
        x_length_mask = x_length_mask or \
            LengthMask(x.new_full((N,), L, dtype=torch.int64))
        memory_mask = memory_mask or FullMask(L, L_prime, device=x.device)
        memory_length_mask = memory_length_mask or \
            LengthMask(x.new_full((N,), L_prime, dtype=torch.int64))
        # First apply the self attention and add it to the input
        x_ = self.norm1(x / torch.max(x))
        x_ = self.self_attention(
            x_,
            x_,
            x_,
            attn_mask=x_mask,
            query_lengths=x_length_mask,
            key_lengths=x_length_mask,
        )
        x_ = self.norm2(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        # Secondly apply the cross attention and add it to the previous output
        x_ = self.norm3(x / torch.max(x))
        x_ = self.cross_attention(
            x_,
            memory,
            memory,
            attn_mask=memory_mask,
            query_lengths=x_length_mask,
            key_lengths=memory_length_mask,
        )
        x_ = self.norm4(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        # Finally run the fully connected part of the layer
        x_ = self.norm5(x / torch.max(x))
        x_ = self.activation(self.linear1(x_))
        x_ = self.linear2(x_)
        x_ = self.norm6(x_ / torch.max(x_))
        x = x + self.dropout(x_)
        return x


class CV_TDB(TransformerDecoderBuilder):
    def _get_decoder_layer_class(self):
        return CV_TDL


class CV_RFA(Module):
    """Implement the full softmax attention with PB-Relax as a recurrent module.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        alpha: constant for preventing overflow (default: 32)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(
        self,
        softmax_temp=None,
        alpha=32.0,
        attention_dropout=0.1,
        event_dispatcher="",
    ):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.alpha = alpha

    def forward(self, query, key, value, state=None):
        # Extract some shapes and compute the temperature
        N, H, E = query.shape
        _, _, D = value.shape
        softmax_temp = self.softmax_temp or 1. / sqrt(E)
        # Aggregate the list of keys and values
        if state is not None:
            keys, values = state
            keys = torch.cat([keys, key[:, :, None]], dim=2)
            values = torch.cat([values, value[:, :, None]], dim=2)
        else:
            keys = key[:, :, None]
            values = value[:, :, None]
        query /= self.alpha
        query *= softmax_temp
        # Compute the unnormalized attention
        QK = torch.einsum("nhe,nhse->nhs", query, keys)
        # PB-Relax
        QK = QK - torch.max(QK)
        QK *= self.alpha
        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(QK, dim=-1))
        V = torch.einsum("nhs,nhsd->nhd", A, values)
        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))
        # Make sure that what we return is contiguous
        return V.contiguous(), [keys, values]


# Register the attention implementation so that it becomes available in builders
RecurrentAttentionRegistry.register("cv_full", CV_RFA, [
    ("softmax_temp", Optional(Float)),
    ("alpha", Optional(Float, 32.0)),
    ("attention_dropout", Optional(Float, 0.1)),
    ("event_dispatcher", Optional(EventDispatcherInstance, "")),
])
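

# Worked note (added; not in the original gist): softmax is invariant to
# subtracting a constant from all of its logits, so in exact arithmetic
#     softmax(alpha * ((Q / alpha) @ K * temp - max)) == softmax(Q @ K * temp)
# because alpha * max is a single scalar here. Scaling the queries down by
# alpha and subtracting the global max before scaling back up therefore leaves
# the attention weights unchanged while keeping the intermediate logits small
# enough not to overflow in fp16.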


class CV_RCFA(Module):
    """Implement the full softmax cross attention with PB-Relax as a
    recurrent module.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        alpha: constant for preventing overflow (default: 32)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(
        self,
        softmax_temp=None,
        alpha=32.0,
        attention_dropout=0.1,
        event_dispatcher="",
    ):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.alpha = alpha

    def forward(self, query, keys, values, key_lengths, state=None):
        # Extract some shapes and compute the temperature
        N, H, E = query.shape
        softmax_temp = self.softmax_temp or 1. / sqrt(E)
        # Extract the keys and values either from the arguments or the state
        if state is not None:
            keys, values = state
        # Scale the queries instead of applying the softmax temperature to the
        # dot products
        query /= self.alpha
        query *= softmax_temp
        # Compute the unnormalized attention and apply the key length mask
        QK = torch.einsum("nhe,nshe->nsh", query, keys)
        QK += key_lengths.additive_matrix[:, :, None]
        # PB-Relax
        QK = QK - torch.max(QK)
        QK *= self.alpha
        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(QK, dim=1))
        V = torch.einsum("nsh,nshd->nhd", A, values)
        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))
        # Make sure that we return a contiguous value
        return V.contiguous(), [keys, values]


# Register the attention implementation so that it becomes available in builders
RecurrentCrossAttentionRegistry.register("cv_full", CV_RCFA, [
    ("softmax_temp", Optional(Float)),
    ("alpha", Optional(Float, 32.0)),
    ("attention_dropout", Optional(Float, 0.1)),
    ("event_dispatcher", Optional(EventDispatcherInstance, "")),
])


class CV_FA(Module):
    """Implement the scaled dot product attention with PB-Relax and softmax.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        alpha: constant for preventing overflow (default: 32)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(
        self,
        softmax_temp=None,
        alpha=32.0,
        attention_dropout=0.1,
        event_dispatcher="",
    ):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.alpha = alpha
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
        """
        # Extract some shapes and compute the temperature
        N, L, H, E = queries.shape
        _, S, _, D = values.shape
        softmax_temp = self.softmax_temp or 1. / sqrt(E)
        # Scale the queries instead of applying the softmax temperature to the
        # dot products
        queries /= self.alpha
        queries *= softmax_temp
        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
        if not attn_mask.all_ones:
            QK += attn_mask.additive_matrix
        if not key_lengths.all_ones:
            QK += key_lengths.additive_matrix[:, None, None]
        # PB-Relax
        QK = QK - torch.max(QK)
        QK *= self.alpha
        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(QK, dim=-1))
        V = torch.einsum("nhls,nshd->nlhd", A, values)
        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))
        # Make sure that what we return is contiguous
        return V.contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register("cv_full", CV_FA, [
    ("softmax_temp", Optional(Float)),
    ("alpha", Optional(Float, 32.0)),
    ("attention_dropout", Optional(Float, 0.1)),
    ("event_dispatcher", Optional(EventDispatcherInstance, "")),
])
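

# ---------------------------------------------------------------------------
# Minimal usage sketch (added; not part of the original gist). It assumes the
# standard fast_transformers builder interface (`from_kwargs()` / `.get()`)
# and that the model dimension is n_heads * query_dimensions, as in the
# library's README examples; the hyper-parameters below are illustrative only.
if __name__ == "__main__":
    n_heads, head_dim = 8, 64
    d_model = n_heads * head_dim

    # Batch (parallel) encoder built from CV_TEL layers with PB-Relax attention.
    encoder = CV_TEB.from_kwargs(
        n_layers=2,
        n_heads=n_heads,
        query_dimensions=head_dim,
        value_dimensions=head_dim,
        feed_forward_dimensions=4 * d_model,
        attention_type="cv_full",
    ).get()
    x = torch.rand(4, 100, d_model)  # (batch, sequence length, d_model)
    print(encoder(x).shape)

    # Recurrent encoder built from CV_RTEL layers, fed one position at a time.
    recurrent_encoder = CV_REB.from_kwargs(
        n_layers=2,
        n_heads=n_heads,
        query_dimensions=head_dim,
        value_dimensions=head_dim,
        feed_forward_dimensions=4 * d_model,
        attention_type="cv_full",
    ).get()
    state = None
    for _ in range(10):
        y_t, state = recurrent_encoder(torch.rand(4, d_model), state=state)
    print(y_t.shape)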