Skip to content

Instantly share code, notes, and snippets.

@jansel
Last active August 27, 2022 21:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jansel/f4af078791ad681a0d4094adeb844396 to your computer and use it in GitHub Desktop.
Save jansel/f4af078791ad681a0d4094adeb844396 to your computer and use it in GitHub Desktop.
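Example output with TORCHINDUCTOR_TRACE=1 (cProfile statistics, compile logs, graph repro, scheduler IR dump, and generated output code)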
1380522 function calls (1367278 primitive calls) in 3.192 seconds
Ordered by: cumulative time
List reduced from 2159 to 100 due to restriction <100>
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 3.193 3.193 compile_fx.py:39(compile_fx_inner)
1 0.000 0.000 2.741 2.741 compile_fx.py:74(cudagraphify)
2 0.000 0.000 2.690 1.345 crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.py:99(call)
2 0.000 0.000 2.689 1.344 code_gen.py:998(__call__)
2 0.000 0.000 2.689 1.344 code_gen.py:1056(__call__)
1 0.000 0.000 2.688 2.688 code_gen.py:1073(<dictcomp>)
4 0.000 0.000 2.688 0.672 autotune.py:26(_bench)
4 0.004 0.001 2.688 0.672 code_gen.py:1037(_bench)
4 0.015 0.004 2.684 0.671 testing.py:119(do_bench)
8955 0.021 0.000 2.029 0.000 code_gen.py:1049(kernel_call)
8957 0.059 0.000 2.008 0.000 code_gen.py:959(__call__)
8957 1.728 0.000 1.737 0.000 {built-in method triton._C.libtriton.triton.runtime.launch}
15 0.000 0.000 0.345 0.023 __init__.py:491(synchronize)
15 0.345 0.023 0.345 0.023 {built-in method torch._C._cuda_synchronize}
1 0.000 0.000 0.290 0.290 graph.py:320(compile_to_fn)
1 0.000 0.000 0.290 0.290 graph.py:304(compile_to_module)
1 0.000 0.000 0.265 0.265 graph.py:296(codegen)
14302 0.015 0.000 0.217 0.000 streams.py:173(record)
5 0.000 0.000 0.206 0.041 subprocess.py:736(__init__)
5 0.000 0.000 0.206 0.041 subprocess.py:1552(_execute_child)
4 0.000 0.000 0.177 0.044 subprocess.py:452(run)
1 0.000 0.000 0.177 0.177 scheduler.py:500(__init__)
14304 0.022 0.000 0.165 0.000 __init__.py:517(current_stream)
1 0.000 0.000 0.132 0.132 code_gen.py:1201(cache_key)
1 0.001 0.001 0.130 0.130 code_gen.py:1093(version_key)
14 0.128 0.009 0.128 0.009 {built-in method posix.read}
86/38 0.000 0.000 0.123 0.003 {built-in method builtins.exec}
23282 0.025 0.000 0.120 0.000 _utils.py:7(_get_device_index)
57/4 0.000 0.000 0.118 0.030 <frozen importlib._bootstrap>:986(_find_and_load)
57/4 0.000 0.000 0.118 0.030 <frozen importlib._bootstrap>:956(_find_and_load_unlocked)
54/4 0.000 0.000 0.117 0.029 <frozen importlib._bootstrap>:650(_load_unlocked)
51/4 0.000 0.000 0.117 0.029 <frozen importlib._bootstrap_external>:837(exec_module)
69/4 0.000 0.000 0.116 0.029 <frozen importlib._bootstrap>:211(_call_with_frames_removed)
5 0.000 0.000 0.114 0.023 __init__.py:1(<module>)
1 0.000 0.000 0.100 0.100 debug.py:312(graph_diagram)
1 0.000 0.000 0.100 0.100 debug.py:51(draw_buffers)
1 0.000 0.000 0.096 0.096 debug.py:296(fx_graph)
1 0.000 0.000 0.096 0.096 debug_utils.py:104(save_graph_repro)
1 0.000 0.000 0.096 0.096 debug_utils.py:22(generate_repro_string)
23282 0.018 0.000 0.086 0.000 _utils.py:591(_get_device_index)
2 0.000 0.000 0.082 0.041 subprocess.py:368(check_output)
5 0.076 0.015 0.076 0.015 {built-in method _posixsubprocess.fork_exec}
13 0.071 0.005 0.071 0.005 {built-in method _hashlib.openssl_md5}
14/6 0.000 0.000 0.066 0.011 {built-in method builtins.__import__}
1 0.000 0.000 0.065 0.065 graph.py:125(run)
1 0.000 0.000 0.065 0.065 interpreter.py:95(run)
28 0.000 0.000 0.065 0.002 graph.py:283(run_node)
14321 0.006 0.000 0.063 0.000 _utils.py:565(_get_current_device_index)
28 0.000 0.000 0.062 0.002 interpreter.py:144(run_node)
1 0.000 0.000 0.057 0.057 partitioners.py:427(draw_graph)
14321 0.012 0.000 0.057 0.000 _utils.py:555(_get_device_attr)
4683/2928 0.003 0.000 0.056 0.000 cache.py:67(wrapper)
1 0.000 0.000 0.055 0.055 pydot.py:1739(new_method)
2/1 0.000 0.000 0.055 0.055 pydot.py:1794(write)
1 0.000 0.000 0.055 0.055 pydot.py:1833(create)
7 0.000 0.000 0.055 0.008 scheduler.py:218(__init__)
1 0.000 0.000 0.053 0.053 pydot.py:113(call_graphviz)
1 0.000 0.000 0.053 0.053 wrapper.py:168(__init__)
405 0.000 0.000 0.053 0.000 decorators.py:224(_func)
1 0.000 0.000 0.052 0.052 utils.py:18(has_triton)
7146 0.052 0.000 0.052 0.000 {method 'zero_' of 'torch._C._TensorBase' objects}
30 0.000 0.000 0.052 0.002 code_gen.py:1442(jit)
30 0.000 0.000 0.052 0.002 code_gen.py:1169(__init__)
381 0.000 0.000 0.050 0.000 decorators.py:99(binary_op_wrapper)
1 0.000 0.000 0.050 0.050 graphs.py:145(__enter__)
21 0.000 0.000 0.049 0.002 ir.py:3290(__call__)
1 0.049 0.049 0.049 0.049 {built-in method gc.collect}
175 0.000 0.000 0.048 0.000 expr.py:215(__mul__)
22/20 0.000 0.000 0.048 0.002 operations.py:52(__new__)
5 0.000 0.000 0.048 0.010 graph.py:193(placeholder)
5 0.000 0.000 0.048 0.010 graph.py:33(symbolic_sizes_strides)
30 0.000 0.000 0.048 0.002 inspect.py:991(getsource)
30 0.000 0.000 0.048 0.002 inspect.py:970(getsourcelines)
2 0.000 0.000 0.048 0.024 mul.py:197(flatten)
6 0.000 0.000 0.048 0.008 add.py:184(flatten)
4 0.000 0.000 0.048 0.012 mul.py:467(_gather)
21 0.000 0.000 0.047 0.002 ir.py:3394(__call__)
1 0.000 0.000 0.046 0.046 tensor.py:1(<module>)
30 0.004 0.000 0.046 0.002 inspect.py:959(getblock)
1 0.000 0.000 0.041 0.041 debug.py:42(has_dot)
7 0.000 0.000 0.039 0.006 ir.py:1851(simplify_and_reorder)
9070 0.014 0.000 0.039 0.000 tokenize.py:429(_tokenize)
1 0.000 0.000 0.038 0.038 polyhedron.py:1(<module>)
92/82 0.001 0.000 0.038 0.000 <frozen importlib._bootstrap>:1017(_handle_fromlist)
14302 0.038 0.000 0.038 0.000 {function Event.record at 0x7f321a0c8a60}
930 0.010 0.000 0.036 0.000 basic.py:802(subs)
23293 0.011 0.000 0.034 0.000 __init__.py:485(current_device)
1 0.000 0.000 0.032 0.032 scheduler.py:987(codegen)
1 0.000 0.000 0.031 0.031 triton.py:1032(codegen_nodes)
1 0.000 0.000 0.031 0.031 triton.py:1101(codegen_node_schedule)
8957 0.006 0.000 0.029 0.000 __init__.py:307(set_device)
61 0.028 0.000 0.028 0.000 {method 'read' of '_io.BufferedReader' objects}
14 0.000 0.000 0.027 0.002 ir.py:3227(__init__)
14 0.000 0.000 0.027 0.002 ir.py:3311(__init__)
14306 0.014 0.000 0.027 0.000 streams.py:31(__new__)
2 0.000 0.000 0.025 0.013 __init__.py:2(<module>)
14321 0.005 0.000 0.025 0.000 _utils.py:567(<lambda>)
5 0.000 0.000 0.025 0.005 subprocess.py:984(communicate)
1 0.000 0.000 0.025 0.025 codecache.py:152(load)
7 0.000 0.000 0.024 0.003 scheduler.py:313(codegen)
1380522 function calls (1367278 primitive calls) in 3.192 seconds
Ordered by: internal time
List reduced from 2159 to 100 due to restriction <100>
ncalls tottime percall cumtime percall filename:lineno(function)
8957 1.728 0.000 1.737 0.000 {built-in method triton._C.libtriton.triton.runtime.launch}
15 0.345 0.023 0.345 0.023 {built-in method torch._C._cuda_synchronize}
14 0.128 0.009 0.128 0.009 {built-in method posix.read}
5 0.076 0.015 0.076 0.015 {built-in method _posixsubprocess.fork_exec}
13 0.071 0.005 0.071 0.005 {built-in method _hashlib.openssl_md5}
8957 0.059 0.000 2.008 0.000 code_gen.py:959(__call__)
7146 0.052 0.000 0.052 0.000 {method 'zero_' of 'torch._C._TensorBase' objects}
1 0.049 0.049 0.049 0.049 {built-in method gc.collect}
14302 0.038 0.000 0.038 0.000 {function Event.record at 0x7f321a0c8a60}
61 0.028 0.000 0.028 0.000 {method 'read' of '_io.BufferedReader' objects}
23282 0.025 0.000 0.120 0.000 _utils.py:7(_get_device_index)
14304 0.022 0.000 0.165 0.000 __init__.py:517(current_stream)
39479 0.021 0.000 0.021 0.000 {built-in method __new__ of type object at 0x560c66fc39a0}
239119/238013 0.021 0.000 0.023 0.000 {built-in method builtins.isinstance}
8955 0.021 0.000 2.029 0.000 code_gen.py:1049(kernel_call)
23282 0.018 0.000 0.086 0.000 _utils.py:591(_get_device_index)
14302 0.015 0.000 0.217 0.000 streams.py:173(record)
4 0.015 0.004 2.684 0.671 testing.py:119(do_bench)
9070 0.014 0.000 0.039 0.000 tokenize.py:429(_tokenize)
14306 0.014 0.000 0.027 0.000 streams.py:31(__new__)
7 0.012 0.002 0.012 0.002 {method 'poll' of 'select.poll' objects}
14321 0.012 0.000 0.057 0.000 _utils.py:555(_get_device_attr)
2 0.011 0.005 0.011 0.005 {built-in method _imp.create_dynamic}
23293 0.011 0.000 0.034 0.000 __init__.py:485(current_device)
930 0.010 0.000 0.036 0.000 basic.py:802(subs)
23293 0.009 0.000 0.009 0.000 {built-in method torch._C._cuda_getDevice}
37631 0.009 0.000 0.024 0.000 __init__.py:196(_lazy_init)
9719 0.008 0.000 0.008 0.000 {method 'match' of 're.Pattern' objects}
37632 0.008 0.000 0.015 0.000 __init__.py:153(is_initialized)
8957 0.008 0.000 0.013 0.000 code_gen.py:962(<dictcomp>)
14326 0.007 0.000 0.013 0.000 __init__.py:77(is_available)
17914 0.007 0.000 0.009 0.000 core.py:340(__init__)
37632 0.007 0.000 0.007 0.000 {built-in method torch._C._cuda_isInBadFork}
10199/10163 0.006 0.000 0.011 0.000 {built-in method builtins.sorted}
32 0.006 0.000 0.006 0.000 {built-in method builtins.compile}
14321 0.006 0.000 0.063 0.000 _utils.py:565(_get_current_device_index)
8957 0.006 0.000 0.029 0.000 __init__.py:307(set_device)
14321 0.005 0.000 0.018 0.000 _utils.py:546(_get_available_device_type)
14304 0.005 0.000 0.005 0.000 {built-in method torch._C._cuda_getCurrentStream}
18028 0.005 0.000 0.005 0.000 {method 'index' of 'list' objects}
51 0.005 0.000 0.005 0.000 {built-in method marshal.loads}
8957 0.005 0.000 0.005 0.000 {built-in method torch._C._cuda_setDevice}
20253/19676 0.005 0.000 0.006 0.000 {built-in method builtins.getattr}
14321 0.005 0.000 0.025 0.000 _utils.py:567(<lambda>)
8957 0.005 0.000 0.007 0.000 autotune.py:452(grid_fn)
14302 0.004 0.000 0.011 0.000 streams.py:163(__new__)
4 0.004 0.001 2.688 0.672 code_gen.py:1037(_bench)
8749 0.004 0.000 0.009 0.000 re.py:289(_compile)
30 0.004 0.000 0.046 0.002 inspect.py:959(getblock)
22654 0.004 0.000 0.004 0.000 {built-in method builtins.hasattr}
7644/7464 0.004 0.000 0.010 0.000 sympify.py:102(sympify)
37624 0.004 0.000 0.004 0.000 _jit_internal.py:1082(is_scripting)
7150 0.003 0.000 0.003 0.000 {function Event.elapsed_time at 0x7f321a0c8c10}
3784/1814 0.003 0.000 0.007 0.000 node.py:603(map_aggregate)
9028 0.003 0.000 0.003 0.000 inspect.py:909(tokeneater)
14330 0.003 0.000 0.003 0.000 {built-in method torch._C._cuda_getDeviceCount}
8957 0.003 0.000 0.003 0.000 {built-in method torch._C._cuda_getCurrentRawStream}
34440/33719 0.003 0.000 0.003 0.000 {built-in method builtins.len}
4683/2928 0.003 0.000 0.056 0.000 cache.py:67(wrapper)
1743/1523 0.003 0.000 0.010 0.000 sorting.py:203(ordered)
8550 0.003 0.000 0.012 0.000 tokenize.py:98(_compile)
4 0.003 0.001 0.008 0.002 testing.py:153(<listcomp>)
17973 0.003 0.000 0.003 0.000 {method 'insert' of 'list' objects}
4 0.003 0.001 0.008 0.002 testing.py:154(<listcomp>)
8224 0.003 0.000 0.004 0.000 <frozen importlib._bootstrap>:389(parent)
568/500 0.003 0.000 0.004 0.000 printer.py:294(_print)
501 0.002 0.000 0.006 0.000 basic.py:2016(_aresame)
8957 0.002 0.000 0.002 0.000 code_gen.py:1482(cdiv)
22431 0.002 0.000 0.002 0.000 {method 'items' of 'dict' objects}
980 0.002 0.000 0.006 0.000 sorting.py:10(default_sort_key)
8550 0.002 0.000 0.003 0.000 types.py:171(__get__)
8957 0.002 0.000 0.005 0.000 code_gen.py:32(current_cuda_stream)
15540 0.002 0.000 0.002 0.000 {method 'lower' of 'str' objects}
4478 0.002 0.000 0.006 0.000 numbers.py:2238(__eq__)
8676 0.002 0.000 0.010 0.000 re.py:250(compile)
31 0.002 0.000 0.002 0.000 pydot.py:530(create_attribute_methods)
7150 0.002 0.000 0.005 0.000 streams.py:204(elapsed_time)
1662 0.002 0.000 0.003 0.000 graph.py:123(create_name)
746 0.002 0.000 0.004 0.000 function.py:2495(expand)
6122 0.002 0.000 0.002 0.000 misc.py:491(as_int)
9037 0.002 0.000 0.003 0.000 <string>:1(__new__)
118/117 0.002 0.000 0.005 0.000 {built-in method builtins.__build_class__}
100 0.001 0.000 0.009 0.000 iterables.py:1173(least_rotation)
1470 0.001 0.000 0.003 0.000 basic.py:350(__eq__)
571 0.001 0.000 0.003 0.000 sorting.py:180(_nodes)
1 0.001 0.001 0.130 0.130 code_gen.py:1093(version_key)
459 0.001 0.000 0.004 0.000 permutations.py:384(<listcomp>)
3291 0.001 0.000 0.005 0.000 sympify.py:503(_sympify)
826/376 0.001 0.000 0.004 0.000 _symbolic_trace.py:254(create_arg)
2015 0.001 0.000 0.003 0.000 containers.py:58(__getitem__)
8510 0.001 0.000 0.001 0.000 {method 'rpartition' of 'str' objects}
30 0.001 0.000 0.014 0.000 graph.py:297(_gen_python_code)
4 0.001 0.000 0.006 0.002 testing.py:175(<listcomp>)
4 0.001 0.000 0.001 0.000 {built-in method torch.quantile}
5658 0.001 0.000 0.002 0.000 numbers.py:2284(__hash__)
10 0.001 0.000 0.001 0.000 {built-in method torch.empty_strided}
367 0.001 0.000 0.002 0.000 permutations.py:1316(__mul__)
1730 0.001 0.000 0.003 0.000 numbers.py:1867(__eq__)
8490/6732 0.001 0.000 0.001 0.000 {built-in method builtins.hash}
2950/2720 0.001 0.000 0.003 0.000 containers.py:54(<genexpr>)
[compile_fx.py:48 INFO] Compiling FORWARDS graph
[triton.py:1098 INFO] schedule: [SchedulerNode(name='buf1'), <class 'torchinductor.codegen.triton.DisableReduction'>, <class 'torchinductor.codegen.triton.EnableReduction'>, SchedulerNode(name='buf2'), SchedulerNode(name='buf3'), <class 'torchinductor.codegen.triton.DisableReduction'>, <class 'torchinductor.codegen.triton.EnableReduction'>, SchedulerNode(name='buf4'), <class 'torchinductor.codegen.triton.DisableReduction'>, <class 'torchinductor.codegen.triton.EnableReduction'>, SchedulerNode(name='buf5'), <class 'torchinductor.codegen.triton.DisableReduction'>, <class 'torchinductor.codegen.triton.EnableReduction'>, SchedulerNode(name='buf6'), <class 'torchinductor.codegen.triton.DisableReduction'>, SchedulerNode(name='buf7'), <class 'torchinductor.codegen.triton.EnableReduction'>]
[scheduler.py:944 DEBUG] remove_buffer('buf3')
[scheduler.py:944 DEBUG] remove_buffer('buf2')
[scheduler.py:944 DEBUG] remove_buffer('buf1')
[graph.py:315 INFO] Output code: /tmp/torchinductor_jansel/rh/crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.py
[debug.py:260 WARNING] model_forward_0 debug trace: /tmp/torchinductor_jansel/rh/crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.debug
import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx
# torch version: 1.13.0a0+gita089114
# torch cuda version: 11.6
# torch git version: a08911400edb62c9caa0c94d1ce176cf8cb29765
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Tue_Mar__8_18:18:20_PST_2022
# Cuda compilation tools, release 11.6, V11.6.124
# Build cuda_11.6.r11.6/compiler.31057947_0
# GPU Hardware Info:
# NVIDIA GeForce RTX 3090 : 1
class Repro(torch.nn.Module):
    """Auto-generated standalone repro of the forward graph compiled by TorchInductor.

    ``forward`` is a flat sequence of ``torch.ops.aten`` / ``torch.ops.prims``
    calls. Statement order and the ``name = None`` releases are kept exactly as
    generated so that ``make_fx`` re-traces the same graph.
    """
    def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
        """Compute relu(normalize(primals_5 @ primals_1.T + primals_2) * primals_3 + primals_4).

        Normalization is over dim 0: subtract the mean (sum / 10.0) and multiply
        by 1 / sqrt(biased variance + 1e-05).

        Returns a list: [relu output, primals_3, primals_5 unsqueezed to a
        (1, N) row, normalized values, (relu output <= 0) mask,
        reciprocal std / 10]. The extra outputs are presumably saved for a
        backward pass, which is not visible here.
        """
        # Row-vector matmul with the transposed primals_1, then add primals_2.
        permute_default = torch.ops.aten.permute.default(primals_1, [1, 0]); primals_1 = None
        unsqueeze_default = torch.ops.aten.unsqueeze.default(primals_5, 0); primals_5 = None
        mm_default = torch.ops.aten.mm.default(unsqueeze_default, permute_default); permute_default = None
        squeeze_dim = torch.ops.aten.squeeze.dim(mm_default, 0); mm_default = None
        add_tensor = torch.ops.aten.add.Tensor(squeeze_dim, primals_2); squeeze_dim = primals_2 = None
        # Statistics over dim 0: biased variance (correction = 0) and the sum used for the mean.
        convert_element_type_default = torch.ops.prims.convert_element_type.default(add_tensor, torch.float32)
        var_default = torch.ops.prims.var.default(convert_element_type_default, [0], correction = 0)
        broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(var_default, [1], []); var_default = None
        sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0]); convert_element_type_default = None
        broadcast_in_dim_default_1 = torch.ops.prims.broadcast_in_dim.default(sum_default, [1], []); sum_default = None
        # NOTE(review): 10.0 is the traced length of dim 0 baked in as a constant.
        div_default = torch.ops.prims.div.default(broadcast_in_dim_default_1, 10.0); broadcast_in_dim_default_1 = None
        # reciprocal(sqrt(var + 1e-05)), i.e. the inverse standard deviation.
        add_tensor_1 = torch.ops.aten.add.Tensor(broadcast_in_dim_default, 1e-05); broadcast_in_dim_default = None
        sqrt_default = torch.ops.aten.sqrt.default(add_tensor_1); add_tensor_1 = None
        reciprocal_default = torch.ops.aten.reciprocal.default(sqrt_default); sqrt_default = None
        # Center and scale, then apply primals_3 (scale) and primals_4 (shift), then ReLU.
        sub_tensor = torch.ops.aten.sub.Tensor(add_tensor, div_default); add_tensor = div_default = None
        mul_tensor = torch.ops.aten.mul.Tensor(sub_tensor, reciprocal_default); sub_tensor = None
        mul_tensor_1 = torch.ops.aten.mul.Tensor(mul_tensor, primals_3)
        add_tensor_2 = torch.ops.aten.add.Tensor(mul_tensor_1, primals_4); mul_tensor_1 = primals_4 = None
        convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(add_tensor_2, torch.float32); add_tensor_2 = None
        relu_default = torch.ops.aten.relu.default(convert_element_type_default_2); convert_element_type_default_2 = None
        # Extra outputs: mask of non-positive ReLU outputs and inverse std divided by 10.
        le_scalar = torch.ops.aten.le.Scalar(relu_default, 0)
        div_tensor = torch.ops.aten.div.Tensor(reciprocal_default, 10); reciprocal_default = None
        return [relu_default, primals_3, unsqueeze_default, mul_tensor, le_scalar, div_tensor]
# Input specs as (shape, stride, dtype, device), one per Repro.forward argument:
# a contiguous (10, 10) float32 matrix followed by four length-10 float32 vectors.
_vector_spec = ((10,), (1,), torch.float32, 'cuda')
args = [((10, 10), (10, 1), torch.float32, 'cuda')] + [_vector_spec] * 4
# Materialize random tensors in argument order (same RNG call order as before).
args = [rand_strided(*spec) for spec in args]
# Re-trace the repro module into an FX graph, then compile it and run it once.
mod = make_fx(Repro())(*args)
from torchinductor.compile_fx import compile_fx_inner
compiled = compile_fx_inner(mod, args)
compiled(*args)
buf0: ExternKernelSchedulerNode(MatrixMultiply)
buf0.writes = {StarDep(name='buf0')}
buf0.unmet_dependencies = set()
buf0.met_dependencies = {StarDep(name='primals_5'), StarDep(name='primals_1')}
buf0.node.kernel = aten.mm.out
buf1_buf2_buf3_buf4_buf5_buf6_buf7: FusedSchedulerNode(NoneType)
buf1_buf2_buf3_buf4_buf5_buf6_buf7.writes =
{ MemoryDep(name='buf1', index=0, size=()),
MemoryDep(name='buf1', index=0, size=(s0,)),
MemoryDep(name='buf2', index=0, size=()),
MemoryDep(name='buf2', index=0, size=(s0,)),
MemoryDep(name='buf3', index=0, size=()),
MemoryDep(name='buf3', index=0, size=(s0,)),
MemoryDep(name='buf4', index=c0, size=(s0,)),
MemoryDep(name='buf5', index=c0, size=(s0,)),
MemoryDep(name='buf6', index=c0, size=(s0,)),
MemoryDep(name='buf7', index=0, size=())}
buf1_buf2_buf3_buf4_buf5_buf6_buf7.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf1_buf2_buf3_buf4_buf5_buf6_buf7.met_dependencies =
{ MemoryDep(name='primals_2', index=c0, size=(s0,)),
MemoryDep(name='primals_3', index=c0, size=(s0,)),
MemoryDep(name='primals_4', index=c0, size=(s0,))}
buf1_buf2_buf3_buf4_buf5_buf6_buf7.snodes = ['buf1', 'buf2', 'buf3', 'buf4', 'buf5', 'buf6', 'buf7']
buf0: ExternKernelSchedulerNode(MatrixMultiply)
buf0.writes = {StarDep(name='buf0')}
buf0.unmet_dependencies = set()
buf0.met_dependencies = {StarDep(name='primals_5'), StarDep(name='primals_1')}
buf0.node.kernel = aten.mm.out
buf1: SchedulerNode(ComputedBuffer)
buf1.writes =
{ MemoryDep(name='buf1', index=0, size=()),
MemoryDep(name='buf1', index=0, size=(s0,))}
buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf1.group.device = cuda:0
buf1.group.iteration = (1, s0)
buf1.sizes = ([], [s0])
class buf1_loop_body:
var_ranges = {z0: s0}
index0 = z0
index1 = 0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('buf0', get_index, False)
get_index_1 = self.get_index('index0')
load_1 = ops.load('primals_2', get_index_1, False)
add = ops.add(load, load_1)
get_index_2 = self.get_index('index1')
reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add)
return reduction
buf2: SchedulerNode(ComputedBuffer)
buf2.writes =
{ MemoryDep(name='buf2', index=0, size=()),
MemoryDep(name='buf2', index=0, size=(s0,))}
buf2.unmet_dependencies =
{ MemoryDep(name='buf0', index=c0, size=(s0,)),
MemoryDep(name='buf1', index=0, size=(s0,))}
buf2.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf2.group.device = cuda:0
buf2.group.iteration = (1, s0)
buf2.sizes = ([], [s0])
class buf2_loop_body:
var_ranges = {z0: s0}
index0 = z0
index1 = 0
index2 = s0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('buf0', get_index, False)
get_index_1 = self.get_index('index0')
load_1 = ops.load('primals_2', get_index_1, False)
add = ops.add(load, load_1)
get_index_2 = self.get_index('index1')
load_2 = ops.load('buf1', get_index_2, False)
get_index_3 = self.get_index('index2')
index_expr = ops.index_expr(get_index_3, torch.float32)
div = ops.div(load_2, index_expr)
sub = ops.sub(add, div)
square = ops.square(sub)
get_index_4 = self.get_index('index1')
reduction = ops.reduction('buf2', torch.float32, torch.float32, 'sum', get_index_4, square)
return reduction
buf3: SchedulerNode(ComputedBuffer)
buf3.writes =
{ MemoryDep(name='buf3', index=0, size=()),
MemoryDep(name='buf3', index=0, size=(s0,))}
buf3.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf3.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf3.group.device = cuda:0
buf3.group.iteration = (1, s0)
buf3.sizes = ([], [s0])
class buf3_loop_body:
var_ranges = {z0: s0}
index0 = z0
index1 = 0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('buf0', get_index, False)
get_index_1 = self.get_index('index0')
load_1 = ops.load('primals_2', get_index_1, False)
add = ops.add(load, load_1)
get_index_2 = self.get_index('index1')
reduction = ops.reduction('buf3', torch.float32, torch.float32, 'sum', get_index_2, add)
return reduction
buf4: SchedulerNode(ComputedBuffer)
buf4.writes = {MemoryDep(name='buf4', index=c0, size=(s0,))}
buf4.unmet_dependencies =
{ MemoryDep(name='buf0', index=c0, size=(s0,)),
MemoryDep(name='buf2', index=0, size=(s0,)),
MemoryDep(name='buf3', index=0, size=(s0,))}
buf4.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf4.group.device = cuda:0
buf4.group.iteration = (s0, 1)
buf4.sizes = ([s0], [])
class buf4_loop_body:
var_ranges = {z0: s0}
index0 = z0
index1 = 0
index2 = s0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('buf0', get_index, False)
get_index_1 = self.get_index('index0')
load_1 = ops.load('primals_2', get_index_1, False)
add = ops.add(load, load_1)
get_index_2 = self.get_index('index1')
load_2 = ops.load('buf3', get_index_2, False)
constant = ops.constant(10.0, torch.float32)
div = ops.div(load_2, constant)
sub = ops.sub(add, div)
get_index_3 = self.get_index('index1')
load_3 = ops.load('buf2', get_index_3, False)
get_index_4 = self.get_index('index2')
index_expr = ops.index_expr(get_index_4, torch.float32)
div_1 = ops.div(load_3, index_expr)
constant_1 = ops.constant(1e-05, torch.float32)
add_1 = ops.add(div_1, constant_1)
sqrt = ops.sqrt(add_1)
reciprocal = ops.reciprocal(sqrt)
mul = ops.mul(sub, reciprocal)
get_index_5 = self.get_index('index0')
store = ops.store('buf4', get_index_5, mul, None)
return store
buf5: SchedulerNode(ComputedBuffer)
buf5.writes = {MemoryDep(name='buf5', index=c0, size=(s0,))}
buf5.unmet_dependencies = {MemoryDep(name='buf4', index=c0, size=(s0,))}
buf5.met_dependencies =
{ MemoryDep(name='primals_3', index=c0, size=(s0,)),
MemoryDep(name='primals_4', index=c0, size=(s0,))}
buf5.group.device = cuda:0
buf5.group.iteration = (s0, 1)
buf5.sizes = ([s0], [])
class buf5_loop_body:
var_ranges = {z0: s0}
index0 = z0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('buf4', get_index, False)
get_index_1 = self.get_index('index0')
load_1 = ops.load('primals_3', get_index_1, False)
mul = ops.mul(load, load_1)
get_index_2 = self.get_index('index0')
load_2 = ops.load('primals_4', get_index_2, False)
add = ops.add(mul, load_2)
relu = ops.relu(add)
get_index_3 = self.get_index('index0')
store = ops.store('buf5', get_index_3, relu, None)
return store
buf6: SchedulerNode(ComputedBuffer)
buf6.writes = {MemoryDep(name='buf6', index=c0, size=(s0,))}
buf6.unmet_dependencies = {MemoryDep(name='buf5', index=c0, size=(s0,))}
buf6.met_dependencies = set()
buf6.group.device = cuda:0
buf6.group.iteration = (s0, 1)
buf6.sizes = ([s0], [])
class buf6_loop_body:
var_ranges = {z0: s0}
index0 = z0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('buf5', get_index, False)
constant = ops.constant(0, torch.float32)
le = ops.le(load, constant)
get_index_1 = self.get_index('index0')
store = ops.store('buf6', get_index_1, le, None)
return store
buf7: SchedulerNode(ComputedBuffer)
buf7.writes = {MemoryDep(name='buf7', index=0, size=())}
buf7.unmet_dependencies = {MemoryDep(name='buf2', index=0, size=())}
buf7.met_dependencies = set()
buf7.group.device = cuda:0
buf7.group.iteration = (1, 1)
buf7.sizes = ([], [])
class buf7_loop_body:
var_ranges = {}
index0 = 0
index1 = s0
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('buf2', get_index, False)
get_index_1 = self.get_index('index1')
index_expr = ops.index_expr(get_index_1, torch.float32)
div = ops.div(load, index_expr)
constant = ops.constant(1e-05, torch.float32)
add = ops.add(div, constant)
sqrt = ops.sqrt(add)
reciprocal = ops.reciprocal(sqrt)
constant_1 = ops.constant(10, torch.float32)
div_1 = ops.div(reciprocal, constant_1)
get_index_2 = self.get_index('index0')
store = ops.store('buf7', get_index_2, div_1, None)
return store
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torchinductor.codecache import CppCodeCache, TritonCodeCache
aten = torch.ops.aten
import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import pointwise_heuristics
from torchinductor.triton_ops.autotune import reduction_heuristics
from torchinductor.triton_ops.autotune import grid
@reduction_heuristics(size_hints=[1, 16])
@triton.jit
def kernel0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, out_ptr4, out_ptr5, out_ptr6, ks0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # Fused kernel for scheduler nodes buf1..buf7. Argument mapping at the
    # launch site in `call`:
    #   in_ptr0 = buf0 (mm result), in_ptr1 = primals_2,
    #   in_ptr2 = primals_3, in_ptr3 = primals_4,
    #   out_ptr3 = buf4, out_ptr4 = buf5, out_ptr5 = buf6, out_ptr6 = buf7,
    #   ks0 = s0, xnumel = 1, rnumel = s0.
    # Each `for roffset` loop is one pass over the reduction dimension in
    # RBLOCK-sized chunks; masks keep out-of-range lanes from contributing.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    # Pass 1 (buf1): tmp3 = sum(buf0 + primals_2).
    _tmp3 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + r0, rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + r0, rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        _tmp3 = tl.where(xmask & rmask, _tmp3 + tmp2, _tmp3)
    tmp3 = tl.reshape(tl.sum(_tmp3, 1), [XBLOCK, 1])
    # Pass 2 (buf2, buf3): tmp11 = sum((x - tmp3 / ks0) ** 2), tmp12 = sum(x),
    # where x = buf0 + primals_2.
    # NOTE(review): tmp12 recomputes the same sum as tmp3 (buf3 duplicates buf1).
    _tmp11 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    _tmp12 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp4 = tl.load(in_ptr0 + r0, rmask, eviction_policy='evict_last')
        tmp5 = tl.load(in_ptr1 + r0, rmask, eviction_policy='evict_last')
        tmp6 = tmp4 + tmp5
        tmp7 = ks0
        tmp8 = tmp3 / tmp7
        tmp9 = tmp6 - tmp8
        tmp10 = tmp9 * tmp9
        _tmp11 = tl.where(xmask & rmask, _tmp11 + tmp10, _tmp11)
        _tmp12 = tl.where(xmask & rmask, _tmp12 + tmp6, _tmp12)
    tmp11 = tl.reshape(tl.sum(_tmp11, 1), [XBLOCK, 1])
    tmp12 = tl.reshape(tl.sum(_tmp12, 1), [XBLOCK, 1])
    # Pass 3 (buf4): (x - tmp12 / 10.0) * 1 / sqrt(tmp11 / ks0 + 1e-05).
    # NOTE(review): the mean divides by the literal 10.0 while the variance
    # divides by ks0; the two agree only when s0 == 10 (the traced size).
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp13 = tl.load(in_ptr0 + r0, rmask, eviction_policy='evict_last')
        tmp14 = tl.load(in_ptr1 + r0, rmask, eviction_policy='evict_last')
        tmp15 = tmp13 + tmp14
        tmp16 = 10.0
        tmp17 = tmp12 / tmp16
        tmp18 = tmp15 - tmp17
        tmp19 = ks0
        tmp20 = tmp11 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = tl.sqrt(tmp22)
        tmp24 = 1 / tmp23
        tmp25 = tmp18 * tmp24
        tl.store(out_ptr3 + r0 + tl.zeros([XBLOCK, RBLOCK], tl.int32), tmp25, xmask & rmask)
    # Pass 4 (buf5): relu(buf4 * primals_3 + primals_4).
    # NOTE(review): reads back buf4 from global memory written in pass 3 with
    # no barrier in between; presumably safe because load/store shapes match.
    # Verify against Triton's memory-ordering guarantees.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp26 = tl.load(out_ptr3 + r0, rmask, eviction_policy='evict_last')
        tmp27 = tl.load(in_ptr2 + r0, rmask, eviction_policy='evict_last')
        tmp29 = tl.load(in_ptr3 + r0, rmask, eviction_policy='evict_last')
        tmp28 = tmp26 * tmp27
        tmp30 = tmp28 + tmp29
        tmp31 = tl.maximum(0, tmp30)
        tl.store(out_ptr4 + r0 + tl.zeros([XBLOCK, RBLOCK], tl.int32), tmp31, xmask & rmask)
    # Pass 5 (buf6): boolean mask buf5 <= 0.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp32 = tl.load(out_ptr4 + r0, rmask, eviction_policy='evict_last')
        tmp33 = 0
        tmp34 = tmp32 <= tmp33
        tl.store(out_ptr5 + r0 + tl.zeros([XBLOCK, RBLOCK], tl.int32), tmp34, xmask & rmask)
    # Epilogue (buf7, one scalar): 1 / sqrt(tmp11 / ks0 + 1e-05) / 10.
    # NOTE(review): the divisor 10 is again the traced size, not ks0.
    tmp35 = ks0
    tmp36 = tmp11 / tmp35
    tmp37 = 1e-05
    tmp38 = tmp36 + tmp37
    tmp39 = tl.sqrt(tmp38)
    tmp40 = 1 / tmp39
    tmp41 = 10
    tmp42 = tmp40 / tmp41
    tl.store(out_ptr6 + 0 + tl.zeros([XBLOCK, 1], tl.int32), tmp42, None)
def call(primals_1, primals_2, primals_3, primals_4, primals_5):
    """Run the compiled forward graph.

    Computes buf0 = (primals_5 viewed as a 1 x s0 row) @ (transposed view of
    primals_1) with ``aten.mm.out``. It then launches ``kernel0`` once to fill
    the fused outputs buf4..buf7.

    Returns (buf5, primals_3, primals_5 as a (1, s0) view, buf4, buf6, buf7).
    """
    s0 = primals_1.size()[0]
    f32 = torch.float32

    # Matmul output: a single contiguous row of length s0.
    buf0 = empty_strided((1, s0), (s0, 1), device='cuda', dtype=f32)
    lhs_row = as_strided(primals_5, (1, s0), (s0, 1))
    rhs_transposed = as_strided(primals_1, (s0, s0), (1, s0))
    aten.mm.out(lhs_row, rhs_transposed, out=buf0)

    # Buffers written by kernel0: three length-s0 vectors and one scalar slot.
    vec_shape, vec_stride = (s0, ), (1, )
    buf4 = empty_strided(vec_shape, vec_stride, device='cuda', dtype=f32)
    buf5 = empty_strided(vec_shape, vec_stride, device='cuda', dtype=f32)
    buf6 = empty_strided(vec_shape, vec_stride, device='cuda', dtype=torch.bool)
    buf7 = empty_strided((1, ), (1, ), device='cuda', dtype=f32)
    kernel0[grid(1)](
        buf0, primals_2, primals_3, primals_4,  # inputs
        buf4, buf5, buf6, buf7,                 # outputs
        s0, 1, s0,                              # ks0, xnumel, rnumel
    )

    # A fresh (1, s0) view of primals_5 is part of the returned tuple.
    primals_5_row = as_strided(primals_5, (1, s0), (s0, 1))
    return (buf5, primals_3, primals_5_row, buf4, buf6, buf7)
if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from torchinductor.utils import print_performance

    # Benchmark at the concrete sizes the graph was traced with: a contiguous
    # (10, 10) float32 matrix followed by four length-10 float32 vectors,
    # created in argument order.
    inputs = [rand_strided((10, 10), (10, 1), device='cuda', dtype=torch.float32)]
    inputs += [rand_strided((10, ), (1, ), device='cuda', dtype=torch.float32) for _ in range(4)]
    print_performance(lambda: call(*inputs))
@jansel
Copy link
Author

jansel commented Aug 27, 2022

graph_diagram.svg

graph_diagram

Profile visualized with snakeviz (`snakeviz compile.prof`):

Screenshot from 2022-08-27 14-04-37

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment