Created
November 9, 2020 11:58
-
-
Save ngoodger/6cf50a05c9b3c189be30fab34ab5d85c to your computer and use it in GitHub Desktop.
PyTorch Gym Environments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
import math | |
import torch | |
from gym import spaces, logger | |
import numpy as np | |
class Pendulum(gym.Env): | |
metadata = { | |
'render.modes': ['human', 'rgb_array'], | |
'video.frames_per_second': 30 | |
} | |
def __init__(self, env_count=1, device="cpu", g=10.0): | |
self.max_speed = 8 | |
self.max_torque = 2. | |
self.dt = .05 | |
self.g = g | |
self.m = 1. | |
self.l = 1. | |
self.viewer = None | |
self.env_count = env_count | |
self.device = device | |
""" | |
high = np.array([1., 1., self.max_speed], dtype=np.float32) | |
self.action_space = spaces.Box( | |
low=-self.max_torque, | |
high=self.max_torque, shape=(1,), | |
dtype=np.float32 | |
) | |
self.observation_space = spaces.Box( | |
low=-high, | |
high=high, | |
dtype=np.float32 | |
) | |
""" | |
high = np.array([1., 1., self.max_speed], dtype=np.float32) | |
self.action_space = spaces.Box( | |
low=-self.max_torque, | |
high=self.max_torque, shape=(1,), | |
dtype=np.float32 | |
) | |
self.observation_space = spaces.Box( | |
low=-high, | |
high=high, | |
dtype=np.float32 | |
) | |
#TODO seeding | |
#self.seed() | |
def seed(self, seed=None): | |
self.np_random, seed = seeding.np_random(seed) | |
return [seed] | |
def step(self, u): | |
th, thdot = self.state[:, 0], self.state[:, 1] # th := theta | |
g = self.g | |
m = self.m | |
l = self.l | |
dt = self.dt | |
u = torch.clamp(u, -self.max_torque, self.max_torque).squeeze(1) | |
self.last_u = u # for rendering | |
costs = angle_normalize(th) ** 2 + .1 * thdot ** 2 + .001 * (u ** 2) | |
newthdot = thdot + (-3 * g / (2 * l) * torch.sin(th + math.pi) + 3. / (m * l ** 2) * u) * dt | |
newth = th + newthdot * dt | |
newthdot = torch.clamp(newthdot, -self.max_speed, self.max_speed) | |
self.state = torch.stack((newth, newthdot), dim=1) | |
return self._get_obs(), -costs, False, {} | |
def reset(self): | |
theta = math.pi * (2. * torch.rand(self.env_count, device=self.device) - 0.5) | |
thetadot = (2. * torch.rand(self.env_count, device=self.device) - 0.5) | |
self.state = torch.stack((theta, thetadot), dim=1) | |
self.last_u = None | |
return self._get_obs() | |
def _get_obs(self): | |
theta, thetadot = self.state[:, 0], self.state[:, 1] | |
return torch.stack(([torch.cos(theta), torch.sin(theta), thetadot]), dim=1) | |
def render(self, mode='human'): | |
if self.viewer is None: | |
from gym.envs.classic_control import rendering | |
self.viewer = rendering.Viewer(500, 500) | |
self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2) | |
rod = rendering.make_capsule(1, .2) | |
rod.set_color(.8, .3, .3) | |
self.pole_transform = rendering.Transform() | |
rod.add_attr(self.pole_transform) | |
self.viewer.add_geom(rod) | |
axle = rendering.make_circle(.05) | |
axle.set_color(0, 0, 0) | |
self.viewer.add_geom(axle) | |
#fname = path.join(path.dirname(__file__), "assets/clockwise.png") | |
fname = path.join(path.dirname(""), "assets/clockwise.png") | |
self.img = rendering.Image(fname, 1., 1.) | |
self.imgtrans = rendering.Transform() | |
self.img.add_attr(self.imgtrans) | |
self.viewer.add_onetime(self.img) | |
self.pole_transform.set_rotation(self.state[0] + math.pi / 2) | |
if self.last_u: | |
self.imgtrans.scale = (-self.last_u / 2, torch.abs(self.last_u) / 2) | |
return self.viewer.render(return_rgb_array=mode == 'rgb_array') | |
def close(self): | |
if self.viewer: | |
self.viewer.close() | |
self.viewer = None | |
def angle_normalize(x): | |
return (((x + math.pi) % (2 * math.pi)) - math.pi) | |
""" | |
Classic cart-pole system implemented by Rich Sutton et al. | |
Copied from http://incompleteideas.net/sutton/book/code/pole.c | |
permalink: https://perma.cc/C9ZM-652R | |
""" | |
class CartPole(gym.Env): | |
""" | |
Description: | |
A pole is attached by an un-actuated joint to a cart, which moves along | |
a frictionless track. The pendulum starts upright, and the goal is to | |
prevent it from falling over by increasing and reducing the cart's | |
velocity. | |
Source: | |
This environment corresponds to the version of the cart-pole problem | |
described by Barto, Sutton, and Anderson | |
Observation: | |
Type: Box(4) | |
Num Observation Min Max | |
0 Cart Position -4.8 4.8 | |
1 Cart Velocity -Inf Inf | |
2 Pole Angle -0.418 rad (-24 deg) 0.418 rad (24 deg) | |
3 Pole Angular Velocity -Inf Inf | |
Actions: | |
Type: Discrete(2) | |
Num Action | |
0 Push cart to the left | |
1 Push cart to the right | |
Note: The amount the velocity that is reduced or increased is not | |
fixed; it depends on the angle the pole is pointing. This is because | |
the center of gravity of the pole increases the amount of energy needed | |
to move the cart underneath it | |
Reward: | |
Reward is 1 for every step taken, including the termination step | |
Starting State: | |
All observations are assigned a uniform random value in [-0.05..0.05] | |
Episode Termination: | |
Pole Angle is more than 12 degrees. | |
Cart Position is more than 2.4 (center of the cart reaches the edge of | |
the display). | |
Episode length is greater than 200. | |
Solved Requirements: | |
Considered solved when the average return is greater than or equal to | |
195.0 over 100 consecutive trials. | |
""" | |
metadata = { | |
'render.modes': ['human', 'rgb_array'], | |
'video.frames_per_second': 50 | |
} | |
def __init__(self, env_count=1, device="cpu"): | |
self.gravity = 9.8 | |
self.masscart = 1.0 | |
self.masspole = 0.1 | |
self.total_mass = (self.masspole + self.masscart) | |
self.length = 0.5 # actually half the pole's length | |
self.polemass_length = (self.masspole * self.length) | |
self.force_mag = 10.0 | |
self.tau = 0.02 # seconds between state updates | |
self.kinematics_integrator = 'euler' | |
# Angle at which to fail the episode | |
self.theta_threshold_radians = 12 * 2 * math.pi / 360 | |
self.x_threshold = 2.4 | |
self.env_count = env_count | |
# Angle limit set to 2 * theta_threshold_radians so failing observation | |
# is still within bounds. | |
high = np.array([self.x_threshold * 2, | |
np.finfo(np.float32).max, | |
self.theta_threshold_radians * 2, | |
np.finfo(np.float32).max], | |
dtype=np.float32) | |
self.action_space = spaces.Discrete(2) | |
self.observation_space = spaces.Box(-high, high, dtype=np.float32) | |
self.seed() | |
self.viewer = None | |
self.state = None | |
self.done = torch.full([env_count], True, dtype=torch.bool, device=device) | |
self.state = torch.zeros([self.env_count, 4], dtype=torch.float32, device=device) | |
self.device = device | |
def seed(self, seed=None): | |
return [seed] | |
def step(self, action): | |
#breakpoint() | |
# All env must already have been reset. | |
self.done[:] = False | |
x, x_dot, theta, theta_dot = self.state[:, 0], self.state[:, 1], self.state[:, 2], self.state[:, 3] | |
#breakpoint() | |
force = self.force_mag * ((action * 2.) - 1.) | |
costheta = torch.cos(theta) | |
sintheta = torch.sin(theta) | |
# For the interested reader: | |
# https://coneural.org/florian/papers/05_cart_pole.pdf | |
temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass | |
thetaacc = ((self.gravity * sintheta - costheta * temp) | |
/ (self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass))) | |
xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass | |
if self.kinematics_integrator == 'euler': | |
x = x + self.tau * x_dot | |
x_dot = x_dot + self.tau * xacc | |
theta = theta + self.tau * theta_dot | |
theta_dot = theta_dot + self.tau * thetaacc | |
else: # semi-implicit euler | |
x_dot = x_dot + self.tau * xacc | |
x = x + self.tau * x_dot | |
theta_dot = theta_dot + self.tau * thetaacc | |
theta = theta + self.tau * theta_dot | |
self.state[:, 0], self.state[:, 1], self.state[:, 2], self.state[:, 3] = x, x_dot, theta, theta_dot | |
self.done = ( | |
(x < -self.x_threshold) | |
| (x > self.x_threshold) | |
| (theta < -self.theta_threshold_radians) | |
| (theta > self.theta_threshold_radians) | |
) | |
reward = ~self.done | |
self.state = self.reset() | |
return self.state, reward, self.done, {} | |
def reset(self): | |
#breakpoint() | |
self.state = torch.where(self.done.unsqueeze(1), (torch.rand(self.env_count, 4, device=self.device) -0.5) / 10., self.state) | |
#self.state = (torch.rand((self.env_count, 4)) -0.5) / 10. | |
return self.state | |
def render(self, mode='human'): | |
screen_width = 600 | |
screen_height = 400 | |
world_width = self.x_threshold * 2 | |
scale = screen_width/world_width | |
carty = 100 # TOP OF CART | |
polewidth = 10.0 | |
polelen = scale * (2 * self.length) | |
cartwidth = 50.0 | |
cartheight = 30.0 | |
if self.viewer is None: | |
from gym.envs.classic_control import rendering | |
self.viewer = rendering.Viewer(screen_width, screen_height) | |
l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 | |
axleoffset = cartheight / 4.0 | |
cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) | |
self.carttrans = rendering.Transform() | |
cart.add_attr(self.carttrans) | |
self.viewer.add_geom(cart) | |
l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2 | |
pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) | |
pole.set_color(.8, .6, .4) | |
self.poletrans = rendering.Transform(translation=(0, axleoffset)) | |
pole.add_attr(self.poletrans) | |
pole.add_attr(self.carttrans) | |
self.viewer.add_geom(pole) | |
self.axle = rendering.make_circle(polewidth/2) | |
self.axle.add_attr(self.poletrans) | |
self.axle.add_attr(self.carttrans) | |
self.axle.set_color(.5, .5, .8) | |
self.viewer.add_geom(self.axle) | |
self.track = rendering.Line((0, carty), (screen_width, carty)) | |
self.track.set_color(0, 0, 0) | |
self.viewer.add_geom(self.track) | |
self._pole_geom = pole | |
if self.state is None: | |
return None | |
# Edit the pole polygon vertex | |
pole = self._pole_geom | |
l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2 | |
pole.v = [(l, b), (l, t), (r, t), (r, b)] | |
x = self.state | |
cartx = x[0] * scale + screen_width / 2.0 # MIDDLE OF CART | |
self.carttrans.set_translation(cartx, carty) | |
self.poletrans.set_rotation(-x[2]) | |
return self.viewer.render(return_rgb_array=mode == 'rgb_array') | |
def close(self): | |
if self.viewer: | |
self.viewer.close() | |
self.viewer = None |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment