Skip to content

Instantly share code, notes, and snippets.

@RuABraun
Created Apr 7, 2019
Embed
What would you like to do?
Weird performance difference with pytorch
import logging
import numpy as np
import plac
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
DATA='data'
BATCH_SIZE = 256
DEVICE = torch.device("cuda")
af = F.relu
logger = logging.getLogger(__name__)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.cnn1 = nn.Conv2d(1, 32, 5, 2, padding=2)
self.cnn2 = nn.Conv2d(32, 64, 3, 2, padding=0)
self.cnn3 = nn.Conv2d(64, 64, 3, 2, padding=0)
self.fc1 = nn.Linear(256, 512)
self.fc2 = nn.Linear(512, 10)
def forward(self, x, adddim=False):
logger.debug('forward %s', x.size())
if adddim:
x = torch.unsqueeze(x, 1)
h = af(self.cnn1(x))
h = af(self.cnn2(h))
h = af(self.cnn3(h))
h = h.view(x.size(0), -1)
h = F.relu(self.fc1(h))
h = self.fc2(h)
return h
def eval_test(model, dset, adddim):
tot_cost, acc = 0., 0.
for i, sample in enumerate(dset):
feat = sample[0].to(DEVICE)
targ = sample[1].to(DEVICE)
out = model(feat, adddim)
cost = F.nll_loss(F.log_softmax(out, dim=1), targ)
tot_cost += cost.item()
pred = out.argmax(1)
acc += pred.eq(targ).sum().item() / 256
tot_cost /= i
acc /= i
print('Test cost / acc: {} / {}'.format(tot_cost, acc))
class FMNIST(torch.utils.data.Dataset):
def __init__(self, data, targets):
self.data = data.type(torch.FloatTensor)
self.targets = targets
def __getitem__(self, index):
return self.data[index], self.targets[index]
def __len__(self):
return len(self.data)
def run_training(train_loader, test_loader, adddim=False):
model = CNN()
model = model.to(DEVICE)
model = model.train()
optimizer = torch.optim.Adam(model.parameters(), 0.003, weight_decay=1e-6)
modelparams = [p for p in model.parameters()]
avg_loss = 0.
print('Starting training.')
epochnum = torch.tensor([0.0], requires_grad=False)
for i in range(10):
for j, sample in enumerate(train_loader):
feat = sample[0].to(DEVICE)
targ = sample[1].to(DEVICE)
model.zero_grad()
out = model(feat, adddim)
loss = F.nll_loss(F.log_softmax(out, dim=1), targ, reduction='sum')
loss.backward()
nn.utils.clip_grad_norm_(modelparams, 2.0)
optimizer.step()
c = loss.item() / BATCH_SIZE
avg_loss = 0.1 * c + 0.9 * avg_loss
epochnum += 1.
print(f'Done epoch {epochnum.item()}, avg loss is {avg_loss:.3f}')
eval_test(model, test_loader, adddim)
def run_standard():
dset_train = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), download=True)
dset_test = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), train=False, download=True)
print(len(dset_train), len(dset_test))
train_loader = torch.utils.data.DataLoader(dataset=dset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=dset_test, batch_size=256, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
run_training(train_loader, test_loader)
def run_modified():
tmp_dset_train = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), download=True)
tmp_dset_test = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), train=False, download=True)
dset_train = FMNIST(tmp_dset_train.data, tmp_dset_train.targets)
dset_test = FMNIST(tmp_dset_test.data, tmp_dset_test.targets)
print(len(dset_train), len(dset_test))
train_loader = torch.utils.data.DataLoader(dataset=dset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=dset_test, batch_size=256, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
run_training(train_loader, test_loader, True)
def main(
recipe_number,
debug: ("Print debug messages", "flag", "d")
):
if debug:
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)
else:
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
torch.manual_seed(1)
np.set_printoptions(linewidth=300)
torch.set_printoptions(profile="full")
if recipe_number == '1':
run_standard()
elif recipe_number == '2':
run_modified()
plac.call(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment