import torch

dtype = torch.bfloat16
device = 'cuda:0'
torch.manual_seed(100)

# GEMM shapes: activations x are (M, K), weight is (K, N)
M, K, N = 128, 8192, 4096
percent_outliers = 0.10  # fraction of activation entries replaced by large outliers

def quantize_int8(data, axis):
    # Symmetric int8 quantization with one scale per slice along `axis`
    scales = data.abs().amax(axis=axis, keepdim=True) / 128.
    data_int8 = (data / scales).round().to(torch.int8)
    return data_int8, scales

weight = torch.randn((K, N), dtype=dtype, device=device) / 20.
x = torch.randn((M, K), dtype=dtype, device=device) / 20.

# Inject large-magnitude outliers into a random subset of activation entries
if percent_outliers > 0:
    x[torch.rand(x.shape) < percent_outliers] = 1000.

# bf16 reference output
y_ref = torch.matmul(x, weight)

weight_int8, weight_scales = quantize_int8(weight, axis=1)
x_int8, x_scales = quantize_int8(x, axis=1)

# W8A16: int8 weights (dequantized), bf16 activations
y_w8a16 = torch.matmul(x, weight_int8 * weight_scales)
# W8A8: int8 weights and int8 activations (both dequantized before the matmul)
y_w8a8 = torch.matmul(x_int8 * x_scales, weight_int8 * weight_scales)

# Mean absolute error as a percentage of the reference output's max value
print('W8A16 error', 100 * (y_ref - y_w8a16).abs().mean() / y_ref.max())
print('W8A8 error', 100 * (y_ref - y_w8a8).abs().mean() / y_ref.max())
# No outliers:
# W8A16 error tensor(1.1172, device='cuda:0', dtype=torch.bfloat16)
# W8A8 error  tensor(1.5000, device='cuda:0', dtype=torch.bfloat16)
#
# 10% outliers:
# W8A16 error tensor(0.8320, device='cuda:0', dtype=torch.bfloat16)
# W8A8 error  tensor(36.5000, device='cuda:0', dtype=torch.bfloat16)
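
# A minimal follow-up sketch (not part of the original gist): quantize only the
# activations (W16A8) to isolate where the W8A8 degradation comes from. Since the
# W8A16 error above is small, the weights quantize cleanly; the activation-only
# variant should land close to the W8A8 error, pointing at the outlier-heavy
# activations as the culprit.
y_w16a8 = torch.matmul(x_int8 * x_scales, weight)
print('W16A8 error', 100 * (y_ref - y_w16a8).abs().mean() / y_ref.max())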