@RuABraun
Created April 7, 2019 13:39
Weird performance difference with PyTorch
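The script below trains the same small CNN on FashionMNIST in two ways. Recipe 1 uses the standard torchvision dataset, whose ToTensor() transform rescales pixels to [0, 1]; recipe 2 wraps the raw .data tensors in a custom in-memory Dataset and trains on those directly. Run it as "python <script> 1" or "python <script> 2" (plac maps recipe_number to the first positional argument and -d to a debug flag); despite training the same model, the two recipes behave noticeably differently, which is the question this gist poses.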
import logging

import numpy as np
import plac
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

DATA = 'data'
BATCH_SIZE = 256
DEVICE = torch.device("cuda")
af = F.relu  # activation function used throughout the CNN

logger = logging.getLogger(__name__)
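# DEVICE assumes a CUDA GPU is present; a common fallback (not in the
# original gist) is:
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")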
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.cnn1 = nn.Conv2d(1, 32, 5, 2, padding=2)
        self.cnn2 = nn.Conv2d(32, 64, 3, 2, padding=0)
        self.cnn3 = nn.Conv2d(64, 64, 3, 2, padding=0)
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x, adddim=False):
        logger.debug('forward %s', x.size())
        if adddim:
            # Raw FashionMNIST tensors are (N, 28, 28); add the channel dim.
            x = torch.unsqueeze(x, 1)
        h = af(self.cnn1(x))
        h = af(self.cnn2(h))
        h = af(self.cnn3(h))
        h = h.view(x.size(0), -1)
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h
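# Why fc1 takes 256 inputs: a 28x28 image passes cnn1 (k=5, s=2, p=2) -> 14x14,
# then cnn2 (k=3, s=2, p=0) -> 6x6, then cnn3 (k=3, s=2, p=0) -> 2x2, leaving
# 64 channels * 2 * 2 = 256 features after flattening.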
def eval_test(model, dset, adddim):
    tot_cost, acc = 0., 0.
    for i, sample in enumerate(dset):
        feat = sample[0].to(DEVICE)
        targ = sample[1].to(DEVICE)
        out = model(feat, adddim)
        cost = F.nll_loss(F.log_softmax(out, dim=1), targ)
        tot_cost += cost.item()
        pred = out.argmax(1)
        # Batches are always full-sized because both loaders use drop_last=True.
        acc += pred.eq(targ).sum().item() / BATCH_SIZE
    num_batches = i + 1  # enumerate starts at 0, so i alone undercounts by one
    tot_cost /= num_batches
    acc /= num_batches
    print('Test cost / acc: {} / {}'.format(tot_cost, acc))
class FMNIST(torch.utils.data.Dataset):
    """In-memory dataset wrapping the raw FashionMNIST tensors."""

    def __init__(self, data, targets):
        # NB: .data is uint8 in [0, 255]; casting to float keeps that range,
        # unlike ToTensor(), which rescales pixels to [0, 1].
        self.data = data.type(torch.FloatTensor)
        self.targets = targets

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return len(self.data)
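# A minimal check of the scale mismatch between the two pipelines
# (check_input_scale is a hypothetical helper, not called by either recipe):
# ToTensor() rescales pixels to [0, 1], while FMNIST above keeps the raw
# uint8 range, so recipe 2 feeds the network inputs roughly 255x larger.
def check_input_scale():
    dset = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), download=True)
    print('ToTensor() max:', dset[0][0].max().item())        # ~1.0
    print('raw .data max:', dset.data.float().max().item())  # 255.0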
def run_training(train_loader, test_loader, adddim=False):
    model = CNN()
    model = model.to(DEVICE)
    model = model.train()
    optimizer = torch.optim.Adam(model.parameters(), 0.003, weight_decay=1e-6)
    modelparams = [p for p in model.parameters()]
    avg_loss = 0.
    print('Starting training.')
    epochnum = torch.tensor([0.0], requires_grad=False)
    for i in range(10):
        for j, sample in enumerate(train_loader):
            feat = sample[0].to(DEVICE)
            targ = sample[1].to(DEVICE)
            model.zero_grad()
            out = model(feat, adddim)
            loss = F.nll_loss(F.log_softmax(out, dim=1), targ, reduction='sum')
            loss.backward()
            nn.utils.clip_grad_norm_(modelparams, 2.0)
            optimizer.step()
            c = loss.item() / BATCH_SIZE
            # Running exponential moving average of the per-sample loss.
            avg_loss = 0.1 * c + 0.9 * avg_loss
        epochnum += 1.
        print(f'Done epoch {epochnum.item()}, avg loss is {avg_loss:.3f}')
    eval_test(model, test_loader, adddim)
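# Side note: F.nll_loss(F.log_softmax(out, dim=1), targ, ...) above is the
# unfused form of F.cross_entropy(out, targ, ...), which combines both ops.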
def run_standard():
    dset_train = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), download=True)
    dset_test = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), train=False, download=True)
    print(len(dset_train), len(dset_test))
    train_loader = torch.utils.data.DataLoader(dataset=dset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=dset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
    run_training(train_loader, test_loader)
def run_modified():
    tmp_dset_train = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), download=True)
    tmp_dset_test = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), train=False, download=True)
    # Wrap the raw tensors directly; the ToTensor() transform above is never applied.
    dset_train = FMNIST(tmp_dset_train.data, tmp_dset_train.targets)
    dset_test = FMNIST(tmp_dset_test.data, tmp_dset_test.targets)
    print(len(dset_train), len(dset_test))
    train_loader = torch.utils.data.DataLoader(dataset=dset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=dset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
    # adddim=True because the raw tensors lack a channel dimension.
    run_training(train_loader, test_loader, True)
def main(recipe_number,
         debug: ("Print debug messages", "flag", "d")):
    if debug:
        logger.setLevel(logging.DEBUG)
        logging.basicConfig(level=logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)
        logging.basicConfig(level=logging.INFO)
    torch.manual_seed(1)
    np.set_printoptions(linewidth=300)
    torch.set_printoptions(profile="full")
    if recipe_number == '1':
        run_standard()
    elif recipe_number == '2':
        run_modified()


if __name__ == '__main__':
    # Guarding the entry point keeps the num_workers=2 loaders safe on
    # platforms that spawn worker processes by re-importing this module.
    plac.call(main)