@cool-RR
Created September 20, 2022 16:28
ry.py
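'''A minimal multi-agent RLlib example: agents on a ring learn to spread out.

Each of the N_AGENTS agents occupies a cell on a ring of `length` cells and can
step left or right on every timestep. An agent's reward is the squared
ring-distance to its closest neighbor, so the agents are pushed to space
themselves apart. Each agent is trained with its own PPO policy.
'''
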
from __future__ import annotations
import random
from typing import Tuple
import re
import pprint
import dataclasses
import functools
import ray.rllib.env.multi_agent_env
from ray.rllib.agents.ppo import PPOTrainer
import gym
from gym.spaces import Box, Discrete
import numpy as np
N_AGENTS = 3


class State:
    '''An immutable snapshot of all agent positions on the ring.'''

    def __init__(self, length: int, positions: Tuple[int, ...], i_timestep: int) -> None:
        self.length = length
        self.positions = positions
        self.i_timestep = i_timestep
        self.agents = tuple(f'agent_{i}' for i in range(len(self.positions)))
        # Each agent observes all positions, rotated so its own position comes first:
        self.observations = {agent: self.positions[i:] + self.positions[:i]
                             for i, agent in enumerate(self.agents)}
        # The reward is the squared ring-distance to the closest other agent:
        self.rewards = {agent: self._get_distance_to_closest_neighbor(i) ** 2
                        for i, agent in enumerate(self.agents)}

    @staticmethod
    def make_initial(n_agents: int, length: int) -> State:
        assert n_agents <= 10
        return State(
            length=length,
            positions=tuple(random.sample(range(length), n_agents)),
            i_timestep=0,
        )

    def step(self, actions: dict) -> State:
        new_positions = list(self.positions)
        for i, agent in enumerate(self.agents):
            offset = 1 if (actions[agent] == 1) else -1
            # Wrap around so positions stay on the ring (and inside the observation space):
            new_positions[i] = (new_positions[i] + offset) % self.length
        return State(
            length=self.length,
            positions=tuple(new_positions),
            i_timestep=self.i_timestep + 1,
        )

    def _get_distance_to_closest_neighbor(self, i: int) -> int:
        position = self.positions[i]
        other_positions = set(self.positions[:i] + self.positions[i + 1:])
        for distance in range(self.length):
            # Walk left and right until finding the closest neighbor:
            if other_positions & {(position - distance) % self.length,
                                  (position + distance) % self.length}:
                return distance
        else:
            raise RuntimeError

    @property
    def text(self) -> str:
        '''Render the ring as a one-line ASCII picture.'''
        text = ([' '] * self.length) + [']']
        for i, position in enumerate(self.positions):
            if text[position] == ' ':
                text[position] = str(i)
            else:
                text[position] = '*'  # More than one agent on this cell.
        result = '[' + ''.join(text)
        # Crude progress bar: agent 0's reward scaled to roughly ten plus signs at most.
        score = int((10 * self.rewards[self.agents[0]]) // ((self.length / 2) ** 2))
        result += ' ' + ('+' * score) + '\n'
        return result


class Env(ray.rllib.env.multi_agent_env.MultiAgentEnv):
    def __init__(self, config=None):
        config = config or {}
        self.length = config.get('length', 15)
        self.n_agents = config.get('n_agents', N_AGENTS)
        self.agents = tuple(f'agent_{i}' for i in range(self.n_agents))
        self.timestep_limit = config.get('ts', 100)
        self.observation_space = Box(0, self.length - 1, shape=(self.n_agents,), dtype=int)
        self.action_space = Discrete(2)
        self.reset()

    def reset(self):
        self.state = State.make_initial(self.n_agents, self.length)
        return self.observations

    @property
    def observations(self):
        return self.state.observations

    def step(self, actions: dict):
        self.state = self.state.step(actions)
        is_done = (self.state.i_timestep >= self.timestep_limit)
        dones = {key: is_done for key in self.agents + ('__all__',)}
        return self.observations, self.state.rewards, dones, {}


env = Env()

policies = {
    f'policy_{i}': (None, env.observation_space, env.action_space, {})
    for i in range(N_AGENTS)
}


def policy_mapping_fn(agent_id: str) -> str:
    # Map 'agent_i' to 'policy_i', giving every agent its own policy:
    match = re.fullmatch('agent_([0-9])', agent_id)
    i = int(match.group(1))
    assert 0 <= i <= 9
    return f'policy_{i}'


config = {
    'env': Env,
    'env_config': {
        'config': {},
    },
    'create_env_on_driver': True,
    'multiagent': {
        'policies': policies,
        'policy_mapping_fn': policy_mapping_fn,
    },
}

rllib_trainer = PPOTrainer(config=config)

for _ in range(5):
    results = rllib_trainer.train()
    print(f"Iteration={rllib_trainer.iteration}: "
          f"R(\"return\")={results['episode_reward_mean']}")
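

# Optional sketch: roll out the trained policies for one episode and print the
# environment's ASCII rendering after each step. This assumes a Ray version where
# `Trainer.compute_single_action(obs, policy_id=...)` is available (roughly Ray 1.12
# through 2.0, matching the legacy `ray.rllib.agents.ppo` import above).
observations = env.reset()
done = False
while not done:
    actions = {
        agent_id: rllib_trainer.compute_single_action(
            np.array(observations[agent_id]),
            policy_id=policy_mapping_fn(agent_id),
            explore=False,
        )
        for agent_id in env.agents
    }
    observations, rewards, dones, _ = env.step(actions)
    done = dones['__all__']
    print(env.state.text, end='')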