@bhosmer, created May 21, 2021 03:32

from typing import List
import torch
from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence


class TransformerEncoder(nn.Module):
    """
    An example module for use with GenericPyTorchLightningMulticlassClassifierVarSized.
    Applies a 1-layer Transformer encoder followed by self-attention pooling.
    """

    input_embeddings_size: int
    fix_empty: bool

    def __init__(
        self,
        input_embeddings_size: int,
        embeddings_size: int,
        num_layers: int = 1,
        num_heads: int = 8,
        pooling: str = "CLS",
        fix_empty: bool = False,
    ):
        super().__init__()
        self.input_embeddings_size = input_embeddings_size
        self.fix_empty = fix_empty
        # Project inputs only when the input and model dimensions differ.
        if input_embeddings_size != embeddings_size:
            self.projection = nn.Linear(input_embeddings_size, embeddings_size)
        else:
            self.projection = None
        encoder_norm = nn.LayerNorm(embeddings_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embeddings_size,
            nhead=num_heads,
            dim_feedforward=embeddings_size * 4,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, norm=encoder_norm
        )
        self.pooling = pooling
        # log_class_usage(__class__)

    def forward(self, seqs: List[Tensor]) -> Tensor:
        # Flatten each input to (seq_length, input_embeddings_size).
        reshaped_seqs = [x.reshape(-1, self.input_embeddings_size) for x in seqs]
        if self.fix_empty:
            # Replace empty sequences with a single zero vector so that
            # padding and masking below still work.
            for i in range(len(reshaped_seqs)):
                if reshaped_seqs[i].shape[0] == 0:
                    reshaped_seqs[i] = torch.zeros(
                        1, self.input_embeddings_size, device=seqs[0].device
                    )
        if self.projection is not None:
            reshaped_seqs = [self.projection(seq) for seq in reshaped_seqs]
        # Boolean mask where True marks padded (ignored) positions.
        lengths_to_mask = torch.tensor(
            [len(x) for x in reshaped_seqs], device=seqs[0].device
        )
        padding_mask = torch.ge(
            torch.arange(torch.max(lengths_to_mask), device=seqs[0].device)[None, :],
            lengths_to_mask[:, None].long(),
        )
        padded_seqs = pad_sequence(reshaped_seqs, batch_first=True)
        # The encoder expects (seq_length, batch_size, embeddings_size).
        result = self.transformer_encoder(
            padded_seqs.permute((1, 0, 2)), src_key_padding_mask=padding_mask
        )
        if self.pooling == "CLS":
            # Return the embeddings of the first token.
            return result[0, :, :]
        else:
            raise NotImplementedError(
                "pooling approach {} is not implemented.".format(self.pooling)
            )


model = TransformerEncoder(1024, 1024)
# Dynamically quantize the module's nn.Linear layers to int8, then script it.
model = torch.quantization.quantize_dynamic(
    model, dtype=torch.qint8, inplace=False
)
scripted = torch.jit.script(model)
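
# A minimal usage sketch (not part of the original gist): the batch below and
# its sequence lengths are arbitrary assumptions, and it presumes that
# scripting the quantized module succeeded on your PyTorch version.
seqs = [torch.randn(5, 1024), torch.randn(3, 1024)]  # two variable-length sequences
pooled = scripted(seqs)
print(pooled.shape)  # expected: torch.Size([2, 1024]), one CLS embedding per sequence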