Skip to content

Instantly share code, notes, and snippets.

@mlazos
Created May 1, 2024 23:30
Show Gist options
  • Save mlazos/8797aa4e93efeab8a1fe8396cfd61024 to your computer and use it in GitHub Desktop.
Save mlazos/8797aa4e93efeab8a1fe8396cfd61024 to your computer and use it in GitHub Desktop.
`--> TORCH_LOGS="output_code" python optim_repro.py
[WARNING]:Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[DEBUG]:Output code:
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short module-level aliases for torch op namespaces and private helper
# functions. In the code visible here only assert_size_stride (the input
# layout checks inside call()) and async_compile (kernel submission) are
# referenced; the other names are bound but not used in this view.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Handle through which the Triton kernel sources below are submitted (see the
# async_compile.triton(...) calls). The name is deleted again right after
# async_compile.wait(globals()) has run.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_mlazos/qw/cqwdta64dilqxthaf7r7oq6533jvl3wc6hr6slcetf5nno7l7mkg.py
# Source Nodes: [], Original ATen: []
#
# triton_for_fused_0: a "foreach" kernel that covers three independent scalar
# updates in a single launch. The program id selects the branch (0, 1 or 2);
# branch k loads one fp32 value from in_ptr<k> at offset 0, adds 1.0 and stores
# the result to out_ptr<k> at offset 0. The loaded value is broadcast to
# XBLOCK (1024) lanes and the store is unmasked, so every lane writes the same
# number to the same address. Program ids outside 0..2 fall into `else: pass`.
#
# call() launches this with grid (3, 1, 1) and passes the same three 0-dim
# tensors (arg10_1, arg3_1, arg11_1) as both inputs and outputs, so the net
# effect there is an in-place "+= 1" on each of them.
# NOTE(review): these look like per-parameter optimizer step counters (the
# repro script is optim_repro.py) -- confirm against the traced graph.
# NOTE(review): xnumel / xoffset / xindex / xmask are computed in every branch
# but are not used by the load or the store in this kernel.
# NOTE(review): the kernel source inside the string literal is flush-left in
# this copy; Triton kernels are Python source and need indentation, so the text
# appears to have been whitespace-stripped -- compare with the file at the
# kernel path above before reusing it.
triton_for_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.foreach(
num_warps=8,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'kernel_name': 'triton_for_fused_0', 'backend_hash': '63937d058519033f995f0585a4aab6c8c8898fe6839dd14ce1536da9b902b160', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2):
xpid = tl.program_id(0)
XBLOCK: tl.constexpr = 1024
if xpid >= 0 and xpid < 1:
xpid_offset = xpid - 0
xnumel = 1
xoffset = xpid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
tmp0 = tl.load(in_ptr0 + (0))
tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
tmp2 = 1.0
tmp3 = tmp1 + tmp2
tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp3, None)
elif xpid >= 1 and xpid < 2:
xpid_offset = xpid - 1
xnumel = 1
xoffset = xpid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
tmp4 = tl.load(in_ptr1 + (0))
tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
tmp6 = 1.0
tmp7 = tmp5 + tmp6
tl.store(out_ptr1 + (tl.full([XBLOCK], 0, tl.int32)), tmp7, None)
elif xpid >= 2 and xpid < 3:
xpid_offset = xpid - 2
xnumel = 1
xoffset = xpid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
tmp8 = tl.load(in_ptr2 + (0))
tmp9 = tl.broadcast_to(tmp8, [XBLOCK])
tmp10 = 1.0
tmp11 = tmp9 + tmp10
tl.store(out_ptr2 + (tl.full([XBLOCK], 0, tl.int32)), tmp11, None)
else:
pass
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_mlazos/f6/cf6mqlvcaho5wytbkdyareia6v2d32yctobezzb3f2ttjj6mskmc.py
# Source Nodes: [], Original ATen: []
#
# triton_for_fused_1: a "foreach" kernel with three branches, one per program
# id (0, 1, 2). Each branch processes xnumel = 20 elements (lanes are masked
# to the first 20 of XBLOCK = 1024) and reads five pointers: four per-element
# tensors plus one scalar loaded from offset 0 and broadcast.
#
# Naming the branch-0 inputs
#   a = in_ptr0, g = in_ptr1, m = in_ptr2, p = in_ptr3, t = in_ptr4 (scalar)
# the arithmetic in the kernel is
#   a_new = a * 0.999 + (g * g) * 0.0010000000000000009
#   m_new = m + (g - m) * 0.09999999999999998
#   denom = sqrt(a_new) / sqrt(-(0.999 ** t - 1.0)) + 1e-08
#   scale = 1 / ((0.9 ** t - 1.0) * 1000.0)
#   p_new = p + m_new / (denom / scale)
# with stores  m_new -> out_ptr1,  p_new -> out_ptr2,  a_new -> out_ptr3.
# Branch 1 repeats this with in_ptr5..in_ptr9 and out_ptr5..out_ptr7, branch 2
# with in_ptr10..in_ptr14 and out_ptr9..out_ptr11.
#
# call() launches this with grid (3, 1, 1) and passes, per branch, the same
# tensors for the outputs as for the matching inputs (e.g. branch 0:
# out_ptr1 is in_ptr2 = arg6_1, out_ptr2 is in_ptr3 = arg0_1, out_ptr3 is
# in_ptr0 = arg8_1), so the three updates happen in place. The (4, 5)
# contiguous tensors checked in call() have 20 elements, matching xnumel.
# NOTE(review): the formulas match an Adam step with betas (0.9, 0.999),
# eps 1e-08 and lr 1e-3 (the 1000.0 factor), with `t` being the step count
# already incremented by triton_for_fused_0 -- confirm against the optimizer
# used in optim_repro.py.
# NOTE(review): the kernel source inside the string literal is flush-left in
# this copy; Triton kernels are Python source and need indentation, so the text
# appears to have been whitespace-stripped -- compare with the file at the
# kernel path above before reusing it.
triton_for_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.foreach(
num_warps=8,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: '*fp32', 11: '*fp32', 12: '*fp32', 13: '*fp32', 14: '*fp32', 15: '*fp32', 16: '*fp32', 17: '*fp32', 18: '*fp32', 19: '*fp32', 20: '*fp32', 21: '*fp32', 22: '*fp32', 23: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'kernel_name': 'triton_for_fused_1', 'backend_hash': '63937d058519033f995f0585a4aab6c8c8898fe6839dd14ce1536da9b902b160', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr1, out_ptr2, out_ptr3, out_ptr5, out_ptr6, out_ptr7, out_ptr9, out_ptr10, out_ptr11):
xpid = tl.program_id(0)
XBLOCK: tl.constexpr = 1024
if xpid >= 0 and xpid < 1:
xpid_offset = xpid - 0
xnumel = 20
xoffset = xpid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp3 = tl.load(in_ptr1 + (x0), xmask)
tmp8 = tl.load(in_ptr2 + (x0), xmask)
tmp13 = tl.load(in_ptr3 + (x0), xmask)
tmp15 = tl.load(in_ptr4 + (0))
tmp16 = tl.broadcast_to(tmp15, [XBLOCK])
tmp1 = 0.999
tmp2 = tmp0 * tmp1
tmp4 = tmp3 * tmp3
tmp5 = 0.0010000000000000009
tmp6 = tmp4 * tmp5
tmp7 = tmp2 + tmp6
tmp9 = tmp3 - tmp8
tmp10 = 0.09999999999999998
tmp11 = tmp9 * tmp10
tmp12 = tmp8 + tmp11
tmp14 = libdevice.sqrt(tmp7)
tmp17 = libdevice.pow(tmp1, tmp16)
tmp18 = 1.0
tmp19 = tmp17 - tmp18
tmp20 = -tmp19
tmp21 = libdevice.sqrt(tmp20)
tmp22 = tmp14 / tmp21
tmp23 = 1e-08
tmp24 = tmp22 + tmp23
tmp25 = 0.9
tmp26 = libdevice.pow(tmp25, tmp16)
tmp27 = tmp26 - tmp18
tmp28 = 1000.0
tmp29 = tmp27 * tmp28
tmp30 = 1 / tmp29
tmp31 = tmp24 / tmp30
tmp32 = tmp12 / tmp31
tmp33 = tmp13 + tmp32
tl.store(out_ptr1 + (x0), tmp12, xmask)
tl.store(out_ptr2 + (x0), tmp33, xmask)
tl.store(out_ptr3 + (x0), tmp7, xmask)
elif xpid >= 1 and xpid < 2:
xpid_offset = xpid - 1
xnumel = 20
xoffset = xpid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = xindex
tmp34 = tl.load(in_ptr5 + (x1), xmask)
tmp37 = tl.load(in_ptr6 + (x1), xmask)
tmp42 = tl.load(in_ptr7 + (x1), xmask)
tmp47 = tl.load(in_ptr8 + (x1), xmask)
tmp49 = tl.load(in_ptr9 + (0))
tmp50 = tl.broadcast_to(tmp49, [XBLOCK])
tmp35 = 0.999
tmp36 = tmp34 * tmp35
tmp38 = tmp37 * tmp37
tmp39 = 0.0010000000000000009
tmp40 = tmp38 * tmp39
tmp41 = tmp36 + tmp40
tmp43 = tmp37 - tmp42
tmp44 = 0.09999999999999998
tmp45 = tmp43 * tmp44
tmp46 = tmp42 + tmp45
tmp48 = libdevice.sqrt(tmp41)
tmp51 = libdevice.pow(tmp35, tmp50)
tmp52 = 1.0
tmp53 = tmp51 - tmp52
tmp54 = -tmp53
tmp55 = libdevice.sqrt(tmp54)
tmp56 = tmp48 / tmp55
tmp57 = 1e-08
tmp58 = tmp56 + tmp57
tmp59 = 0.9
tmp60 = libdevice.pow(tmp59, tmp50)
tmp61 = tmp60 - tmp52
tmp62 = 1000.0
tmp63 = tmp61 * tmp62
tmp64 = 1 / tmp63
tmp65 = tmp58 / tmp64
tmp66 = tmp46 / tmp65
tmp67 = tmp47 + tmp66
tl.store(out_ptr5 + (x1), tmp46, xmask)
tl.store(out_ptr6 + (x1), tmp67, xmask)
tl.store(out_ptr7 + (x1), tmp41, xmask)
elif xpid >= 2 and xpid < 3:
xpid_offset = xpid - 2
xnumel = 20
xoffset = xpid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
tmp68 = tl.load(in_ptr10 + (x2), xmask)
tmp71 = tl.load(in_ptr11 + (x2), xmask)
tmp76 = tl.load(in_ptr12 + (x2), xmask)
tmp81 = tl.load(in_ptr13 + (x2), xmask)
tmp83 = tl.load(in_ptr14 + (0))
tmp84 = tl.broadcast_to(tmp83, [XBLOCK])
tmp69 = 0.999
tmp70 = tmp68 * tmp69
tmp72 = tmp71 * tmp71
tmp73 = 0.0010000000000000009
tmp74 = tmp72 * tmp73
tmp75 = tmp70 + tmp74
tmp77 = tmp71 - tmp76
tmp78 = 0.09999999999999998
tmp79 = tmp77 * tmp78
tmp80 = tmp76 + tmp79
tmp82 = libdevice.sqrt(tmp75)
tmp85 = libdevice.pow(tmp69, tmp84)
tmp86 = 1.0
tmp87 = tmp85 - tmp86
tmp88 = -tmp87
tmp89 = libdevice.sqrt(tmp88)
tmp90 = tmp82 / tmp89
tmp91 = 1e-08
tmp92 = tmp90 + tmp91
tmp93 = 0.9
tmp94 = libdevice.pow(tmp93, tmp84)
tmp95 = tmp94 - tmp86
tmp96 = 1000.0
tmp97 = tmp95 * tmp96
tmp98 = 1 / tmp97
tmp99 = tmp92 / tmp98
tmp100 = tmp80 / tmp99
tmp101 = tmp81 + tmp100
tl.store(out_ptr9 + (x2), tmp80, xmask)
tl.store(out_ptr10 + (x2), tmp101, xmask)
tl.store(out_ptr11 + (x2), tmp75, xmask)
else:
pass
''', device_str='cuda')
# Pass the module globals to the compile handle.
# NOTE(review): presumably this blocks until the kernels submitted above have
# finished compiling and rebinds triton_for_fused_0 / triton_for_fused_1 in
# globals() to the finished kernels -- confirm against AsyncCompile.wait.
async_compile.wait(globals())
# The handle is not referenced after this point, so drop the module-level name.
del async_compile
def call(args):
    """Check the 15 inputs and launch both fused kernels on CUDA device 0.

    ``args`` must unpack into exactly 15 tensors and is emptied in place, so
    the caller's list no longer holds references to them afterwards.
    Positions 3, 10 and 11 must be 0-dim; every other position must have size
    (4, 5) and stride (5, 1). Both kernels receive the same tensors as inputs
    and outputs, so results are written back into the inputs. An empty tuple
    is returned.
    """
    (
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1,
        arg5_1, arg6_1, arg7_1, arg8_1, arg9_1,
        arg10_1, arg11_1, arg12_1, arg13_1, arg14_1,
    ) = args
    args.clear()

    # Layout checks, performed in positional order (arg0_1 first).
    matrix_layout = ((4, 5), (5, 1))
    scalar_layout = ((), ())
    expected_layouts = (
        [matrix_layout] * 3      # arg0_1 .. arg2_1
        + [scalar_layout]        # arg3_1
        + [matrix_layout] * 6    # arg4_1 .. arg9_1
        + [scalar_layout] * 2    # arg10_1, arg11_1
        + [matrix_layout] * 3    # arg12_1 .. arg14_1
    )
    inputs = (
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1,
        arg5_1, arg6_1, arg7_1, arg8_1, arg9_1,
        arg10_1, arg11_1, arg12_1, arg13_1, arg14_1,
    )
    for tensor, (size, stride) in zip(inputs, expected_layouts):
        assert_size_stride(tensor, size, stride)
    # Drop the helper references so only the argN_1 names keep the tensors alive.
    del inputs, tensor

    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        raw_stream = get_raw_stream(0)
        launch_grid = (3, 1, 1)  # one program per fused sub-kernel branch

        # Source Nodes: [], Original ATen: []
        triton_for_fused_0.run(
            arg10_1, arg3_1, arg11_1,  # in_ptr0 .. in_ptr2
            arg10_1, arg3_1, arg11_1,  # out_ptr0 .. out_ptr2 (same tensors)
            grid=launch_grid,
            stream=raw_stream,
        )

        # Source Nodes: [], Original ATen: []
        triton_for_fused_1.run(
            arg8_1, arg12_1, arg6_1, arg0_1, arg10_1,   # in_ptr0 .. in_ptr4
            arg5_1, arg13_1, arg4_1, arg1_1, arg3_1,    # in_ptr5 .. in_ptr9
            arg9_1, arg14_1, arg7_1, arg2_1, arg11_1,   # in_ptr10 .. in_ptr14
            arg6_1, arg0_1, arg8_1,                     # out_ptr1 .. out_ptr3
            arg4_1, arg1_1, arg5_1,                     # out_ptr5 .. out_ptr7
            arg7_1, arg2_1, arg9_1,                     # out_ptr9 .. out_ptr11
            grid=launch_grid,
            stream=raw_stream,
        )

        # Release every local tensor reference once both launches are queued.
        del (
            arg0_1, arg10_1, arg11_1, arg12_1, arg13_1,
            arg14_1, arg1_1, arg2_1, arg3_1, arg4_1,
            arg5_1, arg6_1, arg7_1, arg8_1, arg9_1,
        )
    return ()
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on randomly filled float32 inputs on ``cuda:0``.

    Fifteen tensors are created in positional order: positions 3, 10 and 11
    are 0-dim, all others have size (4, 5) and stride (5, 1). ``times`` and
    ``repeat`` are forwarded to ``print_performance``, whose result is
    returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    matrix_layout = ((4, 5), (5, 1))
    scalar_layout = ((), ())
    layouts = (
        [matrix_layout] * 3      # arg0_1 .. arg2_1
        + [scalar_layout]        # arg3_1
        + [matrix_layout] * 6    # arg4_1 .. arg9_1
        + [scalar_layout] * 2    # arg10_1, arg11_1
        + [matrix_layout] * 3    # arg12_1 .. arg14_1
    )
    inputs = [
        rand_strided(size, stride, device='cuda:0', dtype=torch.float32)
        for size, stride in layouts
    ]

    def fn():
        # call() empties the list it is given, so pass a fresh copy each time.
        return call(list(inputs))

    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry point: hand the benchmark function to inductor's
    # compiled-module driver.
    import torch._inductor.wrapper_benchmark as wrapper_benchmark

    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)
[INFO]:Output code written to: /tmp/torchinductor_mlazos/ue/cueqxubtnm3far2kh6u4xpvwjaptxhw4kwvxjgaf4ulsgbi26rzx.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment