Created
April 9, 2020 16:29
-
-
Save HenryJia/17e3a647cc2da1dd0ceeb6365bdfeaac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import numpy as np
import torch
from lighter.train import AsynchronousLoader
from torch.utils.data import DataLoader
# Benchmark configuration.
# batch_size is used both as the number of samples per batch and as the width
# of every sample, so a collated batch is presumably a square
# (batch_size x batch_size) matrix, which torch.mm(x, x) needs.
batch_size = 2 ** 13
# Number of batches in one pass over the dataset.
length = 100
# All GPU work in this script targets the first CUDA device.
device = torch.device('cuda:0')

# Reported size: batch_size ** 2 elements at 4 bytes each (float32 is torch's
# default dtype), divided by 2 ** 20 bytes.
print('Running Async speed test at batch size {}MB for {} batches'.format(batch_size ** 2 * 4 / 2 ** 20, length))
class RandomDataset:
    """Map-style dataset that fabricates a fresh random vector per lookup.

    Every item is an independent ``torch.randn`` draw, so the index never
    influences the returned values; it is only bounds-checked.

    Args:
        sample_dim: Number of elements in each returned 1-D tensor. Defaults
            to the module-level ``batch_size``.
        num_samples: Length reported by ``len()``. Defaults to the
            module-level ``length * batch_size``.
    """

    def __init__(self, sample_dim=None, num_samples=None):
        # Fall back to the script-wide settings so a bare ``RandomDataset()``
        # keeps the sizes the benchmark was written for.
        self.sample_dim = batch_size if sample_dim is None else sample_dim
        self.num_samples = (
            length * batch_size if num_samples is None else num_samples
        )

    def __getitem__(self, i):
        """Return a new standard-normal tensor of shape ``(sample_dim,)``.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        # Without this check, plain iteration (``for x in dataset``) would
        # never stop: the legacy sequence protocol ends only on IndexError.
        if not -self.num_samples <= i < self.num_samples:
            raise IndexError('RandomDataset index out of range')
        return torch.randn(self.sample_dim)

    def __len__(self):
        """Return the number of items the dataset reports."""
        return self.num_samples
# --- Benchmark 1: AsynchronousLoader -----------------------------------------
dataset = RandomDataset()
loader = AsynchronousLoader(dataset, device=device, batch_size=batch_size, queue_size=10, shuffle=True, workers=10)

# First time to wake CUDA up: run one untimed pass so one-off start-up costs
# are not charged to the timed pass below.
main_stream = torch.cuda.default_stream(device=device)
# Allocate the accumulator directly on the GPU instead of building it on the
# host and copying it over.
y = torch.zeros(batch_size, batch_size, device=device)
for x in loader:
    y = y + torch.mm(x, x)
# .item() copies the scalar to the host, which blocks until the GPU work that
# produces it has finished.
# NOTE(review): this call is placed after the loop on the assumption that it
# is a single end-of-pass sync; the source's indentation was lost — confirm.
y.sum().item()

# Now actually time things
y = torch.zeros(batch_size, batch_size, device=device)

# Synchronize the main stream to make sure everything is done before starting our test
main_stream.synchronize()

# perf_counter() is monotonic and intended for measuring elapsed intervals.
t0 = time.perf_counter()

for x in loader:
    y = y + torch.mm(x, x)
# Force completion of the queued GPU work before the clock is stopped.
y.sum().item()

t1 = time.perf_counter()

print('Async time', t1 - t0)
# --- Benchmark 2: stock DataLoader with pinned memory ------------------------
# Built once; the same sizes, worker count and shuffling as the asynchronous
# loader so the two timings are comparable.
loader = DataLoader(
    dataset,
    num_workers=10,
    pin_memory=True,
    batch_size=batch_size,
    shuffle=True,
)

# Allocate the accumulator directly on the GPU instead of building it on the
# host and copying it over.
y = torch.zeros(batch_size, batch_size, device=device)

# Drain any outstanding work on the main stream before starting the clock.
main_stream.synchronize()

# perf_counter() is monotonic and intended for measuring elapsed intervals.
t0 = time.perf_counter()

for x in loader:
    # Batches arrive in pinned host memory (pin_memory=True), so this
    # host-to-device copy can be issued without blocking the host.
    a = x.to(device=device, non_blocking=True)
    y = y + torch.mm(a, a)
# .item() blocks until the queued GPU work has finished, so the measured
# interval covers the whole pass.
# NOTE(review): this call is placed after the loop on the assumption that it
# is a single end-of-pass sync; the source's indentation was lost — confirm.
y.sum().item()

t1 = time.perf_counter()

print('Non sync time', t1 - t0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment