Last active
August 9, 2022 18:31
-
-
Save maximus12793/698f2d33382aef46a3f99ef1d31c755a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Validate attention: check that Flax's dot_product_attention and a hand-written PyTorch implementation agree.
import math

import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from flax.linen.attention import dot_product_attention
# Setup
# Fixed seed so both implementations are compared on the same reproducible inputs.
np.random.seed(0)
seq_len, d_k = 5, 5
qkv = np.random.rand(3, seq_len, d_k)

# Flax version
# Note: qkv stacks three (seq_len, d_k) matrices. dot_product_attention takes
# inputs laid out as [batch..., length, num_heads, depth_per_head], so a
# singleton heads axis is inserted at position 1 of each matrix, giving
# (seq_len, 1, d_k). The whole stack could equally be expanded in one call:
#   jnp.expand_dims(qkv, axis=2)  ->  (3, seq_len, 1, d_k)
q, k, v = (jnp.expand_dims(matrix, axis=1) for matrix in qkv)

# Results (Flax)
flax_value = dot_product_attention(q, k, v, dropout_rate=0.0, deterministic=True)
# Drop only the heads axis; a bare squeeze() would also collapse seq_len or
# d_k if either happened to be 1.
flax_value = jnp.squeeze(flax_value, axis=1)
# Torch Impl.
def attention(q, k, v, dropout_p=0.0):
    """Compute scaled dot-product attention with PyTorch.

    Args:
        q: Query tensor; its last dimension must match the last dimension
            of ``k``.
        k: Key tensor; the size of its last dimension sets the scaling factor.
        v: Value tensor; its second-to-last dimension must match that of ``k``.
        dropout_p: Dropout probability applied to the attention weights
            before they are multiplied with ``v``.

    Returns:
        A ``(values, weights)`` tuple, where ``values`` is the weighted
        combination of ``v`` and ``weights`` is the softmax-normalized
        score matrix (taken before dropout).
    """
    scores = torch.matmul(q, k.transpose(-2, -1))
    scale = math.sqrt(k.shape[-1])  # square root of the key depth
    weights = F.softmax(scores / scale, dim=-1)
    # F.dropout runs in training mode by default, so a non-zero dropout_p is
    # always applied here; with the default of 0.0 it is a no-op.
    dropped = F.dropout(weights, p=dropout_p)
    values = torch.matmul(dropped, v)
    return values, weights
# Feed the Torch implementation the same NumPy inputs that were generated in
# the setup step (the array is named `qkv`; `qkv_random` was never defined).
qkv = torch.tensor(qkv)
q, k, v = qkv[0], qkv[1], qkv[2]

# Results (Torch)
t_values, t_attention = attention(q, k, v, dropout_p=0.0)
t_values = t_values.numpy()

# The two implementations must agree within np.allclose's default tolerances.
assert np.allclose(t_values, flax_value), "Failed"
print("Success")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment