Skip to content

Instantly share code, notes, and snippets.

@ryanpeach
Created February 10, 2017 01:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ryanpeach/a6e989c2f36bee9c926bd1a8cb0c190e to your computer and use it in GitHub Desktop.
Save ryanpeach/a6e989c2f36bee9c926bd1a8cb0c190e to your computer and use it in GitHub Desktop.
import random as rand
from functools import partial
def twiddle(run, args, p, dp, tol=0.2, N=100, logger=None):
    """Minimise ``run`` by coordinate-wise "twiddle" search over the vector ``p``.

    This is not gradient descent. Each sweep visits the coordinates of ``p`` in a
    random order. For coordinate ``i`` it tries ``p[i] + dp[i]`` and, if that does
    not lower the error, ``p[i] - dp[i]``. An improving move is kept and ``dp[i]``
    grows by 10%. Otherwise ``p[i]`` is restored and ``dp[i]`` shrinks by 10%.

    Args:
        run: Callable invoked as ``run(tuple(p), *args)``. It returns an error,
            where lower is better (0 being optimal).
        args: Extra positional arguments forwarded to ``run`` on every call.
        p: Starting parameter values (any sequence of numbers). Not modified.
        dp: Initial step size for each index of ``p``. Not modified.
        tol: The search stops once the sum of ``abs(dp[i])`` is at or below ``tol``.
        N: Maximum number of sweeps over ``p``. ``None`` means no limit.
        logger: Optional logger-like object that receives ``debug`` messages.

    Returns:
        ``(best_p, best_dp, best_err)``. ``best_p`` and ``best_dp`` are list
        copies of the best parameters found and of the step sizes recorded when
        they were found. ``best_err`` is ``run``'s error at ``best_p``.
    """
    # Work on copies so the caller's sequences are never mutated.
    p = list(p)
    dp = list(dp)
    # Evaluate the start so it is kept when no move improves on it.
    best_err = run(tuple(p), *args)
    best_p, best_dp, n = list(p), list(dp), 0
    # Sum magnitudes: with a plain sum, positive and negative steps could cancel.
    while sum(abs(step) for step in dp) > tol:
        if N is not None and n >= N:
            break
        index = list(range(len(p)))
        rand.shuffle(index)
        for i in index:
            old = p[i]
            p[i] = old + dp[i]
            err = run(tuple(p), *args)
            if not err < best_err:
                # The forward step did not help, so try the opposite direction.
                p[i] = old - dp[i]
                err = run(tuple(p), *args)
            if err < best_err:
                dp[i] *= 1.1
                # Snapshot, don't alias: p and dp keep changing after this point.
                best_err, best_p, best_dp = err, list(p), list(dp)
                if logger is not None:
                    logger.debug("P: {0},\nDP: {1},\nRUN(P): {2}".format(p, dp, err))
                    logger.debug("Best P: {0},\nBest Error: {1}\n".format(best_p, best_err))
            else:
                # Neither direction helped: restore exactly (no float drift) and shrink.
                p[i] = old
                dp[i] *= 0.9
                if logger is not None:
                    logger.debug("Unsuccessful, Error: {0}".format(err))
        n += 1
    if logger is not None:
        logger.debug("Best P: {0},\tBest Error: {1}".format(best_p, best_err))
    return best_p, best_dp, best_err
def randomstart(run, runtime, ranges, dranges, tol=.02, N=100, logger=None, args=()):
    """Run ``twiddle`` from random starting points and keep the best result.

    Args:
        run: Objective handed to ``twiddle``. It is called as ``run(tuple(p), *args)``.
        runtime: Number of random restarts. A negative value means restart forever.
        ranges: One ``(low, high)`` pair per parameter. Each restart draws its
            starting ``p[i]`` with ``rand.uniform(low, high)``.
        dranges: One ``(low, high)`` pair per parameter, used the same way to draw
            the initial step ``dp[i]``.
        tol, N, logger: Forwarded unchanged to ``twiddle``.
        args: Extra positional arguments forwarded to ``run``. Defaults to none.

    Returns:
        ``(best_err, best_p)`` for the best restart. If no restart ran,
        returns ``(inf, None)``.
    """
    best_err, best_p, n = float("inf"), None, 0
    while runtime < 0 or n < runtime:
        p = [rand.uniform(a, b) for a, b in ranges]
        dp = [rand.uniform(a, b) for a, b in dranges]
        if logger is not None:
            logger.info("Starting P: {0}, \tDP: {1}".format(p, dp))
        # twiddle's signature is (run, args, p, dp, tol, N, logger).
        # It returns (best_p, best_dp, best_err).
        p, _, err = twiddle(run, args, p, dp, tol, N, logger)
        if err < best_err:
            best_err = err
            best_p = p
            if logger is not None:
                logger.critical("Best P: {0},\tBest Error: {1}\n\n".format(best_p, best_err))
        elif logger is not None:
            logger.info("Error: {0}".format(err))
        # Count every restart; without this a positive runtime never terminates.
        n += 1
    return best_err, best_p
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment