""" | |
This will be a very simple control method consisting of one set of weights and biases (theta) | |
that is optimized via CEM. | |
""" | |
import gym | |
from gym import wrappers | |
import numpy as np | |
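# Rough outline of the cross-entropy method as it is used below (a sketch of the
# procedure, not a second implementation):
#   1. Keep a diagonal Gaussian (mean, std_dev) over the flattened policy parameters.
#   2. Sample num_policies parameter vectors and roll each one out for num_episodes episodes.
#   3. Keep the top 20% of policies by average reward (the "elite" set).
#   4. Refit mean and std_dev to the elite parameters and repeat.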
class DiscreteDeterministicControlPolicy(object):
    def __init__(self, mean, std_dev, observation_dim, action_dim):
        """
        mean    - (action_dim * observation_dim + action_dim,)
        std_dev - (action_dim * observation_dim + action_dim,)
        Mean and standard deviation of the Gaussian over the flattened parameters.
        """
        w_mean = mean[:action_dim * observation_dim].reshape(observation_dim, action_dim)
        b_mean = mean[action_dim * observation_dim:].reshape(action_dim)
        w_std = std_dev[:action_dim * observation_dim].reshape(observation_dim, action_dim)
        b_std = std_dev[action_dim * observation_dim:].reshape(action_dim)
        # Sample this policy's parameters from the current distribution
        self.W = np.random.normal(w_mean, w_std)
        self.b = np.random.normal(b_mean, b_std)
    def get_action(self, observation):
        # Linear map from observation to per-action scores, squashed through a
        # sigmoid and normalized into a probability distribution over actions.
        logits = np.dot(self.W.T, observation) + self.b
        p = 1.0 / (1.0 + np.exp(-logits))
        p /= p.sum()
        return p

    def get_params(self):
        # Flattened parameter vector (weights then biases), matching the layout
        # expected by __init__.
        return np.hstack((self.W.reshape(-1), self.b.reshape(-1)))

    def __repr__(self):
        return self.get_params().__repr__()
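# Illustrative note (not part of the training loop): with observation_dim = 4 and
# action_dim = 2, get_action maps a 4-vector observation to a 2-vector of action
# probabilities that sums to 1; the loop below takes the argmax of that vector,
# so each sampled policy acts deterministically.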
action_dim = 2
observation_dim = 4
num_batches = 10000
num_episodes = 20
num_policies = 25

env = gym.make('CartPole-v1')
env = wrappers.Monitor(env, 'cartpolev1-cem-1', force=True)

# Initial parameter distribution: standard normal over every weight and bias
num_params = action_dim * observation_dim + action_dim
mean = np.zeros(num_params)
std_dev = np.ones(num_params)
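# With observation_dim = 4 and action_dim = 2 this is 2*4 + 2 = 10 parameters,
# each initially drawn from N(0, 1).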
for b in range(num_batches):
    # Sample num_policies policies from the current parameter distribution
    policies = [DiscreteDeterministicControlPolicy(mean, std_dev, observation_dim, action_dim) for x in range(num_policies)]
    rewards = np.zeros(num_policies)

    # Evaluate each policy over num_episodes episodes
    for p in range(num_policies):
        policy = policies[p]
        total_reward = 0.0
        for e in range(num_episodes):
            observation = env.reset()
            done = False
            while not done:
                action_distribution = policy.get_action(observation)
                action = action_distribution.argmax()
                observation, reward, done, info = env.step(action)
                total_reward += reward
        rewards[p] = total_reward / num_episodes
    # Keep the top 20% of policies: the threshold is the 80th percentile of rewards
    reward_threshold = np.percentile(rewards, 80)
    print(reward_threshold)

    # Stack every policy's flattened parameters into a (num_policies, num_params) array
    params = np.vstack([policy.get_params() for policy in policies])

    # Refit the sampling distribution to the elite parameters; this Gaussian is
    # what the next batch of policies is drawn from
    elite_params = params[rewards >= reward_threshold]
    print(rewards)
    mean = elite_params.mean(axis=0)
    std_dev = elite_params.std(axis=0)