@rejuvyesh
Created February 7, 2022 22:12
DLPack reproduce segfault
using PyCall
using DLPack
using Test
using Zygote
using ChainRulesCore
torch = pyimport("torch")
functorch = pyimport("functorch")
dlpack = pyimport("torch.utils.dlpack")
py"""
def buffer_implicit(fn, buffers):
def newfn(params, inputs):
return fn(params, buffers, inputs)
return newfn
"""
pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject
reversedims(x::AbstractArray{T,N}) where {T,N} = permutedims(x, N:-1:1)
function ReverseDimsArray(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
    PermutedDimsArray(a, N:-1:1)
end
struct TorchModuleWrapper
    torch_stateless_module::PyObject
    dtype::PyObject
    device::PyObject
    params::Tuple
    buffers::Tuple
end
function TorchModuleWrapper(torch_module)
    pybuiltin("isinstance")(torch_module, torch.nn.Module) || error("Not a torch.nn.Module")
    device = torch.device("cpu")
    funmod, params, buffers = functorch.make_functional_with_buffers(torch_module)
    dtype = params[1].dtype
    # TODO: shouldn't require reversedims
    # Ideally should not even require conversion to array, it's already DLPack
    jlparams = map(params) do x
        reversedims(Array(DLArray(x, pyto_dlpack)))
    end
    return TorchModuleWrapper(funmod, dtype, device, jlparams, buffers)
end
maybecontiguous(x::AbstractArray) = Array(x)
maybecontiguous(x::StridedArray) = x
function (wrap::TorchModuleWrapper)(args...)
    # TODO: handle multiple outputs
    params = wrap.params
    tensor_out = wrap.torch_stateless_module(Tuple(map(x -> DLPack.share((x), pyfrom_dlpack).requires_grad_(true), params)),
        wrap.buffers, map(x -> DLPack.share((x), pyfrom_dlpack), args)...)
    res = ReverseDimsArray(DLArray(tensor_out, pyto_dlpack))
    return res
end
function ChainRulesCore.rrule(wrap::TorchModuleWrapper, args...)
    params = wrap.params
    torch_primal, torch_vjpfun = functorch.vjp(py"buffer_implicit"(wrap.torch_stateless_module, wrap.buffers),
        Tuple(map(x -> DLPack.share((x), pyfrom_dlpack).requires_grad_(true), params)),
        map(x -> DLPack.share((x), pyfrom_dlpack).to(dtype = wrap.dtype, device = wrap.device).requires_grad_(true), args)...)
    project = ProjectTo(args)
    function TorchModuleWrapper_pullback(Δ)
        torch_tangent_vals = torch_vjpfun(DLPack.share((maybecontiguous(Δ)), pyfrom_dlpack))
        jlparams_tangents = map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[1])
        args_tangents = project(map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[2:end]))
        return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), device = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...)
    end
    res = ReverseDimsArray(DLArray(torch_primal, pyto_dlpack))
    return res, TorchModuleWrapper_pullback
end
batchsize = 1
indim = 3
outdim = 2
hiddendim = 4
function compare_grad_wrt_params(modelwrap, inputs...)
    params = map(x -> torch.as_tensor(copy(ReverseDimsArray(x))).to(device = modelwrap.device, dtype = modelwrap.dtype).requires_grad_(true), (modelwrap.params))
    torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, map(z -> torch.as_tensor(PyReverseDims(copy(z))).to(dtype = modelwrap.dtype), inputs)...).sum()
    torchgrad = map(x -> ReverseDimsArray(x.numpy()), torch.autograd.grad(torch_out, params))
    grad, = Zygote.gradient(m -> sum(m(inputs...)), modelwrap)
    @test length(torchgrad) == length(grad.params)
    for i in 1:length(grad.params)
        @test isapprox(torchgrad[i], grad.params[i])
    end
    @test length(grad.params) == length(modelwrap.params)
    @test grad.params[1] !== nothing
    @test grad.params[2] !== nothing
    @test size(grad.params[1]) == size(modelwrap.params[1])
    @test size(grad.params[2]) == size(modelwrap.params[2])
end
lin = torch.nn.Linear(indim, outdim)
torchparams = Tuple([copy(DLArray(p, pyto_dlpack)) for p in lin.parameters()]) # (outdim, indim), (outdim,)),
linwrap = TorchModuleWrapper(lin)
x = randn(Float32, indim, batchsize)
y = linwrap(x)
compare_grad_wrt_params(linwrap, deepcopy(x))
@pabloferz

Can you add @show DLPack.PYCALL_NOOP_DELETER just after loading the packages and then share the output of this when it segfaults?
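
For reference, the requested change presumably just amounts to adding the @show right after the packages are loaded at the top of the script, i.e.:

using PyCall
using DLPack
@show DLPack.PYCALL_NOOP_DELETER  # print the deleter pointer before anything else runs
using Test
using Zygote
using ChainRulesCore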

@rejuvyesh
Author

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: (nil))
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: 0x961e6ff)
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)
julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f3085010b00
error in running finalizer: ReadOnlyMemoryError()
@pyglobalobj at /home/jagupt/.julia/packages/PyCall/L0fLP/src/startup.jl:145
unknown function (ip: 0x7f2fa3f509bf)

signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:104
jl_system_image_data at /home/jagupt/.julia/juliaup/julia-1.7.1+0~x64/lib/julia/sys.so (unknown line)
Allocations: 64202361 (Pool: 64181938; Big: 20423); GC: 67

@pabloferz commented Feb 9, 2022

Does changing the first two lines to

using Distributed
@everywhere using PyCall
@everywhere using DLPack

fix the issue?

EDIT: No, I don't think it does.

@rejuvyesh
Author

Yep, it doesn't fix the issue. Although I'd love to understand the train of thought that led to this test!

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0

signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:107
unknown function (ip: 0x7f39d7588580)
Allocations: 62960043 (Pool: 62939873; Big: 20170); GC: 69

@pabloferz

Ok, I think I found the culprit. Going over the pytorch repo, I found that the DLManagedTensor is captured (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/DLConvertor.cpp#L243), so we need to keep not only the array around, but also the tensor. I'll update the PR.
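
A minimal, hypothetical sketch of that idea on the Julia side (illustrative names only, not the actual PR): keep the source tensor rooted together with the wrapped array, so the captured DLManagedTensor is not finalized while the Julia side still uses the memory.

struct SharedTensor{A<:AbstractArray}
    array::A          # Julia view of the DLPack-shared memory
    tensor::PyObject  # original torch tensor; holding it keeps the captured DLManagedTensor alive
end

function share_tensor(x::PyObject)
    arr = DLArray(x, pyto_dlpack)  # wrap the torch tensor as a Julia array (helper from the gist above)
    return SharedTensor(arr, x)    # root both objects so neither is garbage-collected prematurely
end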
