Skip to content

Instantly share code, notes, and snippets.

@wkcn
Last active August 7, 2018 06:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wkcn/1d0151c898582541e6eb74f0162eaa87 to your computer and use it in GitHub Desktop.
Save wkcn/1d0151c898582541e6eb74f0162eaa87 to your computer and use it in GitHub Desktop.
DLPackTest
import mxnet as mx
import numpy as np
import torch
from torch.utils import dlpack
def test_dlpack():
    """Round-trip MXNet NDArrays through DLPack and check the data survives.

    For every dtype/shape combination, a random source array is exported in
    four ways and re-imported with ``mx.nd.from_dlpack``:

    * ``NDArray.to_dlpack_for_read`` on the source,
    * ``NDArray.to_dlpack_for_write`` on a copy of the source,
    * ``mx.nd.to_dlpack_for_read`` on the source,
    * ``mx.nd.to_dlpack_for_write`` on a second copy of the source.

    All original handles (the source, both copies and all four capsules) are
    deleted before the imported arrays are read. The test therefore also
    checks that each imported array keeps its data alive on its own.
    """
    for dtype in [np.float32, np.int32]:
        for shape in [(3, 4, 5, 6), (2, 10), (15,)]:
            # Cast to the dtype under test so the int32 iterations really test
            # int32. Draw from a wide range first: uniform values in [0, 1)
            # would all truncate to 0 when cast to an integer type.
            a = mx.nd.random.uniform(low=-100, high=100, shape=shape).astype(dtype)
            a_np = a.asnumpy()

            pack = a.to_dlpack_for_read()
            b = mx.nd.from_dlpack(pack)

            # Write exports are taken from copies so they never alias `a`.
            a_copy = a.copy()
            pack2 = a_copy.to_dlpack_for_write()
            c = mx.nd.from_dlpack(pack2)

            pack3 = mx.nd.to_dlpack_for_read(a)
            d = mx.nd.from_dlpack(pack3)

            a_copy2 = a.copy()
            pack4 = mx.nd.to_dlpack_for_write(a_copy2)
            e = mx.nd.from_dlpack(pack4)

            # Drop every source-side reference so only the imported arrays
            # remain.
            del a, a_copy, a_copy2, pack, pack2, pack3, pack4

            for imported in (b, c, d, e):
                imported_np = imported.asnumpy()
                assert imported_np.dtype == a_np.dtype
                mx.test_utils.assert_almost_equal(a_np, imported_np)
def test_dlpack_torch():
    """Check that a PyTorch tensor exported via DLPack is readable by MXNet.

    A small torch tensor is exported with ``torch.utils.dlpack.to_dlpack``
    and imported with ``mx.nd.from_dlpack``. The two arrays' contents are
    then compared element-wise.

    Only the torch -> MXNet direction is tested. The original author notes
    that torch rejects DLPack tensors whose ``strides`` pointer is NULL, so
    the MXNet -> torch direction is not covered here.
    """
    source = torch.tensor([1, 2, 3])
    imported = mx.nd.from_dlpack(dlpack.to_dlpack(source))
    mx.test_utils.assert_almost_equal(source.numpy(), imported.asnumpy())
# Run the tests only when executed as a script, so importing this module
# (e.g. during pytest collection) does not run them as a side effect.
if __name__ == "__main__":
    test_dlpack()
    test_dlpack_torch()
    print("OK")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment