Skip to content

Instantly share code, notes, and snippets.

@gngdb
Last active May 18, 2022 05:32
Show Gist options
  • Star 16 You must be signed in to star a gist
  • Fork 7 You must be signed in to fork a gist
  • Save gngdb/a9f912df362a85b37c730154ef3c294b to your computer and use it in GitHub Desktop.
Save gngdb/a9f912df362a85b37c730154ef3c294b to your computer and use it in GitHub Desktop.
Wrap PyTorch functions so they can be minimized with scipy's optimize.minimize: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html. I also made a repository that does this, https://github.com/gngdb/pytorch-minimize, although I had forgotten about this gist at the time.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from scipy import optimize
from obj import PyTorchObjective
from tqdm import tqdm
if __name__ == '__main__':
    # Demo: recover a linear map from data by handing a PyTorch loss to
    # scipy.optimize.minimize through PyTorchObjective.

    # Whatever weights this layer is initialised with are our "true" W.
    linear = nn.Linear(32, 32).eval()

    # Inputs X, filled in place with uniform samples between 0 and 1.
    N = 10000
    X = torch.Tensor(N, 32).uniform_(0., 1.)
    # Gaussian noise (mean 0, std 1e-4); drawn here but left out of Y below.
    eps = torch.Tensor(N, 32).normal_(0., 1e-4)

    # Targets Y are produced by the "true" layer; no autograd graph is kept.
    with torch.no_grad():
        Y = linear(X)  # + eps

    class Objective(nn.Module):
        """Module whose forward pass takes no input and returns the loss.

        It owns a freshly initialised linear layer and measures how well
        that layer maps X onto Y.
        """

        def __init__(self):
            super(Objective, self).__init__()
            self.linear = nn.Linear(32, 32).train()
            self.X = X
            self.Y = Y

        def forward(self):
            prediction = self.linear(self.X)
            loss = F.mse_loss(prediction, self.Y)
            return loss.mean()

    objective = Objective()

    maxiter = 100
    bfgs_options = {'gtol': 1e-6, 'disp': True, 'maxiter': maxiter}
    with tqdm(total=maxiter) as pbar:
        def verbose(xk):
            # Passed to scipy as the callback: advance the bar by one step.
            pbar.update(1)

        # Wrap the module and let scipy's BFGS drive the optimisation.
        obj = PyTorchObjective(objective)
        xL = optimize.minimize(obj.fun, obj.x0, jac=obj.jac, method='BFGS',
                               callback=verbose, options=bfgs_options)
        # Alternative solver, kept for reference:
        # xL = optimize.minimize(obj.fun, obj.x0, method='CG', jac=obj.jac)# , options={'gtol': 1e-2})
import torch
from scipy import optimize
import torch.nn.functional as F
import math
import numpy as np
from functools import reduce
from collections import OrderedDict
class PyTorchObjective(object):
    """PyTorch objective function, wrapped to be called by scipy.optimize.

    The wrapped module is called with no arguments and must return a scalar
    loss.  Its parameters are flattened into a single 1D vector (``x0``);
    ``fun`` and ``jac`` evaluate the loss and its gradient at any such
    vector, so they can be passed directly to ``scipy.optimize.minimize``.

    Loss and gradient are computed together in one forward/backward pass and
    cached, so ``fun(x)`` followed by ``jac(x)`` at the same point triggers
    only one evaluation.
    """

    def __init__(self, obj_module):
        """Record parameter shapes of ``obj_module`` and build ``x0``."""
        self.f = obj_module  # some pytorch module, that produces a scalar loss
        # make an x0 from the parameters in this module
        parameters = OrderedDict(obj_module.named_parameters())
        # Ordered mapping: unpack_parameters must walk the parameters in
        # exactly the order used to concatenate x0.
        self.param_shapes = OrderedDict(
            (n, p.size()) for n, p in parameters.items())
        # ravel and concatenate all parameters to make x0
        # (.cpu() so parameters living on another device can be exported)
        self.x0 = np.concatenate([p.detach().cpu().numpy().ravel()
                                  for p in parameters.values()])

    def unpack_parameters(self, x):
        """optimize.minimize will supply 1D array, chop it up for each parameter.

        Returns an OrderedDict mapping each parameter name to a tensor of
        that parameter's shape, built from consecutive slices of ``x``.
        """
        x = np.asarray(x)
        named_parameters = OrderedDict()
        i = 0
        for n, shape in self.param_shapes.items():
            # The initial value 1 makes a 0-dim (scalar) parameter take one
            # slot instead of making reduce() fail on an empty shape.
            param_len = reduce(lambda a, b: a * b, shape, 1)
            # slice out a section of this length and reshape it; the shape
            # is passed as a tuple so that an empty shape is accepted too
            param = x[i:i + param_len].reshape(tuple(shape))
            named_parameters[n] = torch.from_numpy(param)
            # update index
            i += param_len
        return named_parameters

    def pack_grads(self):
        """Pack all the gradients from the parameters in the module into a
        1D numpy array, in the same order as ``x0``."""
        grads = []
        for p in self.f.parameters():
            if p.grad is None:
                # This parameter did not take part in the last backward
                # pass; its gradient is zero, and its slot must still be
                # filled so the result lines up with x0.
                grad = torch.zeros(p.numel(), dtype=p.dtype).numpy()
            else:
                grad = p.grad.detach().cpu().numpy()
            grads.append(grad.ravel())
        return np.concatenate(grads)

    def is_new(self, x):
        """Return True when ``x`` differs from the point held in the cache."""
        # if this is the first thing we've seen
        if not hasattr(self, 'cached_x'):
            return True
        # Exact comparison: with a tolerance, a distinct but nearby point
        # would silently receive the loss/gradient of another point.
        return not np.array_equal(np.asarray(x), self.cached_x)

    def cache(self, x):
        """Evaluate loss and gradient at ``x`` and store both in the cache."""
        # unpack x and load into module; strict=False because the dict only
        # holds parameters, while the module's state dict may also contain
        # buffers or additional names for tied parameters
        state_dict = self.unpack_parameters(x)
        self.f.load_state_dict(state_dict, strict=False)
        # zero the gradient
        self.f.zero_grad()
        # use it to calculate the objective
        obj = self.f()
        # backprop the objective
        obj.backward()
        self.cached_f = obj.item()
        self.cached_jac = self.pack_grads()
        # Store a private copy, and only after the evaluation succeeded:
        # the caller may reuse and modify its array in place, and a failed
        # evaluation must not leave the cache claiming to be valid for x.
        self.cached_x = np.array(x, copy=True)

    def fun(self, x):
        """Return the scalar loss at ``x`` as a Python float."""
        if self.is_new(x):
            self.cache(x)
        return self.cached_f

    def jac(self, x):
        """Return the gradient of the loss at ``x`` as a 1D numpy array."""
        if self.is_new(x):
            self.cache(x)
        return self.cached_jac
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment