Skip to content

Instantly share code, notes, and snippets.

@bwasti
Last active April 18, 2024 02:09
Show Gist options
  • Save bwasti/78d7058aad7f42dc893d906c877b710c to your computer and use it in GitHub Desktop.
Save bwasti/78d7058aad7f42dc893d906c877b710c to your computer and use it in GitHub Desktop.
# This is a test (not an implementation) of the accuracy impact bucketMul has on matrix multiplications
# https://kolinko.github.io/effort/bucketmul.html
import torch
import torch.nn.functional as F
import math
torch.manual_seed(1337)
B = 2  # batch size of the test input v
N = 8  # rows of W (length of each input vector)
M = 16  # columns of W (length of each output vector)
bucket_size = 4  # consecutive magnitude-sorted weights per bucket; must divide M
effort = 0.5  # cutoff rank k = floor(N * effort) among |v * probes|; 1.0 keeps every bucket


def to_buckets(x):
    """Reshape a 2-D ``(rows, M)`` tensor into ``(rows, M // bucket_size, bucket_size)``."""
    return x.reshape(x.shape[0], -1, bucket_size)


def effort_cutoff(v, probes, effort):
    """Return the per-row magnitude threshold used to keep or drop buckets.

    For each row of ``v`` the threshold is the (k+1)-th largest value of
    ``|v * probes|``, where ``k = floor(v.shape[-1] * effort)``.

    When ``k`` reaches the row length there is no such element. In that case
    every row gets a threshold of -1, which every absolute value exceeds, so
    all buckets are kept. The result is always a tensor with one entry per
    row, so callers can broadcast it the same way for any ``effort``.
    """
    k = math.floor(v.shape[-1] * effort)
    if k >= v.shape[-1]:
        return torch.full(v.shape[:-1], -1.0, dtype=v.dtype, device=v.device)
    return torch.sort(abs(v * probes), descending=True).values[..., k]


if M % bucket_size != 0:
    raise ValueError(f"bucket_size ({bucket_size}) must divide M ({M})")
if N > M:
    # torch.diagonal yields min(N, M) entries, but one probe per input row is needed
    raise ValueError(f"diagonal probes need N <= M, got N={N}, M={M}")

# preprocess weights
W = torch.randn(N, M).round(decimals=2) * 100
probes = torch.diagonal(W)  # one probe per input row: W[i, i]
# sort every row of W by magnitude (largest first), then group into buckets
W_s, W_i = torch.sort(abs(W), descending=True)
W_sb = to_buckets(W_s)
W_stat = W_sb.mean(dim=-1)  # (N, M // bucket_size): mean |w| of each bucket

# generate test input
v = torch.randn(B, N)

# generate a "masked" weight to check algorithm impact on accuracy
cutoff = effort_cutoff(v, probes, effort)  # (B,)
print(cutoff)
print(v[..., None].shape, W_stat[None, ...].shape)
# bucket j of row i survives for batch b when |v[b, i] * W_stat[i, j]| > cutoff[b]
mask = abs(v[..., None] * W_stat[None, ...]) > cutoff[:, None, None]  # (B, N, M // bucket_size)
sorted_W = torch.gather(W, 1, W_i)  # signed weights in magnitude-sorted order
sorted_W_b = to_buckets(sorted_W)
print(sorted_W_b.shape, mask[..., None].shape)
masked_sorted_W = (sorted_W_b * mask[..., None]).reshape(B, N, M)
# Undo the per-row sort for all batch elements at once. Each row of W_i is a
# permutation of 0..M-1, so every output slot is written exactly once.
masked_W = torch.zeros((B, N, M), dtype=W.dtype, device=W.device)
masked_W.scatter_(2, W_i.expand(B, N, M), masked_sorted_W)
bucket_out = (v.unsqueeze(1) @ masked_W).squeeze(1)  # (B, M)

# calculate reference
ref = v @ W
print(masked_W.shape, W.shape, ref.shape, bucket_out.shape)
print(F.cosine_similarity(ref, bucket_out))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment