-
-
Save gau-nernst/4afd0f5b97368ecf26d54b5f3415b004 to your computer and use it in GitHub Desktop.
Mixed BF16-FP8 matmul
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AOT ID: ['0_inference'] | |
from ctypes import c_void_p, c_long | |
import torch | |
import math | |
import random | |
import os | |
import tempfile | |
from math import inf, nan | |
from torch._inductor.hooks import run_intermediate_hooks | |
from torch._inductor.utils import maybe_profile | |
from torch._inductor.codegen.memory_planning import _align as align | |
from torch import device, empty_strided | |
from torch._inductor.async_compile import AsyncCompile | |
from torch._inductor.select_algorithm import extern_kernels | |
from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
aten = torch.ops.aten | |
inductor_ops = torch.ops.inductor | |
_quantized = torch.ops._quantized | |
assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
async_compile = AsyncCompile() | |
# kernel path: /tmp/torchinductor_thien/h6/ch6v4627xtcehdofkgmtn7w3ecfczhxjinkk2d34mkf6gjaikpty.py
# Source Nodes: [to], Original ATen: [aten._to_copy]
# to => convert_element_type
#
# Pointwise dtype-conversion kernel for `b.to(a.dtype)`.
# - Each program instance handles XBLOCK consecutive elements of the 1024-element
#   input: xindex = program_id * XBLOCK + arange(XBLOCK).
# - It loads from in_ptr0 ('*fp8e5' in the signature) and upcasts to float32.
# - It stores into out_ptr0, a '*bf16' pointer, so tl.store narrows the value to bf16.
# - The runtime `xnumel` argument is shadowed by the static size 1024, and xmask
#   disables out-of-range lanes in the last block.
# triton_meta pins the kernel to CUDA device 0 with cc=89.
# NOTE(review): async_compile.triton compiles this source string verbatim (and
# presumably keys its cache on it), so change it only by regenerating the module.
triton_poi_fused__to_copy_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[1024],
    filename=__file__,
    triton_meta={'signature': {0: '*fp8e5', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=66), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '3944C5C5320391B1D24D8C144564E62E4ADD46D971AAE933FEB74FF902064863', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=4
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
import triton | |
import triton.language as tl | |
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
# kernel path: /tmp/torchinductor_thien/bf/cbfq6gajc3ohuwxdieudzreasp6klqlkxhhtnhyyo7gqd7bbm3r3.py
# Source Nodes: [mul], Original ATen: [aten.mul]
# mul => mul
#
# In-place elementwise multiply for the trailing `* s`.
# - It reads 32 values from in_out_ptr0 and in_ptr0 (both '*bf16') and upcasts
#   each load to float32.
# - It multiplies and writes the product back into in_out_ptr0; the '*bf16'
#   pointer type narrows the value to bf16 on store.
# - inductor_meta lists in_out_ptr0 as the only mutated argument.
# - As in the conversion kernel, the static size 32 shadows the `xnumel` argument
#   and xmask disables out-of-range lanes.
# NOTE(review): the source string is compiled verbatim by async_compile.triton;
# change it only by regenerating the module.
triton_poi_fused_mul_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[32],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=66), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '3944C5C5320391B1D24D8C144564E62E4ADD46D971AAE933FEB74FF902064863', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp2 = tmp0 * tmp1
    tl.store(in_out_ptr0 + (x0), tmp2, xmask)
''', device_str='cuda')
# Finish compiling the Triton kernels queued above before anything can launch them.
# NOTE(review): AsyncCompile.wait presumably resolves the pending kernel handles
# stored in this module's globals() in place; confirm against the installed torch.
async_compile.wait(globals())
# The compiler handle is no longer needed; nothing below references it.
del async_compile
def call(args):
    """Execute the compiled graph ``mm(act, fp8_weight.to(bfloat16)) * scale``.

    Arguments:
        args: list ``[fp8_weight, act, scale]``. Sizes and strides are checked
            below. ``benchmark_compiled_module`` supplies float8_e5m2, bfloat16
            and bfloat16 tensors for these slots.

    The list is emptied in place, so each input can be released as soon as it
    has been consumed.

    Returns:
        A 1-tuple holding a newly allocated ``(1, 32)`` bfloat16 tensor on
        CUDA device 0.
    """
    fp8_weight, act, scale = args
    args.clear()
    assert_size_stride(fp8_weight, (32, 32), (32, 1))
    assert_size_stride(act, (1, 32), (32, 1))
    assert_size_stride(scale, (32, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream = get_raw_stream(0)

        # aten._to_copy: upcast the fp8 weight into a bf16 scratch buffer.
        weight_bf16 = empty_strided_cuda((32, 32), (32, 1), torch.bfloat16)
        triton_poi_fused__to_copy_0.run(
            fp8_weight, weight_bf16, 1024, grid=grid(1024), stream=stream
        )
        del fp8_weight

        # aten.mm: (1, 32) @ (32, 32) through the external mm kernel.
        result = empty_strided_cuda((1, 32), (32, 1), torch.bfloat16)
        extern_kernels.mm(act, weight_bf16, out=result)
        del act, weight_bf16

        # aten.mul: scale the 32 outputs in place (result's buffer is reused).
        triton_poi_fused_mul_1.run(result, scale, 32, grid=grid(32), stream=stream)
        del scale
    return (result,)
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on random CUDA inputs and report via ``print_performance``.

    Inputs are generated once with ``rand_strided`` from the sizes, strides and
    dtypes listed below. Every timed run passes ``call`` a fresh list because
    ``call`` empties the list it receives.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    # (size, stride, dtype) for each positional input of `call`, in order.
    input_specs = [
        ((32, 32), (32, 1), torch.float8_e5m2),
        ((1, 32), (32, 1), torch.bfloat16),
        ((32, ), (1, ), torch.bfloat16),
    ]
    inputs = [
        rand_strided(size, stride, device='cuda:0', dtype=dtype)
        for size, stride, dtype in input_specs
    ]

    def run_once():
        return call(list(inputs))

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
    # Delegate to inductor's shared command-line driver. The first argument
    # ('None') is presumably the benchmark name.
    from torch._inductor import wrapper_benchmark
    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch._inductor.config | |
# Enabling either of these flags causes a core dump:
# torch._inductor.config.force_mixed_mm = True | |
# torch._inductor.config.use_mixed_mm = True | |
def f(a, b, s):
    """Return ``torch.mm(a, b cast to a's dtype) * s``.

    ``b`` is converted to ``a.dtype`` before the matrix multiply; the repro
    below uses this to turn a float8_e5m2 weight into bfloat16. The product is
    then multiplied elementwise by ``s`` with ordinary broadcasting.
    """
    b_as_a = b.to(dtype=a.dtype)
    product = torch.mm(a, b_as_a)
    return product * s
# Repro inputs, all moved to the current CUDA device:
# - a (1, 32) bfloat16 activation (note: bf16, not fp16);
# - a (32, 32) float8_e5m2 weight;
# - (32,) bfloat16 scales, which broadcast as one scale per output column.
bf16_act = torch.randn(1, 32).to(torch.bfloat16).cuda()
fp8_weight = torch.randn(32, 32).to(torch.float8_e5m2).cuda()
scales = torch.randn(32).to(torch.bfloat16).cuda()
# Compile with max-autotune; fullgraph=True turns any graph break into an error.
torch.compile(f, mode="max-autotune", fullgraph=True)(bf16_act, fp8_weight, scales)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment