Example implementation showing how to convert a basic Transformer model into a CNNTransformer.
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super(CNNTransformer, self).__init__()
        # Stack of Conv1d layers: the first maps input_dim -> hidden_dim, the
        # rest map hidden_dim -> hidden_dim. Padding of kernel_size // 2 keeps
        # seq_len unchanged for odd kernel sizes.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=input_dim if i == 0 else hidden_dim,
                      out_channels=hidden_dim,
                      kernel_size=kernel_size,
                      padding=kernel_size // 2)
            for i in range(num_layers)
        ])
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.layer_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])

    def forward(self, x):
        x = x.transpose(1, 2)  # Transpose for Conv1d: (batch_size, input_dim, seq_len)
        for conv, ln in zip(self.convs, self.layer_norms):
            x = conv(x)
            x = ln(x.transpose(1, 2)).transpose(1, 2)  # LayerNorm over channels, then transpose back
            x = F.relu(x)
        x = x.sum(dim=2)  # Sum-pool across the sequence dimension: (batch_size, hidden_dim)
        output = self.fc(x)  # Final linear projection to (batch_size, output_dim)
        return output
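
# A minimal standalone shape check (hypothetical dimensions, not taken from any
# pretrained model): with kernel_size=3 the padding keeps seq_len unchanged, so
# after sum-pooling and the final linear layer a (4, 10, 8) input maps to (4, 5).
#
#   model = CNNTransformer(input_dim=8, hidden_dim=16, output_dim=5,
#                          num_layers=2, kernel_size=3)
#   out = model(torch.randn(4, 10, 8))
#   assert out.shape == (4, 5)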
def convert_transformer_to_cnn(transformer_model, num_layers, kernel_size):
    """
    Convert a Transformer model to a CNNTransformer model.

    Parameters:
        transformer_model (nn.Module): The original Transformer model.
        num_layers (int): The number of convolutional layers for the CNNTransformer.
        kernel_size (int): The kernel size for the convolutional layers.

    Returns:
        CNNTransformer: The converted CNNTransformer model.
    """
    # Extract the dimensions from the Transformer's embedding layer and
    # self-attention projections (assumes a BERT-style module layout).
    input_dim = transformer_model.embeddings.word_embeddings.embedding_dim  # Assumes an embedding layer
    hidden_dim = transformer_model.encoder.layer[0].attention.self.query.out_features  # Assumes multi-head self-attention
    output_dim = transformer_model.config.vocab_size  # Assumes a language model; may differ by task

    # Instantiate the CNNTransformer with the extracted dimensions
    cnn_transformer = CNNTransformer(input_dim, hidden_dim, output_dim, num_layers, kernel_size)
    return cnn_transformer
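
# Note: this conversion only mirrors the Transformer's dimensions; no weights
# are transferred, so the returned CNNTransformer starts from random
# initialization. For Hugging Face BERT-style models the same dimensions are
# also exposed on the config object, which avoids reaching into submodules:
#
#   input_dim = transformer_model.config.hidden_size
#   hidden_dim = transformer_model.config.hidden_size
#   output_dim = transformer_model.config.vocab_size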
# Example usage:
if __name__ == "__main__":
    # Assuming you have a pre-trained Transformer model like BERT
    from transformers import BertModel

    # Load a pre-trained BERT model (as an example of a Transformer model)
    transformer_model = BertModel.from_pretrained('bert-base-uncased')

    # Convert the Transformer model to a CNNTransformer model
    num_layers = 12
    kernel_size = 3
    cnn_transformer = convert_transformer_to_cnn(transformer_model, num_layers, kernel_size)

    # Example input: batch of sequences (batch_size, seq_len, input_dim)
    batch_size = 16
    seq_len = 20
    input_dim = transformer_model.config.hidden_size
    x = torch.randn(batch_size, seq_len, input_dim)

    # Forward pass
    output = cnn_transformer(x)
    print(output.shape)  # Expected output shape: (batch_size, output_dim)
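    # For 'bert-base-uncased' (vocab_size 30522), this should print
    # torch.Size([16, 30522]).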