TORCHINDUCTOR_TRACE=1 example output
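This gist collects the artifacts that TORCHINDUCTOR_TRACE=1 emits for a single compile: a cProfile of the compilation, the torchinductor log, a saved graph repro script, scheduler IR dumps from after and before fusion, and the generated Triton output code.

The driver script that produced the trace is not included. As a rough sketch, something like the following should generate an equivalent debug trace; it assumes the standalone torchdynamo/torchinductor packages of this era and a CUDA GPU, and the model is a guess reverse-engineered from the shapes in the repro graph below (a Linear feeding a LayerNorm-style normalization and a ReLU), not the author's actual code:

import os
os.environ["TORCHINDUCTOR_TRACE"] = "1"  # enable the debug trace before compiling

import torch
import torchdynamo  # standalone package, predates torch._dynamo

# Hypothetical model chosen to match the captured graph below:
# mm (1x10 @ 10x10) + bias, normalize over the 10 features, then relu.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 10),   # matches primals_1 / primals_2 in the repro
    torch.nn.LayerNorm(10),    # matches the var/sum/sqrt/reciprocal sequence
    torch.nn.ReLU(),
).cuda()

@torchdynamo.optimize("inductor")
def run(x):
    return model(x)

# requires_grad makes aot_autograd split out the FORWARDS graph seen in the log.
x = torch.randn(10, device="cuda", requires_grad=True)
run(x)  # the first call compiles and writes the debug trace directory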
1380522 function calls (1367278 primitive calls) in 3.192 seconds
Ordered by: cumulative time
List reduced from 2159 to 100 due to restriction <100>
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 3.193 3.193 compile_fx.py:39(compile_fx_inner)
1 0.000 0.000 2.741 2.741 compile_fx.py:74(cudagraphify)
2 0.000 0.000 2.690 1.345 crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.py:99(call)
2 0.000 0.000 2.689 1.344 code_gen.py:998(__call__)
2 0.000 0.000 2.689 1.344 code_gen.py:1056(__call__)
1 0.000 0.000 2.688 2.688 code_gen.py:1073(<dictcomp>)
4 0.000 0.000 2.688 0.672 autotune.py:26(_bench)
4 0.004 0.001 2.688 0.672 code_gen.py:1037(_bench)
4 0.015 0.004 2.684 0.671 testing.py:119(do_bench)
8955 0.021 0.000 2.029 0.000 code_gen.py:1049(kernel_call)
8957 0.059 0.000 2.008 0.000 code_gen.py:959(__call__)
8957 1.728 0.000 1.737 0.000 {built-in method triton._C.libtriton.triton.runtime.launch}
15 0.000 0.000 0.345 0.023 __init__.py:491(synchronize)
15 0.345 0.023 0.345 0.023 {built-in method torch._C._cuda_synchronize}
1 0.000 0.000 0.290 0.290 graph.py:320(compile_to_fn)
1 0.000 0.000 0.290 0.290 graph.py:304(compile_to_module)
1 0.000 0.000 0.265 0.265 graph.py:296(codegen)
14302 0.015 0.000 0.217 0.000 streams.py:173(record)
5 0.000 0.000 0.206 0.041 subprocess.py:736(__init__)
5 0.000 0.000 0.206 0.041 subprocess.py:1552(_execute_child)
4 0.000 0.000 0.177 0.044 subprocess.py:452(run)
1 0.000 0.000 0.177 0.177 scheduler.py:500(__init__)
14304 0.022 0.000 0.165 0.000 __init__.py:517(current_stream)
1 0.000 0.000 0.132 0.132 code_gen.py:1201(cache_key)
1 0.001 0.001 0.130 0.130 code_gen.py:1093(version_key)
14 0.128 0.009 0.128 0.009 {built-in method posix.read}
86/38 0.000 0.000 0.123 0.003 {built-in method builtins.exec}
23282 0.025 0.000 0.120 0.000 _utils.py:7(_get_device_index)
57/4 0.000 0.000 0.118 0.030 <frozen importlib._bootstrap>:986(_find_and_load)
57/4 0.000 0.000 0.118 0.030 <frozen importlib._bootstrap>:956(_find_and_load_unlocked)
54/4 0.000 0.000 0.117 0.029 <frozen importlib._bootstrap>:650(_load_unlocked)
51/4 0.000 0.000 0.117 0.029 <frozen importlib._bootstrap_external>:837(exec_module)
69/4 0.000 0.000 0.116 0.029 <frozen importlib._bootstrap>:211(_call_with_frames_removed)
5 0.000 0.000 0.114 0.023 __init__.py:1(<module>)
1 0.000 0.000 0.100 0.100 debug.py:312(graph_diagram)
1 0.000 0.000 0.100 0.100 debug.py:51(draw_buffers)
1 0.000 0.000 0.096 0.096 debug.py:296(fx_graph)
1 0.000 0.000 0.096 0.096 debug_utils.py:104(save_graph_repro)
1 0.000 0.000 0.096 0.096 debug_utils.py:22(generate_repro_string)
23282 0.018 0.000 0.086 0.000 _utils.py:591(_get_device_index)
2 0.000 0.000 0.082 0.041 subprocess.py:368(check_output)
5 0.076 0.015 0.076 0.015 {built-in method _posixsubprocess.fork_exec}
13 0.071 0.005 0.071 0.005 {built-in method _hashlib.openssl_md5}
14/6 0.000 0.000 0.066 0.011 {built-in method builtins.__import__}
1 0.000 0.000 0.065 0.065 graph.py:125(run)
1 0.000 0.000 0.065 0.065 interpreter.py:95(run)
28 0.000 0.000 0.065 0.002 graph.py:283(run_node)
14321 0.006 0.000 0.063 0.000 _utils.py:565(_get_current_device_index)
28 0.000 0.000 0.062 0.002 interpreter.py:144(run_node)
1 0.000 0.000 0.057 0.057 partitioners.py:427(draw_graph)
14321 0.012 0.000 0.057 0.000 _utils.py:555(_get_device_attr)
4683/2928 0.003 0.000 0.056 0.000 cache.py:67(wrapper)
1 0.000 0.000 0.055 0.055 pydot.py:1739(new_method)
2/1 0.000 0.000 0.055 0.055 pydot.py:1794(write)
1 0.000 0.000 0.055 0.055 pydot.py:1833(create)
7 0.000 0.000 0.055 0.008 scheduler.py:218(__init__)
1 0.000 0.000 0.053 0.053 pydot.py:113(call_graphviz)
1 0.000 0.000 0.053 0.053 wrapper.py:168(__init__)
405 0.000 0.000 0.053 0.000 decorators.py:224(_func)
1 0.000 0.000 0.052 0.052 utils.py:18(has_triton)
7146 0.052 0.000 0.052 0.000 {method 'zero_' of 'torch._C._TensorBase' objects}
30 0.000 0.000 0.052 0.002 code_gen.py:1442(jit)
30 0.000 0.000 0.052 0.002 code_gen.py:1169(__init__)
381 0.000 0.000 0.050 0.000 decorators.py:99(binary_op_wrapper)
1 0.000 0.000 0.050 0.050 graphs.py:145(__enter__)
21 0.000 0.000 0.049 0.002 ir.py:3290(__call__)
1 0.049 0.049 0.049 0.049 {built-in method gc.collect}
175 0.000 0.000 0.048 0.000 expr.py:215(__mul__)
22/20 0.000 0.000 0.048 0.002 operations.py:52(__new__)
5 0.000 0.000 0.048 0.010 graph.py:193(placeholder)
5 0.000 0.000 0.048 0.010 graph.py:33(symbolic_sizes_strides)
30 0.000 0.000 0.048 0.002 inspect.py:991(getsource)
30 0.000 0.000 0.048 0.002 inspect.py:970(getsourcelines)
2 0.000 0.000 0.048 0.024 mul.py:197(flatten)
6 0.000 0.000 0.048 0.008 add.py:184(flatten)
4 0.000 0.000 0.048 0.012 mul.py:467(_gather)
21 0.000 0.000 0.047 0.002 ir.py:3394(__call__)
1 0.000 0.000 0.046 0.046 tensor.py:1(<module>)
30 0.004 0.000 0.046 0.002 inspect.py:959(getblock)
1 0.000 0.000 0.041 0.041 debug.py:42(has_dot)
7 0.000 0.000 0.039 0.006 ir.py:1851(simplify_and_reorder)
9070 0.014 0.000 0.039 0.000 tokenize.py:429(_tokenize)
1 0.000 0.000 0.038 0.038 polyhedron.py:1(<module>)
92/82 0.001 0.000 0.038 0.000 <frozen importlib._bootstrap>:1017(_handle_fromlist)
14302 0.038 0.000 0.038 0.000 {function Event.record at 0x7f321a0c8a60}
930 0.010 0.000 0.036 0.000 basic.py:802(subs)
23293 0.011 0.000 0.034 0.000 __init__.py:485(current_device)
1 0.000 0.000 0.032 0.032 scheduler.py:987(codegen)
1 0.000 0.000 0.031 0.031 triton.py:1032(codegen_nodes)
1 0.000 0.000 0.031 0.031 triton.py:1101(codegen_node_schedule)
8957 0.006 0.000 0.029 0.000 __init__.py:307(set_device)
61 0.028 0.000 0.028 0.000 {method 'read' of '_io.BufferedReader' objects}
14 0.000 0.000 0.027 0.002 ir.py:3227(__init__)
14 0.000 0.000 0.027 0.002 ir.py:3311(__init__)
14306 0.014 0.000 0.027 0.000 streams.py:31(__new__)
2 0.000 0.000 0.025 0.013 __init__.py:2(<module>)
14321 0.005 0.000 0.025 0.000 _utils.py:567(<lambda>)
5 0.000 0.000 0.025 0.005 subprocess.py:984(communicate)
1 0.000 0.000 0.025 0.025 codecache.py:152(load)
7 0.000 0.000 0.024 0.003 scheduler.py:313(codegen)

1380522 function calls (1367278 primitive calls) in 3.192 seconds
Ordered by: internal time
List reduced from 2159 to 100 due to restriction <100>
ncalls tottime percall cumtime percall filename:lineno(function)
8957 1.728 0.000 1.737 0.000 {built-in method triton._C.libtriton.triton.runtime.launch}
15 0.345 0.023 0.345 0.023 {built-in method torch._C._cuda_synchronize}
14 0.128 0.009 0.128 0.009 {built-in method posix.read}
5 0.076 0.015 0.076 0.015 {built-in method _posixsubprocess.fork_exec}
13 0.071 0.005 0.071 0.005 {built-in method _hashlib.openssl_md5}
8957 0.059 0.000 2.008 0.000 code_gen.py:959(__call__)
7146 0.052 0.000 0.052 0.000 {method 'zero_' of 'torch._C._TensorBase' objects}
1 0.049 0.049 0.049 0.049 {built-in method gc.collect}
14302 0.038 0.000 0.038 0.000 {function Event.record at 0x7f321a0c8a60}
61 0.028 0.000 0.028 0.000 {method 'read' of '_io.BufferedReader' objects}
23282 0.025 0.000 0.120 0.000 _utils.py:7(_get_device_index)
14304 0.022 0.000 0.165 0.000 __init__.py:517(current_stream)
39479 0.021 0.000 0.021 0.000 {built-in method __new__ of type object at 0x560c66fc39a0}
239119/238013 0.021 0.000 0.023 0.000 {built-in method builtins.isinstance}
8955 0.021 0.000 2.029 0.000 code_gen.py:1049(kernel_call)
23282 0.018 0.000 0.086 0.000 _utils.py:591(_get_device_index)
14302 0.015 0.000 0.217 0.000 streams.py:173(record)
4 0.015 0.004 2.684 0.671 testing.py:119(do_bench)
9070 0.014 0.000 0.039 0.000 tokenize.py:429(_tokenize)
14306 0.014 0.000 0.027 0.000 streams.py:31(__new__)
7 0.012 0.002 0.012 0.002 {method 'poll' of 'select.poll' objects}
14321 0.012 0.000 0.057 0.000 _utils.py:555(_get_device_attr)
2 0.011 0.005 0.011 0.005 {built-in method _imp.create_dynamic}
23293 0.011 0.000 0.034 0.000 __init__.py:485(current_device)
930 0.010 0.000 0.036 0.000 basic.py:802(subs)
23293 0.009 0.000 0.009 0.000 {built-in method torch._C._cuda_getDevice}
37631 0.009 0.000 0.024 0.000 __init__.py:196(_lazy_init)
9719 0.008 0.000 0.008 0.000 {method 'match' of 're.Pattern' objects}
37632 0.008 0.000 0.015 0.000 __init__.py:153(is_initialized)
8957 0.008 0.000 0.013 0.000 code_gen.py:962(<dictcomp>)
14326 0.007 0.000 0.013 0.000 __init__.py:77(is_available)
17914 0.007 0.000 0.009 0.000 core.py:340(__init__)
37632 0.007 0.000 0.007 0.000 {built-in method torch._C._cuda_isInBadFork}
10199/10163 0.006 0.000 0.011 0.000 {built-in method builtins.sorted}
32 0.006 0.000 0.006 0.000 {built-in method builtins.compile}
14321 0.006 0.000 0.063 0.000 _utils.py:565(_get_current_device_index)
8957 0.006 0.000 0.029 0.000 __init__.py:307(set_device)
14321 0.005 0.000 0.018 0.000 _utils.py:546(_get_available_device_type)
14304 0.005 0.000 0.005 0.000 {built-in method torch._C._cuda_getCurrentStream}
18028 0.005 0.000 0.005 0.000 {method 'index' of 'list' objects}
51 0.005 0.000 0.005 0.000 {built-in method marshal.loads}
8957 0.005 0.000 0.005 0.000 {built-in method torch._C._cuda_setDevice}
20253/19676 0.005 0.000 0.006 0.000 {built-in method builtins.getattr}
14321 0.005 0.000 0.025 0.000 _utils.py:567(<lambda>)
8957 0.005 0.000 0.007 0.000 autotune.py:452(grid_fn)
14302 0.004 0.000 0.011 0.000 streams.py:163(__new__)
4 0.004 0.001 2.688 0.672 code_gen.py:1037(_bench)
8749 0.004 0.000 0.009 0.000 re.py:289(_compile)
30 0.004 0.000 0.046 0.002 inspect.py:959(getblock)
22654 0.004 0.000 0.004 0.000 {built-in method builtins.hasattr}
7644/7464 0.004 0.000 0.010 0.000 sympify.py:102(sympify)
37624 0.004 0.000 0.004 0.000 _jit_internal.py:1082(is_scripting)
7150 0.003 0.000 0.003 0.000 {function Event.elapsed_time at 0x7f321a0c8c10}
3784/1814 0.003 0.000 0.007 0.000 node.py:603(map_aggregate)
9028 0.003 0.000 0.003 0.000 inspect.py:909(tokeneater)
14330 0.003 0.000 0.003 0.000 {built-in method torch._C._cuda_getDeviceCount}
8957 0.003 0.000 0.003 0.000 {built-in method torch._C._cuda_getCurrentRawStream}
34440/33719 0.003 0.000 0.003 0.000 {built-in method builtins.len}
4683/2928 0.003 0.000 0.056 0.000 cache.py:67(wrapper)
1743/1523 0.003 0.000 0.010 0.000 sorting.py:203(ordered)
8550 0.003 0.000 0.012 0.000 tokenize.py:98(_compile)
4 0.003 0.001 0.008 0.002 testing.py:153(<listcomp>)
17973 0.003 0.000 0.003 0.000 {method 'insert' of 'list' objects}
4 0.003 0.001 0.008 0.002 testing.py:154(<listcomp>)
8224 0.003 0.000 0.004 0.000 <frozen importlib._bootstrap>:389(parent)
568/500 0.003 0.000 0.004 0.000 printer.py:294(_print)
501 0.002 0.000 0.006 0.000 basic.py:2016(_aresame)
8957 0.002 0.000 0.002 0.000 code_gen.py:1482(cdiv)
22431 0.002 0.000 0.002 0.000 {method 'items' of 'dict' objects}
980 0.002 0.000 0.006 0.000 sorting.py:10(default_sort_key)
8550 0.002 0.000 0.003 0.000 types.py:171(__get__)
8957 0.002 0.000 0.005 0.000 code_gen.py:32(current_cuda_stream)
15540 0.002 0.000 0.002 0.000 {method 'lower' of 'str' objects}
4478 0.002 0.000 0.006 0.000 numbers.py:2238(__eq__)
8676 0.002 0.000 0.010 0.000 re.py:250(compile)
31 0.002 0.000 0.002 0.000 pydot.py:530(create_attribute_methods)
7150 0.002 0.000 0.005 0.000 streams.py:204(elapsed_time)
1662 0.002 0.000 0.003 0.000 graph.py:123(create_name)
746 0.002 0.000 0.004 0.000 function.py:2495(expand)
6122 0.002 0.000 0.002 0.000 misc.py:491(as_int)
9037 0.002 0.000 0.003 0.000 <string>:1(__new__)
118/117 0.002 0.000 0.005 0.000 {built-in method builtins.__build_class__}
100 0.001 0.000 0.009 0.000 iterables.py:1173(least_rotation)
1470 0.001 0.000 0.003 0.000 basic.py:350(__eq__)
571 0.001 0.000 0.003 0.000 sorting.py:180(_nodes)
1 0.001 0.001 0.130 0.130 code_gen.py:1093(version_key)
459 0.001 0.000 0.004 0.000 permutations.py:384(<listcomp>)
3291 0.001 0.000 0.005 0.000 sympify.py:503(_sympify)
826/376 0.001 0.000 0.004 0.000 _symbolic_trace.py:254(create_arg)
2015 0.001 0.000 0.003 0.000 containers.py:58(__getitem__)
8510 0.001 0.000 0.001 0.000 {method 'rpartition' of 'str' objects}
30 0.001 0.000 0.014 0.000 graph.py:297(_gen_python_code)
4 0.001 0.000 0.006 0.002 testing.py:175(<listcomp>)
4 0.001 0.000 0.001 0.000 {built-in method torch.quantile}
5658 0.001 0.000 0.002 0.000 numbers.py:2284(__hash__)
10 0.001 0.000 0.001 0.000 {built-in method torch.empty_strided}
367 0.001 0.000 0.002 0.000 permutations.py:1316(__mul__)
1730 0.001 0.000 0.003 0.000 numbers.py:1867(__eq__)
8490/6732 0.001 0.000 0.001 0.000 {built-in method builtins.hash}
2950/2720 0.001 0.000 0.003 0.000 containers.py:54(<genexpr>)
torchinductor log output:
[compile_fx.py:48 INFO] Compiling FORWARDS graph
[triton.py:1098 INFO] schedule: [SchedulerNode(name='buf1'), <class 'torchinductor.codegen.triton.DisableReduction'>, <class 'torchinductor.codegen.triton.EnableReduction'>, SchedulerNode(name='buf2'), SchedulerNode(name='buf3'), <class 'torchinductor.codegen.triton.DisableReduction'>, <class 'torchinductor.codegen.triton.EnableReduction'>, SchedulerNode(name='buf4'), <class 'torchinductor.codegen.triton.DisableReduction'>, <class 'torchinductor.codegen.triton.EnableReduction'>, SchedulerNode(name='buf5'), <class 'torchinductor.codegen.triton.DisableReduction'>, <class 'torchinductor.codegen.triton.EnableReduction'>, SchedulerNode(name='buf6'), <class 'torchinductor.codegen.triton.DisableReduction'>, SchedulerNode(name='buf7'), <class 'torchinductor.codegen.triton.EnableReduction'>]
[scheduler.py:944 DEBUG] remove_buffer('buf3')
[scheduler.py:944 DEBUG] remove_buffer('buf2')
[scheduler.py:944 DEBUG] remove_buffer('buf1')
[graph.py:315 INFO] Output code: /tmp/torchinductor_jansel/rh/crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.py
[debug.py:260 WARNING] model_forward_0 debug trace: /tmp/torchinductor_jansel/rh/crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.debug
Saved graph repro script:
import torch
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.13.0a0+gita089114
# torch cuda version: 11.6
# torch git version: a08911400edb62c9caa0c94d1ce176cf8cb29765
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Tue_Mar__8_18:18:20_PST_2022
# Cuda compilation tools, release 11.6, V11.6.124
# Build cuda_11.6.r11.6/compiler.31057947_0
# GPU Hardware Info:
# NVIDIA GeForce RTX 3090 : 1

class Repro(torch.nn.Module):
    def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
        permute_default = torch.ops.aten.permute.default(primals_1, [1, 0]); primals_1 = None
        unsqueeze_default = torch.ops.aten.unsqueeze.default(primals_5, 0); primals_5 = None
        mm_default = torch.ops.aten.mm.default(unsqueeze_default, permute_default); permute_default = None
        squeeze_dim = torch.ops.aten.squeeze.dim(mm_default, 0); mm_default = None
        add_tensor = torch.ops.aten.add.Tensor(squeeze_dim, primals_2); squeeze_dim = primals_2 = None
        convert_element_type_default = torch.ops.prims.convert_element_type.default(add_tensor, torch.float32)
        var_default = torch.ops.prims.var.default(convert_element_type_default, [0], correction = 0)
        broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(var_default, [1], []); var_default = None
        sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0]); convert_element_type_default = None
        broadcast_in_dim_default_1 = torch.ops.prims.broadcast_in_dim.default(sum_default, [1], []); sum_default = None
        div_default = torch.ops.prims.div.default(broadcast_in_dim_default_1, 10.0); broadcast_in_dim_default_1 = None
        add_tensor_1 = torch.ops.aten.add.Tensor(broadcast_in_dim_default, 1e-05); broadcast_in_dim_default = None
        sqrt_default = torch.ops.aten.sqrt.default(add_tensor_1); add_tensor_1 = None
        reciprocal_default = torch.ops.aten.reciprocal.default(sqrt_default); sqrt_default = None
        sub_tensor = torch.ops.aten.sub.Tensor(add_tensor, div_default); add_tensor = div_default = None
        mul_tensor = torch.ops.aten.mul.Tensor(sub_tensor, reciprocal_default); sub_tensor = None
        mul_tensor_1 = torch.ops.aten.mul.Tensor(mul_tensor, primals_3)
        add_tensor_2 = torch.ops.aten.add.Tensor(mul_tensor_1, primals_4); mul_tensor_1 = primals_4 = None
        convert_element_type_default_2 = torch.ops.prims.convert_element_type.default(add_tensor_2, torch.float32); add_tensor_2 = None
        relu_default = torch.ops.aten.relu.default(convert_element_type_default_2); convert_element_type_default_2 = None
        le_scalar = torch.ops.aten.le.Scalar(relu_default, 0)
        div_tensor = torch.ops.aten.div.Tensor(reciprocal_default, 10); reciprocal_default = None
        return [relu_default, primals_3, unsqueeze_default, mul_tensor, le_scalar, div_tensor]

args = [((10, 10), (10, 1), torch.float32, 'cuda'), ((10,), (1,), torch.float32, 'cuda'), ((10,), (1,), torch.float32, 'cuda'), ((10,), (1,), torch.float32, 'cuda'), ((10,), (1,), torch.float32, 'cuda')]
args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]
mod = make_fx(Repro())(*args)
from torchinductor.compile_fx import compile_fx_inner
compiled = compile_fx_inner(mod, args)
compiled(*args)
Scheduler IR after fusion:
buf0: ExternKernelSchedulerNode(MatrixMultiply)
buf0.writes = {StarDep(name='buf0')}
buf0.unmet_dependencies = set()
buf0.met_dependencies = {StarDep(name='primals_5'), StarDep(name='primals_1')}
buf0.node.kernel = aten.mm.out
buf1_buf2_buf3_buf4_buf5_buf6_buf7: FusedSchedulerNode(NoneType)
buf1_buf2_buf3_buf4_buf5_buf6_buf7.writes =
    { MemoryDep(name='buf1', index=0, size=()),
      MemoryDep(name='buf1', index=0, size=(s0,)),
      MemoryDep(name='buf2', index=0, size=()),
      MemoryDep(name='buf2', index=0, size=(s0,)),
      MemoryDep(name='buf3', index=0, size=()),
      MemoryDep(name='buf3', index=0, size=(s0,)),
      MemoryDep(name='buf4', index=c0, size=(s0,)),
      MemoryDep(name='buf5', index=c0, size=(s0,)),
      MemoryDep(name='buf6', index=c0, size=(s0,)),
      MemoryDep(name='buf7', index=0, size=())}
buf1_buf2_buf3_buf4_buf5_buf6_buf7.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf1_buf2_buf3_buf4_buf5_buf6_buf7.met_dependencies =
    { MemoryDep(name='primals_2', index=c0, size=(s0,)),
      MemoryDep(name='primals_3', index=c0, size=(s0,)),
      MemoryDep(name='primals_4', index=c0, size=(s0,))}
buf1_buf2_buf3_buf4_buf5_buf6_buf7.snodes = ['buf1', 'buf2', 'buf3', 'buf4', 'buf5', 'buf6', 'buf7']
Scheduler IR before fusion:
buf0: ExternKernelSchedulerNode(MatrixMultiply)
buf0.writes = {StarDep(name='buf0')}
buf0.unmet_dependencies = set()
buf0.met_dependencies = {StarDep(name='primals_5'), StarDep(name='primals_1')}
buf0.node.kernel = aten.mm.out
buf1: SchedulerNode(ComputedBuffer)
buf1.writes =
    { MemoryDep(name='buf1', index=0, size=()),
      MemoryDep(name='buf1', index=0, size=(s0,))}
buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf1.group.device = cuda:0
buf1.group.iteration = (1, s0)
buf1.sizes = ([], [s0])
class buf1_loop_body:
    var_ranges = {z0: s0}
    index0 = z0
    index1 = 0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf0', get_index, False)
        get_index_1 = self.get_index('index0')
        load_1 = ops.load('primals_2', get_index_1, False)
        add = ops.add(load, load_1)
        get_index_2 = self.get_index('index1')
        reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add)
        return reduction
buf2: SchedulerNode(ComputedBuffer)
buf2.writes =
    { MemoryDep(name='buf2', index=0, size=()),
      MemoryDep(name='buf2', index=0, size=(s0,))}
buf2.unmet_dependencies =
    { MemoryDep(name='buf0', index=c0, size=(s0,)),
      MemoryDep(name='buf1', index=0, size=(s0,))}
buf2.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf2.group.device = cuda:0
buf2.group.iteration = (1, s0)
buf2.sizes = ([], [s0])
class buf2_loop_body:
    var_ranges = {z0: s0}
    index0 = z0
    index1 = 0
    index2 = s0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf0', get_index, False)
        get_index_1 = self.get_index('index0')
        load_1 = ops.load('primals_2', get_index_1, False)
        add = ops.add(load, load_1)
        get_index_2 = self.get_index('index1')
        load_2 = ops.load('buf1', get_index_2, False)
        get_index_3 = self.get_index('index2')
        index_expr = ops.index_expr(get_index_3, torch.float32)
        div = ops.div(load_2, index_expr)
        sub = ops.sub(add, div)
        square = ops.square(sub)
        get_index_4 = self.get_index('index1')
        reduction = ops.reduction('buf2', torch.float32, torch.float32, 'sum', get_index_4, square)
        return reduction
buf3: SchedulerNode(ComputedBuffer)
buf3.writes =
    { MemoryDep(name='buf3', index=0, size=()),
      MemoryDep(name='buf3', index=0, size=(s0,))}
buf3.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf3.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf3.group.device = cuda:0
buf3.group.iteration = (1, s0)
buf3.sizes = ([], [s0])
class buf3_loop_body:
    var_ranges = {z0: s0}
    index0 = z0
    index1 = 0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf0', get_index, False)
        get_index_1 = self.get_index('index0')
        load_1 = ops.load('primals_2', get_index_1, False)
        add = ops.add(load, load_1)
        get_index_2 = self.get_index('index1')
        reduction = ops.reduction('buf3', torch.float32, torch.float32, 'sum', get_index_2, add)
        return reduction
buf4: SchedulerNode(ComputedBuffer)
buf4.writes = {MemoryDep(name='buf4', index=c0, size=(s0,))}
buf4.unmet_dependencies =
    { MemoryDep(name='buf0', index=c0, size=(s0,)),
      MemoryDep(name='buf2', index=0, size=(s0,)),
      MemoryDep(name='buf3', index=0, size=(s0,))}
buf4.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf4.group.device = cuda:0
buf4.group.iteration = (s0, 1)
buf4.sizes = ([s0], [])
class buf4_loop_body:
    var_ranges = {z0: s0}
    index0 = z0
    index1 = 0
    index2 = s0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf0', get_index, False)
        get_index_1 = self.get_index('index0')
        load_1 = ops.load('primals_2', get_index_1, False)
        add = ops.add(load, load_1)
        get_index_2 = self.get_index('index1')
        load_2 = ops.load('buf3', get_index_2, False)
        constant = ops.constant(10.0, torch.float32)
        div = ops.div(load_2, constant)
        sub = ops.sub(add, div)
        get_index_3 = self.get_index('index1')
        load_3 = ops.load('buf2', get_index_3, False)
        get_index_4 = self.get_index('index2')
        index_expr = ops.index_expr(get_index_4, torch.float32)
        div_1 = ops.div(load_3, index_expr)
        constant_1 = ops.constant(1e-05, torch.float32)
        add_1 = ops.add(div_1, constant_1)
        sqrt = ops.sqrt(add_1)
        reciprocal = ops.reciprocal(sqrt)
        mul = ops.mul(sub, reciprocal)
        get_index_5 = self.get_index('index0')
        store = ops.store('buf4', get_index_5, mul, None)
        return store
buf5: SchedulerNode(ComputedBuffer)
buf5.writes = {MemoryDep(name='buf5', index=c0, size=(s0,))}
buf5.unmet_dependencies = {MemoryDep(name='buf4', index=c0, size=(s0,))}
buf5.met_dependencies =
    { MemoryDep(name='primals_3', index=c0, size=(s0,)),
      MemoryDep(name='primals_4', index=c0, size=(s0,))}
buf5.group.device = cuda:0
buf5.group.iteration = (s0, 1)
buf5.sizes = ([s0], [])
class buf5_loop_body:
    var_ranges = {z0: s0}
    index0 = z0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf4', get_index, False)
        get_index_1 = self.get_index('index0')
        load_1 = ops.load('primals_3', get_index_1, False)
        mul = ops.mul(load, load_1)
        get_index_2 = self.get_index('index0')
        load_2 = ops.load('primals_4', get_index_2, False)
        add = ops.add(mul, load_2)
        relu = ops.relu(add)
        get_index_3 = self.get_index('index0')
        store = ops.store('buf5', get_index_3, relu, None)
        return store
buf6: SchedulerNode(ComputedBuffer)
buf6.writes = {MemoryDep(name='buf6', index=c0, size=(s0,))}
buf6.unmet_dependencies = {MemoryDep(name='buf5', index=c0, size=(s0,))}
buf6.met_dependencies = set()
buf6.group.device = cuda:0
buf6.group.iteration = (s0, 1)
buf6.sizes = ([s0], [])
class buf6_loop_body:
    var_ranges = {z0: s0}
    index0 = z0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf5', get_index, False)
        constant = ops.constant(0, torch.float32)
        le = ops.le(load, constant)
        get_index_1 = self.get_index('index0')
        store = ops.store('buf6', get_index_1, le, None)
        return store
buf7: SchedulerNode(ComputedBuffer)
buf7.writes = {MemoryDep(name='buf7', index=0, size=())}
buf7.unmet_dependencies = {MemoryDep(name='buf2', index=0, size=())}
buf7.met_dependencies = set()
buf7.group.device = cuda:0
buf7.group.iteration = (1, 1)
buf7.sizes = ([], [])
class buf7_loop_body:
    var_ranges = {}
    index0 = 0
    index1 = s0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf2', get_index, False)
        get_index_1 = self.get_index('index1')
        index_expr = ops.index_expr(get_index_1, torch.float32)
        div = ops.div(load, index_expr)
        constant = ops.constant(1e-05, torch.float32)
        add = ops.add(div, constant)
        sqrt = ops.sqrt(add)
        reciprocal = ops.reciprocal(sqrt)
        constant_1 = ops.constant(10, torch.float32)
        div_1 = ops.div(reciprocal, constant_1)
        get_index_2 = self.get_index('index0')
        store = ops.store('buf7', get_index_2, div_1, None)
        return store
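The *_loop_body dumps above are define-by-run IR: Inductor stores each body function and replays it against different ops backends, for example one that records the MemoryDep read/write sets shown above and one that emits the Triton code shown next. A toy interpreter illustrating the pattern over plain Python lists; EvalOps and the buffer layout are invented for illustration, not torchinductor internals (the real dumps also keep get_index on the loop-body object, merged into the backend here for brevity):

class EvalOps:
    """Toy backend: interprets one loop iteration against concrete floats."""
    def __init__(self, buffers, indices):
        self.buffers = buffers   # buffer name -> list of floats
        self.indices = indices   # 'index0' -> concrete integer value

    def get_index(self, name):
        return self.indices[name]

    def load(self, name, index, upcast=False):
        return self.buffers[name][index]

    def store(self, name, index, value, mode=None):
        self.buffers[name][index] = value

    def mul(self, a, b): return a * b
    def add(self, a, b): return a + b
    def relu(self, a): return max(a, 0.0)

def buf5_body(ops):
    # Same dataflow as buf5_loop_body.body above:
    # buf5 = relu(buf4 * primals_3 + primals_4)
    get_index = ops.get_index('index0')
    load = ops.load('buf4', get_index, False)
    load_1 = ops.load('primals_3', get_index, False)
    mul = ops.mul(load, load_1)
    load_2 = ops.load('primals_4', get_index, False)
    add = ops.add(mul, load_2)
    relu = ops.relu(add)
    return ops.store('buf5', get_index, relu, None)

buffers = {'buf4': [-1.0, 2.0], 'primals_3': [3.0, 3.0],
           'primals_4': [0.5, 0.5], 'buf5': [0.0, 0.0]}
for z0 in range(2):  # var_ranges = {z0: s0}, here with s0 = 2
    buf5_body(EvalOps(buffers, {'index0': z0}))
print(buffers['buf5'])  # -> [0.0, 6.5]

Replaying the same body against a code-generating backend is what produces the tmp26..tmp31 block inside kernel0 below.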
Generated output code:
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torchinductor.codecache import CppCodeCache, TritonCodeCache
aten = torch.ops.aten
import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import pointwise_heuristics
from torchinductor.triton_ops.autotune import reduction_heuristics
from torchinductor.triton_ops.autotune import grid

@reduction_heuristics(size_hints=[1, 16])
@triton.jit
def kernel0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, out_ptr4, out_ptr5, out_ptr6, ks0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    rbase = tl.reshape(tl.arange(0, RBLOCK), [1, RBLOCK])
    _tmp3 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + r0, rmask, eviction_policy='evict_last')
        tmp1 = tl.load(in_ptr1 + r0, rmask, eviction_policy='evict_last')
        tmp2 = tmp0 + tmp1
        _tmp3 = tl.where(xmask & rmask, _tmp3 + tmp2, _tmp3)
    tmp3 = tl.reshape(tl.sum(_tmp3, 1), [XBLOCK, 1])
    _tmp11 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    _tmp12 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp4 = tl.load(in_ptr0 + r0, rmask, eviction_policy='evict_last')
        tmp5 = tl.load(in_ptr1 + r0, rmask, eviction_policy='evict_last')
        tmp6 = tmp4 + tmp5
        tmp7 = ks0
        tmp8 = tmp3 / tmp7
        tmp9 = tmp6 - tmp8
        tmp10 = tmp9 * tmp9
        _tmp11 = tl.where(xmask & rmask, _tmp11 + tmp10, _tmp11)
        _tmp12 = tl.where(xmask & rmask, _tmp12 + tmp6, _tmp12)
    tmp11 = tl.reshape(tl.sum(_tmp11, 1), [XBLOCK, 1])
    tmp12 = tl.reshape(tl.sum(_tmp12, 1), [XBLOCK, 1])
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp13 = tl.load(in_ptr0 + r0, rmask, eviction_policy='evict_last')
        tmp14 = tl.load(in_ptr1 + r0, rmask, eviction_policy='evict_last')
        tmp15 = tmp13 + tmp14
        tmp16 = 10.0
        tmp17 = tmp12 / tmp16
        tmp18 = tmp15 - tmp17
        tmp19 = ks0
        tmp20 = tmp11 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = tl.sqrt(tmp22)
        tmp24 = 1 / tmp23
        tmp25 = tmp18 * tmp24
        tl.store(out_ptr3 + r0 + tl.zeros([XBLOCK, RBLOCK], tl.int32), tmp25, xmask & rmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp26 = tl.load(out_ptr3 + r0, rmask, eviction_policy='evict_last')
        tmp27 = tl.load(in_ptr2 + r0, rmask, eviction_policy='evict_last')
        tmp29 = tl.load(in_ptr3 + r0, rmask, eviction_policy='evict_last')
        tmp28 = tmp26 * tmp27
        tmp30 = tmp28 + tmp29
        tmp31 = tl.maximum(0, tmp30)
        tl.store(out_ptr4 + r0 + tl.zeros([XBLOCK, RBLOCK], tl.int32), tmp31, xmask & rmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp32 = tl.load(out_ptr4 + r0, rmask, eviction_policy='evict_last')
        tmp33 = 0
        tmp34 = tmp32 <= tmp33
        tl.store(out_ptr5 + r0 + tl.zeros([XBLOCK, RBLOCK], tl.int32), tmp34, xmask & rmask)
    tmp35 = ks0
    tmp36 = tmp11 / tmp35
    tmp37 = 1e-05
    tmp38 = tmp36 + tmp37
    tmp39 = tl.sqrt(tmp38)
    tmp40 = 1 / tmp39
    tmp41 = 10
    tmp42 = tmp40 / tmp41
    tl.store(out_ptr6 + 0 + tl.zeros([XBLOCK, 1], tl.int32), tmp42, None)

def call(primals_1, primals_2, primals_3, primals_4, primals_5):
    primals_1_size = primals_1.size()
    s0 = primals_1_size[0]
    buf0 = empty_strided((1, s0), (s0, 1), device='cuda', dtype=torch.float32)
    aten.mm.out(as_strided(primals_5, (1, s0), (s0, 1)), as_strided(primals_1, (s0, s0), (1, s0)), out=buf0)
    buf4 = empty_strided((s0, ), (1, ), device='cuda', dtype=torch.float32)
    buf5 = empty_strided((s0, ), (1, ), device='cuda', dtype=torch.float32)
    buf6 = empty_strided((s0, ), (1, ), device='cuda', dtype=torch.bool)
    buf7 = empty_strided((1, ), (1, ), device='cuda', dtype=torch.float32)
    kernel0[grid(1)](buf0, primals_2, primals_3, primals_4, buf4, buf5, buf6, buf7, s0, 1, s0)
    return (buf5, primals_3, as_strided(primals_5, (1, s0), (s0, 1)), buf4, buf6, buf7, )

if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from torchinductor.utils import print_performance
    primals_1 = rand_strided((10, 10), (10, 1), device='cuda', dtype=torch.float32)
    primals_2 = rand_strided((10, ), (1, ), device='cuda', dtype=torch.float32)
    primals_3 = rand_strided((10, ), (1, ), device='cuda', dtype=torch.float32)
    primals_4 = rand_strided((10, ), (1, ), device='cuda', dtype=torch.float32)
    primals_5 = rand_strided((10, ), (1, ), device='cuda', dtype=torch.float32)
    print_performance(lambda: call(primals_1, primals_2, primals_3, primals_4, primals_5))
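Tying the generated code back to the profile at the top: kernel0 is launched once over the whole fused region (grid(1)), which is why buf1 through buf7 collapse into one kernel, and nearly all of the 3.2 s compile time sits under compile_fx.py:74(cudagraphify), where the call tree shows it going to autotuning benchmarks during warmup (autotune.py:26(_bench) -> testing.py:119(do_bench)) rather than to graph capture itself. For reference, the generic CUDA graph capture/replay pattern uses the public torch.cuda.CUDAGraph API; the sketch below is that general recipe with fn and example_input as placeholders, not Inductor's exact wrapper:

import torch

def simple_cudagraphify(fn, example_input):
    # Warm up on a side stream: kernels must run once before capture.
    static_input = example_input.clone()
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one invocation; tensor addresses become fixed in the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)

    def run(new_input):
        static_input.copy_(new_input)  # replay reuses the captured addresses
        graph.replay()
        return static_output

    return run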
Also attached to the gist: graph_diagram.svg (the buffer graph drawn during compilation) and compile.prof, the raw cProfile data printed above, which can be browsed interactively with: snakeviz compile.prof
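Besides snakeviz, the standard-library pstats module can slice the same data from a script; a quick sketch, assuming compile.prof holds the cProfile dump shown at the top:

import pstats

stats = pstats.Stats('compile.prof')
stats.sort_stats('cumulative').print_stats(20)  # like the first listing above
stats.sort_stats('tottime').print_stats(20)     # like the second listing
stats.print_callers('runtime.launch')           # who calls the Triton launcher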