Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shunting314/d43b74e680b3e5b514f7c28160c39f40 to your computer and use it in GitHub Desktop.
Save shunting314/d43b74e680b3e5b514f7c28160c39f40 to your computer and use it in GitHub Desktop.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short module-level aliases for torch operator namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Private helpers from torch._C._dynamo.guards bound to module-level names;
# the wrapper `call` below uses assert_size_stride and empty_strided_cuda.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Handle used by the async_compile.triton(...) calls below; it is waited on
# and deleted right after the last kernel definition.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_shunting/ka/ckaipznkcgdkakijqawfivs5guxkcpaf7qxx3gyvgtkzpfcd7fty.py
# Source Nodes: [scatter_], Original ATen: [aten.scatter]
# scatter_ => scatter
#
# Pointwise copy kernel: each program loads XBLOCK float32 elements from
# in_ptr0 and stores them at the same offsets in out_ptr0, for 2097152
# elements in total. Loads and stores are unmasked (mask=None). The wrapper
# `call` uses it to copy the first graph input into a scratch buffer.
#
# The embedded source is a self-contained module: besides the kernel it
# defines get_args/call/benchmark_all_configs helpers and a __main__ section
# that times the kernel with do_bench and prints the achieved GB/s.
# Indentation inside the string is significant, because the string is the
# Python source of that module.
triton_poi_fused_scatter_0 = async_compile.triton('triton_poi_fused_scatter_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[2097152],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_scatter_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.016777216},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_scatter_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tl.store(out_ptr0 + (x0), tmp0, None)


def get_args():
    arg_0 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_scatter_0.run(*args, 2097152, grid=grid(2097152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_scatter_0.benchmark_all_configs(*args, 2097152, grid=grid(2097152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.016777216
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_shunting/nr/cnrg6bmqxnledenbwbixhzmxb5ubstwlqbi6pczolpdtgnrybx2c.py
# Source Nodes: [scatter_], Original ATen: [aten.scatter]
# scatter_ => scatter
#
# Scatter kernel: for each of the 1024 positions x0 it loads one int64 index
# from in_ptr0, adds 2048 to negative indices, device-asserts that the result
# lies in [0, 2048), and stores the constant 0.618 at
# out_ptr0[index + 2048 * x0]. out_ptr0 is listed in 'mutated_arg_names',
# i.e. the destination is updated in place.
#
# The embedded source is a self-contained module: besides the kernel it
# defines get_args/call/benchmark_all_configs helpers and a __main__ section
# that times the kernel with do_bench and prints the achieved GB/s.
# Indentation inside the string is significant, because the string is the
# Python source of that module.
triton_poi_fused_scatter_1 = async_compile.triton('triton_poi_fused_scatter_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[1024],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_scatter_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 1.2288e-05},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_scatter_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.full([XBLOCK], 2048, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert(((0 <= tmp4) & (tmp4 < 2048)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 2048")
    tmp6 = 0.618
    tl.store(out_ptr0 + (tmp4 + (2048*x0)), tmp6, xmask)


def get_args():
    arg_0 = rand_strided((1024, 1), (1, 1), device='cuda:0', dtype=torch.int64)
    arg_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_scatter_1.run(*args, 1024, grid=grid(1024), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_scatter_1.benchmark_all_configs(*args, 1024, grid=grid(1024))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 1.2288e-05
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/x5/cx5hm5eodyydcdxjpces4m3nfjghrt6qyi5hegovyzbgdk4y3vra.py
# Source Nodes: [], Original ATen: []
#
# Pointwise copy kernel: each program loads XBLOCK float32 elements from
# in_ptr0 and stores them at the same offsets in out_ptr0, for 2097152
# elements in total, with unmasked loads and stores. out_ptr0 is listed in
# 'mutated_arg_names'; the wrapper `call` uses this kernel to write the
# scratch buffer back into the first graph input.
#
# The embedded source is a self-contained module: besides the kernel it
# defines get_args/call/benchmark_all_configs helpers and a __main__ section
# that times the kernel with do_bench and prints the achieved GB/s.
# Indentation inside the string is significant, because the string is the
# Python source of that module.
triton_poi_fused_2 = async_compile.triton('triton_poi_fused_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid


@triton_heuristics.pointwise(
    size_hints=[2097152],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_2', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '196B341D951BBDD96DE5B2B44B37054C2CBA8E89494C8B8A1052A863B8BC7596', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.008388608},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tl.store(out_ptr0 + (x0), tmp0, None)


def get_args():
    arg_0 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    return arg_0, arg_1,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_2.run(*args, 2097152, grid=grid(2097152), stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_2.benchmark_all_configs(*args, 2097152, grid=grid(2097152))


if __name__ == '__main__':
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    num_gb = 0.008388608
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# Wait on the AsyncCompile handle, handing it this module's globals.
# NOTE(review): presumably this blocks until every kernel submitted above has
# finished compiling and rebinds the kernel names in globals() to the compiled
# results -- confirm against torch._inductor.async_compile.AsyncCompile.wait.
async_compile.wait(globals())
# The handle is not used after this point, so drop the module-level name.
del async_compile
def call(args):
    """Run the compiled graph: an in-place ``scatter_`` on the first input.

    The first input is copied into a scratch buffer, the constant 0.618 is
    scattered into that buffer at the column given by the second input for
    each row, and the buffer is then copied back over the first input.

    Args:
        args: Mutable list ``[arg0_1, arg1_1]``. ``arg0_1`` is checked to
            have size (1024, 2048) and stride (2048, 1) and is passed to
            the kernels as a float32 pointer; ``arg1_1`` is checked to have
            size (1024, 1) and stride (1, 1) and is passed as an int64
            pointer. The list is emptied on entry.

    Returns:
        A 1-tuple containing ``arg0_1``, whose contents have been
        overwritten by the final copy kernel.
    """
    arg0_1, arg1_1 = args
    # NOTE(review): presumably clears the caller's list so it no longer keeps
    # the inputs alive -- confirm against Inductor's calling convention.
    args.clear()
    assert_size_stride(arg0_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg1_1, (1024, 1), (1, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1024, 2048), (2048, 1), torch.float32)
        # Source Nodes: [scatter_], Original ATen: [aten.scatter]
        stream0 = get_raw_stream(0)
        # Copy the input into the scratch buffer.
        triton_poi_fused_scatter_0.run(arg0_1, buf0, 2097152, grid=grid(2097152), stream=stream0)
        # Source Nodes: [scatter_], Original ATen: [aten.scatter]
        # Scatter the constant into the scratch buffer, one element per row.
        triton_poi_fused_scatter_1.run(arg1_1, buf0, 1024, grid=grid(1024), stream=stream0)
        del arg1_1
        # Source Nodes: [], Original ATen: []
        # Write the result back into the input tensor (in-place semantics).
        triton_poi_fused_2.run(buf0, arg0_1, 2097152, grid=grid(2097152), stream=stream0)
        del buf0
    return (arg0_1, )
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on randomly generated CUDA inputs.

    Args:
        times: Forwarded to ``print_performance`` as ``times``.
        repeat: Forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns for the wrapped call.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    arg0_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg1_1 = rand_strided((1024, 1), (1, 1), device='cuda:0', dtype=torch.int64)

    def fn():
        # Build a fresh list per invocation: `call` clears the list it is given.
        return call([arg0_1, arg1_1])

    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry point: hand the module benchmark function to Inductor's
    # compiled-module benchmark driver.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment