@sshleifer
Last active October 29, 2021 15:35
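import torch
import torch.nn.functional as F

# Shapes: d features, a sequence of seq_len tokens, batch size bs.
d = 8
seq_len = 13
bs = 1
wt = torch.rand((d, d))           # weight in F.linear layout: (out_features, in_features)
x = torch.rand((seq_len, bs, d))  # input: (seq_len, bs, in_features)

# Split the input features into two halves, and split the weight's
# input-feature columns the same way.
x_r0, x_r1 = x[:, :, :d // 2], x[:, :, d // 2:]
wt_r0, wt_r1 = wt[:, :d // 2], wt[:, d // 2:]

# Version 1: concatenate the halves back together (recovering x and wt)
# and do a single projection.
x_catted = torch.cat([x_r0, x_r1], 2)
wt_catted = torch.cat([wt_r0, wt_r1], 1)
result_catted = F.linear(x_catted, wt_catted)

# Version 2: project each half separately and sum the partial results.
# Mathematically, x @ wt.T == x_r0 @ wt_r0.T + x_r1 @ wt_r1.T.
or0_proj = F.linear(x_r0, wt_r0)
or1_proj = F.linear(x_r1, wt_r1)
result_summed = or0_proj + or1_proj

# The two versions agree only up to floating-point rounding, so exact
# elementwise equality holds for just a fraction of the outputs.
# No seed is set, so the exact numbers vary from run to run.
equality_pct = result_summed.eq(result_catted).float().mean()
print(f'delta: {(result_summed - result_catted).abs().mean()}, equality_pct: {equality_pct}')
# Example output from one run:
# delta: 5.788528056882569e-08, equality_pct: 0.6346153616905212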
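The mismatch is floating-point rounding rather than a math error: the summed version adds two partial dot products of length d//2, while the single projection accumulates all d terms in whatever order the matmul kernel uses, and float32 addition is not associative. A minimal sketch that generalizes the check to k chunks follows. `split_linear` is a hypothetical helper, the fixed seed is an addition not present in the original, and the comparison uses `torch.allclose` (default tolerances) instead of exact `eq`.

```python
import torch
import torch.nn.functional as F


def split_linear(x: torch.Tensor, weight: torch.Tensor, num_chunks: int) -> torch.Tensor:
    """Hypothetical helper: compute F.linear(x, weight) as a sum of chunk projections.

    The input features (last dim of x, dim 1 of weight) are split into
    `num_chunks` contiguous pieces; each piece is projected separately and
    the partial results are summed.
    """
    x_chunks = x.chunk(num_chunks, dim=-1)
    w_chunks = weight.chunk(num_chunks, dim=1)
    partials = [F.linear(xc, wc) for xc, wc in zip(x_chunks, w_chunks)]
    return torch.stack(partials).sum(dim=0)


torch.manual_seed(0)  # assumption: fixed seed for repeatable runs (the gist sets none)
d, seq_len, bs = 8, 13, 1
weight = torch.rand((d, d))
x = torch.rand((seq_len, bs, d))

full = F.linear(x, weight)
for k in (1, 2, 4):
    split = split_linear(x, weight, k)
    exact = split.eq(full).float().mean().item()  # fraction of bitwise-equal outputs
    close = torch.allclose(split, full)           # equal up to default rtol/atol
    print(f"chunks={k}: exact-match fraction={exact:.3f}, allclose={close}")
```

With more chunks the exact-match fraction will typically fall below 1, while `allclose` should remain True because the differences are on the order of float32 rounding error.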