Last active
May 7, 2018 16:33
-
-
Save Edouard360/4615b4dc35b58c0a2a89be57e7b165d5 to your computer and use it in GitHub Desktop.
Benchmarking DataLoader iteration speed across different Dataset storage layouts (structured arrays vs. concatenated arrays vs. pre-built tensors)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
import numpy as np | |
import torch | |
from scvi.dataset.dataset import GeneExpressionDataset | |
from scvi.dataset.synthetic import SyntheticDataset | |
class NamedDataset1(GeneExpressionDataset):
    """Dataset backed by a NumPy structured array with scalar annotation fields.

    Each record stores one cell's expression counts ('f4' vector of length
    nb_genes) plus its local mean, local variance, batch index and label as
    plain 'f4' scalars. `collate_fn` rebuilds 2-D float tensors from a list
    of such records.
    """

    def __init__(self, X, local_means, local_vars, batch_indices, labels, gene_names=None, n_batches=1):
        self.total_size, self.nb_genes = X.shape
        # One record per cell: a counts vector plus four scalar annotations.
        record_dtype = np.dtype([
            ('counts', 'f4', (X.shape[1],)),
            ('local_mean', 'f4'),
            ('local_var', 'f4'),
            ('batch_indices', 'f4'),
            ('labels', 'f4'),
        ])
        # The annotation arrays are column vectors, hence the [0] unwrap.
        records = [
            (counts, mean[0], var[0], batch[0], label[0])
            for counts, mean, var, batch, label
            in zip(X, local_means, local_vars, batch_indices, labels)
        ]
        self.structured_array = np.array(records, dtype=record_dtype)

    def __getitem__(self, idx):
        return self.structured_array[idx]

    @staticmethod
    def collate_fn(batch):
        """Stack a list of records into five (batch, ...) float tensors."""
        # Counts become a (batch, nb_genes) matrix; each scalar field becomes
        # a (batch, 1) nested list, matching the original collate layout.
        counts = np.concatenate([record[0].reshape(1, -1) for record in batch])
        scalar_fields = [[[record[k]] for record in batch] for k in range(1, 5)]
        return [torch.FloatTensor(field) for field in [counts] + scalar_fields]
class NamedDataset2(GeneExpressionDataset):
    """Dataset backed by a structured array whose fields are pre-shaped 2-D.

    Unlike NamedDataset1, every field already carries a leading length-1
    axis — counts are (1, nb_genes) and the annotations are (1, 1) — so
    `collate_fn` only needs a plain concatenation per field.
    """

    def __init__(self, X, local_means, local_vars, batch_indices, labels, gene_names=None, n_batches=1):
        self.total_size, self.nb_genes = X.shape
        record_dtype = np.dtype([
            ('counts', 'f4', (1, X.shape[1])),
            ('local_mean', 'f4', (1, 1)),
            ('local_var', 'f4', (1, 1)),
            ('batch_indices', 'f4', (1, 1)),
            ('labels', 'f4', (1, 1)),
        ])
        # The annotation arrays are column vectors, hence the [0] unwrap.
        records = [
            (counts, mean[0], var[0], batch[0], label[0])
            for counts, mean, var, batch, label
            in zip(X, local_means, local_vars, batch_indices, labels)
        ]
        self.structured_array = np.array(records, dtype=record_dtype)

    def __getitem__(self, idx):
        return self.structured_array[idx]

    @staticmethod
    def collate_fn(batch):
        """Concatenate each of the five fields along the leading axis."""
        # Fields already have a (1, ...) shape, so concatenation stacks them.
        return [torch.FloatTensor(np.concatenate([record[k] for record in batch]))
                for k in range(5)]
class ConcatenatedDataset1(GeneExpressionDataset):
    """Dataset storing counts and annotations side by side in one 2-D array.

    Column layout of self.X: columns [0, nb_genes) are the counts, followed
    by local_mean, local_var, batch_indices and labels as the last four
    columns. __getitem__ splits a row back into the five pieces.
    """

    def __init__(self, X, local_means, local_vars, batch_indices, labels, gene_names=None, n_batches=1):
        self.total_size, self.nb_genes = X.shape
        # Append the four annotation columns to the right of the counts.
        self.X = np.concatenate([X, local_means, local_vars, batch_indices, labels], axis=1)

    def __getitem__(self, idx):
        row = self.X[idx]
        genes = self.nb_genes
        return row[:genes], row[genes], row[genes + 1], row[genes + 2], row[genes + 3]

    @staticmethod
    def collate_fn(batch):
        """Re-assemble samples into flat rows, stack, then split the columns."""
        rows = [np.concatenate([counts, np.array(rest)]).reshape(1, -1)
                for counts, *rest in batch]
        stacked = np.concatenate(rows, axis=0)
        return (torch.FloatTensor(stacked[:, :-4]),
                torch.FloatTensor(stacked[:, [-4]]),
                torch.FloatTensor(stacked[:, [-3]]),
                torch.FloatTensor(stacked[:, [-2]]),
                torch.FloatTensor(stacked[:, [-1]]))
class ConcatenatedDataset2(ConcatenatedDataset1):
    """Variant returning the raw concatenated row; splitting is deferred
    entirely to collate_fn."""

    def __getitem__(self, idx):
        return self.X[idx]

    @staticmethod
    def collate_fn(batch):
        """Stack flat rows and slice the counts / four annotation columns."""
        stacked = np.concatenate(batch).reshape((len(batch), -1))
        pieces = [stacked[:, :-4]] + [stacked[:, [col]] for col in (-4, -3, -2, -1)]
        return tuple(torch.FloatTensor(piece) for piece in pieces)
class ConcatenatedDataset3(ConcatenatedDataset2):
    """Variant that converts self.X to a torch tensor once at init time,
    so collate_fn can stack rows with torch.cat instead of NumPy."""

    def __init__(self, *args, **kwargs):
        # BUG FIX: the original called super(ConcatenatedDataset2, self),
        # which skips ConcatenatedDataset2 in the MRO. It only worked by
        # accident because ConcatenatedDataset2 defines no __init__; naming
        # this class keeps the intended cooperative-inheritance chain.
        super(ConcatenatedDataset3, self).__init__(*args, **kwargs)
        self.X = torch.FloatTensor(self.X)

    @staticmethod
    def collate_fn(batch):
        """Stack torch rows and slice counts / annotation columns."""
        cat_batch = torch.cat(batch).view(len(batch), -1)
        return cat_batch[:, :-4], cat_batch[:, [-4]], \
            cat_batch[:, [-3]], cat_batch[:, [-2]], \
            cat_batch[:, [-1]]
# Concrete benchmark datasets. SyntheticDataset comes first in the MRO and
# presumably generates the random data and drives initialization of the
# second base via cooperative super() calls — TODO confirm against scvi.
# The second base contributes the storage layout, __getitem__ and the
# collate_fn under test; each subclass benchmarks one layout.
class SyntheticNamedDataset1(SyntheticDataset, NamedDataset1):
    pass


class SyntheticNamedDataset2(SyntheticDataset, NamedDataset2):
    pass


class SyntheticConcatenatedDataset1(SyntheticDataset, ConcatenatedDataset1):
    pass


class SyntheticConcatenatedDataset2(SyntheticDataset, ConcatenatedDataset2):
    pass


class SyntheticConcatenatedDataset3(SyntheticDataset, ConcatenatedDataset3):
    pass
import timeit

# --- timeit setup/statement snippets ---------------------------------------
# Baseline: stock SyntheticDataset with the DataLoader's default collate.
original_setup = '''
from scvi.utils import to_cuda
from scvi.dataset.dataset import GeneExpressionDataset
from scvi.dataset.synthetic import SyntheticDataset
from torch.utils.data import DataLoader
gene_dataset = SyntheticDataset()
loader = DataLoader(gene_dataset, batch_size=128)
'''

names = ['SyntheticNamedDataset1', 'SyntheticNamedDataset2',
         'SyntheticConcatenatedDataset1', 'SyntheticConcatenatedDataset2',
         'SyntheticConcatenatedDataset3']

# Template for each variant; '%' is the placeholder for the class name.
main_setup = '''
from __main__ import %
from scvi.utils import to_cuda
from torch.utils.data import DataLoader
gene_dataset = %()
loader = DataLoader(gene_dataset, batch_size=128, collate_fn=%.collate_fn)
'''

setups = [main_setup.replace('%', name) for name in names]

# Plain iteration over the loader.
stmt = '''
for _ in loader:
    pass
'''

# Iteration plus host->device transfer, synchronous and async variants.
cuda_stmt = '''
for tensors in loader:
    tensors = to_cuda(tensors, async=False)
'''

cuda_stmt_async = '''
for tensors in loader:
    tensors = to_cuda(tensors, async=True)
'''

number = 1000

# Baseline timings first, then one row per custom dataset layout.
print("Original setup : ", timeit.timeit(stmt=stmt, setup=original_setup, number=number))
print("CUDA ASYNC: Original setup : ", timeit.timeit(stmt=cuda_stmt_async, setup=original_setup, number=number))
print("CUDA: Original setup : ", timeit.timeit(stmt=cuda_stmt, setup=original_setup, number=number))

for name, setup in zip(names, setups):
    print(name, " : ", timeit.timeit(stmt=stmt, setup=setup, number=number))
    if name != 'SyntheticConcatenatedDataset3':  # doesn't work since data not contiguous
        print("CUDA ASYNC:", name, " : ", timeit.timeit(stmt=cuda_stmt_async, setup=setup, number=number))
        print("CUDA :", name, " : ", timeit.timeit(stmt=cuda_stmt, setup=setup, number=number))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment