Example implementation converting a basic Transformer model to a CNNTransformer.

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super(CNNTransformer, self).__init__()
        # Stack of 1D convolutions: the first layer maps input_dim -> hidden_dim,
        # later layers keep hidden_dim. padding=kernel_size // 2 preserves the
        # sequence length for odd kernel sizes.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=input_dim if i == 0 else hidden_dim,
                      out_channels=hidden_dim,
                      kernel_size=kernel_size,
                      padding=kernel_size // 2)
            for i in range(num_layers)
        ])
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.layer_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])

    def forward(self, x):
        x = x.transpose(1, 2)  # Transpose for Conv1d: (batch_size, input_dim, seq_len)
        for conv, ln in zip(self.convs, self.layer_norms):
            x = conv(x)
            x = ln(x.transpose(1, 2)).transpose(1, 2)  # LayerNorm over channels, then transpose back
            x = F.relu(x)
        x = x.sum(dim=2)  # Pool by summing across the sequence dimension
        output = self.fc(x)  # Final linear projection
        return output
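

# Minimal shape check with assumed toy dimensions (illustrative sketch only):
# a (batch_size, seq_len, input_dim) input should yield a (batch_size, output_dim) output.
def _shape_check():
    model = CNNTransformer(input_dim=32, hidden_dim=64, output_dim=10,
                           num_layers=2, kernel_size=3)
    out = model(torch.randn(4, 8, 32))  # batch_size=4, seq_len=8, input_dim=32
    assert out.shape == (4, 10)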


def convert_transformer_to_cnn(transformer_model, num_layers, kernel_size):
    """
    Convert a Transformer model to a CNNTransformer model.

    Parameters:
        transformer_model (nn.Module): The original Transformer model.
        num_layers (int): The number of convolutional layers for the CNNTransformer.
        kernel_size (int): The kernel size for the convolutional layers.

    Returns:
        CNNTransformer: The converted CNNTransformer model.

    Note that only the dimensions are carried over: the Transformer's weights are
    not transferred, so the CNNTransformer starts from random initialization.
    """
    # Extract the dimensions from the Transformer's embedding layer,
    # attention projections, and config
    input_dim = transformer_model.embeddings.word_embeddings.embedding_dim  # Assumes a BERT-style embedding layer
    hidden_dim = transformer_model.encoder.layer[0].attention.self.query.out_features  # Assumes multi-head self-attention
    output_dim = transformer_model.config.vocab_size  # Assumes a language model; this may differ by task

    # Instantiate the CNNTransformer with the extracted dimensions
    cnn_transformer = CNNTransformer(input_dim, hidden_dim, output_dim, num_layers, kernel_size)
    return cnn_transformer
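

# Rough size comparison helper (illustrative only): simply counts the parameters
# of each model, since no weights are shared between them.
def _compare_param_counts(transformer_model, cnn_transformer):
    t_params = sum(p.numel() for p in transformer_model.parameters())
    c_params = sum(p.numel() for p in cnn_transformer.parameters())
    return t_params, c_params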


# Example usage:
if __name__ == "__main__":
    # Assuming you have a pre-trained Transformer model like BERT
    from transformers import BertModel

    # Load a pre-trained BERT model (as an example of a Transformer model)
    transformer_model = BertModel.from_pretrained('bert-base-uncased')

    # Convert the Transformer model to a CNNTransformer model
    num_layers = 12
    kernel_size = 3
    cnn_transformer = convert_transformer_to_cnn(transformer_model, num_layers, kernel_size)

    # Example input: random tensors standing in for embedded token sequences,
    # shaped (batch_size, seq_len, input_dim)
    batch_size = 16
    seq_len = 20
    input_dim = transformer_model.config.hidden_size
    x = torch.randn(batch_size, seq_len, input_dim)

    # Forward pass
    output = cnn_transformer(x)
    print(output.shape)  # Expected output shape: (batch_size, output_dim)
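
    # Optional rough size comparison (illustrative only; uses _compare_param_counts above)
    t_params, c_params = _compare_param_counts(transformer_model, cnn_transformer)
    print(f"Transformer params: {t_params:,} | CNNTransformer params: {c_params:,}")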