Skip to content

Instantly share code, notes, and snippets.

@yanatan16
Last active March 15, 2024 19:05
Show Gist options
  • Save yanatan16/5420795 to your computer and use it in GitHub Desktop.
Save yanatan16/5420795 to your computer and use it in GitHub Desktop.
Simultaneous Perturbation Stochastic Approximation code in Python
'''
Simultaneous Perturbation Stochastic Approximation
Author: Jon Eisen
License: MIT
This code runs SPSA using iterators.
A quick intro to iterators:
Iterators are like arrays except that we don't store the whole array; we just
store how to get to the next element. In this way, we can create infinite
iterators. In Python, iterators can be used much like arrays (e.g. in for loops).
NumPy (a numerical computing library) is not used here so that PyPy (an
alternative Python implementation that is typically faster) can be used.
'''
import random
from itertools import count, izip
def identity(x):
    '''Return x unchanged; used as the default (no-op) constraint for SPSA.'''
    return x
def SPSA(y, t0, a, c, delta, constraint=identity):
    '''
    Create a Simultaneous Perturbation Stochastic Approximation iterator.

    y - a function of theta that returns a scalar
    t0 - the starting value of theta (a sequence of numbers)
    a - an iterable of a_k values (step sizes)
    c - an iterable of c_k values (perturbation sizes)
    delta - a function of no parameters which creates the delta vector
    constraint - a function of theta that returns theta

    Each iteration estimates the gradient with estimate_gk, moves theta by
    -a_k times that estimate, applies constraint and yields the result.
    t0 itself is never yielded: the first value is the iterate after one
    step. Iteration stops only when a or c is exhausted, so with the
    infinite standard_ak/standard_ck generators it never ends on its own.
    '''
    try:
        from itertools import izip as lazy_zip  # Python 2
    except ImportError:
        lazy_zip = zip  # Python 3: the builtin zip is already lazy
    theta = t0
    # a and c may be infinite, so they must be paired lazily.
    for ak, ck in lazy_zip(a, c):
        gk = estimate_gk(y, theta, delta, ck)
        # Descent step; theta is rebuilt as a new list, never mutated in place.
        theta = constraint([t - ak * g for t, g in zip(theta, gk)])
        yield theta  # This makes this function a generator (an iterator)
def estimate_gk(y, theta, delta, ck):
    '''
    Estimate the SPSA gradient g_k(theta) from two evaluations of y.

    y - a function of theta that returns a scalar
    theta - the current parameter vector
    delta - a function of no parameters which creates the delta vector;
            every component is used as a divisor, so none may be zero
    ck - the perturbation size for this iteration (must be non-zero)

    Returns a list with one gradient component per entry of delta().
    y is called exactly twice: first at theta + ck*delta, then at
    theta - ck*delta.
    '''
    # One shared random direction for both perturbed evaluations.
    delta_k = delta()
    ta = [t + ck * dk for t, dk in zip(theta, delta_k)]
    tb = [t - ck * dk for t, dk in zip(theta, delta_k)]
    # float() forces true division: on Python 2, integer-valued y, ck and
    # delta would otherwise floor-divide and bias the gradient estimate.
    diff = float(y(ta) - y(tb))
    return [diff / (2 * ck * dk) for dk in delta_k]
def standard_ak(a, A, alpha):
    '''
    Create an infinite generator of a_k values in the standard form

        a_k = a / (k + 1 + A) ** alpha    for k = 0, 1, 2, ...

    a - numerator of every term
    A - offset added to the iteration count
    alpha - decay exponent

    a is converted to float so the division is true division even on
    Python 2 with all-integer arguments (e.g. alpha=1), where floor
    division would otherwise collapse a_k to 0.
    '''
    a = float(a)
    # Parentheses make this a generator expression;
    # count() is an infinite iterator 0, 1, 2, ...
    return (a / (k + 1 + A) ** alpha for k in count())
def standard_ck(c, gamma):
    '''
    Create an infinite generator of c_k values in the standard form

        c_k = c / (k + 1) ** gamma    for k = 0, 1, 2, ...

    c - numerator of every term
    gamma - decay exponent

    c is converted to float so the division is true division even on
    Python 2 with all-integer arguments (e.g. gamma=1); a floor-divided
    c_k of 0 would make estimate_gk divide by zero.
    '''
    c = float(c)
    return (c / (k + 1) ** gamma for k in count())
class Bernoulli:
    '''
    Bernoulli perturbation distribution.

    Calling an instance returns a new list of p values, each chosen
    independently with random.choice from (-r, +r).

    r - magnitude of the perturbation (values are +r or -r)
    p - the dimension (length of the returned vector)
    '''
    def __init__(self, r=1, p=2):
        self.p = p
        self.r = r

    def __call__(self):
        choices = (-self.r, self.r)
        # range (not xrange) so this also runs on Python 3.
        return [random.choice(choices) for _ in range(self.p)]
class LossFunction:
    '''
    Base class for noisy loss functions: y(theta) = L(theta) + epsilon(theta).

    Subclasses provide L (the loss itself) and epsilon (the additive noise
    term). y calls both with the same theta, L first, and sums the results.
    '''
    def y(self, theta):
        exact = self.L(theta)
        noise = self.epsilon(theta)
        return exact + noise
from spsa import *
from itertools import islice, izip, tee
import random
def nth(iterable, n, default=None):
    '''
    Return the item at zero-based position n of iterable, or default when
    the iterable runs out first.

    When given an iterator, this consumes it up to and including the
    returned item.
    '''
    for item in islice(iterable, n, None):
        return item
    return default
class SkewedQuarticLoss(LossFunction):
    '''
    Skewed quartic loss function on vectors of length p:

        L(theta) = |B theta|^2 + sum_i (0.1 * b_i^3 + 0.01 * b_i^4),
        where b = B theta

    B is the p x p upper-triangular matrix whose entries on and above the
    diagonal are all 1/p. epsilon returns Gaussian noise with mean 0 and
    standard deviation sigma; y (inherited) is L + epsilon.
    '''
    def __init__(self, p, sigma):
        x = 1. / p
        # Row j holds 1/p in every column i >= j and 0 elsewhere.
        self.B = [[x if i >= j else 0 for i in range(p)] for j in range(p)]
        self.sigma = sigma
        # Variance, kept for callers that read it directly.
        self.sigmasq = sigma ** 2

    @staticmethod
    def _dot(u, v):
        '''Return the inner product of two equal-length sequences.'''
        return sum(ui * vi for ui, vi in zip(u, v))

    def L(self, theta):
        bt = [self._dot(row, theta) for row in self.B]
        return self._dot(bt, bt) + sum(.1 * b ** 3 + .01 * b ** 4 for b in bt)

    def epsilon(self, theta):
        # random.gauss expects the standard deviation, not the variance.
        return random.gauss(0, self.sigma)
def run_spsa(n=1000, replications=40):
    '''
    Run SPSA repeatedly on a 20-dimensional skewed quartic loss.

    n - number of SPSA iterations per replication (must be >= 0)
    replications - number of independent SPSA runs

    Returns a list with the loss L (without noise) at the terminal theta
    of each replication. You can calculate means/variances from this data.

    Raises ValueError if n is negative.
    '''
    if n < 0:
        raise ValueError("n must be non-negative")
    p = 20
    loss = SkewedQuarticLoss(p, sigma=1)
    theta0 = [1] * p
    delta = Bernoulli(p=p)
    losses = []
    for _ in range(replications):
        # Fresh gain sequences per run, so every replication starts at
        # a_0/c_0 and the number of runs is independent of n.
        a = standard_ak(a=1, A=100, alpha=.602)
        c = standard_ck(c=1, gamma=.101)
        theta_iter = SPSA(a=a, c=c, y=loss.y, t0=theta0, delta=delta)
        # SPSA yields theta_1, theta_2, ..., so index n - 1 is theta_n.
        terminal_theta = nth(theta_iter, n - 1) if n > 0 else theta0
        losses.append(loss.L(terminal_theta))
    return losses
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment