Skip to content

Instantly share code, notes, and snippets.

@gau-nernst
Created June 11, 2024 08:56
Show Gist options
  • Save gau-nernst/4afd0f5b97368ecf26d54b5f3415b004 to your computer and use it in GitHub Desktop.
Save gau-nernst/4afd0f5b97368ecf26d54b5f3415b004 to your computer and use it in GitHub Desktop.
Mixed BF16-FP8 matmul
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short module-level aliases for operator namespaces exposed under torch.ops.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Aliases for helpers living in torch._C._dynamo.guards; `call` below uses
# assert_size_stride to validate its inputs and empty_strided_cuda to allocate
# its intermediate/output buffers.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler handle used below to register the Triton kernel sources; it is
# waited on and deleted once all kernels have been declared.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_thien/h6/ch6v4627xtcehdofkgmtn7w3ecfczhxjinkk2d34mkf6gjaikpty.py
# Source Nodes: [to], Original ATen: [aten._to_copy]
# to => convert_element_type
#
# Pointwise dtype-conversion kernel over 1024 elements. Per the signature in
# triton_meta, in_ptr0 is an fp8e5 pointer and out_ptr0 is a bf16 pointer; each
# element is loaded, converted to float32 and stored through out_ptr0.
# The embedded source must itself be valid Python, so the body of `triton_`
# is indented inside the string.
triton_poi_fused__to_copy_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[1024],
    filename=__file__,
    triton_meta={'signature': {0: '*fp8e5', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=66), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '3944C5C5320391B1D24D8C144564E62E4ADD46D971AAE933FEB74FF902064863', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=4
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_thien/bf/cbfq6gajc3ohuwxdieudzreasp6klqlkxhhtnhyyo7gqd7bbm3r3.py
# Source Nodes: [mul], Original ATen: [aten.mul]
# mul => mul
#
# Pointwise in-place multiply over 32 elements. Both pointers are bf16 per the
# signature in triton_meta; each element of in_out_ptr0 is multiplied (in
# float32) by the matching element of in_ptr0 and written back to in_out_ptr0,
# which is why inductor_meta lists 'in_out_ptr0' as a mutated argument.
# The embedded source must itself be valid Python, so the body of `triton_`
# is indented inside the string.
triton_poi_fused_mul_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[32],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=66), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '3944C5C5320391B1D24D8C144564E62E4ADD46D971AAE933FEB74FF902064863', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp2 = tmp0 * tmp1
    tl.store(in_out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')
# Passes this module's globals to the compiler handle; presumably this blocks
# until the kernels registered above have finished compiling and rebinds their
# module-level names to the ready kernels -- TODO confirm against AsyncCompile.
async_compile.wait(globals())
# The handle is not needed once the kernels are available.
del async_compile
def call(args):
    """Run the compiled graph: dtype-convert, matrix-multiply, then scale.

    Args:
        args: list of exactly three tensors, in order: a contiguous (32, 32)
            matrix, a contiguous (1, 32) matrix and a contiguous (32,) vector.
            The list is emptied in place.

    Returns:
        A one-element tuple holding the (1, 32) bfloat16 result buffer.
    """
    weight, act, scale = args
    # Empty the caller's list so it no longer holds references to the inputs.
    args.clear()

    assert_size_stride(weight, (32, 32), (32, 1))
    assert_size_stride(act, (1, 32), (32, 1))
    assert_size_stride(scale, (32, ), (1, ))

    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream = get_raw_stream(0)

        # Source Nodes: [to], Original ATen: [aten._to_copy]
        # Convert the first input into a freshly allocated bfloat16 buffer.
        weight_bf16 = empty_strided_cuda((32, 32), (32, 1), torch.bfloat16)
        triton_poi_fused__to_copy_0.run(weight, weight_bf16, 1024, grid=grid(1024), stream=stream)
        del weight

        # Source Nodes: [mm, to], Original ATen: [aten._to_copy, aten.mm]
        out = empty_strided_cuda((1, 32), (32, 1), torch.bfloat16)
        extern_kernels.mm(act, weight_bf16, out=out)
        del act, weight_bf16

        # Source Nodes: [mul], Original ATen: [aten.mul]
        # The multiply kernel writes its result back into `out` in place.
        triton_poi_fused_mul_1.run(out, scale, 32, grid=grid(32), stream=stream)
        del scale
    return (out, )
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark `call` on randomly generated inputs.

    Args:
        times: forwarded to `print_performance` as `times`.
        repeat: forwarded to `print_performance` as `repeat`.

    Returns:
        Whatever `print_performance` returns for the benchmarked callable.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    weight = rand_strided((32, 32), (32, 1), device='cuda:0', dtype=torch.float8_e5m2)
    act = rand_strided((1, 32), (32, 1), device='cuda:0', dtype=torch.bfloat16)
    scale = rand_strided((32, ), (1, ), device='cuda:0', dtype=torch.bfloat16)

    def run_once():
        # Build a new list on every invocation: `call` clears the list it gets.
        return call([weight, act, scale])

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # When executed as a script, delegate to inductor's `compiled_module_main`
    # helper, passing it the benchmark function defined above.
    import torch._inductor.wrapper_benchmark as wrapper_benchmark
    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
import torch
import torch._inductor.config
# Core dump if either of these flags is enabled
# torch._inductor.config.force_mixed_mm = True
# torch._inductor.config.use_mixed_mm = True
def f(a, b, s):
    """Return ``(a @ b) * s`` after casting ``b`` to the dtype of ``a``.

    Args:
        a: left-hand matrix; its dtype decides the matmul dtype.
        b: right-hand matrix, converted to ``a.dtype`` before multiplying.
        s: scale multiplied element-wise (with broadcasting) into the product.
    """
    b_cast = b.to(a.dtype)
    product = torch.mm(a, b_cast)
    return product * s
# Repro inputs, created on the CPU, cast, then moved to the GPU:
# a (1, 32) bfloat16 activation, a (32, 32) float8_e5m2 weight and a (32,)
# bfloat16 scale vector that broadcasts across the columns of the product.
# The activation is bfloat16, so it is named accordingly (not "fp16").
bf16_act = torch.randn(1, 32).to(torch.bfloat16).cuda()
fp8_weight = torch.randn(32, 32).to(torch.float8_e5m2).cuda()
scales = torch.randn(32).to(torch.bfloat16).cuda()
# Compile `f` as a single full graph in max-autotune mode and run it once.
torch.compile(f, mode="max-autotune", fullgraph=True)(bf16_act, fp8_weight, scales)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment