Skip to content

Instantly share code, notes, and snippets.

@kirarpit
Last active June 17, 2019 06:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kirarpit/799b11670537e332296a17159cba17bb to your computer and use it in GitHub Desktop.
Save kirarpit/799b11670537e332296a17159cba17bb to your computer and use it in GitHub Desktop.
Robbie, the soda-can-collecting robot (https://bit.ly/2MPUku5), trained with PPO, a deep RL algorithm, using Ray RLlib.
import ray
from gym import spaces
from ray.rllib.env.multi_agent_env import MultiAgentEnv
import numpy as np
from ray.tune.registry import register_env
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.rollout import rollout
import time
from ray import tune
# Action ids. Actions are consecutive integers starting at 0, so they can be
# used directly as a gym Discrete(NUM_ACTS) action.
NUM_ACTS = 7
(
    ACT_STAY,    # 0: stay in the current cell
    ACT_PICKUP,  # 1: try to pick up a can in the current cell
    ACT_N,       # 2: move one row up
    ACT_S,       # 3: move one row down
    ACT_E,       # 4: move one column right
    ACT_W,       # 5: move one column left
    ACT_RAND,    # 6: move in a random direction among N, S, E, W
) = range(NUM_ACTS)

# Cell states stored in the grid and reported in observations.
NUM_CELL_STATES = 4
(
    CELL_EMPTY,  # 0: nothing in the cell
    CELL_FOOD,   # 1: a can is in the cell
    CELL_WALL,   # 2: border wall
    CELL_ROBOT,  # 3: occupied by a robot (only while observations are built)
) = range(NUM_CELL_STATES)
class RobbieCanPicker(MultiAgentEnv):
    """Multi-agent grid world in which robots collect soda cans.

    The playable area is ``height`` x ``width`` cells surrounded by a ring of
    CELL_WALL cells, so ``self.grid`` has ``height + 2`` rows of ``width + 2``
    cells; ``(r, c)`` is row ``r`` / column ``c`` counted from the top-left.
    Agents are identified by the integers ``0 .. num_robots - 1``.

    env_config keys:
        width, height: size of the playable area.
        num_robots: number of agents; every robot starts at cell (1, 1).
        num_cans: number of cans scattered at random on each reset (capped at
            the number of playable cells).
        max_steps: maximum number of steps per episode.
        alpha: weight of the pickup reward mirrored to other robots (see
            ``step``).
    """

    # Row/column offset for each movement action; any other action (stay,
    # pick up) leaves the robot where it is.
    _MOVES = {ACT_N: (-1, 0), ACT_S: (1, 0), ACT_E: (0, 1), ACT_W: (0, -1)}

    def __init__(self, env_config):
        self.width = env_config['width']
        self.height = env_config['height']
        self.num_robots = env_config['num_robots']
        self.num_cans = env_config['num_cans']
        self.max_steps = env_config['max_steps']
        self.alpha = env_config['alpha']
        self._over = False
        self.can_cnt = 0
        self.step_cnt = 0
        self.grid, self.robot_positions = self.generate_grid()
        self.action_space = spaces.Discrete(NUM_ACTS)
        # Each robot sees five cells (N, S, E, W, centre), each holding one of
        # NUM_CELL_STATES cell states.
        self.observation_space = spaces.Tuple([spaces.Discrete(NUM_CELL_STATES)] * 5)
        self.viewer = None
        # Scale factor: side length of one grid cell in the render window.
        self.sf = 100

    def reset(self):
        """Start a new episode on a freshly randomised grid.

        Returns:
            dict: initial observation for every robot, keyed by agent index.
        """
        self._over = False
        self.can_cnt = 0
        self.step_cnt = 0
        self.grid, self.robot_positions = self.generate_grid()
        return self.get_observations()

    def step(self, action_dict):
        """Advance the environment by one step.

        All robots in ``action_dict`` move first (in dict order); pick-up
        attempts are then resolved against the new positions.

        Rewards for the step:
          * -5 for bumping into a wall (the robot stays where it was);
          * -1 for a pick-up attempt on a cell without a can;
          * +10 per collected can, split evenly among the robots that picked
            it up this step. For each such robot, every *other* robot also
            receives ``alpha`` times that robot's share (alpha = 1: others gain
            the same amount, 0: others get nothing, -1: others lose it).

        Args:
            action_dict (dict): maps agent index to an action in
                ``[0, NUM_ACTS)``.

        Returns:
            tuple: ``(obs, rewards, dones, infos)``. ``obs``, ``rewards`` and
            ``dones`` have an entry for every robot, even ones missing from
            ``action_dict``. ``dones["__all__"]`` becomes True once every
            placed can is collected or ``max_steps`` steps have been taken.
            ``infos`` is always empty.

        Raises:
            Exception: if called after the episode ended without ``reset``,
                or with an action outside ``[0, NUM_ACTS)``.
        """
        if self._over:
            raise Exception("Game is over. Must reset to play.")

        scores = {i: 0 for i in range(self.num_robots)}
        dones = {i: False for i in range(self.num_robots)}
        dones["__all__"] = False

        self._move_robots(action_dict, scores)
        self._collect_cans(action_dict, scores)

        self.step_cnt += 1
        # generate_grid places at most width * height cans. Comparing against
        # the raw num_cans would keep an over-sized episode running until
        # max_steps even after every can has been collected.
        cans_placed = min(self.num_cans, self.width * self.height)
        if self.can_cnt >= cans_placed or self.step_cnt >= self.max_steps:
            self._over = True
            for i in range(self.num_robots):
                dones[i] = True
            dones["__all__"] = True

        return self.get_observations(), scores, dones, {}

    def _move_robots(self, action_dict, scores):
        """Apply every robot's movement, penalising wall collisions in ``scores``."""
        for idx, action in action_dict.items():
            # Reject negative ids too; they used to fall through to "stay".
            if not 0 <= action < NUM_ACTS:
                raise Exception('Invalid action. This shouldn\'t be the case.')
            r, c = self.robot_positions[idx]
            if action == ACT_RAND:
                action = np.random.randint(ACT_N, ACT_W + 1)
            dr, dc = self._MOVES.get(action, (0, 0))
            rn, cn = r + dr, c + dc
            # A robot that runs into a wall is penalised and stays put.
            if self.grid[rn][cn] == CELL_WALL:
                scores[idx] -= 5
                rn, cn = r, c
            self.robot_positions[idx] = (rn, cn)

    def _collect_cans(self, action_dict, scores):
        """Resolve pick-up attempts: reward successes, penalise misses, remove cans."""
        # Group the robots that picked up a can by the can's cell, so robots
        # sharing a cell split its reward.
        food_picks = {}
        for idx, action in action_dict.items():
            if action != ACT_PICKUP:
                continue
            r, c = self.robot_positions[idx]
            if self.grid[r][c] == CELL_FOOD:
                food_picks.setdefault((r, c), []).append(idx)
            else:
                scores[idx] -= 1

        for (r, c), agent_ids in food_picks.items():
            for agent_id in agent_ids:
                scores[agent_id] += 10 / len(agent_ids)
                # Mirror alpha times this robot's share to every other robot.
                for i in range(self.num_robots):
                    if i != agent_id:
                        scores[i] = scores[i] + self.alpha * 10 / len(agent_ids)
            # The can is gone once it has been picked up.
            self.grid[r][c] = CELL_EMPTY
            self.can_cnt += 1

    def generate_grid(self):
        """Build a fresh walled grid with randomly placed cans.

        Returns:
            tuple: ``(grid, robot_positions)``. ``grid`` is a list of
            ``height + 2`` rows of ``width + 2`` cell states: border cells are
            CELL_WALL, and ``min(num_cans, width * height)`` random interior
            cells are CELL_FOOD. ``robot_positions`` places every robot at
            (1, 1).
        """
        width = self.width
        height = self.height
        # origin at top left corner. (r, c) means rth row and cth col
        grid = [[CELL_EMPTY for _ in range(width + 2)] for _ in range(height + 2)]
        # Initialize the walls of the grid.
        for i in range(height + 2):
            grid[i][0] = CELL_WALL
            grid[i][width + 1] = CELL_WALL
        for j in range(width + 2):
            grid[0][j] = CELL_WALL
            grid[height + 1][j] = CELL_WALL
        # Generate the configuration of cans and robots.
        grid_locations = [(r, c) for c in range(1, self.width + 1)
                          for r in range(1, self.height + 1)]
        np.random.shuffle(grid_locations)
        for (r, c) in grid_locations[:self.num_cans]:
            grid[r][c] = CELL_FOOD
        return grid, [(1, 1)] * self.num_robots

    def get_observations(self):
        """Return each robot's view of its neighbourhood.

        Returns:
            dict: agent index -> ``(N, S, E, W, centre)`` cell states. A
            neighbouring cell that holds a robot reads as CELL_ROBOT. The
            centre is always the cell's underlying contents.
        """
        obs_dict = {}
        # First, store the state of the cells the robots are on
        square_contents = [self.grid[r][c] for (r, c) in self.robot_positions]
        # Then, temporarily place robots on the grid so neighbours see them.
        for (r, c) in self.robot_positions:
            self.grid[r][c] = CELL_ROBOT
        for idx, (r, c) in enumerate(self.robot_positions):
            obs_dict[idx] = (self.grid[r - 1][c], self.grid[r + 1][c],
                             self.grid[r][c + 1], self.grid[r][c - 1],
                             square_contents[idx])
        # Restore the robot squares to their previous values.
        for idx, (r, c) in enumerate(self.robot_positions):
            self.grid[r][c] = square_contents[idx]
        return obs_dict

    def _cell_center(self, r, c, offset=0):
        """Window coordinates of the centre of cell (r, c), shifted by ``offset``."""
        sf = self.sf
        # Row 0 is the top of the grid, so rows are flipped into y-coordinates.
        return (c * sf + sf // 2 + offset, (self.height + 1 - r) * sf + sf // 2 + offset)

    def render(self, mode='human'):
        """Draw the current state with gym's classic-control viewer.

        Walls (blue) are added once, when the window is created. Cans (green)
        and robots (red) are drawn on every call. Each call sleeps 0.1 s.

        Returns:
            The result of ``Viewer.render``; an RGB array is requested when
            ``mode == 'rgb_array'``.
        """
        from gym.envs.classic_control import rendering
        sf = self.sf
        if self.viewer is None:
            self.viewer = rendering.Viewer((self.width + 2) * sf, (self.height + 2) * sf)
            # put walls: they never change, so they are added as persistent geoms
            grid = np.array(self.grid)
            for r, c in zip(*np.where(grid == CELL_WALL)):
                wall = rendering.make_circle(sf // 2, res=4)
                # gym colours are RGB components in [0, 1].
                wall.set_color(0, 0, 1)
                transform = rendering.Transform()
                wall.add_attr(transform)
                transform.set_translation(*self._cell_center(r, c))
                self.viewer.add_geom(wall)
        # put food items
        grid = np.array(self.grid)
        for r, c in zip(*np.where(grid == CELL_FOOD)):
            circle = self.viewer.draw_circle(sf // 2, color=(0, 1, 0))
            transform = rendering.Transform()
            circle.add_attr(transform)
            transform.set_translation(*self._cell_center(r, c, offset=1))
        # put robots
        for r, c in self.robot_positions:
            circle = self.viewer.draw_circle(sf // 2, color=(1, 0, 0))
            transform = rendering.Transform()
            circle.add_attr(transform)
            transform.set_translation(*self._cell_center(r, c, offset=1))
        time.sleep(1 / 10)
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            # Drop the closed viewer so a later render() opens a new window.
            self.viewer = None
config = {'width': 10, 'height': 10, 'num_robots': 2, 'num_cans': 10, 'max_steps': 200, 'alpha': 0}

# Tune experiment name; results go to ~/ray_results/<EXPERIMENT_NAME> by default.
EXPERIMENT_NAME = "tuning_robbie"


def find_latest_checkpoint(experiment_dir):
    """Return the most recently written checkpoint file under *experiment_dir*.

    Looks for ``<experiment_dir>/<trial>/checkpoint_<i>/checkpoint-<i>`` and
    skips the ``.tune_metadata`` sidecar files stored next to checkpoints.

    Raises:
        FileNotFoundError: if no checkpoint exists under *experiment_dir*.
    """
    import glob
    import os

    pattern = os.path.join(experiment_dir, "*", "checkpoint_*", "checkpoint-*")
    candidates = [p for p in glob.glob(pattern) if not p.endswith(".tune_metadata")]
    if not candidates:
        raise FileNotFoundError("No checkpoint found under {}".format(experiment_dir))
    # Newest file wins; the path breaks mtime ties so the choice is deterministic.
    return max(candidates, key=lambda p: (os.path.getmtime(p), p))


def main():
    """Train PPO on the Robbie environment, then render rollouts of the result."""
    import os

    if not ray.is_initialized():
        ray.init()
    register_env("robbie_env", lambda env_config: RobbieCanPicker(env_config))

    # Shared by training and restoration so the restored trainer is built the
    # same way as the one that produced the checkpoint.
    trainer_config = {
        "num_gpus": 0,
        "num_workers": 1,
        "env_config": config,
    }

    # training
    tune.run(
        "PPO",
        name=EXPERIMENT_NAME,
        stop={"training_iteration": 20},
        config=dict(trainer_config, env="robbie_env"),
        checkpoint_at_end=True,
    )

    # Render rollouts from the checkpoint written at the end of training,
    # instead of a hard-coded path that only existed on the author's machine.
    # NOTE(review): assumes Tune's default local_dir (~/ray_results).
    experiment_dir = os.path.expanduser(os.path.join("~", "ray_results", EXPERIMENT_NAME))
    cls = get_agent_class("PPO")
    trainer = cls(env="robbie_env", config=trainer_config)
    trainer.restore(checkpoint_path=find_latest_checkpoint(experiment_dir))
    # Presumably (num_steps=5000, out=None, no_render=False), i.e. render the
    # rollout. TODO confirm against the installed ray.rllib.rollout.rollout.
    rollout(trainer, "robbie_env", 5000, None, False)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment