@mobicham
Created September 3, 2024 15:30
import torch

dtype = torch.bfloat16
device = 'cuda:0'
torch.manual_seed(100)

# GEMM shape and the fraction of activation entries turned into outliers
M, K, N = 128, 8192, 4096
percent_outliers = 0.10
def quantize_int8(data, axis):
    # Symmetric absmax quantization: one scale per slice along `axis`
    # (note: dividing by 128. maps each slice's max to 128, just past the int8 range).
    scales = data.abs().amax(axis=axis, keepdim=True) / 128.
    data_int8 = (data / scales).round().to(torch.int8)
    return data_int8, scales
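# Hedged usage check (not part of the original gist): round-trip a small random
# tensor through quantize_int8 and print the mean absolute reconstruction error.
# The `_t*` names below are just for this check.
_t = torch.randn((4, 16), dtype=dtype, device=device)
_t_int8, _t_scales = quantize_int8(_t, axis=1)
print('quantize_int8 round-trip mean abs error:', ((_t_int8 * _t_scales) - _t).abs().mean().item())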
# Synthetic weight and activations; scatter large outliers into the activations
weight = torch.randn((K, N), dtype=dtype, device=device) / 20.
x = torch.randn((M, K), dtype=dtype, device=device) / 20.
if percent_outliers > 0:
    x[torch.rand(x.shape, device=device) < percent_outliers] = 1000.  # outliers
# Reference result in bf16
y_ref = torch.matmul(x, weight)

# One scale per row: per input channel for the weight, per token for the activations
weight_int8, weight_scales = quantize_int8(weight, axis=1)
x_int8, x_scales = quantize_int8(x, axis=1)

# W8A16: only the weights are quantized; W8A8: both weights and activations
y_w8a16 = torch.matmul(x, weight_int8 * weight_scales)
y_w8a8 = torch.matmul(x_int8 * x_scales, weight_int8 * weight_scales)

# Mean absolute error as a percentage of the reference maximum
print('W8A16 error', 100 * (y_ref - y_w8a16).abs().mean() / y_ref.max())
print('W8A8 error', 100 * (y_ref - y_w8a8).abs().mean() / y_ref.max())
# No outliers:
#   W8A16 error tensor(1.1172, device='cuda:0', dtype=torch.bfloat16)
#   W8A8 error  tensor(1.5000, device='cuda:0', dtype=torch.bfloat16)
# 10% outliers:
#   W8A16 error tensor(0.8320, device='cuda:0', dtype=torch.bfloat16)
#   W8A8 error  tensor(36.5000, device='cuda:0', dtype=torch.bfloat16)
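# ---------------------------------------------------------------------------
# Hedged diagnostic (not part of the original gist), to show where the W8A8
# degradation comes from: with 1000-valued outliers the per-token absmax scale
# is about 1000/128 ≈ 7.8, so the non-outlier activations (|x| on the order of
# 0.1) all round to 0 after quantization. W8A16 is unaffected because x is
# never quantized.
print('mean per-token x scale with outliers:', x_scales.mean().item())
print('scale the non-outlier values alone would need:', (x[x.abs() < 1.].abs().max() / 128.).item())
print('fraction of activation entries quantized to 0:', (x_int8 == 0).float().mean().item())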