Skip to content

Instantly share code, notes, and snippets.

Created February 14, 2021 16:41
Show Gist options
  • Save binarybana/ec17e4b9bedebf03a0aca65df1611bb1 to your computer and use it in GitHub Desktop.
Save binarybana/ec17e4b9bedebf03a0aca65df1611bb1 to your computer and use it in GitHub Desktop.
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
import numpy as np
# Benchmark cases as [shape, axis] pairs: `axis` is the dimension that gets
# sorted and has length `sort`; the other two dimensions have length
# `non_sort`.
shapes = []
for sort in [1000, 10000, 100000, 1000000]:
    for non_sort in [1, 10, 100]:
        # Element count is non_sort * non_sort * sort whichever axis is
        # sorted, so the size cap can be checked once per (sort, non_sort).
        if non_sort * non_sort * sort >= 1e9:
            continue
        for axis in [0, 1, 2]:
            shape = [non_sort] * 3
            shape[axis] = sort
            shapes.append([tuple(shape), axis])

# Extra hand-picked 1-D and 2-D cases.
shapes += [
    [(4507,), 0],
    [(1, 122640), 1],
    [(1, 120000), 1],
    [(1, 30000), 1],
    [(1, 7500), 1],
    [(1, 1000), 1],
]
# (device context, TVM target name) pairs that each case is built and run on.
envs = [
    (tvm.cl(0), 'opencl'),
    (tvm.cpu(0), 'llvm'),
    # (tvm.metal(0), 'metal'),
]
def is_sorted(x, axis):
    """Return True if *x* is in non-decreasing order along *axis*.

    Neighbouring elements are compared directly rather than by subtracting
    them, so the answer stays correct for unsigned dtypes and for integer
    values whose difference would overflow. NaN never compares as smaller,
    so it does not count as a violation.

    Args:
        x: Array-like with at least one dimension.
        axis: Axis along which the ordering is checked.

    Returns:
        bool: False if any element is smaller than its predecessor along
        *axis*; True otherwise, including for length 0 or 1 along *axis*.
    """
    # Bring the checked axis to the front so plain slicing pairs up neighbours.
    arr = np.moveaxis(np.asarray(x), axis, 0)
    return not np.any(arr[1:] < arr[:-1])
# Correctness sweep: build an int32 argsort for every (shape, axis) case on
# every backend, run it on random data, and check that gathering the input
# with the returned indices gives a sorted array.
for shape, axis in shapes:
    for ctx, target in envs:
        x = relay.var("x", relay.TensorType(shape, "int32"))
        z = relay.argsort(x, axis=axis, is_ascend=True, dtype="int32")
        func = relay.Function([x], z)
        mod = tvm.IRModule.from_expr(func)
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target)
        m = graph_runtime.GraphModule(lib['default'](ctx))

        np_x = np.random.randint(10000, size=shape, dtype='int32')
        # The module must receive the input and execute before its output
        # can be read back.
        m.set_input("x", np_x)
        m.run()
        res = m.get_output(0).asnumpy()

        # Print the verdict rather than assert, so the sweep continues
        # through every case after a failure.
        b = is_sorted(np.take_along_axis(np_x, res, axis=axis), axis)
        if b:
            print(f"{target}, {shape}, {axis} Correct sort")
        else:
            print(f"{target}, {shape}, {axis} BAD SORT!")
import pandas as pa
# Time the compiled float32 argsort on every backend. `data` maps each TVM
# target name to a list of mean run times in milliseconds, one entry per case
# in `shapes` and in the same order.
data = {}
for ctx, target in envs:
    values = []
    for shape, axis in shapes:
        print(f"Shape: {shape}, axis: {axis}, target: {target}")
        x = relay.var("x", relay.TensorType(shape, "float32"))
        z = relay.argsort(x, axis=axis, is_ascend=True, dtype="int32")
        func = relay.Function([x], z)
        mod = tvm.IRModule.from_expr(func)
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target)
        m = graph_runtime.GraphModule(lib['default'](ctx))
        # NOTE(review): no input is bound before timing, so the kernel sorts
        # whatever the runtime allocated for "x" -- confirm this is intended.
        ftimer = m.module.time_evaluator("run", ctx, number=10, repeat=5)
        prof_res = np.array(ftimer().results)
        mean_ms = np.mean(prof_res) * 1000
        print(shape, "\t", "%.2f ms" % mean_ms)
        # Record the same figure that was just printed.
        values.append(mean_ms)
    data[target] = values
import timeit
# NumPy baseline: time ndarray.argsort for every case and store the per-call
# time in milliseconds under data['numpy'], in the same order as `shapes`.
values = []
for shape, axis in shapes:
    # np.empty skips initialisation; the array contents are arbitrary.
    x = np.empty(shape, dtype="float32")
    # Hand the statement exactly the names it needs instead of relying on
    # module globals.
    timer = timeit.Timer('x.argsort(axis=axis)', globals={'x': x, 'axis': axis})
    # autorange() returns (number_of_loops, total_seconds).
    loops, total_seconds = timer.autorange()
    per_call_ms = total_seconds / loops * 1000
    print(f"{shape} \t {per_call_ms:.2f} ms")
    # Record the same figure that was just printed.
    values.append(per_call_ms)
data['numpy'] = values
# Tabulate the timings with one row per benchmark case, then add a ratio
# column for each TVM backend (NumPy time divided by that backend's time).
row_labels = list(map(str, shapes))
d = pa.DataFrame(data, index=row_labels)
for ratio_column, backend in (('numpy/ocl', 'opencl'), ('numpy/llvm', 'llvm')):
    d[ratio_column] = d['numpy'] / d[backend]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment