@HenryJia
Created April 9, 2020 16:03
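# Benchmark: compares feeding MNIST through a plain PyTorch DataLoader
# against lighter's AsynchronousLoader (Flatten and AsynchronousLoader are
# imported from the author's `lighter` library below), timing 100 epochs of
# training plus validation on a single GPU.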
import time, os, sys, argparse
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
np.random.seed(94103)
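# Note: only NumPy's RNG is seeded; torch.manual_seed is not set, so weight
# initialisation and batch shuffling still vary between runs.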
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from lighter.modules.utils import Flatten
from lighter.train import AsynchronousLoader
from tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument('--root_dir', type=str, help='Root directory containing the folder with the MNIST dataset')
parser.add_argument('--use_async', action='store_true', help='Whether to use the asynchronous loader')
args = parser.parse_args()
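# Example invocation (the script name is whatever this file is saved as):
#   python benchmark.py --root_dir ./data --use_async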
train_set = MNIST(args.root_dir, train=True, download=True, transform=ToTensor())
validation_set = MNIST(args.root_dir, train=False, download=True, transform=ToTensor())
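# ToTensor yields float images of shape (1, 28, 28) scaled to [0, 1]; labels
# are class indices 0-9. In the model below, each 2x2 max-pool halves the
# spatial size: 28x28 -> 14x14 -> 7x7, so the flattened feature vector has
# 32 * 7 * 7 = 1568 elements, matching the first Linear layer.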
model = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1),
                      nn.LeakyReLU(),
                      nn.MaxPool2d((2, 2)),
                      nn.Conv2d(16, 32, 3, padding=1),
                      nn.LeakyReLU(),
                      nn.MaxPool2d((2, 2)),
                      Flatten(),
                      nn.Linear(32 * 7 * 7, 512),
                      nn.LeakyReLU(),
                      nn.Linear(512, 10),
                      nn.LogSoftmax(dim=1)).to(device=torch.device('cuda:0'))
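# LogSoftmax followed by NLLLoss computes the cross-entropy of the class
# predictions (together they are equivalent to CrossEntropyLoss on raw logits).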
loss = nn.NLLLoss().to(device=torch.device('cuda:0'))
optim = Adam(model.parameters(), lr=3e-4)
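# Two data-feeding strategies are compared. AsynchronousLoader is designed to
# stage each batch onto the GPU in the background, so the loops below skip the
# explicit .to() copies when it is in use; the plain DataLoader instead relies
# on pinned host memory plus non_blocking copies inside the loop. The
# batch_size=1024 and num_workers=10 settings are this gist's choices and
# would need tuning per machine.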
if args.use_async:
    print('Using AsynchronousLoader')
    train_loader = AsynchronousLoader(train_set, device=torch.device('cuda:0'), batch_size=1024, shuffle=True)
    validation_loader = AsynchronousLoader(validation_set, device=torch.device('cuda:0'), batch_size=1024, shuffle=True)
else:
    print('Using DataLoader')
    train_loader = DataLoader(train_set, batch_size=1024, shuffle=True, pin_memory=True, num_workers=10)
    validation_loader = DataLoader(validation_set, batch_size=1024, shuffle=True, pin_memory=True, num_workers=10)
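# Time the full run: 100 epochs, each a training pass plus a validation pass.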
t0 = time.time()
for i in range(100):
    pb = tqdm(total=len(train_loader))
    for x, y in train_loader:
        if not args.use_async:
            # The plain DataLoader yields CPU tensors; copy them to the GPU.
            # non_blocking=True can overlap the copy with compute because the
            # loader was built with pin_memory=True.
            x = x.to(device=torch.device('cuda:0'), non_blocking=True)
            y = y.to(device=torch.device('cuda:0'), non_blocking=True)
        out = model(x)
        l = loss(out, y)
        optim.zero_grad()
        l.backward()
        optim.step()
        pb.set_postfix(train_loss=l.item())
        pb.update(1)
    pb.close()
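    # Validation pass: forward passes only. The model has no dropout or
    # batch-norm layers, so model.eval() would change nothing here.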
    pb = tqdm(total=len(validation_loader))
    with torch.no_grad():  # evaluation only: gradients are never needed here
        for x, y in validation_loader:
            if not args.use_async:
                x = x.to(device=torch.device('cuda:0'), non_blocking=True)
                y = y.to(device=torch.device('cuda:0'), non_blocking=True)
            out = model(x)
            l = loss(out, y)
            pb.set_postfix(validation_loss=l.item())
            pb.update(1)
    pb.close()
t1 = time.time()
print('Total training time for 100 epochs:', t1 - t0)