@terasakisatoshi
Last active June 3, 2022 08:01
Torch: DataLoader, DLPack, PyCall
using PyCall
using ProgressMeter
using DLPack

py"""
import torch
from tqdm import tqdm

# Just in case: avoid running out of file descriptors with many
# DataLoader workers, see:
# https://github.com/pytorch/pytorch/issues/11201
sharing_strategy = "file_system"

def set_worker_sharing_strategy(worker_id: int) -> None:
    torch.multiprocessing.set_sharing_strategy(sharing_strategy)

class PyData(torch.utils.data.Dataset):
    def __init__(self, ndata):
        self.ndata = ndata

    def __len__(self):
        return self.ndata

    def __getitem__(self, idx):
        return torch.rand(3, 128, 224)

ndata = 100000
pydata = PyData(ndata=ndata)

def maintorch():
    loader = torch.utils.data.DataLoader(
        pydata, batch_size=128, num_workers=16,
    )
    # Two full passes over the loader.
    for batch in tqdm(loader):
        assert batch.data.size()[1:4] == (3, 128, 224)
    for batch in tqdm(loader):
        assert batch.data.size()[1:4] == (3, 128, 224)
    print("Dekita-dekyu: Done")
"""
# A thin Julia wrapper so a Python torch Dataset can be passed around
# and indexed like a Julia object.
struct PyWrapDataset
    pyobj::PyObject
end

# Extending PyCall.PyObject (the method must be qualified to be extended)
# lets PyCall convert the wrapper back to the underlying Python object,
# e.g. when it is handed to DataLoader.
PyCall.PyObject(t::PyWrapDataset) = t.pyobj

Base.propertynames(t::PyWrapDataset) = propertynames(getfield(t, :pyobj))

function Base.getproperty(t::PyWrapDataset, s::Symbol)
    if s ∈ fieldnames(PyWrapDataset)
        return getfield(t, s)
    else
        return getproperty(getfield(t, :pyobj), s)
    end
end

# Forward indexing to Python's __getitem__ (zero-based) and length to len().
Base.getindex(dset::PyWrapDataset, idx) = get(dset.pyobj, idx)
Base.length(dset::PyWrapDataset) = pybuiltin("len")(dset)
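A quick usage sketch of the wrapper (hypothetical REPL session; assumes the py-block above has run, so pydata exists):

dset = PyWrapDataset(py"pydata"o)
length(dset)   # forwards to Python's len() -> 100000
x = dset[0]    # calls __getitem__ via PyCall's `get` (zero-based index)
x.shape        # property access forwarded to the PyObject: torch.Size (3, 128, 224)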
const torch = py"torch"

# Wrap torch.to_dlpack so DLPack.jl can call it; the @pycall annotation
# keeps the returned DLPack capsule as a raw PyObject.
function to_dlpack(o)
    @pycall torch.to_dlpack(o)::PyObject
end
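A quick sanity check of the axis order (a sketch; the names here are illustrative):

t = torch.rand(3, 128, 224)       # PyTorch sees shape (3, 128, 224)
a = DLPack.wrap(t, to_dlpack)     # zero-copy view as a Julia array
@assert size(a) == (224, 128, 3)  # column-major Julia reverses the axes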
function maindlpack()
    rd = PyWrapDataset(py"PyData"(py"ndata"))
    loader = torch.utils.data.DataLoader(
        rd, batch_size=128, num_workers=16,
        # worker_init_fn=py"set_worker_sharing_strategy",
    )
    @showprogress for (i, pybatch) in enumerate(loader)
        # Zero-copy view of the PyTorch batch as a Julia array.
        jlbatch = DLPack.wrap(pybatch, to_dlpack)
        # Julia is column-major, so the PyTorch shape (batch, 3, 128, 224)
        # shows up with the axes reversed: (224, 128, 3, batch).
        @assert size(jlbatch)[1:3] == (224, 128, 3)
        # Incremental GC each iteration so PyObjects from previous
        # batches are released promptly.
        gcfull = false
        GC.gc(gcfull)
    end
end
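If PyTorch's row-major axis order is wanted on the Julia side, the dimensions can be reversed explicitly, at the cost of a copy. A sketch, with jlbatch as inside the loop above:

# jlbatch has size (224, 128, 3, nbatch) on the Julia side; permutedims
# with the reversed axes recovers PyTorch's (nbatch, 3, 128, 224).
batch_torch_order = permutedims(jlbatch, reverse(1:ndims(jlbatch)))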
@time py"maintorch"()
@time maindlpack()
@terasakisatoshi (Author)
Result: the DLPack-based Julia loop (maindlpack) runs at about the same speed as the pure-PyTorch loop (maintorch).

julia> @time py"maintorch"()
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:10<00:00, 75.35it/s]
Dekita-dekyu: Done
 10.380944 seconds (5 allocations: 272 bytes)

julia> @time maindlpack()
Progress: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:09
  9.486080 seconds (90.71 k allocations: 5.206 MiB, 5.97% gc time, 0.24% compilation time)
