bench_isin.py (@asmeurer, last active January 16, 2024)
https://gist.github.com/asmeurer/d5e7de610fb932786dcfd0db2c11f870
import torch
from torch.testing import make_tensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils.benchmark import Timer, Compare
from torch._inductor.compile_fx import compile_fx_inner, cudagraphify_impl
from torch._inductor.decomposition import decompositions
from itertools import product
import numpy as np
torch._logging.set_logs(output_code=True)
benchmark_name = "isin"
elements_sizes = [10, 10000, 1000000]
test_elements_sizes = [10, 10000, 1000000]
highest_element = [10, 1000, int(1e8)]
# The isin decomposition switches between a broadcast-based and a sort-based
# algorithm depending on the relative sizes of the inputs. Make sure the
# sizes above exercise both sides of that condition.
assert test_elements_sizes[0] < 10*elements_sizes[-1]**0.145
assert test_elements_sizes[-1] >= 10*elements_sizes[0]**0.145
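# Working the numbers for the sizes above (assuming the numpy.in1d-style
# threshold that the decomposition appears to use, where
# test_elements.numel() < 10 * elements.numel() ** 0.145 selects the
# broadcast path): 10 * 1000000**0.145 ~= 74, so a test size of 10 always
# stays on the broadcast path, while 10 * 10**0.145 ~= 14, so a test size
# of 1000000 always takes the sort path.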

def gen_inputs(unique=False):
    rng = np.random.default_rng(0)
    for s, test_s, M in product(elements_sizes, test_elements_sizes,
                                highest_element):
        if unique:
            if M < s or M < test_s:
                continue
            # torch does not have an equivalent of numpy.random.choice
            # https://github.com/pytorch/pytorch/issues/16897, so we just
            # generate arrays with numpy then move them to torch
            np_elements = rng.choice(M, size=s, replace=False)
            np_test_elements = rng.choice(M, size=test_s, replace=False)
            yield (torch.from_numpy(np_elements).to(device="cuda:0"),
                   torch.from_numpy(np_test_elements).to(device="cuda:0"),
                   M)
        else:
            yield (make_tensor((s,), low=0, high=M, dtype=torch.int64, device="cuda:0"),
                   make_tensor((test_s,), low=0, high=M, dtype=torch.int64, device="cuda:0"),
                   M)

def benchmark(label, f, elements, test_elements, highest_element, unique, assume_unique):
    return Timer("f([elements, test_elements, highest_element])",
                 globals=locals(),
                 label=benchmark_name,
                 description=label,
                 sub_label=f"{tuple(elements.shape)}, {tuple(test_elements.shape)}, {highest_element}, {unique}, {assume_unique}",
                 num_threads=torch.get_num_threads()).blocked_autorange(min_run_time=2)

def compare(elements, test_elements, highest_element, unique, assume_unique):
    def f(args):
        elements, test_elements, highest_element = args
        val = torch.ops.aten.isin(elements, test_elements, assume_unique=assume_unique)
        return (val,)

    print(f"{tuple(elements.shape)}, {tuple(test_elements.shape)}, {highest_element}, {unique}, {assume_unique}")
    args = [elements, test_elements, highest_element]

    # "Decomposed": compiled with the decomposition table, so isin is
    # broken into primitive ops before Inductor sees it
    decomposed = make_fx(f, decomposition_table=decompositions, tracing_mode="fake")(args)
    compiled_decomposed = compile_fx_inner(decomposed, args, cudagraphs=False)
    yield benchmark("Decomposed", compiled_decomposed, *args, unique=unique, assume_unique=assume_unique)

    # "Lowering": compiled without decompositions, so Inductor handles
    # aten.isin itself (in the second generated module below it falls back
    # to calling aten.isin.Tensor_Tensor)
    non_decomposed = make_fx(f, tracing_mode="fake")(args)
    compiled_nondecomposed = compile_fx_inner(non_decomposed, args, cudagraphs=False)
    yield benchmark("Lowering", compiled_nondecomposed, *args, unique=unique, assume_unique=assume_unique)

    # Just show the first two generated kernels
    torch._logging.set_logs(output_code=False)

    # "Eager": the uncompiled op (cudagraphs variant left commented out)
    # cuda_f = cudagraphify_impl(f, args, static_input_idxs=tuple(range(len(args))))
    # yield benchmark("Eager", cuda_f, *args)
    yield benchmark("Eager", f, *args, unique=unique, assume_unique=assume_unique)

results = []
for args in gen_inputs(unique=True):
    for res in compare(*args, unique=True, assume_unique=True):
        results.append(res)
    for res in compare(*args, unique=True, assume_unique=False):
        results.append(res)
for args in gen_inputs(unique=False):
    for res in compare(*args, unique=False, assume_unique=False):
        results.append(res)

compare = Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
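
For reference, a rough sketch of the two membership-test strategies the
decomposition's heuristic selects between. This mirrors the classic
numpy.in1d algorithm and is an assumption about the shape of the
decomposition, not a copy of the actual PyTorch code; the sort-based
variant is only valid when the inputs contain no duplicates (the
assume_unique=True case).

import torch

def isin_via_broadcast(elements, test_elements):
    # O(n*m) compare-and-reduce; cheap when test_elements is small. This
    # is essentially what the first Triton kernel below computes.
    return (elements[..., None] == test_elements.ravel()).any(-1)

def isin_via_sort(elements, test_elements):
    # Sort-based membership test; wins when test_elements is large.
    # Assumes `elements` has no duplicates; numpy.in1d uniquifies the
    # inputs first in the general case.
    e, t = elements.ravel(), test_elements.ravel()
    combined = torch.cat([e, t])
    order = torch.argsort(combined, stable=True)
    sorted_vals = combined[order]
    # After a stable sort, a value from `e` is in `t` iff it is
    # immediately followed by an equal value.
    adjacent_eq = sorted_vals[1:] == sorted_vals[:-1]
    flag = torch.cat([adjacent_eq, adjacent_eq.new_zeros(1)])
    result = torch.empty_like(flag)
    result[order] = flag
    return result[:e.numel()].reshape(elements.shape)

# Example: isin_via_sort(torch.tensor([1, 5, 9]), torch.tensor([5, 2]))
# -> tensor([False, True, False])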
$ python bench_isin.py
(10,), (10,), 10, True, True
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] Output code:
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from ctypes import c_void_p, c_long
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import torch
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import math
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import random
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import os
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import tempfile
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from math import inf, nan
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.hooks import run_intermediate_hooks
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import maybe_profile
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codegen.memory_planning import _align as align
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch import device, empty, empty_strided
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codecache import AsyncCompile
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.select_algorithm import extern_kernels
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] aten = torch.ops.aten
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] inductor_ops = torch.ops.inductor
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] alloc_from_pool = torch.ops.inductor._alloc_from_pool
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] async_compile = AsyncCompile()
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] # kernel path: /tmp/torchinductor_aaronmeurer/q4/cq4ikhwhzmouh4y4yr6yfdjxy6gcgf5cnj2jbhbbmkweiklven7n.py
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] # Source Nodes: [], Original ATen: []
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] triton_per_fused_0 = async_compile.triton('triton_', '''
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import triton
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import triton.language as tl
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.ir import ReductionHint
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.ir import TileHint
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.triton_heuristics import AutotuneHint, persistent_reduction
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import instance_descriptor
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor import triton_helpers
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] @persistent_reduction(
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] size_hints=[16, 16],
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] reduction_hint=ReductionHint.INNER,
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] filename=__file__,
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] triton_meta={'signature': {0: '*i64', 1: '*i64', 2: '*i1', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_0', 'mutated_arg_names': []}
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] )
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] @triton.jit
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] xnumel = 10
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] rnumel = 10
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] RBLOCK: tl.constexpr = 16
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] xoffset = tl.program_id(0) * XBLOCK
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] xmask = xindex < xnumel
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] rindex = tl.arange(0, RBLOCK)[None, :]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] rmask = rindex < rnumel
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] x0 = xindex
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] r1 = rindex
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] tmp2 = tmp0 == tmp1
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] tmp5 = tl.where(rmask & xmask, tmp3, 0)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] tmp6 = triton_helpers.any(tmp5, 1)[:, None]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] tl.store(out_ptr0 + (x0), tmp6, xmask)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] ''')
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import triton
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] import triton.language as tl
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.triton_heuristics import grid, start_graph, end_graph
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] async_compile.wait(globals())
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] del async_compile
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] def call(args):
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] args_1, args_2, args_3 = args
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] args.clear()
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride(args_1, (10, ), (1, ))
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride(args_2, (10, ), (1, ))
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] with torch.cuda._DeviceGuard(0):
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] torch.cuda.set_device(0)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] buf0 = empty((10, ), device='cuda', dtype=torch.bool)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] # Source Nodes: [], Original ATen: []
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] stream0 = get_raw_stream(0)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] triton_per_fused_0.run(args_1, args_2, buf0, 10, 10, grid=grid(10), stream=stream0)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] del args_1
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] del args_2
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] return (buf0, )
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] def benchmark_compiled_module(times=10, repeat=10):
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._dynamo.testing import rand_strided
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import print_performance
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] args_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.int64)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] args_2 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.int64)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] args_3 = 10
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] fn = lambda: call([args_1, args_2, args_3])
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] return print_performance(fn, times=times, repeat=repeat)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] if __name__ == "__main__":
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.wrapper_benchmark import compiled_module_main
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG] compiled_module_main('None', benchmark_compiled_module)
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:05,041] torch._inductor.graph.__output_code: [INFO] Output code written to: /tmp/torchinductor_aaronmeurer/h2/ch2d5pkompgzxiq75vunwyergcz6sakx7ueqx6y6kjbzvsfsgenv.py
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] Output code:
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from ctypes import c_void_p, c_long
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] import torch
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] import math
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] import random
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] import os
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] import tempfile
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from math import inf, nan
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.hooks import run_intermediate_hooks
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import maybe_profile
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codegen.memory_planning import _align as align
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch import device, empty, empty_strided
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codecache import AsyncCompile
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.select_algorithm import extern_kernels
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] aten = torch.ops.aten
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] inductor_ops = torch.ops.inductor
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] alloc_from_pool = torch.ops.inductor._alloc_from_pool
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] async_compile = AsyncCompile()
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] async_compile.wait(globals())
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] del async_compile
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] def call(args):
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] args_1, args_2, args_3 = args
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] args.clear()
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride(args_1, (10, ), (1, ))
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride(args_2, (10, ), (1, ))
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] with torch.cuda._DeviceGuard(0):
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] torch.cuda.set_device(0)
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] # Source Nodes: [], Original ATen: []
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] buf0 = aten.isin.Tensor_Tensor(args_1, args_2, assume_unique=True)
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] del args_1
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] del args_2
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] buf1 = buf0
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] return (buf1, )
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] def benchmark_compiled_module(times=10, repeat=10):
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch._dynamo.testing import rand_strided
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import print_performance
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] args_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.int64)
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] args_2 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.int64)
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] args_3 = 10
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] fn = lambda: call([args_1, args_2, args_3])
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] return print_performance(fn, times=times, repeat=repeat)
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] if __name__ == "__main__":
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.wrapper_benchmark import compiled_module_main
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG] compiled_module_main('None', benchmark_compiled_module)
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [DEBUG]
[2024-01-15 19:00:07,353] torch._inductor.graph.__output_code: [INFO] Output code written to: /tmp/torchinductor_aaronmeurer/d7/cd7mgpsoiaxbfej6hghzfc52gawhxwpivhvpkbvumxpjbelnlxn7.py
(10,), (10,), 10, True, False
(10,), (10,), 1000, True, True
(10,), (10,), 1000, True, False
(10,), (10,), 100000000, True, True
(10,), (10,), 100000000, True, False
(10,), (10000,), 100000000, True, True
(10,), (10000,), 100000000, True, False
(10,), (1000000,), 100000000, True, True
(10,), (1000000,), 100000000, True, False
(10000,), (10,), 100000000, True, True
(10000,), (10,), 100000000, True, False
(10000,), (10000,), 100000000, True, True
(10000,), (10000,), 100000000, True, False
(10000,), (1000000,), 100000000, True, True
(10000,), (1000000,), 100000000, True, False
(1000000,), (10,), 100000000, True, True
(1000000,), (10,), 100000000, True, False
(1000000,), (10000,), 100000000, True, True
(1000000,), (10000,), 100000000, True, False
(1000000,), (1000000,), 100000000, True, True
(1000000,), (1000000,), 100000000, True, False
(10,), (10,), 10, False, False
(10,), (10,), 1000, False, False
(10,), (10,), 100000000, False, False
(10,), (10000,), 10, False, False
(10,), (10000,), 1000, False, False
(10,), (10000,), 100000000, False, False
(10,), (1000000,), 10, False, False
(10,), (1000000,), 1000, False, False
(10,), (1000000,), 100000000, False, False
(10000,), (10,), 10, False, False
(10000,), (10,), 1000, False, False
(10000,), (10,), 100000000, False, False
(10000,), (10000,), 10, False, False
(10000,), (10000,), 1000, False, False
(10000,), (10000,), 100000000, False, False
(10000,), (1000000,), 10, False, False
(10000,), (1000000,), 1000, False, False
(10000,), (1000000,), 100000000, False, False
(1000000,), (10,), 10, False, False
(1000000,), (10,), 1000, False, False
(1000000,), (10,), 100000000, False, False
(1000000,), (10000,), 10, False, False
(1000000,), (10000,), 1000, False, False
(1000000,), (10000,), 100000000, False, False
(1000000,), (1000000,), 10, False, False
(1000000,), (1000000,), 1000, False, False
(1000000,), (1000000,), 100000000, False, False
[------------------------------------------ isin -----------------------------------------]
| Decomposed | Lowering | Eager
1 threads: --------------------------------------------------------------------------------
(10,), (10,), 10, True, True | 16 | 33 | 26
(10,), (10,), 10, True, False | 16 | 31 | 26
(10,), (10,), 1000, True, True | 16 | 32 | 26
(10,), (10,), 1000, True, False | 16 | 31 | 26
(10,), (10,), 100000000, True, True | 16 | 32 | 26
(10,), (10,), 100000000, True, False | 16 | 31 | 26
(10,), (10000,), 100000000, True, True | 160 | 120 | 114
(10,), (10000,), 100000000, True, False | 78 | 250 | 242
(10,), (1000000,), 100000000, True, True | 263 | 228 | 228
(10,), (1000000,), 100000000, True, False | 178 | 425 | 417
(10000,), (10,), 100000000, True, True | 16 | 32 | 26
(10000,), (10,), 100000000, True, False | 16 | 31 | 26
(10000,), (10000,), 100000000, True, True | 160 | 129 | 120
(10000,), (10000,), 100000000, True, False | 78 | 300 | 300
(10000,), (1000000,), 100000000, True, True | 260 | 228 | 228
(10000,), (1000000,), 100000000, True, False | 181 | 470 | 463
(1000000,), (10,), 100000000, True, True | 16 | 68 | 68
(1000000,), (10,), 100000000, True, False | 16 | 68 | 68
(1000000,), (10000,), 100000000, True, True | 261 | 229 | 228
(1000000,), (10000,), 100000000, True, False | 98 | 543 | 536
(1000000,), (1000000,), 100000000, True, True | 448 | 424 | 424
(1000000,), (1000000,), 100000000, True, False | 320 | 785 | 778
(10,), (10,), 10, False, False | 16 | 31 | 26
(10,), (10,), 1000, False, False | 16 | 31 | 26
(10,), (10,), 100000000, False, False | 16 | 31 | 26
(10,), (10000,), 10, False, False | 78 | 210 | 200
(10,), (10000,), 1000, False, False | 77 | 210 | 210
(10,), (10000,), 100000000, False, False | 77 | 250 | 240
(10,), (1000000,), 10, False, False | 155 | 270 | 259
(10,), (1000000,), 1000, False, False | 163 | 270 | 264
(10,), (1000000,), 100000000, False, False | 178 | 420 | 413
(10000,), (10,), 10, False, False | 16 | 31 | 26
(10000,), (10,), 1000, False, False | 16 | 30 | 26
(10000,), (10,), 100000000, False, False | 16 | 31 | 26
(10000,), (10000,), 10, False, False | 79 | 260 | 250
(10000,), (10000,), 1000, False, False | 78 | 263 | 260
(10000,), (10000,), 100000000, False, False | 78 | 300 | 290
(10000,), (1000000,), 10, False, False | 155 | 315 | 310
(10000,), (1000000,), 1000, False, False | 164 | 321 | 310
(10000,), (1000000,), 100000000, False, False | 181 | 471 | 464
(1000000,), (10,), 10, False, False | 16 | 68 | 68
(1000000,), (10,), 1000, False, False | 16 | 68 | 68
(1000000,), (10,), 100000000, False, False | 16 | 68 | 68
(1000000,), (10000,), 10, False, False | 85 | 360 | 358
(1000000,), (10000,), 1000, False, False | 93 | 388 | 380
(1000000,), (10000,), 100000000, False, False | 98 | 549 | 542
(1000000,), (1000000,), 10, False, False | 176 | 423 | 414
(1000000,), (1000000,), 1000, False, False | 278 | 446 | 439
(1000000,), (1000000,), 100000000, False, False | 322 | 778 | 772
Times are in microseconds (us).