@zh4ngx
Created July 4, 2017 04:58
Monte Carlo EM - weighted sampling of mean/variance of theta by reward
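The whole method boils down to one reward-weighted update of the sampling distribution: parameter vectors that earned more reward pull the mean (and spread) of the Gaussian toward themselves. A minimal, self-contained sketch of just that step on made-up numbers (the toy thetas and rewards below are illustrative assumptions, not taken from the gist):

import numpy as np

# Toy data (assumed for illustration): 4 sampled parameter vectors and their episode rewards.
thetas = np.array([[0.1, -0.2],
                   [0.4, 0.0],
                   [0.3, 0.5],
                   [-0.1, 0.2]])
rewards = np.array([10.0, 50.0, 40.0, 5.0])

# Reward-weighted mean and variance: higher-reward samples get more weight,
# so the next sampling distribution concentrates around them.
theta_mean = np.average(thetas, axis=0, weights=rewards)
theta_variance = np.average((thetas - theta_mean) ** 2, axis=0, weights=rewards)

print(theta_mean)      # pulled toward the high-reward rows [0.4, 0.0] and [0.3, 0.5]
print(theta_variance)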
# evaluation.py
from utils import make_policy


def do_episode(policy, env, max_steps, render=False):
    """Roll out one episode with the given policy and return the total reward."""
    total_rew = 0
    ob = env.reset()
    for t in range(max_steps):
        a = policy.act(ob)
        (ob, reward, done, _info) = env.step(a)
        total_rew += reward
        if render and t % 3 == 0:
            env.render()
        if done:
            break
    return total_rew


def noisy_evaluation(env, theta, num_steps):
    """Evaluate a parameter vector theta by rolling out a single (noisy) episode."""
    policy = make_policy(env, theta)
    rew = do_episode(policy, env, num_steps)
    return rew

# main.py
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
# Implementation of Monte-Carlo Expectation Maximization
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor

from evaluation import noisy_evaluation, do_episode
from utils import get_dim_theta, make_policy

# Task settings:
env = gym.make('CartPole-v0')  # Change as needed
env = Monitor(env, 'tmp/cart-pole-monte-carlo-em-1', force=True)
num_steps = 500  # maximum length of an episode

# Algorithm settings:
n_iter = 100  # number of iterations
batch_size = 25  # number of samples per batch
dim_theta = get_dim_theta(env)

# Initialize the mean and variance of the sampling distribution over theta
theta_mean = np.zeros(dim_theta)
theta_variance = np.ones(dim_theta)

# Now, for the algorithm
for iteration in range(n_iter):
    # Sample parameter vectors from a diagonal Gaussian
    thetas = np.vstack([np.random.multivariate_normal(theta_mean, np.diag(theta_variance))
                        for _ in range(batch_size)])
    rewards = [noisy_evaluation(env, theta, num_steps) for theta in thetas]
    # Weight parameters by score: update theta_mean, theta_variance
    theta_mean = np.average(thetas, axis=0, weights=rewards)
    theta_variance = np.average((thetas - theta_mean) ** 2, axis=0, weights=rewards)
    if iteration % 10 == 0:
        print("iteration %i. mean f: %8.3g. max f: %8.3g" % (iteration, np.mean(rewards), np.max(rewards)))
        print("theta mean %s \n theta variance %s" % (theta_mean, theta_variance))
        do_episode(make_policy(env, theta_mean), env, num_steps, render=True)
env.close()

# policy.py
# ================================================================
# Policies
# ================================================================
import numpy as np


class DeterministicDiscreteActionLinearPolicy(object):
    def __init__(self, theta, ob_space, ac_space):
        """
        theta: flat vector of parameters, length (dim_ob + 1) * n_actions
        ob_space: observation space (dim_ob = ob_space.shape[0])
        ac_space: discrete action space (n_actions = ac_space.n)
        """
        dim_ob = ob_space.shape[0]
        n_actions = ac_space.n
        assert len(theta) == (dim_ob + 1) * n_actions
        self.W = theta[0: dim_ob * n_actions].reshape(dim_ob, n_actions)
        self.b = theta[dim_ob * n_actions:].reshape(1, n_actions)

    def act(self, ob):
        """Return the action with the highest linear score for observation ob."""
        y = ob.dot(self.W) + self.b
        a = y.argmax()
        return a


class DeterministicContinuousActionLinearPolicy(object):
    def __init__(self, theta, ob_space, ac_space):
        """
        theta: flat vector of parameters, length (dim_ob + 1) * dim_ac
        ob_space: observation space (dim_ob = ob_space.shape[0])
        ac_space: continuous (Box) action space (dim_ac = ac_space.shape[0])
        """
        self.ac_space = ac_space
        dim_ob = ob_space.shape[0]
        dim_ac = ac_space.shape[0]
        assert len(theta) == (dim_ob + 1) * dim_ac
        self.W = theta[0: dim_ob * dim_ac].reshape(dim_ob, dim_ac)
        self.b = theta[dim_ob * dim_ac:]

    def act(self, ob):
        """Return the linear action, clipped to the action-space bounds."""
        a = np.clip(ob.dot(self.W) + self.b, self.ac_space.low, self.ac_space.high)
        return a
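
To make the flat-theta layout in DeterministicDiscreteActionLinearPolicy concrete, here is a small standalone sketch of how the (dim_ob + 1) * n_actions parameters split into weights and bias (the 4-observation / 2-action sizes are assumptions chosen to match CartPole, not something the gist fixes):

import numpy as np

dim_ob, n_actions = 4, 2                                     # assumed sizes for illustration
theta = np.arange((dim_ob + 1) * n_actions, dtype=float)     # 10 flat parameters

W = theta[0: dim_ob * n_actions].reshape(dim_ob, n_actions)  # first 8 entries -> 4x2 weight matrix
b = theta[dim_ob * n_actions:].reshape(1, n_actions)         # last 2 entries  -> bias row

ob = np.array([0.1, -0.2, 0.05, 0.3])
scores = ob.dot(W) + b      # one linear score per action
a = scores.argmax()         # greedy action, as in act()
print(W.shape, b.shape, a)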

# utils.py
from gym.spaces import Discrete, Box

from policy import DeterministicDiscreteActionLinearPolicy, DeterministicContinuousActionLinearPolicy


def make_policy(environment, theta):
    """Build a linear policy of the right type for the environment's action space."""
    if isinstance(environment.action_space, Discrete):
        return DeterministicDiscreteActionLinearPolicy(theta,
                                                       environment.observation_space,
                                                       environment.action_space)
    elif isinstance(environment.action_space, Box):
        return DeterministicContinuousActionLinearPolicy(theta,
                                                         environment.observation_space,
                                                         environment.action_space)
    else:
        raise NotImplementedError


def get_dim_theta(env):
    """Number of parameters the linear policy needs for this environment."""
    if isinstance(env.action_space, Discrete):
        return (env.observation_space.shape[0] + 1) * env.action_space.n
    elif isinstance(env.action_space, Box):
        return (env.observation_space.shape[0] + 1) * env.action_space.shape[0]
    else:
        raise NotImplementedError