@rejuvyesh
Created February 7, 2022 22:12
DLPack reproduce segfault
using PyCall
using DLPack
using Test
using Zygote
using ChainRulesCore
torch = pyimport("torch")
functorch = pyimport("functorch")
dlpack = pyimport("torch.utils.dlpack")
py"""
def buffer_implicit(fn, buffers):
    def newfn(params, inputs):
        return fn(params, buffers, inputs)
    return newfn
"""
# Round-trip helpers between torch tensors and Julia arrays via DLPack
pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

# Eager and lazy dimension reversal (torch is row-major, Julia is column-major)
reversedims(x::AbstractArray{T,N}) where {T,N} = permutedims(x, N:-1:1)
function ReverseDimsArray(a::AbstractArray{T,N}) where {T<:AbstractFloat,N}
    PermutedDimsArray(a, N:-1:1)
end
# Wraps a functorch stateless module plus Julia-side copies of its parameters
struct TorchModuleWrapper
    torch_stateless_module::PyObject
    dtype::PyObject
    device::PyObject
    params::Tuple
    buffers::Tuple
end
function TorchModuleWrapper(torch_module)
    pybuiltin("isinstance")(torch_module, torch.nn.Module) || error("Not a torch.nn.Module")
    device = torch.device("cpu")
    funmod, params, buffers = functorch.make_functional_with_buffers(torch_module)
    dtype = params[1].dtype
    # TODO: shouldn't require reversedims
    # Ideally should not even require conversion to Array; it's already DLPack
    jlparams = map(params) do x
        reversedims(Array(DLArray(x, pyto_dlpack)))
    end
    return TorchModuleWrapper(funmod, dtype, device, jlparams, buffers)
end
# Materialize possibly non-contiguous arrays before sharing them via DLPack
maybecontiguous(x::AbstractArray) = Array(x)
maybecontiguous(x::StridedArray) = x
function (wrap::TorchModuleWrapper)(args...)
    # TODO: handle multiple outputs
    params = wrap.params
    tensor_out = wrap.torch_stateless_module(
        Tuple(map(x -> DLPack.share(x, pyfrom_dlpack).requires_grad_(true), params)),
        wrap.buffers,
        map(x -> DLPack.share(x, pyfrom_dlpack), args)...)
    res = ReverseDimsArray(DLArray(tensor_out, pyto_dlpack))
    return res
end
function ChainRulesCore.rrule(wrap::TorchModuleWrapper, args...)
    params = wrap.params
    torch_primal, torch_vjpfun = functorch.vjp(
        py"buffer_implicit"(wrap.torch_stateless_module, wrap.buffers),
        Tuple(map(x -> DLPack.share(x, pyfrom_dlpack).requires_grad_(true), params)),
        map(x -> DLPack.share(x, pyfrom_dlpack).to(dtype = wrap.dtype, device = wrap.device).requires_grad_(true), args)...)
    project = ProjectTo(args)
    function TorchModuleWrapper_pullback(Δ)
        torch_tangent_vals = torch_vjpfun(DLPack.share(maybecontiguous(Δ), pyfrom_dlpack))
        jlparams_tangents = map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[1])
        args_tangents = project(map(x -> ReverseDimsArray(DLArray(x, pyto_dlpack)), torch_tangent_vals[2:end]))
        return (Tangent{TorchModuleWrapper}(; torch_stateless_module = NoTangent(), dtype = NoTangent(), device = NoTangent(), params = jlparams_tangents, buffers = NoTangent()), args_tangents...)
    end
    res = ReverseDimsArray(DLArray(torch_primal, pyto_dlpack))
    return res, TorchModuleWrapper_pullback
end
batchsize = 1
indim = 3
outdim = 2
hiddendim = 4
function compare_grad_wrt_params(modelwrap, inputs...)
    params = map(x -> torch.as_tensor(copy(ReverseDimsArray(x))).to(device = modelwrap.device, dtype = modelwrap.dtype).requires_grad_(true), modelwrap.params)
    torch_out = modelwrap.torch_stateless_module(params, modelwrap.buffers, map(z -> torch.as_tensor(PyReverseDims(copy(z))).to(dtype = modelwrap.dtype), inputs)...).sum()
    torchgrad = map(x -> ReverseDimsArray(x.numpy()), torch.autograd.grad(torch_out, params))
    grad, = Zygote.gradient(m -> sum(m(inputs...)), modelwrap)
    @test length(torchgrad) == length(grad.params)
    for i in 1:length(grad.params)
        @test isapprox(torchgrad[i], grad.params[i])
    end
    @test length(grad.params) == length(modelwrap.params)
    @test grad.params[1] !== nothing
    @test grad.params[2] !== nothing
    @test size(grad.params[1]) == size(modelwrap.params[1])
    @test size(grad.params[2]) == size(modelwrap.params[2])
end
lin = torch.nn.Linear(indim, outdim)
torchparams = Tuple([copy(DLArray(p, pyto_dlpack)) for p in lin.parameters()]) # shapes: (outdim, indim), (outdim,)
linwrap = TorchModuleWrapper(lin)
x = randn(Float32, indim, batchsize)
y = linwrap(x)
compare_grad_wrt_params(linwrap, deepcopy(x))
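
A minimal stress-loop sketch, assuming the segfault is GC-timing dependent (the session log in the comments only hits it after several repeated includes): exercising the forward/backward round trip while forcing collection should surface it faster. The loop count is arbitrary.

# Assumption: the crash comes from Python-owned tensor memory being freed while
# Julia still holds arrays aliasing it, so extra GC pressure reproduces it sooner.
for _ in 1:1000
    xs = randn(Float32, indim, batchsize)
    linwrap(xs)                                  # forward pass through the DLPack round trip
    Zygote.gradient(m -> sum(m(xs)), linwrap)    # reverse pass via the rrule above
    GC.gc()
end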
@pabloferz

pabloferz commented Feb 9, 2022

Does changing the first two lines to

using Distributed
@everywhere using PyCall
@everywhere using DLPack

fix the issue?

EDIT: No, I don't think it does.

@rejuvyesh (Author)

Yep, it doesn't fix the issue. Although I would love to understand the train of thought that led to this test!

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0
Test Passed
  Expression: size(grad.params[2]) == size(modelwrap.params[2])
   Evaluated: (2,) == (2,)

julia> include("test/stresstest_dlpack.jl")
DLPack.PYCALL_NOOP_DELETER = Ptr{Nothing} @0x00007f39f1456bf0

signal (11): Segmentation fault
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/stresstest_dlpack.jl:107
unknown function (ip: 0x7f39d7588580)
Allocations: 62960043 (Pool: 62939873; Big: 20170); GC: 69

@pabloferz

Ok, I think I found the culprit. Going over the PyTorch repo, I found that the DLManagedTensor is captured (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/DLConvertor.cpp#L243), so we need to keep not only the array but also the tensor around. I'll update the PR.
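
A minimal sketch of the kind of fix described above, with hypothetical names (TensorBackedArray, owner); the actual DLPack.jl change may look different. The idea is just to keep the tensor PyObject rooted for as long as the Julia array that aliases its memory is reachable, so the captured DLManagedTensor is never released underneath us.

using PyCall  # for PyObject

# Hypothetical wrapper: `owner` exists only to keep the torch tensor (and the
# DLManagedTensor it captured) alive while the aliased data is still in use.
struct TensorBackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A          # array view over the DLPack-shared memory
    owner::PyObject  # the torch tensor that owns that memory
end

Base.size(x::TensorBackedArray) = size(x.data)
Base.getindex(x::TensorBackedArray, i...) = getindex(x.data, i...)
Base.setindex!(x::TensorBackedArray, v, i...) = setindex!(x.data, v, i...)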
