Skip to content

Instantly share code, notes, and snippets.

@zhuangh
Last active May 5, 2024 02:00
Show Gist options
  • Save zhuangh/df682b954e36484e6c274f3446b408aa to your computer and use it in GitHub Desktop.
MQA reshape_go_faster.py
"""
baseline runtime(s) 0.5243263244628906
with reshape runtime (s) 0.0022399425506591797
@ cpu
=========
baseline runtime (s) 0.25386476516723633
with reshape runtime (s) 0.0008966922760009766
@ cuda:0
"""
import torch
import torch.nn.functional as F
from einops import einsum, rearrange
import time
# --- CPU benchmark: two einsum formulations of grouped-query attention ---
# Query carries 8 heads while key/value carry 2, so 4 query heads share
# each kv head (num_head_groups = 4).
#
# NOTE(review): the two paths are not the same contraction — path 1's
# softmax normalizes over the kv-head axis "m" while path 2's normalizes
# over sequence positions "ss". The assert_close at the end passes only
# because every input is torch.ones (both paths then yield constant
# outputs); with random inputs they would differ — confirm intent.
qq = torch.ones(1, 256, 8, 64)  # (batch, seq, groups*heads, emb)
kk = torch.ones(1, 256, 2, 64)  # (batch, seq, heads, emb)
vv = torch.ones(1, 256, 2, 64)  # (batch, seq, heads, emb)
num_head_groups = qq.shape[2] // kk.shape[2]  # 8 // 2 = 4
scale = qq.size(-1) ** 0.5  # sqrt(d_head) softmax temperature

# Path 1 ("baseline"): split the head axis in place, seq stays in dim 1.
start_time = time.time()
grouped_q = rearrange(qq, "b s (g h) e -> b s g h e", g=num_head_groups)
scores = einsum(grouped_q, kk, "b s g h e, b s m e -> b s h m")
att = F.softmax(scores / scale, dim=-1)
out1 = einsum(att, vv, "b s h m, b s m e -> b s h e")
end_time = time.time() - start_time
print("baseline runtime(s)", end_time)

# Path 2 ("with reshape"): move heads ahead of seq first, then split groups,
# so the contractions run over the trailing (seq, emb) dims.
start_time = time.time()
q = rearrange(qq, "b s g e -> b g s e")
v = rearrange(vv, "b s h e -> b h s e")
k = rearrange(kk, "b s h e -> b h s e")
q = rearrange(q, "b (g h) s e -> b g h s e", g=num_head_groups)
scores = einsum(q, k, "b g h s e, b h ss e -> b h s ss")
att = F.softmax(scores / scale, dim=-1)
out = einsum(att, v, "b h s ss, b h ss e -> b h s e")
out = rearrange(out, "b h s e -> b s h e")
end_time = time.time() - start_time
print("with reshape runtime (s)", end_time)

torch.testing.assert_close(out1, out)
print("@", out.device)
print("=========")
# --- GPU benchmark of the same two attention formulations ---
# Fix: CUDA kernels launch asynchronously, so reading time.time() without
# torch.cuda.synchronize() measures only kernel-launch overhead, not actual
# execution time. Synchronize before starting the clock (to drain prior
# work, e.g. the .to("cuda") copies) and again before stopping it.
#
# NOTE(review): as in the CPU section, inputs are all ones, so the final
# assert_close passes even though the two contractions normalize over
# different axes — confirm intent.
qq = torch.ones(1, 256, 8, 64).to("cuda")  # (batch, seq, groups*heads, emb)
kk = torch.ones(1, 256, 2, 64).to("cuda")  # (batch, seq, heads, emb)
vv = torch.ones(1, 256, 2, 64).to("cuda")  # (batch, seq, heads, emb)
num_head_groups = qq.shape[2] // kk.shape[2]  # 8 // 2 = 4
scale = qq.size(-1) ** 0.5  # sqrt(d_head) softmax temperature

torch.cuda.synchronize()  # drain pending host->device copies before timing
start_time = time.time()
q = rearrange(qq, "b s (g h) e -> b s g h e", g=num_head_groups)
scores = einsum(q, kk, "b s g h e, b s m e -> b s h m")
att = F.softmax(scores / scale, dim=-1)
out1 = einsum(att, vv, "b s h m, b s m e -> b s h e")
torch.cuda.synchronize()  # wait for kernels to finish before reading clock
end_time = time.time() - start_time
print("baseline runtime (s)", end_time)

torch.cuda.synchronize()
start_time = time.time()
q = rearrange(qq, "b s g e -> b g s e")
v = rearrange(vv, "b s h e -> b h s e")
k = rearrange(kk, "b s h e -> b h s e")
q = rearrange(q, "b (g h) s e -> b g h s e", g=num_head_groups)
scores = einsum(q, k, "b g h s e, b h ss e -> b h s ss")
att = F.softmax(scores / scale, dim=-1)
out = einsum(att, v, "b h s ss, b h ss e -> b h s e")
out = rearrange(out, "b h s e -> b s h e")
torch.cuda.synchronize()  # wait for kernels to finish before reading clock
end_time = time.time() - start_time
print("with reshape runtime (s)", end_time)

torch.testing.assert_close(out1, out)
print("@", out.device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment