Skip to content

Instantly share code, notes, and snippets.

View zh4ngx's full-sized avatar

Andy Zhang zh4ngx

  • San Francisco, CA
View GitHub Profile
@zh4ngx
zh4ngx / cross_entropy.py
Created July 4, 2017 19:17
Cleaned up CartPole
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor
from policy import Policy
# Task settings:
env = gym.make('CartPole-v0') # Change as needed
env = Monitor(env, 'tmp/cart-pole-cross-entropy-1', force=True)
@zh4ngx
zh4ngx / evaluation.py
Created July 4, 2017 05:21
Monte Carlo EM CartPole-v0 with exponentially weighted variance
from utils import make_policy
def do_episode(policy, env, max_steps, render=False):
total_rew = 0
ob = env.reset()
for t in range(max_steps):
a = policy.act(ob)
(ob, reward, done, _info) = env.step(a)
total_rew += reward
@zh4ngx
zh4ngx / evaluation.py
Created July 4, 2017 04:58
Monte Carlo EM - weighted sampling of mean/variance of theta by reward
from utils import make_policy
def do_episode(policy, env, max_steps, render=False):
total_rew = 0
ob = env.reset()
for t in range(max_steps):
a = policy.act(ob)
(ob, reward, done, _info) = env.step(a)
total_rew += reward
@zh4ngx
zh4ngx / cross_entropy.py
Created July 4, 2017 04:31
Cross-Entropy Method (an evolution strategy) on CartPole-v0 - somewhat overparameterized
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor
from evaluation import noisy_evaluation, do_episode
from utils import get_dim_theta, make_policy
# Task settings:
env = gym.make('CartPole-v0') # Change as needed
@zh4ngx
zh4ngx / cart_pole_cem_3.py
Created July 4, 2017 01:36
CartPole-v0 Cross Entropy Method with Minimal Params
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.spaces import Discrete, Box
from gym.wrappers.monitoring import Monitor
# ================================================================
# Policies
# ================================================================
@zh4ngx
zh4ngx / cart_pole_cem_1
Created July 4, 2017 01:35
CartPole-v0 Cross Entropy Method - Affine Function and Base HyperParams
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.spaces import Discrete, Box
from gym.wrappers.monitoring import Monitor
# ================================================================
# Policies
# ================================================================
@zh4ngx
zh4ngx / cart_pole_cem_2.py
Created July 4, 2017 01:28
CartPole-v0 Cross Entropy Method with no bias
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.spaces import Discrete, Box
from gym.wrappers.monitoring import Monitor
# ================================================================
# Policies
# ================================================================
@zh4ngx
zh4ngx / hill_climb_4
Created July 1, 2017 03:53
CartPole - Hill Climb v4 - Correct hill climb and reduced noise & variance (MC-10)
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor
MC_POLICY_EVAL_EP = 10
BASE_NOISE_FACTOR = 0.1
NUM_POLICY_EVAL = 500
env = gym.make('CartPole-v0')
env = Monitor(env, 'tmp/cart-pole-hill-climb-4', force=True)
@zh4ngx
zh4ngx / hill_climb_3.py
Created July 1, 2017 03:20
CartPole-v0 Hill Climb + MC(10) + Gaussian Noise (Sigma 0.5)
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor
MC_POLICY_EVAL_EP = 10
BASE_NOISE_FACTOR = 0.5
NUM_POLICY_EVAL = 500
env = gym.make('CartPole-v0')
@zh4ngx
zh4ngx / hill_climb_2.py
Created July 1, 2017 02:34
CartPole-v0 Hill Climb with MC(10) Eval + Simulated Annealing
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor
env = gym.make('CartPole-v0')
env = Monitor(env, 'tmp/cart-pole-hill-climb-2', force=True)
print("Action space: {0}".format(env.action_space))
print("Observation space: {0}\n\tLow: {1}\n\tHigh: {2}".format(
env.observation_space,