Skip to content

Instantly share code, notes, and snippets.

@rsuderman
Last active June 10, 2024 21:10
Show Gist options
  • Save rsuderman/ca2dbf8d998e34c4880a51fb94fceb85 to your computer and use it in GitHub Desktop.
Save rsuderman/ca2dbf8d998e34c4880a51fb94fceb85 to your computer and use it in GitHub Desktop.
Matmul per channel quant
import matplotlib.pyplot as plt
import torch
# Shapes of the two matmul operands; both share the trailing dimension of 128.
A_SHAPE = (8, 128)
B_SHAPE = (16, 128)

# Seed the global RNG before the first random draw so every run is reproducible.
torch.manual_seed(12345)

# One random float32 factor per row of each operand, stored as column
# vectors of shape (rows, 1).
A_QUANT = torch.rand(A_SHAPE[0], 1, dtype=torch.float32)
B_QUANT = torch.rand(B_SHAPE[0], 1, dtype=torch.float32)
def generate_input(shape):
    """Return a float32 tensor of the given shape drawn from ``torch.rand``.

    Args:
        shape: Size of the tensor to create, as accepted by ``torch.rand``.

    Returns:
        A tensor of dtype ``torch.float`` filled by ``torch.rand`` using the
        global RNG state.
    """
    return torch.rand(shape, dtype=torch.float)
# Draw each operand and scale every row by its own random factor so that
# rows have different magnitudes. A is drawn before B, so the sequence of
# RNG draws is A first, then B.
A = generate_input(A_SHAPE) * A_QUANT
B = generate_input(B_SHAPE) * B_QUANT
def quant_i8_per_tensor(A):
    """Quantize a 2-D tensor to uint8 using one scale for the whole tensor.

    The scale maps the tensor's maximum value to 255. Values are rounded and
    clamped to [0, 255], so negative inputs saturate to 0 (the scheme is
    unsigned despite the ``i8`` in the name).

    Args:
        A: 2-D floating-point tensor.

    Returns:
        A tuple ``(Q, SCALE)`` where ``Q`` is the ``torch.uint8`` quantized
        tensor and ``SCALE`` has shape ``(A.shape[0],)`` with the same scale
        repeated for every row, so callers can treat it exactly like the
        per-channel (axis=1) scale.
    """
    peak = torch.max(A)
    inv = 255.0 / peak
    # A non-positive maximum would give an infinite or negative scale (and
    # NaNs after multiplying); fall back to a neutral scale of 1 instead.
    scale = torch.where(peak > 0, inv, torch.ones_like(inv))
    # The scalar scale broadcasts over the full tensor. The original
    # multiplied by a (rows,) vector, which only broadcasts for square input.
    Q = torch.clamp(torch.round(A * scale), 0, 255.0).to(torch.uint8)
    SCALE = scale.reshape(1).repeat(A.shape[0])
    return Q, SCALE


def quant_i8_per_channel(A, axis=1):
    """Quantize a tensor to uint8 with an independent scale per channel.

    The maximum is reduced along ``axis``; each resulting slice gets its own
    scale that maps that maximum to 255. Values are rounded and clamped to
    [0, 255], so negative inputs saturate to 0.

    Args:
        A: Floating-point tensor (2-D in this script).
        axis: Dimension that is reduced when taking the maximum. With the
            default ``axis=1`` on a 2-D tensor there is one scale per row.

    Returns:
        A tuple ``(Q, SCALE)`` where ``Q`` is the ``torch.uint8`` quantized
        tensor and ``SCALE`` has the shape of ``A`` with ``axis`` removed.
    """
    # Derive the scale shape from the reduction itself so any axis works;
    # the original hard-coded A.shape[0], which is only right for axis=1.
    peak = torch.max(A, dim=axis).values
    inv = 255.0 / peak
    # Channels whose maximum is not positive get a neutral scale of 1 so they
    # quantize to zeros rather than producing inf/NaN.
    SCALE = torch.where(peak > 0, inv, torch.ones_like(inv))
    scaled = A * SCALE.unsqueeze(axis)
    Q = torch.clamp(torch.round(scaled), 0, 255.0).to(torch.uint8)
    return Q, SCALE


def mmt_quant(A, B, per_tensor=True):
    """Compute ``A @ B.T`` through uint8 quantization and dequantize the result.

    Both operands are quantized, multiplied as float tensors holding the
    quantized integer values, and the product is divided by the row scales of
    ``A`` and of ``B`` to return to the original value range.

    Args:
        A: 2-D tensor of shape (M, K).
        B: 2-D tensor of shape (N, K); it is transposed before the multiply.
        per_tensor: When true, each operand uses a single scale for the whole
            tensor; when false, each row of each operand gets its own scale.

    Returns:
        Float tensor of shape (M, N) approximating ``A @ B.T``.
    """
    # The original had these two branches swapped relative to the flag name.
    quantize = quant_i8_per_tensor if per_tensor else quant_i8_per_channel
    A, A_SCALE = quantize(A)
    B, B_SCALE = quantize(B)

    B = torch.transpose(B, 0, 1)
    MM = torch.mm(A.to(torch.float), B.to(torch.float))

    # Row i of A and row j of B (column j after the transpose) each carry
    # their own scale, so output element (i, j) is divided by both.
    MM = MM / A_SCALE.unsqueeze(1)
    MM = MM / B_SCALE.unsqueeze(0)
    return MM
def mmt_float(A, B):
    """Return the matrix product of ``A`` with the transpose of ``B``.

    Args:
        A: 2-D tensor of shape (M, K).
        B: 2-D tensor of shape (N, K).

    Returns:
        Tensor of shape (M, N) computed with ``torch.mm``.
    """
    b_transposed = B.transpose(0, 1)
    product = torch.mm(A, b_transposed)
    return product
# Run the quantized and the floating-point matmul on the same operands and
# keep their element-wise difference.
OUT_QUANT = mmt_quant(A, B)
OUT_FLOAT = mmt_float(A, B)
OUT_DIFF = OUT_QUANT - OUT_FLOAT

# Print the mean absolute value of the float result, the quantized result,
# and the difference between them, in that order.
for result in (OUT_FLOAT, OUT_QUANT, OUT_DIFF):
    print(result.abs().mean().item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment