Torch: DataLoader, DLPack, PyCall
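A Julia script that compares two ways of consuming a PyTorch DataLoader: iterating it entirely from Python, and iterating it from Julia via PyCall with each batch shared zero-copy through DLPack.jl.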
using PyCall
using ProgressMeter
using DLPack

py"""
import torch
from tqdm import tqdm

# Just in case the default sharing strategy runs out of file descriptors:
# https://github.com/pytorch/pytorch/issues/11201
sharing_strategy = "file_system"

def set_worker_sharing_strategy(worker_id: int) -> None:
    torch.multiprocessing.set_sharing_strategy(sharing_strategy)

# Dummy dataset: every item is a random (3, 128, 224) tensor.
class PyData(torch.utils.data.Dataset):
    def __init__(self, ndata):
        self.ndata = ndata

    def __len__(self):
        return self.ndata

    def __getitem__(self, idx):
        return torch.rand(3, 128, 224)

ndata = 100000
pydata = PyData(ndata=ndata)

# Pure-Python baseline: iterate the DataLoader twice from Python.
def maintorch():
    loader = torch.utils.data.DataLoader(
        pydata, batch_size=128, num_workers=16,
    )
    for batch in tqdm(loader):
        assert batch.data.size()[1:4] == (3, 128, 224)
    for batch in tqdm(loader):
        assert batch.data.size()[1:4] == (3, 128, 224)
    print("Dekita-dekyu: Done")
"""
struct PyWrapDataset
    pyobj::PyObject
end

# Converting the wrapper back to its PyObject lets PyCall pass it straight
# to Python APIs such as DataLoader.
PyCall.PyObject(t::PyWrapDataset) = t.pyobj
Base.propertynames(t::PyWrapDataset) = propertynames(getfield(t, :pyobj))

function Base.getproperty(t::PyWrapDataset, s::Symbol)
    if s ∈ fieldnames(PyWrapDataset)
        return getfield(t, s)
    else
        return getproperty(getfield(t, :pyobj), s)
    end
end

Base.getindex(dset::PyWrapDataset, idx) = get(dset.pyobj, idx)
Base.length(dset::PyWrapDataset) = pybuiltin("len")(dset)
const torch = py"torch"

# torch.to_dlpack returns a DLPack capsule; annotate the call as ::PyObject
# so PyCall does not try to auto-convert it before DLPack.wrap takes over.
function to_dlpack(o)
    @pycall torch.to_dlpack(o)::PyObject
end
function maindlpack()
    rd = PyWrapDataset(py"PyData"(py"ndata"))
    loader = torch.utils.data.DataLoader(
        rd, batch_size=128, num_workers=16,
        # worker_init_fn=py"set_worker_sharing_strategy",
    )
    @showprogress for (i, pybatch) in enumerate(loader)
        # Zero-copy view of the PyTorch batch; dimensions appear reversed
        # because Julia arrays are column-major.
        jlbatch = DLPack.wrap(pybatch, to_dlpack)
        @assert size(jlbatch)[1:3] == (224, 128, 3)
        # Incremental GC each iteration keeps released Python batches from
        # piling up while the loop runs.
        GC.gc(false)
    end
end
@time py"maintorch"()
@time maindlpack()
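
A minimal sketch of the dimension flip, assuming the `torch` and `to_dlpack` definitions above: DLPack.wrap exposes PyTorch's row-major memory as a column-major Julia array, so a tensor of shape (3, 128, 224) shows up in Julia with size (224, 128, 3), which is exactly what the @assert in maindlpack checks.

# Hypothetical standalone check, not part of the original script:
x = torch.rand(3, 128, 224)       # PyTorch shape: (3, 128, 224)
y = DLPack.wrap(x, to_dlpack)     # same memory, no copy
@assert size(y) == (224, 128, 3)  # axes reversed on the Julia side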