Last active
May 25, 2018 16:45
-
-
Save iandanforth/05683376a37a04b7fce217dd23bb14ce to your computer and use it in GitHub Desktop.
Cartpole implemented using Pymunk 2D physics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pygame | |
import pymunk | |
from pygame.locals import (QUIT, KEYDOWN, K_ESCAPE) | |
############################################################################## | |
# Pygame | |
def handlePygameEvents():
    """Drain the pygame event queue and exit on a quit request.

    The process is terminated with exit status 0 when a ``QUIT`` event is
    received or when the Escape key is pressed. All other events are
    discarded.
    """
    # Imported locally: this module's import block never brings ``sys`` into
    # scope, so the bare ``sys.exit`` calls would raise NameError.
    import sys

    for event in pygame.event.get():
        quit_requested = event.type == QUIT
        # ``event.key`` only exists on keyboard events, hence the type guard.
        escape_pressed = event.type == KEYDOWN and event.key == K_ESCAPE
        if quit_requested or escape_pressed:
            sys.exit(0)
############################################################################## | |
# Pymunk | |
def addTrack(screen_width, space, track_pos_y, padding):
    """Build the track and register it with the simulation space.

    Args:
        screen_width: Width of the visible area; forwarded to getTrack() as
            the track length.
        space: Pymunk space the track shape and body are added to.
        track_pos_y: Vertical position assigned to the track body.
        padding: Extra track length beyond the screen width.

    Returns:
        The ``(track_body, track_shape)`` pair.
    """
    body, shape = getTrack(screen_width, padding)
    # The track is ``padding`` longer than the screen; starting it half the
    # padding to the left leaves the same overhang on both sides.
    body.position = (-padding / 2, track_pos_y)
    space.add(shape, body)
    return body, shape
def addCart(
    screen_width,
    space,
    cart_width,
    cart_height,
    cart_mass,
    track_pos_y
):
    """Build the cart and register it with the simulation space.

    Args:
        screen_width: Width of the visible area, used to centre the cart.
        space: Pymunk space the cart shape and body are added to.
        cart_width: Width of the cart rectangle.
        cart_height: Height of the cart rectangle.
        cart_mass: Mass handed to getCart().
        track_pos_y: Vertical position assigned to the cart body.

    Returns:
        The ``(cart_body, cart_shape)`` pair.
    """
    body, shape = getCart(cart_width, cart_height, cart_mass)
    # Offset by half the cart width so the rectangle, whose local origin is
    # one of its corners, ends up horizontally centred on the screen.
    start_x = (screen_width / 2) - (cart_width / 2)
    body.position = (start_x, track_pos_y)
    space.add(shape, body)
    return body, shape
def addPole(
    screen_width,
    space,
    pole_length,
    pole_mass,
    track_pos_y,
    cart_height
):
    """Build the pole and register it with the simulation space.

    Args:
        screen_width: Width of the visible area; the pole is placed at its
            horizontal midpoint.
        space: Pymunk space the pole shape and body are added to.
        pole_length: Length handed to getPole().
        pole_mass: Mass handed to getPole().
        track_pos_y: Vertical position of the track.
        cart_height: Height of the cart; the pole body sits half this
            distance above the track position.

    Returns:
        The ``(pole_body, pole_shape)`` pair.
    """
    body, shape = getPole(pole_length, pole_mass)
    base_x = screen_width / 2
    base_y = track_pos_y + (cart_height / 2)
    body.position = (base_x, base_y)
    space.add(shape, body)
    return body, shape
def addConstraints(space, cart_shape, track_shape, pole_shape):
    """Create the cart/track/pole joints and add them to ``space``.

    Returns:
        The tuple of constraints produced by getCartConstraints().
    """
    joints = getCartConstraints(cart_shape, track_shape, pole_shape)
    space.add(*joints)
    return joints
def getShapeWidthHeight(shape):
    """Return ``(width, height)`` of the bounding box stored in ``shape.bb``."""
    box = shape.bb
    width = box.right - box.left
    height = box.top - box.bottom
    return (width, height)
def getCartConstraints(cart_shape, track_shape, pole_shape):
    """Create the joints that tie the cart to the track and the pole to the cart.

    Two groove joints attach the cart to the track, one per bottom corner of
    the cart, and a pivot joint attaches the pole to the middle of the cart.
    Dimensions are read from the shapes' bounding boxes.

    Returns:
        ``(track_c_1, track_c_2, cart_pole_c)``: the two groove joints
        followed by the pivot joint.
    """
    cart_width, cart_height = getShapeWidthHeight(cart_shape)
    track_width, _ = getShapeWidthHeight(track_shape)

    track_body = track_shape.body
    cart_body = cart_shape.body

    # Both grooves run the full length of the track (track-local coordinates).
    groove_start = (0, 0)
    groove_end = (track_width, 0)

    joints = []
    # Body-local anchors on the cart: its two bottom corners.
    for cart_anchor in ((0, 0), (cart_width, 0)):
        groove = pymunk.GrooveJoint(
            track_body,
            cart_body,
            groove_start,
            groove_end,
            cart_anchor
        )
        # Make constraints as 'strong' as possible
        groove.error_bias = 0.0001
        joints.append(groove)

    pivot = pymunk.PivotJoint(
        cart_body,
        pole_shape.body,
        # Body local anchor on cart
        (cart_width / 2, cart_height / 2),
        # Body local anchor on pole
        (0, 0)
    )
    pivot.error_bias = 0.0001
    joints.append(pivot)

    return tuple(joints)
def getPole(length, mass, friction=1.0):
    """Create the pole as a segment running from its body origin up to ``length``.

    The body is created with zero mass and moment; ``mass`` is assigned to
    the shape instead.

    Args:
        length: Length of the segment along the body's local y axis.
        mass: Mass assigned to the segment shape.
        friction: Friction assigned to the segment shape.

    Returns:
        The ``(body, shape)`` pair.
    """
    pole_body = pymunk.Body(0, 0)
    start, end, radius = (0, 0), (0, length), 5
    pole_shape = pymunk.Segment(pole_body, start, end, radius)
    pole_shape.sensor = True  # Disable collision
    pole_shape.mass = mass
    pole_shape.friction = friction
    return (pole_body, pole_shape)
def getTrack(track_length, padding):
    """Create the static track: a horizontal segment ``track_length + padding`` long.

    Returns:
        The ``(track_body, track_shape)`` pair.
    """
    static_body = pymunk.Body(body_type=pymunk.Body.STATIC)
    total_length = track_length + padding
    radius = 2
    segment = pymunk.Segment(static_body, (0, 0), (total_length, 0), radius)
    segment.sensor = True  # Disable collision
    return (static_body, segment)
def getCart(width, height, mass):
    """Create the cart: a rectangular body with a box moment of inertia.

    Args:
        width: Width of the cart rectangle.
        height: Height of the cart rectangle.
        mass: Mass of the cart body.

    Returns:
        The ``(body, shape)`` pair.
    """
    size = (width, height)
    moment = pymunk.moment_for_box(mass, size)
    cart_body = pymunk.Body(mass, moment)
    cart_shape = getPymunkRect(cart_body, width, height)
    return (cart_body, cart_shape)
def getPymunkRect(body, width, height):
    """Return a rectangular ``pymunk.Poly`` attached to ``body``.

    The rectangle's corner sits at the body's local origin and extends
    ``width`` along x and ``height`` along y.
    """
    corners = [
        (0, 0),
        (width, 0),
        (width, height),
        (0, height),
    ]
    return pymunk.Poly(body, corners)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Classic cart-pole system implemented by Rich Sutton et al. | |
Copied from http://incompleteideas.net/sutton/book/code/pole.c | |
permalink: https://perma.cc/C9ZM-652R | |
Pymunk version by Ian Danforth | |
""" | |
import math | |
import gym | |
import pygame | |
import pymunk | |
import pymunk.pygame_util | |
import numpy as np | |
from gym import spaces, logger | |
from gym.utils import seeding | |
from . import cartpole_utils as utils | |
class PymunkCartPoleEnv(gym.Env):
    """Continuous-action cart-pole environment simulated with Pymunk.

    A cart slides along a horizontal track and carries a pole attached by a
    pivot joint. Each step applies a horizontal force to the cart,
    proportional to the single action value in ``[min_action, max_action]``.
    An episode ends when the cart leaves the visible screen or the pole
    tilts more than ``theta_threshold_radians`` away from its starting
    orientation.

    Observations are ``[x, x_dot, theta, theta_dot]``. The simulation works
    in screen pixels; positions and velocities are rescaled so the screen
    edges correspond to ``+/- x_threshold``, which keeps a failing
    observation inside ``observation_space``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 50
    }

    def __init__(self):
        # Pygame and display setup
        pygame.init()
        self.screen_width = 600
        self.screen_height = 400
        self.screen = pygame.display.set_mode(
            (self.screen_width, self.screen_height)
        )
        pygame.display.set_caption("pymunk_cartpole.py")
        self.clock = pygame.time.Clock()

        self._initPymunk()

        # Action Space
        self.force_mag = 100.0
        self.min_action = -1.0
        self.max_action = 1.0
        self.action_space = spaces.Box(
            low=self.min_action,
            high=self.max_action,
            shape=(1,)
        )

        # Observation Space
        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4
        # Angle limit set to 2 * theta_threshold_radians so failing observation
        # is still within bounds
        high = np.array([
            self.x_threshold * 2,
            np.finfo(np.float32).max,
            self.theta_threshold_radians * 2,
            np.finfo(np.float32).max])
        self.observation_space = spaces.Box(-high, high)

        self.steps_beyond_done = None

    def _initPymunk(self):
        """Create a fresh physics space holding the track, cart and pole."""
        # Simulation space
        self.space = pymunk.Space()
        self.space.gravity = (0.0, -900.0)
        self.space.iterations = 20  # Double default

        # Debug draw setup (called in render())
        self.draw_options = pymunk.pygame_util.DrawOptions(self.screen)
        # NOTE(review): 3 presumably equals DRAW_SHAPES | DRAW_CONSTRAINTS
        # (collision points not drawn) -- confirm against the pymunk version
        # in use.
        self.draw_options.flags = 3

        # Track
        track_pos_y = 100
        # Track outside of view area
        padding = 400
        self.track_body, self.track_shape = utils.addTrack(
            self.screen_width,
            self.space,
            track_pos_y,
            padding
        )

        # Cart
        cart_width = 60
        cart_height = 30
        cart_mass = 1.0
        self.cart_body, self.cart_shape = utils.addCart(
            self.screen_width,
            self.space,
            cart_width,
            cart_height,
            cart_mass,
            track_pos_y
        )

        # Pole
        pole_length = 100
        pole_mass = 0.1
        self.pole_body, self.pole_shape = utils.addPole(
            self.screen_width,
            self.space,
            pole_length,
            pole_mass,
            track_pos_y,
            cart_height
        )

        # Constraints
        self.constraints = utils.addConstraints(
            self.space,
            self.cart_shape,
            self.track_shape,
            self.pole_shape
        )

    def _getPoleAngle(self):
        """Return the pole angle wrapped into ``[-pi, pi)``.

        Zero is the orientation the pole is created with. Wrapping
        symmetrically around zero (rather than into ``[0, 2*pi)``) keeps
        small negative tilts small, so both threshold comparisons in step()
        are meaningful.
        """
        wrapped = (self.pole_body.angle + math.pi) % (2 * math.pi)
        return wrapped - math.pi

    def _getObservation(self):
        """Return the current state as ``[x, x_dot, theta, theta_dot]``.

        Cart position is measured from the screen centre and, together with
        the cart velocity, rescaled so that the screen edges map to
        ``+/- x_threshold``. Angles are in radians.
        """
        half_width = self.screen_width / 2.0
        scale = self.x_threshold / half_width
        x = (self.cart_body.position[0] - half_width) * scale
        x_dot = self.cart_body.velocity[0] * scale
        theta = self._getPoleAngle()
        theta_dot = self.pole_body.angular_velocity
        return np.array([x, x_dot, theta, theta_dot], dtype=np.float32)

    def seed(self, seed=None):
        """Seed the environment's random number generator.

        Returns:
            A one-element list containing the seed that was used.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):
        """Advance the simulation by one 1/50 s time step.

        - Take action
        - Step the physics of the world
        - Check for 'done' conditions
        - Return reward as appropriate

        Note: render() must be called at least once before
        this method is called otherwise pymunk breaks.
        # e.g. OverflowError: Python int too large to convert to C long

        Args:
            action: Scalar or array whose first element is the commanded
                force fraction; it is clipped to
                ``[min_action, max_action]``.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is
            the array described in _getObservation(), ``reward`` is 1.0 for
            every step up to and including the one on which the episode
            ends and 0.0 afterwards, and ``info`` is an empty dict.
        """
        # Accept scalars as well as the shape-(1,) arrays sampled from
        # action_space, and keep the command inside the declared bounds.
        command = float(np.clip(
            np.ravel(action)[0],
            self.min_action,
            self.max_action
        ))
        # The physics engine expects a 2-D force vector; push along x only.
        force = (self.force_mag * command, 0.0)
        self.cart_body.apply_force_at_local_point(
            force,
            self.cart_body.center_of_gravity
        )

        # Step first so the returned observation and the termination check
        # both describe the state that results from this action.
        self.space.step(1 / 50.0)

        x = self.cart_body.position[0]
        theta = self._getPoleAngle()
        done = bool(
            x < 0.0
            or x > self.screen_width
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            if self.steps_beyond_done == 0:
                logger.warn("""
You are calling 'step()' even though this environment has already returned
done = True. You should always call 'reset()' once you receive 'done = True'
Any further steps are undefined behavior.
""")
            self.steps_beyond_done += 1
            reward = 0.0

        return self._getObservation(), reward, done, {}

    def render(self, mode='human'):
        """Draw the current simulation state to the pygame window.

        Also processes pending window events (which may exit the process)
        and throttles the caller to 50 frames per second.

        Returns:
            For ``mode='rgb_array'``, the frame as a
            ``(height, width, 3)`` array; otherwise ``None``.
        """
        utils.handlePygameEvents()

        # Redraw all objects
        self.screen.fill((255, 255, 255))
        self.space.debug_draw(self.draw_options)
        pygame.display.flip()
        self.clock.tick(50)

        if mode == 'rgb_array':
            # surfarray is indexed (x, y, channel); swap the first two axes
            # to get the conventional (row, column, channel) image layout.
            frame = pygame.surfarray.array3d(self.screen)
            return np.transpose(frame, (1, 0, 2))
        return None

    def reset(self):
        """Rebuild the simulation and return the initial observation."""
        # Rebinding self.space inside _initPymunk() releases the old space.
        self._initPymunk()
        # Forget any previous episode end so step() rewards and warnings
        # start afresh.
        self.steps_beyond_done = None
        return self._getObservation()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
from gym.envs.registration import register | |
# Register the environment under a gym id so gym.make() can construct it.
register(
    id='PymunkCartPole-v0',
    entry_point='envs:PymunkCartPoleEnv',
)

# Make a new continuous cartpole
env = gym.make('PymunkCartPole-v0')
env.reset()

# Drive the environment with random actions for a fixed number of steps.
for _ in range(1000):
    env.render()  # Must be called before step
    action = env.action_space.sample()
    _, _, done, _ = env.step(action)
    if done:
        # Stepping a finished episode is undefined behaviour (the
        # environment warns about it), so start a new episode instead.
        env.reset()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Warning: Untested Code
This code was developed on OSX but has not yet been used as part of a full learning scenario. Do not expect a trained agent to work in this environment out of the box. If you have questions, feel free to email me at (iandanforth at gmail).
Requirements
Instructions
Copy these files into a directory of your choice then: