Created February 7, 2022 22:12
DLPack reproduce segfault
using PyCall
using DLPack
using Test
using Zygote
using ChainRulesCore
torch = pyimport("torch")
functorch = pyimport("functorch")
dlpack = pyimport("torch.utils.dlpack")
py"""
def buffer_implicit(fn, buffers):
    def newfn(params, inputs):
        return fn(params, buffers, inputs)

    return newfn
"""
pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject
reversedims(x::AbstractArray{T,N}) where {T,N} = permutedims(x, N:-1:1)
function ReverseDimsArray(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
    PermutedDimsArray(a, N:-1:1)
end
struct TorchModuleWrapper
    torch_stateless_module
    dtype
    device
    params
    buffers
end

function TorchModuleWrapper(torch_module)
    pybuiltin("isinstance")(torch_module, torch.nn.Module) || error("Not a torch.nn.Module")
    device = torch.device("cpu")
    funmod, params, buffers = functorch.make_functional_with_buffers(torch_module)
    dtype = params[1].dtype
    # TODO: shouldn't require reversedims
    # Ideally should not even require conversion to array, it's already DLPack
    jlparams = map(params) do x
        reversedims(Array(DLArray(x, pyto_dlpack)))
    end
    return TorchModuleWrapper(funmod, dtype, device, jlparams, buffers)
end
maybecontiguous(x::AbstractArray) = Array(x)
maybecontiguous(x::StridedArray) = x
function (wrap::TorchModuleWrapper)(args...)
    # TODO: handle multiple outputs
    params = wrap.params
    tensor_out = wrap.torch_stateless_module(Tuple(map(x -> DLPack.share((x), pyfrom_dlpack).requires_grad_(true), params)),
        wrap.buffers, map(x -> DLPack.share((x), pyfrom_dlpack), args)...)
    res = ReverseDimsArray(DLArray(tensor_out, pyto_dlpack))
    return res
end
function ChainRulesCore.rrule(wrap::TorchModuleWrapper, args...)
    params = wrap.params
    torch_primal, torch_vjpfun = functorch.vjp(py"buffer_implicit"(wrap.torch_stateless_module, wrap.buffers), Tuple(map(x -> DLPack.share((x), pyfrom_dlpack).requires_grad_(true), params)),
        map(x -> DLPack.share((x), pyfrom_dlpack).to(dtype = wrap.dtype, device = wrap.device).requires_grad_(true), args)...)
    project = ProjectTo(args)
    function TorchModuleWrapper_pullback(Δ)
        torch_tangent_vals = torch_vjpfun(DLPack.share((maybecontiguous(Δ)), pyfrom_dlpack))
        jlparams_tangents = map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[1])
        args_tangents = project(map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[2:end]))
        return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), device = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...)
    end
    res = ReverseDimsArray(DLArray(torch_primal, pyto_dlpack))
    return res, TorchModuleWrapper_pullback
end
batchsize = 1
indim = 3
outdim = 2
hiddendim = 4
function compare_grad_wrt_params(modelwrap, inputs...)
    params = map(x -> torch.as_tensor(copy(ReverseDimsArray(x))).to(device = modelwrap.device, dtype = modelwrap.dtype).requires_grad_(true), (modelwrap.params))
    torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, map(z->torch.as_tensor(PyReverseDims(copy(z))).to(dtype=modelwrap.dtype), inputs)...).sum()
    torchgrad = map(x-> ReverseDimsArray(x.numpy()), torch.autograd.grad(torch_out, params))
    grad, = Zygote.gradient(m->sum(m(inputs...)), modelwrap)
    @test length(torchgrad) == length(grad.params)
    for i in 1:length(grad.params)
        @test isapprox(torchgrad[i], grad.params[i])
    end
    @test length(grad.params) == length(modelwrap.params)
    @test grad.params[1] !== nothing
    @test grad.params[2] !== nothing
    @test size(grad.params[1]) == size(modelwrap.params[1])
    @test size(grad.params[2]) == size(modelwrap.params[2])
end
lin = torch.nn.Linear(indim, outdim)
torchparams = Tuple([copy(DLArray(p, pyto_dlpack)) for p in lin.parameters()]) # (outdim, indim), (outdim,)),
linwrap = TorchModuleWrapper(lin)
x = randn(Float32, indim, batchsize)
y = linwrap(x)
compare_grad_wrt_params(linwrap, deepcopy(x))

Can you add @show DLPack.PYCALL_NOOP_DELETER just after loading the packages and then share the output of this when it segfaults?

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: (nil))
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: 0x961e6ff)
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: 0x7f2fa3f509bf)

signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:104
jl_system_image_data at /home/jagupt/.julia/juliaup/julia-1.7.1+0~x64/lib/julia/ (unknown line)
Allocations: 64202361 (Pool: 64181938; Big: 20423); GC: 67

pabloferz commented Feb 9, 2022

Does replacing the first two lines with

using Distributed
@everywhere using PyCall
@everywhere using DLPack

fix the issue?

EDIT: No, I don't think it does.

Yep, it doesn't fix the issue, although I would love to understand the train of thought that led to this test!

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0

signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:107
unknown function (ip: 0x7f39d7588580)
Allocations: 62960043 (Pool: 62939873; Big: 20170); GC: 69

OK, I think I found the culprit. Going over the PyTorch repo, I found that the DLManagedTensor itself is captured, so we need to keep alive not only the array but also the DLManagedTensor. I'll update the PR.
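
For anyone following along, here is a rough sketch of that lifetime rule in plain Julia, with no PyCall or DLPack involved. It is not the code from the PR or from DLPack.jl: `FakeManagedTensor`, `SHARED_POOL`, `release`, and `share_sketch` are hypothetical names made up for illustration, and the struct is only a stand-in for the real C `DLManagedTensor`. The idea is that the consumer holds a raw pointer to the managed struct, so the struct and the array must both stay rooted until the consumer calls the deleter.

```julia
# Hypothetical stand-in for the C struct DLManagedTensor (not DLPack.jl's type).
mutable struct FakeManagedTensor
    data::Ptr{Cvoid}      # points into the Julia array's memory
    deleter::Ptr{Cvoid}   # C-callable the consumer invokes when it is done
end

# Everything stored here stays reachable, so the GC cannot free it.
const SHARED_POOL = IdDict{Ptr{Cvoid},Any}()

# Deleter handed to the consumer: drops the array and the struct together.
function release(handle::Ptr{Cvoid})::Cvoid
    delete!(SHARED_POOL, handle)
    return nothing
end

function share_sketch(A::Array)
    tensor = FakeManagedTensor(convert(Ptr{Cvoid}, pointer(A)), C_NULL)
    tensor.deleter = @cfunction(release, Cvoid, (Ptr{Cvoid},))
    handle = pointer_from_objref(tensor)
    # Root BOTH objects. Storing only `A` would leave `tensor` collectable
    # while the consumer still holds `handle`, which is a use-after-free.
    SHARED_POOL[handle] = (A, tensor)
    return handle
end

A = rand(Float32, 3, 2)
handle = share_sketch(A)
GC.gc()                                        # the pool still roots both objects
tensor = unsafe_pointer_to_objref(handle)      # valid because the struct is rooted
ccall(tensor.deleter, Cvoid, (Ptr{Cvoid},), handle)  # consumer signals it is done
```

After the deleter runs, the pool entry is gone and both objects become collectable again. If only the array were rooted, the intermittent finalizer errors and segfaults in the logs above would be a plausible symptom, since the crash would depend on when the GC happens to run.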