airalcorn2 / hook_transformer_attn.py
Last active July 15, 2024 19:38
A simple script for extracting the attention weights from a PyTorch Transformer.
# Inspired by: https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca.
# Monkey patching idea suggested by @kklemon here:
# https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91?permalink_comment_id=4407423#gistcomment-4407423.

import torch
from torch import nn


# Monkey patch the attention module's forward pass so that it returns the attention weights.
def patch_attention(m):