@wkcn
Created October 17, 2022 09:50
FP8GEMM
import torch
import transformer_engine.pytorch.cpp_extensions as texcpp
from transformer_engine.pytorch.module import get_workspace
import transformer_engine_extensions as tex
scale = 1.0
meta = tex.FP8TensorMeta()
meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale
meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
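# The casts below share this metadata. With scale = 1.0 they amount to a plain
# FP8 quantization: as far as I understand TE's scaling scheme, `scale` is
# applied when casting to FP8, `scale_inv` when casting back, and
# `amax_history` records observed absolute maxima for delayed scaling.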
def cast_to_fp8(x, qtype):
    # Quantize x into the requested FP8 format and tag the result with it.
    ret = texcpp.cast_to_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, qtype)
    ret._fp8_qtype = qtype
    return ret

def cast_from_fp8(x, qtype):
    # Dequantize an FP8 tensor back to the requested higher-precision dtype.
    ret = texcpp.cast_from_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, x._fp8_qtype, qtype)
    ret._fp8_qtype = qtype
    return ret
one_scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")
empty_tensor = torch.Tensor()
workspace = get_workspace()
assert workspace.is_cuda
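# Map Transformer Engine dtypes back to PyTorch dtypes. FP8 tensors have no
# native torch dtype here, so their bit patterns are stored as torch.uint8.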
PT_DType = dict([(v, k) for k, v in texcpp.TE_DType.items()])
PT_DType[tex.DType.kFloat8E4M3] = torch.uint8
PT_DType[tex.DType.kFloat8E5M2] = torch.uint8
def fp8_gemm(fa, fb, trans_a, trans_b, bias=None, qtype=tex.DType.kFloat32):
    '''
    # te_gemm
    input_A: (A_row, A_col)
    input_B: (B_row, B_col)
    when transa, transb = True, False
    m, k, n = A_row, A_col, B_row
    lda, ldb, ldd = A_col, A_col, A_row
    output_D: (B_row, A_row)
    when transa, transb = False, False
    m, k, n = A_col, A_row, B_row
    lda, ldb, ldd = A_col, A_row, A_col
    output_D: (B_row, A_col)
    when transa, transb = False, True
    m, k, n = A_col, A_row, B_col
    lda, ldb, ldd = A_col, B_col, A_col
    output_D: (B_col, A_col)
    '''
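    # Note (added): in row-major PyTorch terms, the trans_a=True, trans_b=False
    # case above computes output_D = input_B @ input_A.T with shape
    # (B_row, A_row), which is why fp8_matmul below passes the transposed
    # second operand as input_A to realize an ordinary matmul.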
    assert fa.is_cuda and fb.is_cuda
    assert fa.is_contiguous()
    assert fb.is_contiguous()
    device = fa.device
    fa_qtype, fb_qtype = fa._fp8_qtype, fb._fp8_qtype
    A_row, A_col = fa.shape
    B_row, B_col = fb.shape
    if trans_a and not trans_b:
        assert A_col == B_col
        C_row, C_col = B_row, A_row
    elif not trans_a and not trans_b:
        assert A_row == B_col
        C_row, C_col = B_row, A_col
    elif not trans_a and trans_b:
        assert A_row == B_row
        C_row, C_col = B_col, A_col
    else:
        raise NotImplementedError("trans_a=True with trans_b=True is not handled")
    out_shape = (C_row, C_col)
    dtype = PT_DType[qtype]
    out = torch.empty(out_shape, dtype=dtype, device=device)
    # te_gemm follows the cuBLAS column-major convention.
    tex.te_gemm(
        fa, one_scale_inv, fa_qtype, trans_a,
        fb, one_scale_inv, fb_qtype, trans_b,
        out, qtype,
        bias if bias is not None else empty_tensor, empty_tensor, False,
        workspace, workspace.shape[0],
        False, True,
    )
    out._fp8_qtype = qtype
    return out
def fp8_matmul(fa, fb, bias=None, qtype=tex.DType.kFloat32):
    # trans_a = False and trans_b = False is not implemented, so compute
    # fa @ fb as fp8_gemm(fb.T, fa, trans_a=True), which produces the same
    # row-major result as a plain matmul.
    fb_qtype = fb._fp8_qtype
    fb = fb.T.contiguous()
    fb._fp8_qtype = fb_qtype
    return fp8_gemm(fb, fa, trans_a=True, trans_b=False, bias=bias, qtype=qtype)
if __name__ == '__main__':
    a = torch.randn(128, 128).cuda()
    b = torch.randn(128, 128).cuda()
    fa = cast_to_fp8(a, tex.DType.kFloat8E4M3)
    fb = cast_to_fp8(b, tex.DType.kFloat8E4M3)
    qa = cast_from_fp8(fa, tex.DType.kFloat32)
    qb = cast_from_fp8(fb, tex.DType.kFloat32)
    qc = torch.matmul(qa, qb)
    qc2 = fp8_matmul(fa, fb, qtype=tex.DType.kFloat16)
    # E4M3/E5M2 @ E4M3/E5M2 = FP16/FP32
    print(qc, qc2)
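    # Sanity check (added, not in the original gist): the FP8 GEMM should agree
    # with the matmul of the dequantized operands up to FP16 rounding error,
    # since both paths consume the same quantized values.
    rel_err = (qc2.float() - qc).abs().max() / qc.abs().max()
    print("max relative error:", rel_err.item())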