# Continuous Cartpole for OpenAI Gym
# GitHub gist by @niceboy120 (gist id e3538e5bb98023eb2a2f61ebbff91406),
# forked from iandanforth/continuous_cartpole.py.
# Created August 5, 2022 21:15.
"""
Classic cart-pole system implemented by Rich Sutton et al.
Copied from http://incompleteideas.net/sutton/book/code/pole.c
permalink: https://perma.cc/C9ZM-652R
Continuous version by Ian Danforth
"""
import math
import gym
from gym import spaces, logger
from gym.utils import seeding
import numpy as np
class ContinuousCartPoleEnv(gym.Env):
    """Cart-pole balancing environment with a continuous action.

    The action is a single value in ``[min_action, max_action]``; it is
    multiplied by ``force_mag`` to obtain the horizontal force applied to the
    cart. The state/observation is ``[x, x_dot, theta, theta_dot]``.

    An episode is reported as done once the cart position leaves
    ``[-x_threshold, x_threshold]`` or the pole angle leaves
    ``[-theta_threshold_radians, theta_threshold_radians]``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 50
    }

    def __init__(self):
        """Set the physical constants, build the spaces and seed the RNG."""
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masspole + self.masscart)
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 30.0
        self.tau = 0.02  # seconds between state updates
        self.min_action = -1.0
        self.max_action = 1.0

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Angle limit set to 2 * theta_threshold_radians so failing observation
        # is still within bounds
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ],
            # Build the bounds as float32 up front so Box neither has to guess
            # a dtype nor warn about lowering the precision of float64 bounds.
            dtype=np.float32,
        )

        self.action_space = spaces.Box(
            low=self.min_action,
            high=self.max_action,
            shape=(1,),
            dtype=np.float32,
        )
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        self.seed()
        self.viewer = None
        self.state = None
        self.steps_beyond_done = None

    def seed(self, seed=None):
        """(Re)create ``self.np_random`` and return the seed used, in a list."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def stepPhysics(self, force):
        """Return the state one ``tau`` later under a constant ``force``.

        Reads ``self.state`` but does not modify it. Positions and velocities
        are advanced with a forward-Euler step, i.e. every update uses the
        values from the start of the step.

        Returns:
            Tuple ``(x, x_dot, theta, theta_dot)`` for the next time step.
        """
        x, x_dot, theta, theta_dot = self.state
        costheta = math.cos(theta)
        sintheta = math.sin(theta)
        temp = (
            force + self.polemass_length * theta_dot * theta_dot * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length
            * (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc
        return (x, x_dot, theta, theta_dot)

    def step(self, action):
        """Apply ``action`` for one time step.

        Args:
            action: Element of ``self.action_space`` (one value, shape
                ``(1,)``), scaled by ``force_mag`` to give the applied force.

        Returns:
            Tuple ``(observation, reward, done, info)``. ``reward`` is 1.0 on
            every step up to and including the one on which ``done`` first
            becomes True, and 0.0 on any step taken after that. ``info`` is
            an empty dict.

        Raises:
            AssertionError: If ``action`` is not contained in the action space.
        """
        assert self.action_space.contains(action), \
            "%r (%s) invalid" % (action, type(action))
        # Pull out the single element explicitly: calling float() directly on
        # a 1-element ndarray is deprecated since NumPy 1.25.
        force = self.force_mag * float(np.asarray(action).item())
        self.state = self.stepPhysics(force)
        x, _, theta, _ = self.state
        done = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            # Warn only on the first step taken after termination.
            if self.steps_beyond_done == 0:
                logger.warn("""
You are calling 'step()' even though this environment has already returned
done = True. You should always call 'reset()' once you receive 'done = True'
Any further steps are undefined behavior.
""")
            self.steps_beyond_done += 1
            reward = 0.0

        return np.array(self.state), reward, done, {}

    def reset(self):
        """Draw a new state uniformly from [-0.05, 0.05) and return a copy."""
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        return np.array(self.state)

    def render(self, mode='human'):
        """Draw the cart, pole, axle and track for the current state.

        The viewer and its geometry are created lazily on the first call.

        Args:
            mode: ``'rgb_array'`` makes the viewer return the frame as an
                array; any other value renders without requesting the array.

        Returns:
            ``None`` if no state has been set yet, otherwise whatever the
            viewer's ``render`` call returns.
        """
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        polelen = scale * 1.0
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)

            # Cart body.
            left, right = -cartwidth / 2, cartwidth / 2
            top, bottom = cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = rendering.FilledPolygon(
                [(left, bottom), (left, top), (right, top), (right, bottom)]
            )
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)

            # Pole: rotated about the axle, then carried along with the cart.
            left, right = -polewidth / 2, polewidth / 2
            top, bottom = polelen - polewidth / 2, -polewidth / 2
            pole = rendering.FilledPolygon(
                [(left, bottom), (left, top), (right, top), (right, bottom)]
            )
            pole.set_color(.8, .6, .4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)

            # Axle.
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5, .5, .8)
            self.viewer.add_geom(self.axle)

            # Track.
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

        if self.state is None:
            return None

        state = self.state
        cartx = state[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-state[2])

        return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))

    def close(self):
        """Close the viewer, if one was created, and forget it."""
        if self.viewer:
            self.viewer.close()
            # Drop the reference so a later render() builds a fresh viewer
            # instead of drawing into the closed one.
            self.viewer = None
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment