@maximus12793 · Last active August 9, 2022 18:31
# Validate Attention: check that Flax's dot_product_attention matches a
# hand-rolled PyTorch scaled dot-product attention on the same inputs.
import math

from flax.linen.attention import dot_product_attention
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn.functional as F

# Setup: stack query, key, and value as a single (3, seq_len, d_k) array.
np.random.seed(0)
seq_len, d_k = 5, 5
qkv = np.random.rand(3, seq_len, d_k)
# Flax version
# Note: qkv is currently (3, seq_len, d_k), i.e. three 5x5 matrices. Flax's
# dot_product_attention expects inputs shaped [batch..., length, num_heads,
# depth_per_head], so we need to add a num_heads dimension. We could expand the
# stacked array to (3, 5, 1, 5) in one shot via
#   qkv = jnp.expand_dims(qkv, axis=2)
# but here we expand each of q, k, v separately instead.
q, k, v = qkv[0], qkv[1], qkv[2]
q = jnp.expand_dims(q, axis=1)  # [q_length, num_heads, qk_depth_per_head]
k = jnp.expand_dims(k, axis=1)
v = jnp.expand_dims(v, axis=1)
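# Sanity check (added for illustration): each tensor now has the
# [q_length, num_heads, qk_depth_per_head] layout described above.
assert q.shape == k.shape == v.shape == (seq_len, 1, d_k)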
# Results (Flax)
flax_value = dot_product_attention(q, k, v, dropout_rate=0.0, deterministic=True)
flax_value = jnp.squeeze(flax_value)  # Remove the singleton num_heads dim -> (seq_len, d_k)
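# Note (added): JAX defaults to float32 unless x64 is enabled, so flax_value is
# float32 even though the numpy inputs are float64; np.allclose's default
# tolerances below absorb that precision gap.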
# Torch implementation of scaled dot-product attention:
#   attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
def attention(q, k, v, dropout_p=0.0):
    upper = torch.matmul(q, k.transpose(-2, -1))  # raw scores Q K^T
    lower = math.sqrt(k.shape[-1])                # scale factor sqrt(d_k)
    inner = upper / lower
    attention = F.softmax(inner, dim=-1)
    out = F.dropout(attention, p=dropout_p)
    values = torch.matmul(out, v)
    return values, attention
# Reuse the same random values for the Torch path (qkv is still the numpy array here).
qkv = torch.tensor(qkv)
q, k, v = qkv[0], qkv[1], qkv[2]
# Results (Torch)
t_values, t_attention = attention(q, k, v, dropout_p=0.0)
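# Optional cross-check (not in the original gist): PyTorch >= 2.0 ships
# F.scaled_dot_product_attention. On CPU this should fall back to the math
# backend, which handles these float64 tensors; if your build objects to the
# dtype, cast to float32 first. Inputs are reshaped to the documented
# (N, H, L, E) layout.
if hasattr(F, "scaled_dot_product_attention"):
    q4, k4, v4 = (t.reshape(1, 1, seq_len, d_k) for t in (q, k, v))
    builtin = F.scaled_dot_product_attention(q4, k4, v4, dropout_p=0.0)
    assert torch.allclose(builtin.reshape(seq_len, d_k), t_values), "Built-in SDPA disagrees"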
t_values = t_values.numpy()
assert np.allclose(t_values, flax_value), "Flax and Torch attention outputs differ"
print("Success")