Skip to content

Instantly share code, notes, and snippets.

@njp947
Last active February 25, 2017 13:17
Show Gist options
  • Save njp947/0ce775605cc52d7d814ac822523d41c7 to your computer and use it in GitHub Desktop.
Save njp947/0ce775605cc52d7d814ac822523d41c7 to your computer and use it in GitHub Desktop.
CE 0.1
import argparse
import numpy
import keras
import gym
def ce(f, th_mean, sigma0, n_iter=50, batch_size=200, elite_frac=0.2):
    """Maximise *f* with the cross-entropy method over a diagonal Gaussian.

    Each iteration draws ``batch_size`` parameter vectors around the current
    mean (per-coordinate std ``th_std``), scores each with *f*, keeps the
    ``elite_frac`` best-scoring fraction, and refits the mean and
    per-coordinate standard deviation to those elite samples.

    Args:
        f: objective mapping a 1-D parameter vector to a scalar; larger is
            better.
        th_mean: initial mean, a 1-D array-like.
        sigma0: initial standard deviation used for every coordinate.
        n_iter: number of CE iterations (default 50).
        batch_size: number of samples drawn per iteration (default 200).
        elite_frac: fraction of each batch kept as elites (default 0.2);
            at least one sample is always kept.

    Returns:
        The mean vector after the final iteration (a float ndarray).
    """
    th_mean = numpy.asarray(th_mean, dtype=float)
    # Guard against a zero-size elite set for tiny batches / fractions.
    n_elite = max(1, int(numpy.round(batch_size * elite_frac)))
    th_std = numpy.ones_like(th_mean) * sigma0
    for _ in range(n_iter):
        # (size,) std broadcasts over the (batch_size, size) noise; adding the
        # mean shifts every row, i.e. one sample per row.
        ths = th_mean + th_std * numpy.random.randn(batch_size, th_mean.size)
        ys = numpy.array([f(th) for th in ths])
        # argsort is ascending, so reverse it to take the highest scores first.
        elite_ths = ths[ys.argsort()[::-1][:n_elite]]
        th_mean = elite_ths.mean(axis=0)
        th_std = elite_ths.std(axis=0)
    return th_mean
# Command line: a single positional argument naming the environment for gym.make.
parser = argparse.ArgumentParser()
parser.add_argument("environment")
args = parser.parse_args()
environment = gym.make(args.environment)

# Policy network: observation -> Dense(10, tanh) -> Dense(5, tanh) ->
# Dense(action_space.n) with no activation (one linear score per action).
model = keras.models.Sequential()
model.add(keras.layers.Dense(10, activation="tanh", input_shape=environment.observation_space.shape))
model.add(keras.layers.Dense(5, activation="tanh"))
model.add(keras.layers.Dense(environment.action_space.n))

# Weight-array shapes in model.get_weights() order, kept so flat parameter
# vectors can be cut back into per-layer arrays.
shapes = [weight.shape for weight in model.get_weights()]
def get_solution(weights):
    """Flatten a sequence of weight arrays into one 1-D parameter vector.

    Each array is raveled in C order and the pieces are laid end to end in
    the order given (e.g. the order returned by ``model.get_weights()``).
    """
    flat_parts = tuple(w.ravel() for w in weights)
    return numpy.concatenate(flat_parts)
def set_weights(solution, layer_shapes=None, target=None):
    """Cut a flat parameter vector into per-layer arrays and load them.

    This is the inverse of ``get_solution``. The vector is split into
    consecutive chunks, one per entry of *layer_shapes* and in that order.
    Each chunk is reshaped to its entry's shape, and the resulting list is
    passed to ``target.set_weights``.

    Args:
        solution: 1-D array-like whose length equals the total number of
            elements described by *layer_shapes*.
        layer_shapes: weight shapes in ``get_weights()`` order; defaults to
            the module-level ``shapes``.
        target: object whose ``set_weights`` receives the arrays; defaults to
            the module-level ``model``.

    Raises:
        ValueError: if the length of *solution* does not match the total
            size implied by *layer_shapes*.
    """
    if layer_shapes is None:
        layer_shapes = shapes
    if target is None:
        target = model
    solution = numpy.asarray(solution)
    sizes = [int(numpy.prod(layer_shape)) for layer_shape in layer_shapes]
    if sum(sizes) != solution.size:
        raise ValueError("solution has %d elements but the layer shapes need %d"
                         % (solution.size, sum(sizes)))
    weights = []
    offset = 0
    for layer_shape, size in zip(layer_shapes, sizes):
        # Each chunk starts where the previous one ended, matching the order
        # in which get_solution concatenates the raveled arrays.
        weights.append(solution[offset:offset + size].reshape(layer_shape))
        offset += size
    target.set_weights(weights)
def get_action(observation):
    """Return the index of the largest output of ``model`` for *observation*.

    *observation* is handed unchanged to ``model.predict_on_batch``. Called
    without an axis, ``numpy.argmax`` returns an index into the flattened
    prediction.
    """
    scores = model.predict_on_batch(observation)
    best = numpy.argmax(scores)
    return best
# Single-observation input shape: a leading batch axis of length 1 followed by
# the observation-space shape; get_reward reshapes observations to this.
shape = (1,) + tuple(environment.observation_space.shape)
def get_reward():
    """Play one episode of ``environment`` with the current model; return its total reward.

    The environment is reset, then stepped until it reports done. Each
    observation is reshaped to the module-level ``shape`` before
    ``get_action`` picks the action, and the per-step rewards are summed.
    """
    obs = environment.reset()
    total = 0
    while True:
        act = get_action(obs.reshape(shape))
        obs, step_reward, finished, _ = environment.step(act)
        total += step_reward
        if finished:
            return total
def f(x):
    """CE objective: load parameter vector *x* into the model, return one episode's reward."""
    set_weights(x)
    return get_reward()
# Start the CE search from the network's freshly initialised weights, flattened.
x0 = get_solution(model.get_weights())
# NOTE(review): environment.monitor and gym.upload belong to the early (2016-era)
# gym API and were removed in later gym releases; pin gym or port to
# gym.wrappers.Monitor if this no longer runs.
environment.monitor.start("gym")
try:
    ce(f, x0, 0.1)
finally:
    # Close the monitor even if training raises, so it is always shut down.
    environment.monitor.close()
# Upload the results recorded in the "gym" directory (only reached on success).
gym.upload("gym", algorithm_id="alg_wJvr38jvQTe9jPCN040VCA")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment