"""Iterating with the DataLoader through different Datasets."""
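
# This gist benchmarks five DataLoader configurations over the same synthetic
# scVI data: two NumPy structured-array layouts (NamedDataset1/2) and three
# single-dense-matrix layouts (ConcatenatedDataset1/2/3), each paired with its
# own collate_fn, timed with and without a CUDA transfer.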

import re
import timeit

import numpy as np
import torch

from scvi.dataset.dataset import GeneExpressionDataset
from scvi.dataset.synthetic import SyntheticDataset

class NamedDataset1(GeneExpressionDataset):
    """Stores each cell as one record of a NumPy structured array, with scalar metadata fields."""

    def __init__(self, X, local_means, local_vars, batch_indices, labels, gene_names=None, n_batches=1):
        self.total_size, self.nb_genes = X.shape
        my_type = np.dtype(
            [('counts', 'f4', (X.shape[1],)),
             ('local_mean', 'f4'),
             ('local_var', 'f4'),
             ('batch_indices', 'f4'),
             ('labels', 'f4')])
        new_X = [(X[i], local_means[i][0],
                  local_vars[i][0],
                  batch_indices[i][0],
                  labels[i][0]) for i in range(len(X))]
        self.structured_array = np.array(new_X, dtype=my_type)

    def __getitem__(self, idx):
        return self.structured_array[idx]

    @staticmethod
    def collate_fn(batch):
        # The scalar metadata fields must be re-wrapped into length-1 lists
        # so each tensor ends up with shape (batch_size, 1).
        counts, local_mean, local_var, batch_indices, labels = [], [], [], [], []
        for b in batch:
            counts += [b[0].reshape(1, -1)]
            local_mean += [[b[1]]]
            local_var += [[b[2]]]
            batch_indices += [[b[3]]]
            labels += [[b[4]]]
        counts = np.concatenate(counts)
        return [torch.FloatTensor(l) for l in [counts, local_mean, local_var, batch_indices, labels]]
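
# A minimal sketch (not in the original gist) of how one structured-array
# record behaves; `nb_genes` is whatever SyntheticDataset generates:
#
#     ds = SyntheticNamedDataset1()    # defined further down
#     rec = ds[0]                      # a single np.void record
#     rec['counts'].shape              # (nb_genes,)
#     float(rec['local_mean'])         # scalar 'f4' field
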
class NamedDataset2(GeneExpressionDataset):
    """Same structured-array idea, but every field already carries a leading axis
    of size 1, so collate_fn can concatenate the fields directly."""

    def __init__(self, X, local_means, local_vars, batch_indices, labels, gene_names=None, n_batches=1):
        self.total_size, self.nb_genes = X.shape
        my_type = np.dtype(
            [('counts', 'f4', (1, X.shape[1])),
             ('local_mean', 'f4', (1, 1)),
             ('local_var', 'f4', (1, 1)),
             ('batch_indices', 'f4', (1, 1)),
             ('labels', 'f4', (1, 1))])
        new_X = [(X[i], local_means[i][0],
                  local_vars[i][0],
                  batch_indices[i][0],
                  labels[i][0]) for i in range(len(X))]
        self.structured_array = np.array(new_X, dtype=my_type)

    def __getitem__(self, idx):
        return self.structured_array[idx]

    @staticmethod
    def collate_fn(batch):
        counts, local_mean, local_var, batch_indices, labels = [], [], [], [], []
        for b in batch:
            counts += [b[0]]
            local_mean += [b[1]]
            local_var += [b[2]]
            batch_indices += [b[3]]
            labels += [b[4]]
        return [torch.FloatTensor(np.concatenate(l))
                for l in [counts, local_mean, local_var, batch_indices, labels]]

class ConcatenatedDataset1(GeneExpressionDataset):
    """Stores everything in one dense matrix: the last four columns hold
    local_mean, local_var, batch_indices and labels."""

    def __init__(self, X, local_means, local_vars, batch_indices, labels, gene_names=None, n_batches=1):
        self.total_size, self.nb_genes = X.shape
        self.X = np.concatenate([X, local_means, local_vars, batch_indices, labels], axis=1)

    def __getitem__(self, idx):
        return self.X[idx, :self.nb_genes], self.X[idx, self.nb_genes], self.X[idx, self.nb_genes + 1], \
               self.X[idx, self.nb_genes + 2], self.X[idx, self.nb_genes + 3]

    @staticmethod
    def collate_fn(batch):
        cat_batch = np.concatenate([np.concatenate([el[0], np.array(el[1:])]).reshape(1, -1) for el in batch], axis=0)
        return torch.FloatTensor(cat_batch[:, :-4]), torch.FloatTensor(cat_batch[:, [-4]]), \
               torch.FloatTensor(cat_batch[:, [-3]]), torch.FloatTensor(cat_batch[:, [-2]]), \
               torch.FloatTensor(cat_batch[:, [-1]])

class ConcatenatedDataset2(ConcatenatedDataset1):
    """Returns the whole row at once; splitting counts from metadata is deferred to collate_fn."""

    def __getitem__(self, idx):
        return self.X[idx]

    @staticmethod
    def collate_fn(batch):
        cat_batch = np.concatenate(batch).reshape((len(batch), -1))
        return torch.FloatTensor(cat_batch[:, :-4]), torch.FloatTensor(cat_batch[:, [-4]]), \
               torch.FloatTensor(cat_batch[:, [-3]]), torch.FloatTensor(cat_batch[:, [-2]]), \
               torch.FloatTensor(cat_batch[:, [-1]])

class ConcatenatedDataset3(ConcatenatedDataset2):
    """Same layout, but the matrix is converted to a FloatTensor up front,
    so collate_fn only slices tensors instead of converting NumPy arrays."""

    def __init__(self, *args, **kwargs):
        super(ConcatenatedDataset3, self).__init__(*args, **kwargs)
        self.X = torch.FloatTensor(self.X)

    @staticmethod
    def collate_fn(batch):
        cat_batch = torch.cat(batch).view(len(batch), -1)
        return cat_batch[:, :-4], cat_batch[:, [-4]], \
               cat_batch[:, [-3]], cat_batch[:, [-2]], \
               cat_batch[:, [-1]]

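# Note (added): the collate_fn above returns column slices of `cat_batch`,
# which are non-contiguous tensors; this is why the async CUDA timing at the
# bottom of the script skips SyntheticConcatenatedDataset3.
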
# The Synthetic* classes combine SyntheticDataset, which generates the data and
# (through the MRO) is expected to hand it to the mixin's __init__, with each
# storage variant defined above.
class SyntheticNamedDataset1(SyntheticDataset, NamedDataset1):
    pass


class SyntheticNamedDataset2(SyntheticDataset, NamedDataset2):
    pass


class SyntheticConcatenatedDataset1(SyntheticDataset, ConcatenatedDataset1):
    pass


class SyntheticConcatenatedDataset2(SyntheticDataset, ConcatenatedDataset2):
    pass


class SyntheticConcatenatedDataset3(SyntheticDataset, ConcatenatedDataset3):
    pass

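# A hedged sanity check (not in the original gist): every variant should yield
# batches with identical shapes. Assumes SyntheticDataset's default
# constructor, exactly as in the timing setups below:
#
#     from torch.utils.data import DataLoader
#     ds = SyntheticNamedDataset1()
#     loader = DataLoader(ds, batch_size=128, collate_fn=ds.collate_fn)
#     counts, local_mean, local_var, batch_indices, labels = next(iter(loader))
#     # counts: (128, nb_genes); each metadata tensor: (128, 1)
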
original_setup = '''
from scvi.utils import to_cuda
from scvi.dataset.dataset import GeneExpressionDataset
from scvi.dataset.synthetic import SyntheticDataset
from torch.utils.data import DataLoader
gene_dataset = SyntheticDataset()
loader = DataLoader(gene_dataset, batch_size=128)
'''
names = ['SyntheticNamedDataset1', 'SyntheticNamedDataset2',
         'SyntheticConcatenatedDataset1', 'SyntheticConcatenatedDataset2',
         'SyntheticConcatenatedDataset3']
main_setup = '''
from __main__ import %
from scvi.utils import to_cuda
from torch.utils.data import DataLoader
gene_dataset = %()
loader = DataLoader(gene_dataset, batch_size=128, collate_fn=%.collate_fn)
'''
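# '%' serves as a crude template placeholder; re.sub below swaps in each class name.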
setups = [re.sub('%', name, main_setup) for name in names]
stmt = '''
for _ in loader:
    pass
'''
cuda_stmt = '''
for tensors in loader:
    tensors = to_cuda(tensors, async=False)
'''
cuda_stmt_async = '''
for tensors in loader:
    tensors = to_cuda(tensors, async=True)
'''
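
# Note (added): `async` became a reserved keyword in Python 3.7; the 2018-era
# PyTorch/scvi versions this gist targets still accepted it as an argument name
# (PyTorch later renamed it to `non_blocking`).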
number = 1000
print("Original setup : ", timeit.timeit(stmt=stmt, setup=original_setup, number=number))
print("CUDA ASYNC: Original setup : ", timeit.timeit(stmt=cuda_stmt_async, setup=original_setup, number=number))
print("CUDA: Original setup : ", timeit.timeit(stmt=cuda_stmt, setup=original_setup, number=number))
for name, setup in zip(names, setups):
    print(name, " : ", timeit.timeit(stmt=stmt, setup=setup, number=number))
    if name != 'SyntheticConcatenatedDataset3':  # async doesn't work here: its batches are not contiguous
        print("CUDA ASYNC:", name, " : ", timeit.timeit(stmt=cuda_stmt_async, setup=setup, number=number))
    print("CUDA :", name, " : ", timeit.timeit(stmt=cuda_stmt, setup=setup, number=number))