@hushell
Last active November 26, 2018 19:42
svrg mnist
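SVRG (stochastic variance reduced gradient) training of a small convolutional net on MNIST in PyTorch. The gist has two parts: the training script below, which alternates a full-gradient snapshot with an epoch of variance-reduced inner updates, and the svrg_solver module it imports, which defines the SVRG optimizer. A sketch of how it might be run, assuming the script is saved as mnist_svrg.py next to svrg_solver.py (both file names are assumptions, not given in the gist):

    python mnist_svrg.py --epochs 10 --lr 0.01 --batch-size 64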
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import svrg_solver
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--inners', type=int, default=200, metavar='N',
                    help='number of inner updates to train (default: 200)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
args.inners = train_loader.sampler.num_samples # TODO: try m < n

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        #x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        #x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
model = Net()
model_prev = Net()
if args.cuda:
    model.cuda()
    model_prev.cuda()
sgd = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
svrg = svrg_solver.SVRG(model.parameters(), model_prev.parameters(), lr=args.lr)
svrg.copy_params()
def train(epoch, opt, inners):
    model.train()
    verbose = True
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > inners:
            break
        if batch_idx == 3:
            verbose = False

        # the closures capture data/target, which are moved to CUDA and wrapped
        # in Variables below, before the closures are actually called
        def closure_sgd():
            sgd.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            return loss

        def closure_svrg():
            # net
            model.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            # net_prev
            model_prev.zero_grad()
            output = model_prev(data)
            loss0 = F.nll_loss(output, target)
            loss0.backward()
            return loss

        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        if opt == 'sgd':
            loss = sgd.step(closure_sgd)
        elif opt == 'svrg':
            loss = svrg.step(closure_svrg, verbose)
        else:
            raise ValueError('unsupported optimizer!')

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target).data[0]
        pred = output.data.max(1)[1] # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()
    test_loss /= len(test_loader) # loss function already averages over batch size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
#import ipdb; ipdb.set_trace()
for epoch in range(1, args.epochs + 1):
    def closure(data, target):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        model_prev.zero_grad()
        output = model_prev(data)
        loss0 = F.nll_loss(output, target)
        loss0.backward()
        return loss0

    if epoch == -1:  # never true: the plain SGD branch is effectively disabled
        train(epoch, 'sgd', args.inners)
    else:
        # full grad: g_k = f'(y_k)
        svrg.full_grad(train_loader, closure) # TODO: itertools.islice(train_loader, i_end)
        # one epoch, m steps of v^t = f'_i(x^t) - f'_i(y_k) + g_k, x^t+1 = x^t - lr * v^t
        train(epoch, 'svrg', args.inners)
        # copy params
        svrg.copy_params()
        # acceleration
    # validation
    test(epoch)
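What follows is the svrg_solver module imported above (presumably saved as svrg_solver.py; the file name is implied by the import, not shown on the page). Its step() applies the variance-reduced update described in the training-loop comments, v = f'_i(x) - f'_i(y_k) + g_k followed by x <- x - lr * v, where g_k is the full gradient at the snapshot y_k accumulated by full_grad() and the snapshot is refreshed once per epoch by copy_params().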
from torch.optim.optimizer import Optimizer, required
class SVRG(Optimizer):
    r"""Implements SVRG (stochastic variance reduced gradient) descent.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        params_prev (iterable): parameters of the snapshot model (y_k), kept in
            sync via :meth:`copy_params`
        lr (float): learning rate
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """

    def __init__(self, params, params_prev, lr=required, weight_decay=0):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(SVRG, self).__init__(params, defaults)
        for group in self.param_groups:
            group['params_prev'] = list(params_prev)

    def __setstate__(self, state):
        super(SVRG, self).__setstate__(state)

    def step(self, closure=None, verbose=False):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            verbose (bool, optional): print the norms of the variance-reduced
                gradient and of the stored full gradient.
        """
        loss = None
        if closure is not None:
            loss = closure() # forward, backward for model and model_prev
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for p, p0 in zip(group['params'], group['params_prev']):
                if p.grad is None:
                    continue
                state = self.state[p0]
                assert 'full_grad' in state # should be non-empty
                d_p = p.grad.data
                d_p0 = p0.grad.data
                v = d_p - d_p0 + state['full_grad']
                if verbose:
                    print('v.norm = {}, fg.norm = {}'.format(v.norm(), state['full_grad'].norm()))
                if state['full_grad'].norm() < 1e-12:
                    import ipdb; ipdb.set_trace()
                if weight_decay != 0:
                    v.add_(weight_decay, p.data)
                p.data.add_(-group['lr'], v)
        return loss
    def full_grad(self, data_loader, closure):
        """Accumulates the gradient of the snapshot model over data_loader
        into state['full_grad'] (the g_k term used by step()).

        Note: the closure produces batch-mean gradients, so dividing the
        accumulated sum by num_samples leaves full_grad smaller than the
        per-batch gradients used in step() by a factor of the batch size;
        dividing by len(data_loader) instead would put them on the same scale.
        """
        num_samples = data_loader.sampler.num_samples
        for i, (data, target) in enumerate(data_loader):
            # forward, backward -> grad
            closure(data, target)
            for group in self.param_groups:
                for p in group['params_prev']:
                    if p.grad is None:
                        continue
                    grad = p.grad.data
                    state = self.state[p]
                    if 'full_grad' not in state:
                        state['full_grad'] = grad.new().resize_as_(grad).zero_()
                    if i == 0:
                        state['full_grad'].zero_()
                    state['full_grad'].add_(1/float(num_samples), grad)
    def copy_params(self):
        """Copies the current parameters into the snapshot parameters
        (y_k <- x_k).
        """
        for group in self.param_groups:
            for p, p0 in zip(group['params'], group['params_prev']):
                #if p.grad is None:
                #    continue
                p0.data.copy_(p.data)
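A minimal usage sketch of the SVRG class above, written against the same Variable-era PyTorch API the gist uses. The tiny linear model, random data, and closure names are illustrative stand-ins, not part of the gist; shuffle=True is used so that the loader's sampler exposes the num_samples attribute that full_grad() reads:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
from svrg_solver import SVRG

net, net_prev = nn.Linear(4, 2), nn.Linear(4, 2)
opt = SVRG(net.parameters(), net_prev.parameters(), lr=0.1)
opt.copy_params()  # snapshot y_k <- current parameters

# toy data; shuffle=True gives a RandomSampler, whose num_samples full_grad() uses
dataset = TensorDataset(torch.randn(32, 4), torch.LongTensor(32).random_(0, 2))
loader = DataLoader(dataset, batch_size=8, shuffle=True)

def snapshot_closure(data, target):
    # gradient of the snapshot model on one batch, accumulated into g_k
    data, target = Variable(data), Variable(target)
    net_prev.zero_grad()
    loss = F.nll_loss(F.log_softmax(net_prev(data), dim=1), target)
    loss.backward()
    return loss

opt.full_grad(loader, snapshot_closure)  # g_k = averaged gradient at y_k

for data, target in loader:  # inner loop of variance-reduced steps
    data, target = Variable(data), Variable(target)
    def closure():
        # gradients at the current iterate and at the snapshot
        net.zero_grad()
        loss = F.nll_loss(F.log_softmax(net(data), dim=1), target)
        loss.backward()
        net_prev.zero_grad()
        F.nll_loss(F.log_softmax(net_prev(data), dim=1), target).backward()
        return loss
    opt.step(closure)  # x <- x - lr * (f'_i(x) - f'_i(y_k) + g_k)

opt.copy_params()  # refresh the snapshot before the next outer iteration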