# Created December 27, 2017 02:04
# Edited Cartpole
# Classic cart-pole system implemented by Rich Sutton et al.
# Copied from (source reference not preserved in this copy)
# The reset was edited so that the initial state has the pole hanging straight down.
import logging
import math
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
# Module-level logger; CartPoleEnv uses it to warn when step is called after done.
logger = logging.getLogger(__name__)
class CartPoleEnv(gym.Env):
    """Cart-pole swing-up variant of the classic cart-pole environment.

    The state is the 4-tuple ``(x, x_dot, theta, theta_dot)``.  Every episode
    starts at rest with ``theta = pi`` (pole hanging down) and lasts a fixed
    number of steps; the pole is expected to make full swings, so neither the
    angle nor the cart position ends an episode.

    Two discrete actions are available: ``1`` pushes the cart with
    ``+force_mag``, ``0`` pushes it with ``-force_mag``.

    The per-step reward is ``cos(theta)`` (largest at ``theta = 0``, the
    opposite of the hanging-down start), minus 1 while the cart is outside
    ``[-x_threshold, x_threshold]``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 50,
    }

    def __init__(self):
        """Set the physical constants, the spaces and the episode bookkeeping."""
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates

        # Episodes are cut off once more than this many steps have been taken.
        self.max_steps = 500

        # Full swings are expected, so these thresholds only size the
        # observation space (and x_threshold the track penalty / rendering);
        # they never terminate an episode.
        self.theta_threshold_radians = np.pi
        self.x_threshold = 2.4

        # One bound per state component (x, x_dot, theta, theta_dot).
        # Position and angle limits are twice their thresholds; the two
        # velocities are effectively unbounded.
        high = np.array([
            self.x_threshold * 2,
            np.finfo(np.float32).max,
            self.theta_threshold_radians * 2,
            np.finfo(np.float32).max,
        ])
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high)

        self.viewer = None
        self.state = None
        self.steps_beyond_done = None
        self.steps = 0

    def _seed(self, seed=None):
        """Create the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _step(self, action):
        """Advance the simulation by one time step of length ``self.tau``.

        Args:
            action: member of ``self.action_space``; 1 applies ``+force_mag``
                to the cart, anything else applies ``-force_mag``.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is the
            new state as an array, ``done`` becomes True once more than
            ``self.max_steps`` steps have been taken since the last reset, and
            ``info`` is an empty dict.
        """
        assert self.action_space.contains(action), "%r (%s) invalid"%(action, type(action))
        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag

        # Cart-pole dynamics, evaluated at the current (pre-update) state.
        costheta = math.cos(theta)
        sintheta = math.sin(theta)
        temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # Euler integration: positions advance with the old velocities.
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc
        self.state = (x, x_dot, theta, theta_dot)

        # Termination is purely time based; leaving the track is penalised in
        # the reward below but does not end the episode.
        self.steps += 1
        done = self.steps > self.max_steps

        # Shaped reward computed from the updated state.
        shaped_reward = math.cos(theta)
        if x < -self.x_threshold or x > self.x_threshold:
            shaped_reward -= 1.0

        if not done:
            reward = shaped_reward
        elif self.steps_beyond_done is None:
            # First step on which the time limit is reported.
            self.steps_beyond_done = 0
            reward = shaped_reward
        else:
            if self.steps_beyond_done == 0:
                logger.warning("You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.")
            self.steps_beyond_done += 1
            reward = 0.0

        return np.array(self.state), reward, done, {}

    def _reset(self):
        """Start a new episode and return the initial observation.

        The start state is deterministic: cart centred and at rest, pole at
        ``theta = pi`` with zero angular velocity.
        """
        # x, x_dot, theta, theta_dot
        self.state = np.array([0, 0, np.pi, 0])
        self.steps_beyond_done = None
        # Restart the step counter so the time limit applies per episode.
        self.steps = 0
        return np.array(self.state)

    def _render(self, mode='human', close=False):
        """Draw the cart and pole for the current state.

        Args:
            mode: ``'rgb_array'`` asks the viewer to return the frame as an
                array; any other value is passed on as a normal render.
            close: when True, close and drop the viewer instead of drawing.

        Returns:
            Whatever ``viewer.render`` returns, or None when closing or when
            no state has been set yet.
        """
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return None

        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        polelen = scale * 1.0
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            # Imported lazily so the environment can be used without a display.
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)

            # Cart: rectangle centred on its own transform.
            left, right, top, bottom = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = rendering.FilledPolygon([(left, bottom), (left, top), (right, top), (right, bottom)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)

            # Pole: rotates about the axle and travels with the cart.
            left, right, top, bottom = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
            pole = rendering.FilledPolygon([(left, bottom), (left, top), (right, top), (right, bottom)])
            pole.set_color(.8, .6, .4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)

            # Axle: drawn at the pole's pivot.
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5, .5, .8)
            self.viewer.add_geom(self.axle)

            # Track: horizontal line across the whole screen.
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

        if self.state is None:
            return None

        state = self.state
        cartx = state[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        # The pole is drawn pointing up at zero rotation; rotate it to theta.
        self.poletrans.set_rotation(-state[2])

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')
