@rejuvyesh
Last active February 9, 2022 17:32
DLPack segfault reproducer on CUDA + JAX
using PyCall
using CUDA
using DLPack
using Test
#using Zygote
#using ChainRulesCore

# Inspect DLPack.jl's no-op deleter for PyCall-managed tensors
@show DLPack.PYCALL_NOOP_DELETER

jax = pyimport("jax")
dlpack = pyimport("jax.dlpack")
numpy = pyimport("numpy")

# Call jax.dlpack and keep the result as a raw PyObject (no automatic conversion)
pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

# JAX -> Julia: wrap JAX arrays of several ranks and compare them with a NumPy copy
@testset "dlpack" begin
    key = jax.random.PRNGKey(0)
    for dims in ((10,), (1, 10), (2, 3, 5), (2, 3, 4, 5))
        xto = jax.random.normal(key, dims)
        xjl = DLArray(xto, pyto_dlpack)
        @test isapprox(numpy.array(xto), Array(xjl))
    end
end

# Julia -> JAX on the GPU: share a CuArray with JAX and check that both sides agree on the sum
cujl = CuArray(randn(Float32, 2, 3, 5))
cujax = DLPack.share(cujl, pyfrom_dlpack)
@test (jax.numpy.sum(cujax).item()) ≈ sum(cujl)
@rejuvyesh
Author

Edit: I haven't been able to reproduce these segfaults again, huh.
