# Source: GitHub gist by @maxpagels, created January 27, 2018 17:12
# Gist ID: maxpagels/a670c0e44f733393984a7d9066b330d3
# Evolution Strategies for Reinforcement Learning
# See: https://blog.openai.com/evolution-strategies/
import numpy as np
from keras.layers import Dense
from keras.models import Sequential
np.random.seed(0)
# Small fully-connected network: 5 inputs -> Dense(2) -> Dense(1).
# The layers are bound to module-level names because f() assigns their
# weights directly via set_weights().
layer1 = Dense(2, input_dim=5)
layer2 = Dense(1)
model = Sequential([layer1, layer2])
# The function we want to optimize (the reward maximized by the ES loop).
# Fixed input batch used by f; built once here rather than on every call.
_F_INPUTS = np.array([[1, 0, 1, 0, 0], [1, 1, 1, 1, 1]])


def f(w, inputs=None, target=4):
    """Return the negated squared error between the summed predictions and target.

    The flat parameter vector ``w`` is unpacked into the weights of the
    module-level layers ``layer1`` and ``layer2``. The model then predicts
    on ``inputs`` and the predictions are summed. The reward is
    ``-(sum - target) ** 2``, so it is largest (0) when the predictions
    sum exactly to ``target``.

    Args:
        w: sequence/array of at least 15 numbers, laid out as
            ``w[0:10]``  -> layer1 kernel (two 5-wide columns),
            ``w[10:12]`` -> layer1 bias,
            ``w[12:14]`` -> layer2 kernel,
            ``w[14]``    -> layer2 bias.
        inputs: batch passed to ``model.predict``; rows must have 5
            features to match ``layer1``. Defaults to the two fixed rows
            used by the original script.
        target: value the summed predictions should reach (default 4).

    Returns:
        The reward as a Python float (<= 0 for finite predictions).

    Side effects:
        Overwrites the weights of ``layer1`` and ``layer2``.
    """
    w = np.asarray(w)
    if inputs is None:
        inputs = _F_INPUTS
    # Rows w[0:5] and w[5:10] are the two output units' input weights;
    # transposing gives the (5, 2) kernel layer1 expects.
    w1 = w[0:10].reshape(2, 5).T
    b1 = w[10:12]
    w2 = w[12:14].reshape(2, 1)
    b2 = w[14:15]
    layer1.set_weights([w1, b1])
    layer2.set_weights([w2, b2])
    # verbose=0: recent Keras releases default to showing a progress bar on
    # every predict() call, and f is evaluated thousands of times.
    res = model.predict(inputs, verbose=0)
    total_res = float(np.sum(res))
    return -(total_res - target) ** 2
# hyperparameters
npop = 50  # population size
sigma = 0.8  # noise standard deviation
alpha = 0.01  # learning rate
# Length of the parameter vector as unpacked by f:
# layer1 kernel (10) + layer1 bias (2) + layer2 kernel (2) + layer2 bias (1).
n_params = 10 + 2 + 2 + 1

# start the optimization
# NOTE(review): `solution` is not used by f; it only appears in the progress
# message and looks like a leftover from the original ES example.
solution = np.array([0.5, 0.1, -0.3])
w = np.random.randn(n_params)  # our initial guess is random
for i in range(10000):
    # Every other iteration, report the reward of the current
    # (unperturbed) parameters.
    if i % 2 == 0:
        print('iter %d. w: %s, solution: %s, reward: %f' %
              (i, str(w), str(solution), f(w)))
    # Sample one N(0, 1) perturbation per population member (rows of N)
    # and record the reward of each perturbed parameter vector in R.
    N = np.random.randn(npop, n_params)
    R = np.zeros(npop)
    for j in range(npop):
        w_try = w + sigma * N[j]  # jitter w with Gaussian noise of std sigma
        R[j] = f(w_try)  # evaluate the jittered version
    # Standardize the rewards to zero mean and unit variance. If every
    # member scored exactly the same (always true when npop == 1), the std
    # is 0. Dividing by it would turn w into NaN permanently; there is no
    # signal to follow, so skip this update instead.
    R_std = np.std(R)
    if R_std == 0:
        continue
    A = (R - np.mean(R)) / R_std
    # Perform the parameter update. The matrix product is an efficient way
    # to sum the rows of the noise matrix N, with each row N[j] weighted by
    # A[j]. This is the evolution-strategies gradient estimate from the
    # OpenAI post referenced at the top of the file.
    w = w + alpha / (npop * sigma) * np.dot(N.T, A)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment