svrg mnist: SVRG training of a small CNN on MNIST in PyTorch, together with the SVRG optimizer it imports.
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import svrg_solver
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--inners', type=int, default=200, metavar='N',
                    help='number of inner updates to train (default: 200)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
args.inners = train_loader.sampler.num_samples  # m = n inner updates per outer step; TODO: try m < n
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # x = F.dropout(x, training=self.training)  # dropout kept off: SVRG assumes the same deterministic loss f_i at current and snapshot parameters
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
model = Net()
model_prev = Net()  # snapshot network holding the parameters y_k at which the full gradient is computed
if args.cuda:
    model.cuda()
    model_prev.cuda()

sgd = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
svrg = svrg_solver.SVRG(model.parameters(), model_prev.parameters(), lr=args.lr)
svrg.copy_params()  # initialize the snapshot: y_0 = x_0
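# SVRG (Johnson & Zhang, 2013) alternates two loops:
#   outer loop: freeze a snapshot y_k of the current parameters and compute the
#               full-dataset gradient g_k = (1/n) * sum_i f'_i(y_k);
#   inner loop: for m mini-batches i, update
#               x^{t+1} = x^t - lr * (f'_i(x^t) - f'_i(y_k) + g_k),
#               a stochastic gradient whose variance is reduced by the snapshot terms.
# train() below runs the inner loop; the epoch loop at the bottom of the file
# drives the outer loop (full_grad -> train -> copy_params).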
def train(epoch, opt, inners):
    model.train()
    verbose = True
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > inners:
            break
        if batch_idx == 3:
            verbose = False  # print SVRG diagnostics only for the first few inner steps

        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        def closure_sgd():
            sgd.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            return loss

        def closure_svrg():
            # gradient of this mini-batch at the current parameters x^t
            model.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            # gradient of the same mini-batch at the snapshot parameters y_k
            model_prev.zero_grad()
            output = model_prev(data)
            loss0 = F.nll_loss(output, target)
            loss0.backward()
            return loss

        if opt == 'sgd':
            loss = sgd.step(closure_sgd)
        elif opt == 'svrg':
            loss = svrg.step(closure_svrg, verbose)
        else:
            raise ValueError('unsupported optimizer: {}'.format(opt))

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target).data[0]
        pred = output.data.max(1)[1]  # index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()
    test_loss /= len(test_loader)  # nll_loss already averages over the batch
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
for epoch in range(1, args.epochs + 1):
    def closure(data, target):
        # one mini-batch through the snapshot network; its gradient is accumulated by full_grad()
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        model_prev.zero_grad()
        output = model_prev(data)
        loss0 = F.nll_loss(output, target)
        loss0.backward()
        return loss0

    if epoch == -1:  # plain-SGD baseline, disabled
        train(epoch, 'sgd', args.inners)
    else:
        # full gradient at the snapshot: g_k = f'(y_k)
        svrg.full_grad(train_loader, closure)  # TODO: itertools.islice(train_loader, i_end)
        # one epoch of m inner steps: v^t = f'_i(x^t) - f'_i(y_k) + g_k,  x^{t+1} = x^t - lr * v^t
        train(epoch, 'svrg', args.inners)
        # refresh the snapshot: y_{k+1} = x
        svrg.copy_params()
        # acceleration (not implemented)
    # validation
    test(epoch)
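A typical invocation of the script above (the filenames are hypothetical: save the training script as main.py next to the optimizer file so that it is importable as svrg_solver; MNIST is downloaded to ../data on first use):

python main.py --epochs 10 --lr 0.01 --batch-size 64

Note that --inners is overwritten with the size of the training set right after the loaders are built, so every outer iteration runs a full epoch of inner updates. The svrg_solver module imported at the top of the script is the file below.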
from torch.optim.optimizer import Optimizer, required


class SVRG(Optimizer):
    r"""Implements the stochastic variance reduced gradient (SVRG) method.

    The optimizer keeps a second copy of the parameters (the snapshot
    ``params_prev``) and, per snapshot, a full-dataset gradient stored in
    ``self.state``.  Each step applies the variance-reduced direction
    ``v = grad(params) - grad(params_prev) + full_grad``.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        params_prev (iterable): parameters of the snapshot model, in the same
            order as ``params``
        lr (float): learning rate
        weight_decay (float, optional): weight decay (L2 penalty)
    """

    def __init__(self, params, params_prev, lr=required, weight_decay=0):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(SVRG, self).__init__(params, defaults)
        for group in self.param_groups:
            group['params_prev'] = list(params_prev)

    def __setstate__(self, state):
        super(SVRG, self).__setstate__(state)
    def step(self, closure=None, verbose=False):
        """Performs a single variance-reduced optimization step.

        Arguments:
            closure (callable): a closure that reevaluates the mini-batch loss,
                populates the gradients of both the current model and the
                snapshot model, and returns the loss.
            verbose (bool, optional): print the norms of the update direction
                and of the stored full gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()  # forward/backward for both model and model_prev
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for p, p0 in zip(group['params'], group['params_prev']):
                if p.grad is None:
                    continue
                state = self.state[p0]
                assert 'full_grad' in state, 'call full_grad() before step()'
                d_p = p.grad.data    # f'_i(x^t)
                d_p0 = p0.grad.data  # f'_i(y_k)
                v = d_p - d_p0 + state['full_grad']  # variance-reduced direction
                if verbose:
                    print('v.norm = {}, fg.norm = {}'.format(v.norm(), state['full_grad'].norm()))
                if weight_decay != 0:
                    v.add_(weight_decay, p.data)
                p.data.add_(-group['lr'], v)
        return loss
    def full_grad(self, data_loader, closure):
        """Computes the full gradient g_k = (1/n) * sum_i f'_i(y_k) at the snapshot.

        Arguments:
            data_loader: loader over the whole training set.
            closure (callable): takes (data, target), runs forward/backward
                through the snapshot model and leaves the gradients in
                ``params_prev``.
        """
        num_samples = data_loader.sampler.num_samples
        for i, (data, target) in enumerate(data_loader):
            # forward/backward through the snapshot -> batch-mean gradient in p.grad
            closure(data, target)
            for group in self.param_groups:
                for p in group['params_prev']:
                    if p.grad is None:
                        continue
                    grad = p.grad.data
                    state = self.state[p]
                    if 'full_grad' not in state:
                        state['full_grad'] = grad.new().resize_as_(grad).zero_()
                    if i == 0:
                        state['full_grad'].zero_()  # fresh accumulation for this snapshot
                    # closure() leaves the batch-mean gradient, so weight by the batch
                    # size to accumulate the dataset mean on the same scale as step()
                    state['full_grad'].add_(len(data) / float(num_samples), grad)
    def copy_params(self):
        """Copies the current parameters into the snapshot: y_{k+1} <- x."""
        for group in self.param_groups:
            for p, p0 in zip(group['params'], group['params_prev']):
                p0.data.copy_(p.data)
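For reference, a minimal, self-contained sketch of how this optimizer is meant to be driven, on a toy least-squares problem instead of MNIST. All names here (model, model_prev, loader, snapshot_closure, step_closure) are illustrative and not part of the gist; the sketch uses the same three calls the MNIST script drives (copy_params, full_grad, step per mini-batch) and assumes the same Variable-era PyTorch API as the rest of the gist.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader

import svrg_solver  # the class above

# toy regression data: Y = X w_true + noise
torch.manual_seed(0)
X = torch.randn(512, 10)
w_true = torch.randn(10, 1)
Y = X.mm(w_true) + 0.01 * torch.randn(512, 1)
loader = DataLoader(TensorDataset(X, Y), batch_size=32, shuffle=True)

model = nn.Linear(10, 1, bias=False)       # current parameters x
model_prev = nn.Linear(10, 1, bias=False)  # snapshot parameters y_k
svrg = svrg_solver.SVRG(model.parameters(), model_prev.parameters(), lr=0.1)
svrg.copy_params()  # y_0 = x_0


def snapshot_closure(data, target):
    # mini-batch gradient at the snapshot; full_grad() averages these into g_k
    model_prev.zero_grad()
    loss = F.mse_loss(model_prev(Variable(data)), Variable(target))
    loss.backward()
    return loss


for epoch in range(5):
    svrg.full_grad(loader, snapshot_closure)   # outer step: g_k = f'(y_k)
    for data, target in loader:                # inner loop over mini-batches
        data_v, target_v = Variable(data), Variable(target)

        def step_closure():
            # gradient of this mini-batch at the current parameters ...
            model.zero_grad()
            loss = F.mse_loss(model(data_v), target_v)
            loss.backward()
            # ... and at the snapshot, so step() can form the variance-reduced direction
            model_prev.zero_grad()
            F.mse_loss(model_prev(data_v), target_v).backward()
            return loss

        loss = svrg.step(step_closure)
    svrg.copy_params()                         # y_{k+1} = x
    print('epoch {}: last mini-batch loss = {:.6f}'.format(epoch, loss.data[0]))

The MNIST script above follows this same pattern, with the CNN, the NLL loss, and the closures defined inline.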