Last active June 18, 2023 13:42
A token-matrix visualization showing how much importance each token receives in the multi-head attention layers of transformers such as GPT-2.
import torch
from transformers import GPT2Tokenizer, GPT2Model
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
def get_model_info(model):
    """Summarize the architecture recorded in ``model.config``.

    Returns:
        A 4-tuple ``(num_layers, num_heads, hidden_dim, head_size)`` taken from
        ``config.n_layer``, ``config.n_head`` and ``config.n_embd``.
        ``head_size`` is ``n_embd // n_head`` (floor division).
    """
    cfg = model.config
    layers, heads, width = cfg.n_layer, cfg.n_head, cfg.n_embd
    return layers, heads, width, width // heads
# The GPT-2 tokenizer and model are loaded exactly once, in the example-usage
# section below, so the pretrained weights are not loaded (and discarded) twice.
def get_activated_neurons_with_weights(model, tokenizer, input_text):
    """Run ``input_text`` through ``model`` and rank its last-layer attention.

    Despite the name, no neurons are inspected. The returned "activated
    neurons" are ``argsort`` positions over the last-layer attention averaged
    along dim 0.

    Args:
        model: callable as ``model(input_ids, output_attentions=True)`` whose
            output exposes ``.attentions`` (e.g. ``GPT2Model``).
        tokenizer: provides ``encode(..., return_tensors="pt")`` and ``decode``.
        input_text: prompt to encode.

    Returns:
        ``(ranked, ranked_weights, decoded)``, where ``A = attentions[-1][0]``:
        ``ranked = argsort(A.mean(dim=0), dim=-1, descending=True)``,
        ``ranked_weights = A[:, ranked]``, and ``decoded`` is the first row of
        the encoded ids decoded back to text.
        Assuming the Hugging Face ``(batch, heads, seq, seq)`` attention layout
        (TODO confirm for other models), ``A`` is ``(heads, seq, seq)``,
        ``ranked`` is ``(seq, seq)`` and ``ranked_weights`` is
        ``(heads, seq, seq, seq)``.
    """
    token_ids = tokenizer.encode(input_text, return_tensors="pt")
    with torch.no_grad():
        attention_stack = model(token_ids, output_attentions=True).attentions

    # Final layer, first batch element.
    final_layer = attention_stack[-1][0]
    # Average over dim 0, then order each row's entries from largest to smallest.
    ranked = final_layer.mean(dim=0).argsort(dim=-1, descending=True)
    # Integer-tensor indexing on dim 1:
    # ranked_weights[h, i, j] == final_layer[h, ranked[i, j]].
    ranked_weights = final_layer[:, ranked]
    return ranked, ranked_weights, tokenizer.decode(token_ids[0])
# Example usage: load a pretrained GPT-2 tokenizer/model pair and inspect the
# last-layer attention for a short prompt.
model_name = "gpt2"  # any GPT-2 checkpoint name works here
input_text = "I dont know?"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name)
# One forward pass over the prompt (no text is generated): keep the ranked
# indices, the reordered attention weights and the decoded prompt.
(
    activated_neurons,
    neuron_weights,
    generated_text,
) = get_activated_neurons_with_weights(model=model, tokenizer=tokenizer, input_text=input_text)
# Keep the decoded prompt and the attention results together in one dict.
memory = {
    'generated_text': generated_text,
    'activated_neurons': activated_neurons,
    'neuron_weights': neuron_weights,
}
# Read the stored values back out for the animation code.
retrieved_generated_text = memory['generated_text']
retrieved_activated_neurons = memory['activated_neurons']
retrieved_neuron_weights = memory['neuron_weights']
# Create the figure the animation draws into.
fig = plt.figure()
# Define a function to update the animation
def update(frame):
    """Redraw the figure for animation frame ``frame``.

    Shows ``retrieved_neuron_weights[0][frame]`` as a heat map. The weights are
    built in ``get_activated_neurons_with_weights`` as ``A[:, ranked]``, so this
    frame equals ``A[0][ranked[frame]]``. Row ``j`` is therefore row
    ``retrieved_activated_neurons[frame][j]`` of head 0's matrix, and the
    y-axis labels are permuted the same way. Columns keep the original token
    order.
    """
    # Clear the previous frame. Otherwise every call stacks another image
    # (and another set of tick labels) on top of the old ones.
    plt.clf()
    # Plot the current frame.
    plt.imshow(np.array(retrieved_neuron_weights[0][frame]), cmap='hot', interpolation='nearest')
    # Label columns with tokens in their original order, and rows with tokens
    # in the permuted order actually used to build this frame.
    tokens = tokenizer.tokenize(retrieved_generated_text)
    row_order = retrieved_activated_neurons[frame].tolist()
    plt.xticks(range(len(tokens)), tokens, rotation='vertical')
    plt.yticks(range(len(row_order)), [tokens[i] for i in row_order])
# Create the animation: one frame per index along dim 0 of the head-0 weights,
# 500 ms apart, looping. The object is kept in `anim` because matplotlib
# requires a live reference for the animation to keep running.
anim = FuncAnimation(fig, update, frames=len(retrieved_neuron_weights[0]), interval=500, repeat=True)
# Display the animation (in non-interactive mode this blocks until the window closes).
plt.show()
