Skip to content

Instantly share code, notes, and snippets.

@Laurawly
Created January 13, 2019 07:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Laurawly/66e8105c8db300bbce0771c1e58853ad to your computer and use it in GitHub Desktop.
Save Laurawly/66e8105c8db300bbce0771c1e58853ad to your computer and use it in GitHub Desktop.
def test_sort():
    """Compare the CUDA sort kernel with NumPy's descending argsort.

    A random (1, 500) input is built and handed to ``topi.cuda.sort``.
    The result lands in an int32 buffer and must match
    ``np.argsort(-input)``, i.e. the indices that order the values
    largest-first. Targets that are not enabled on this machine are
    skipped with a message rather than failing.
    """
    shape = (1, 500)
    data = tvm.placeholder(shape, name="data")
    host_input = np.random.rand(*shape).astype(data.dtype)
    print(host_input)
    # Negating the values turns NumPy's ascending argsort into a descending one.
    expected = np.argsort(-host_input)

    def check_device(device):
        dev_ctx = tvm.context(device, 0)
        if dev_ctx.exist:
            print("Running on target: %s" % device)
            # Compute definition and schedule are created under the target
            # context so the target-specific implementations are selected.
            with tvm.target.create(device):
                sorted_idx = topi.cuda.sort(data)
                sched = topi.generic.schedule_sort(sorted_idx)
            dev_input = tvm.nd.array(host_input, dev_ctx)
            dev_output = tvm.nd.array(np.zeros(shape, dtype="int32"), dev_ctx)
            kernel = tvm.build(sched, [data, sorted_idx], device)
            kernel(dev_input, dev_output)
            tvm.testing.assert_allclose(dev_output.asnumpy(), expected, rtol=1e-4)
        else:
            print("Skip because %s is not enabled" % device)

    for target in ("cuda",):
        check_device(target)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment