Skip to content

Instantly share code, notes, and snippets.

@andrewssobral
Forked from karpathy/nes.py
Created Mar 26, 2017
Embed
What would you like to do?
Natural Evolution Strategies (NES) toy example that optimizes a quadratic function
"""
A bare bones examples of optimizing a black-box function (f) using
Natural Evolution Strategies (NES), where the parameter distribution is a
gaussian of fixed standard deviation.
"""
import numpy as np
np.random.seed(0)
# the function we want to optimize
def f(w):
# here we would normally:
# ... 1) create a neural network with weights w
# ... 2) run the neural network on the environment for some time
# ... 3) sum up and return the total reward
# but for the purposes of an example, lets try to minimize
# the L2 distance to a specific solution vector. So the highest reward
# we can achieve is 0, when the vector w is exactly equal to solution
reward = -np.sum(np.square(solution - w))
return reward
# hyperparameters
npop = 50 # population size
sigma = 0.1 # noise standard deviation
alpha = 0.001 # learning rate
# start the optimization
solution = np.array([0.5, 0.1, -0.3])
w = np.random.randn(3) # our initial guess is random
for i in range(300):
# print current fitness of the most likely parameter setting
if i % 20 == 0:
print('iter %d. w: %s, solution: %s, reward: %f' %
(i, str(w), str(solution), f(w)))
# initialize memory for a population of w's, and their rewards
N = np.random.randn(npop, 3) # samples from a normal distribution N(0,1)
R = np.zeros(npop)
for j in range(npop):
w_try = w + sigma*N[j] # jitter w using gaussian of sigma 0.1
R[j] = f(w_try) # evaluate the jittered version
# standardize the rewards to have a gaussian distribution
A = (R - np.mean(R)) / np.std(R)
# perform the parameter update. The matrix multiply below
# is just an efficient way to sum up all the rows of the noise matrix N,
# where each row N[j] is weighted by A[j]
w = w + alpha/(npop*sigma) * np.dot(N.T, A)
# when run, prints:
# iter 0. w: [ 1.76405235 0.40015721 0.97873798], solution: [ 0.5 0.1 -0.3], reward: -3.323094
# iter 20. w: [ 1.63796944 0.36987244 0.84497941], solution: [ 0.5 0.1 -0.3], reward: -2.678783
# iter 40. w: [ 1.50042904 0.33577052 0.70329169], solution: [ 0.5 0.1 -0.3], reward: -2.063040
# iter 60. w: [ 1.36438269 0.29247833 0.56990397], solution: [ 0.5 0.1 -0.3], reward: -1.540938
# iter 80. w: [ 1.2257328 0.25622233 0.43607161], solution: [ 0.5 0.1 -0.3], reward: -1.092895
# iter 100. w: [ 1.08819889 0.22827364 0.30415088], solution: [ 0.5 0.1 -0.3], reward: -0.727430
# iter 120. w: [ 0.95675286 0.19282042 0.16682465], solution: [ 0.5 0.1 -0.3], reward: -0.435164
# iter 140. w: [ 0.82214521 0.16161165 0.03600742], solution: [ 0.5 0.1 -0.3], reward: -0.220475
# iter 160. w: [ 0.70282088 0.12935569 -0.09779598], solution: [ 0.5 0.1 -0.3], reward: -0.082885
# iter 180. w: [ 0.58380424 0.11579811 -0.21083135], solution: [ 0.5 0.1 -0.3], reward: -0.015224
# iter 200. w: [ 0.52089064 0.09897718 -0.2761225 ], solution: [ 0.5 0.1 -0.3], reward: -0.001008
# iter 220. w: [ 0.50861791 0.10220363 -0.29023563], solution: [ 0.5 0.1 -0.3], reward: -0.000174
# iter 240. w: [ 0.50428202 0.10834192 -0.29828744], solution: [ 0.5 0.1 -0.3], reward: -0.000091
# iter 260. w: [ 0.50147991 0.1044559 -0.30255291], solution: [ 0.5 0.1 -0.3], reward: -0.000029
# iter 280. w: [ 0.50208135 0.0986722 -0.29841024], solution: [ 0.5 0.1 -0.3], reward: -0.000009
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment