
@cpebble
Created October 23, 2022 15:55
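A small Dataset wrapper that preloads an entire map-style dataset into CPU or GPU memory with a multiprocessing pool, so later indexing is a pure tensor lookup.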
import torch
import torch.multiprocessing as mp
from torch.utils.data import Dataset

# Tensors shared between worker processes go through the file system,
# which avoids "too many open files" errors on large datasets.
mp.set_sharing_strategy("file_system")


class InMemoryMap(Dataset):
    """Wraps a map-style dataset and caches all samples in memory."""

    def __init__(self, ds, gpu=False):
        self.ds = ds
        self.device = torch.device("cuda:0" if gpu else "cpu")
        self.samples = None
        self.labels = None
        self.len = len(ds)

    def loadToMem(self, num_workers=6):
        # Fork workers so the wrapped dataset is inherited, not pickled
        ctx = mp.get_context("fork")
        with ctx.Pool(num_workers) as pool:
            data = pool.map(self.ds.__getitem__, range(self.len))
        # Unpack the (sample, label) pairs into two device tensors
        samples = [x for (x, _) in data]
        labels = [y for (_, y) in data]
        self.samples = torch.stack(samples).to(self.device)
        self.labels = torch.tensor(labels).to(self.device)

    def __getitem__(self, index):
        # `is not None`: comparing a tensor with != would be elementwise
        assert self.samples is not None, "call loadToMem() before indexing"
        return self.samples[index], self.labels[index]

    def __len__(self):
        return self.len
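
A minimal usage sketch (the MNIST dataset, path, and parameter values are illustrative assumptions, not part of the gist): wrap any map-style dataset that yields (tensor, label) pairs, preload it once, then iterate as usual.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Example dataset only; any map-style Dataset returning (tensor, label) works.
base = datasets.MNIST("data/", train=True, download=True,
                      transform=transforms.ToTensor())

cached = InMemoryMap(base, gpu=torch.cuda.is_available())
cached.loadToMem(num_workers=6)  # one parallel pass over the wrapped dataset

# After preloading, __getitem__ indexes the cached tensors directly,
# so the DataLoader itself needs no worker processes.
loader = DataLoader(cached, batch_size=64, shuffle=True)
x, y = next(iter(loader))

Because every sample is materialized up front, this trades memory for I/O: it suits datasets that fit in RAM (or in GPU memory with gpu=True) and removes per-epoch loading cost.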