Skip to content

Instantly share code, notes, and snippets.

@innat
Last active September 28, 2023 19:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save innat/e88b096390985e806299b4a3dccc5118 to your computer and use it in GitHub Desktop.
Save innat/e88b096390985e806299b4a3dccc5118 to your computer and use it in GitHub Desktop.
torch 2 tf mha weight porting

About: A simple demonstration of translating multi-head self-attention from PyTorch to Keras.

Multi-Head Self Attention

import torch
import torch.nn as nn

class TorchAttentionModel(nn.Module):
    """Self-attention wrapper around ``nn.MultiheadAttention``.

    The attention layer is stored under the attribute name ``attn``, so its
    parameters are exposed as ``attn.in_proj_weight``, ``attn.in_proj_bias``,
    ``attn.out_proj.weight`` and ``attn.out_proj.bias``.

    Args:
        d_model: Total embedding width (PyTorch's ``embed_dim``).
        n_head: Number of attention heads.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head)

    def forward(self, x):
        """Attend ``x`` to itself and return ``nn.MultiheadAttention``'s output tuple."""
        q = k = v = x  # self-attention: the same tensor plays all three roles
        return self.attn(q, k, v)

Keras also has a dedicated MHSA layer (`keras.layers.MultiHeadAttention`), but we'll use a custom layer here anyway.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class TFMultiheadAttention(keras.Model):
    """Multi-head attention in Keras, laid out to mirror ``torch.nn.MultiheadAttention``.

    Inputs are sequence-first (PyTorch's default ``batch_first=False``
    layout): ``query`` is ``(Tx, N, E)`` and ``key``/``value`` are
    ``(Ty, N, E)``, where ``N`` is the batch size.

    Args:
        num_heads: Number of attention heads.
        key_dim: Total projection width shared by all heads (PyTorch's
            ``embed_dim``). This is *not* the per-head size used by
            ``keras.layers.MultiHeadAttention``. Must be divisible by
            ``num_heads``.
        dropout: Dropout rate applied to the attention weights.
        **kwargs: Forwarded to ``keras.Model``.

    Raises:
        ValueError: If ``key_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_heads, key_dim, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        if key_dim % num_heads != 0:
            # Raise instead of assert so the check survives ``python -O``.
            raise ValueError("key_dim size needs to be divisible by num_heads")
        self.key_dim = key_dim
        self.num_heads = num_heads
        self.head_dim = key_dim // num_heads

        # Separate q/k/v projections. PyTorch packs these into a single
        # ``in_proj_weight``, so porting must split it into three parts.
        self.wq = layers.Dense(key_dim)
        self.wk = layers.Dense(key_dim)
        self.wv = layers.Dense(key_dim)

        # Output projection (PyTorch's ``out_proj``).
        self.fc_out = layers.Dense(key_dim)

        # Dropout on the attention weights.
        self.dropout = layers.Dropout(rate=dropout)

    def transpose_qkv(self, x, T, N):
        """Split heads: ``(T, N, key_dim)`` -> ``(N, num_heads, T, head_dim)``."""
        x = tf.reshape(x, [T, N, self.num_heads, self.head_dim])
        x = tf.transpose(x, [1, 2, 0, 3])
        return x

    def call(
        self,
        query,
        key,
        value,
        attention_mask=None,
        return_attention_scores=False,
        training=None
    ):
        """Run scaled dot-product multi-head attention.

        Args:
            query: ``(Tx, N, E)`` tensor.
            key: ``(Ty, N, E)`` tensor.
            value: ``(Ty, N, E)`` tensor.
            attention_mask: Optional mask added to the logits. Positions where
                the mask is 1/True receive ``-1e9`` (i.e. are masked out).
                This is the opposite of Keras' own convention. It must
                broadcast against ``(N, num_heads, Tx, Ty)``.
            return_attention_scores: If True, also return the attention weights.
            training: Passed to the dropout layer.

        Returns:
            The ``(Tx, N, key_dim)`` output. If ``return_attention_scores``
            is set, returns a tuple of the output and the
            ``(N, num_heads, Tx, Ty)`` attention weights (after dropout).
        """
        # Linear projections: (T, N, E) -> (T, N, key_dim).
        query = self.wq(query)
        key = self.wk(key)
        value = self.wv(value)

        # Sequence-first layout: axis 0 is the sequence length, axis 1 the batch.
        Tx, Ty, N = tf.shape(query)[0], tf.shape(key)[0], tf.shape(query)[1]
        query = self.transpose_qkv(query, Tx, N)  # (N, H, Tx, head_dim)
        key = self.transpose_qkv(key, Ty, N)      # (N, H, Ty, head_dim)
        value = self.transpose_qkv(value, Ty, N)  # (N, H, Ty, head_dim)

        # Scaled dot-product logits: (N, H, Tx, Ty).
        logits = tf.matmul(query, key, transpose_b=True)
        logits = logits / tf.math.sqrt(tf.cast(self.head_dim, dtype=logits.dtype))

        if attention_mask is not None:
            # Cast so boolean/integer masks work as well as float masks.
            mask = tf.cast(attention_mask, dtype=logits.dtype)
            logits += mask * -1e9

        attention_weights = tf.nn.softmax(logits, axis=-1)
        attention_weights = self.dropout(attention_weights, training=training)

        # Merge heads: (N, H, Tx, head_dim) -> (Tx, N, H, head_dim) -> (Tx, N, key_dim).
        attention_output = tf.matmul(attention_weights, value)
        attention_output = tf.transpose(attention_output, perm=[2, 0, 1, 3])
        attention_output = tf.reshape(attention_output, (Tx, N, self.key_dim))
        attention_output = self.fc_out(attention_output)

        if return_attention_scores:
            return attention_output, attention_weights
        return attention_output
@innat
Copy link
Author

innat commented Sep 28, 2023

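Note: In PyTorch, `nn.MultiheadAttention` packs the q, k, and v projections into a single parameter, `in_proj_weight` (shape `(3 * embed_dim, embed_dim)`), plus `in_proj_bias`. The Keras layer above instead uses three separate `layers.Dense` layers. So when porting, split the packed parameter into three parts. Also transpose each weight matrix, because PyTorch stores `(out, in)` while a Keras `Dense` kernel is `(in, out)`.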


Weight Porting Utility [Option 1]

def qkv_torch_to_tf_v1(qkv, embed_dim):
    """Split a packed PyTorch ``in_proj`` parameter into Keras q/k/v parts.

    ``nn.MultiheadAttention`` stacks the query, key and value projections
    along the first axis of ``in_proj_weight`` (``(3 * embed_dim, in)``) and
    ``in_proj_bias`` (``(3 * embed_dim,)``). Weights are transposed to the
    ``(in, out)`` kernel layout used by ``layers.Dense``; biases are only
    sliced.

    Args:
        qkv: The packed weight (2-D) or bias (1-D). Anything supporting
            ``.ndim``, ``.shape``, ``.T`` and slicing works (e.g. NumPy arrays
            or torch tensors).
        embed_dim: Width of each of the three projections.

    Returns:
        Tuple ``(q, k, v)`` of the same array type as ``qkv``.

    Raises:
        ValueError: If ``qkv`` is not 1-D or 2-D, or its first axis is not
            ``3 * embed_dim`` (which would make the slices overlap silently).
    """
    if qkv.ndim not in (1, 2):
        raise ValueError(
            f"expected a 1-D bias or 2-D weight, got ndim={qkv.ndim}"
        )
    if qkv.shape[0] != 3 * embed_dim:
        raise ValueError(
            f"expected leading dimension 3 * embed_dim = {3 * embed_dim}, "
            f"got {qkv.shape[0]}"
        )

    if qkv.ndim == 2:
        # (3E, in) -> (in, 3E), then split along the output (column) axis.
        qkv_tf = qkv.T
        q = qkv_tf[:, :embed_dim]
        k = qkv_tf[:, embed_dim : 2 * embed_dim]
        v = qkv_tf[:, 2 * embed_dim :]
    else:
        q = qkv[:embed_dim]
        k = qkv[embed_dim : 2 * embed_dim]
        v = qkv[2 * embed_dim :]

    return q, k, v

Weight Porting Utility [Option 2]

Alternatively, this utility can be used instead.

def qkv_torch_to_tf_v2(torch_model, tf_model, embed_dim, prefix="attn"):
    """Copy ``nn.MultiheadAttention`` weights from a PyTorch model into ``tf_model``.

    PyTorch packs the q/k/v projections into one ``in_proj_weight`` of shape
    ``(3 * embed_dim, in)`` and one ``in_proj_bias`` of shape
    ``(3 * embed_dim,)``. They are split into three chunks, and each weight
    is transposed to the ``(in, out)`` kernel layout of ``layers.Dense``. The
    output projection is transposed the same way.

    Args:
        torch_model: Module containing the attention layer.
        tf_model: Object exposing ``wq``, ``wk``, ``wv`` and ``fc_out`` layers
            with ``set_weights`` (e.g. an already-built ``TFMultiheadAttention``).
        embed_dim: Width of each projection (PyTorch's ``embed_dim``).
        prefix: Attribute path of the attention layer inside ``torch_model``
            (``"attn"`` for ``TorchAttentionModel``). Pass ``""`` when
            ``torch_model`` is itself the ``nn.MultiheadAttention``.

    Returns:
        ``tf_model``, with its weights overwritten in place.

    Raises:
        KeyError: If an expected parameter is missing (e.g. the layer was
            created with ``bias=False`` or with separate ``kdim``/``vdim``).
        ValueError: If ``in_proj_weight`` does not have ``3 * embed_dim`` rows.
    """
    params = dict(torch_model.named_parameters())

    def _param(name):
        # Look up a parameter by its local name and export it to NumPy.
        key = f"{prefix}.{name}" if prefix else name
        # .cpu() so that parameters living on a GPU can be exported too.
        return params[key].detach().cpu().numpy()

    in_weight = _param("in_proj_weight")
    in_bias = _param("in_proj_bias")
    if in_weight.shape[0] != 3 * embed_dim:
        raise ValueError(
            f"in_proj_weight has {in_weight.shape[0]} rows, "
            f"expected 3 * embed_dim = {3 * embed_dim}"
        )

    # Chunk i of the packed parameter feeds projection i (q, k, v in order).
    for i, layer in enumerate((tf_model.wq, tf_model.wk, tf_model.wv)):
        rows = slice(i * embed_dim, (i + 1) * embed_dim)
        layer.set_weights([in_weight[rows].T, in_bias[rows]])

    tf_model.fc_out.set_weights(
        [_param("out_proj.weight").T, _param("out_proj.bias")]
    )

    return tf_model

@innat
Copy link
Author

innat commented Sep 28, 2023

# PyTorch reference: self-attention on a sequence-first input of shape
# (seq_len=10, batch=16, embed_dim=32).
torch_input = torch.randn(10, 16, 32)
torch_model = TorchAttentionModel(d_model=32, n_head=8)
torch_output, _ = torch_model(torch_input)  # self-attn: q == k == v

torch_params = sum(
    p.numel() for p in torch_model.parameters() if p.requires_grad
)
print(torch_output.shape, torch_params)
# torch.Size([10, 16, 32]) 4224

# Keras counterpart. Here key_dim is the *total* width (torch's embed_dim).
tf_model = TFMultiheadAttention(key_dim=32, num_heads=8)
tf_input = tf.convert_to_tensor(torch_input.numpy(), dtype=tf.float32)
# Calling the model once builds its Dense layers. The weight-porting steps
# below need the layers built before they can set weights.
tf_output, _ = tf_model(tf_input, tf_input, tf_input, return_attention_scores=True)

tf_params = tf_model.count_params()
print(tf_output.shape, tf_params)
# (10, 16, 32) 4224

@innat
Copy link
Author

innat commented Sep 28, 2023

Weight port - using `qkv_torch_to_tf_v2` (utility Option 2)

import numpy as np

# Copy the PyTorch weights into the (already built) Keras model, then check
# that both models produce the same output for the same input.
tf_model = qkv_torch_to_tf_v2(
    torch_model,
    tf_model,
    embed_dim=32,  # d_model == key_dim in this case
)

tf_out = tf_model(
    tf_input, tf_input, tf_input
)

np.testing.assert_allclose(
    torch_output.detach().numpy(),
    tf_out.numpy(),
    rtol=1e-4,
    atol=1e-4,
)  # OK

@innat
Copy link
Author

innat commented Sep 28, 2023

Weight port - using `qkv_torch_to_tf_v1` (utility Option 1)

import numpy as np

# Create the reference model and input *before* copying weights. In the
# original cell they were re-created after the port, so the Keras model was
# compared against a freshly initialised PyTorch model.
torch_input = torch.randn(10, 16, 32)
torch_model = TorchAttentionModel(d_model=32, n_head=8)
torch_output, _ = torch_model(torch_input)  # self-attn: q==k==v

# state_dict() tensors are already detached; move them to CPU NumPy arrays.
state_dict = {
    name: tensor.cpu().numpy()
    for name, tensor in torch_model.state_dict().items()
}

q_weight, k_weight, v_weight = qkv_torch_to_tf_v1(
    state_dict['attn.in_proj_weight'], embed_dim=32  # d_model == key_dim in this case
)
q_bias, k_bias, v_bias = qkv_torch_to_tf_v1(
    state_dict['attn.in_proj_bias'], embed_dim=32  # d_model == key_dim in this case
)

# Assumes tf_model was built earlier (it is called once when first created).
# Variable.assign accepts array-likes directly, so no tf.Variable wrapper is needed.
tf_model.wq.kernel.assign(q_weight)
tf_model.wq.bias.assign(q_bias)
tf_model.wk.kernel.assign(k_weight)
tf_model.wk.bias.assign(k_bias)
tf_model.wv.kernel.assign(v_weight)
tf_model.wv.bias.assign(v_bias)
# PyTorch stores Linear weights as (out, in); a Keras Dense kernel is (in, out).
tf_model.fc_out.kernel.assign(state_dict['attn.out_proj.weight'].T)
tf_model.fc_out.bias.assign(state_dict['attn.out_proj.bias'])

tf_input = tf.convert_to_tensor(torch_input.numpy(), dtype=tf.float32)
tf_out = tf_model(tf_input, tf_input, tf_input)

np.testing.assert_allclose(
    torch_output.detach().numpy(),
    tf_out.numpy(),
    rtol=1e-4,
    atol=1e-4,
)  # OK

Colab Gist.

@innat
Copy link
Author

innat commented Sep 28, 2023

In official keras.layers.MultiHeadAttention(num_heads, key_dim, ...)

num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.

In torch.nn.MultiheadAttention(embed_dim, num_heads, ..

embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

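So, to reproduce `nn.MultiheadAttention(embed_dim=32, num_heads=8)` with the official Keras layer, pass the per-head size: `keras.layers.MultiHeadAttention(num_heads=8, key_dim=32 // 8)`. The custom `TFMultiheadAttention` above is different: its `key_dim` is the total width, so it is created as `TFMultiheadAttention(key_dim=32, num_heads=8)`.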

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment