Skip to content

Instantly share code, notes, and snippets.

@aliwaqas333
Created June 19, 2020 10:28
Show Gist options
  • Save aliwaqas333/0f6060664f63ed27471cfa5307f0fb9a to your computer and use it in GitHub Desktop.
Basic helper functions for running PyTorch code on a GPU (with CPU fallback)
def get_default_device():
    """Choose the device that tensors and models should live on.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)


# Module-level default device, resolved once when this code runs.
device = get_default_device()
def to_device(data, device):
    """Move tensor(s) to ``device``, recursing into common batch containers.

    Args:
        data: A tensor (or any object exposing a ``.to`` method), or a list,
            tuple or dict that nests such objects.
        device: Target device (``torch.device`` or device string) forwarded
            to ``.to``.

    Returns:
        The moved data. Lists and tuples are returned as lists, as before.
        Dicts are returned as dicts with the same keys. Values without a
        ``.to`` method (e.g. strings or ints in a collated batch) are
        returned unchanged instead of raising ``AttributeError``.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    if isinstance(data, dict):
        # default_collate yields dict batches for dict samples; keep the keys.
        return {key: to_device(value, device) for key, value in data.items()}
    if hasattr(data, 'to'):
        # non_blocking lets the host->GPU copy run asynchronously when the source
        # tensor is in pinned memory; otherwise it behaves like a normal copy.
        return data.to(device, non_blocking=True)
    # Non-tensor metadata (file names, labels as str, ...) has nothing to move.
    return data
class DeviceDataLoader:
    """Iterable wrapper that relocates each batch of a loader onto a device.

    Batches are pulled from the wrapped loader one at a time and passed
    through ``to_device`` just before being handed to the caller.
    """

    def __init__(self, dl, device):
        # Wrapped loader and the device every yielded batch is moved to.
        self.dl, self.device = dl, device

    def __iter__(self):
        """Produce the wrapped loader's batches, each moved to ``self.device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Report the number of batches, delegating to the wrapped loader."""
        return len(self.dl)
# Wrap the training and validation loaders so their batches land on `device`.
train_dl, val_dl = DeviceDataLoader(train_dl, device), DeviceDataLoader(val_dl, device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment