Created
September 20, 2022 16:28
-
-
Save cool-RR/a8144840657721e73ee8eca5c7537cf5 to your computer and use it in GitHub Desktop.
ry.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import random | |
from typing import Tuple | |
import re | |
import pprint | |
import dataclasses | |
import functools | |
import ray.rllib.env.multi_agent_env | |
from ray.rllib.agents.ppo import PPOTrainer | |
import gym | |
from gym.spaces import Box, Discrete | |
import numpy as np | |
# Number of agents on the ring. Must stay <= 10: make_initial asserts it and
# policy_mapping_fn parses a single-digit agent index.
N_AGENTS = 3
class State:
    """Immutable snapshot of the ring world.

    Agents occupy cells on a circular track of `length` cells. Each agent's
    observation is the tuple of all positions rotated so its own position
    comes first; its reward is the squared ring-distance to its closest
    neighbor, so maximizing reward spreads the agents out.
    """

    def __init__(self, length: int, positions: Tuple[int, ...], i_timestep: int) -> None:
        self.length = length
        # Normalize to a tuple so states are immutable regardless of caller input
        # (step() builds positions in a list).
        self.positions = tuple(positions)
        self.i_timestep = i_timestep
        self.agents = tuple(f'agent_{i}' for i in range(len(self.positions)))
        # Each agent sees every position, rotated so index 0 is its own.
        self.observations = {agent: self.positions[i:] + self.positions[:i]
                             for i, agent in enumerate(self.agents)}
        # Squared distance to the nearest other agent: bigger spread, bigger reward.
        self.rewards = {agent: self._get_distance_to_closest_neighbor(i) ** 2
                        for i, agent in enumerate(self.agents)}

    @staticmethod
    def make_initial(n_agents: int, length: int) -> State:
        """Return a t=0 state with agents at distinct random positions."""
        assert n_agents <= 10  # policy_mapping_fn assumes single-digit agent ids
        return State(
            length=length,
            positions=tuple(random.sample(range(length), n_agents)),
            i_timestep=0,
        )

    def step(self, actions: dict) -> State:
        """Advance one timestep. Action 1 moves right, anything else moves left.

        Positions wrap modulo `length` (bug fix: the original let agents walk
        off the board, producing out-of-range/negative positions that violate
        the Box(0, length-1) observation space and mis-render in `text`).
        """
        new_positions = list(self.positions)
        for i, agent in enumerate(self.agents):
            offset = 1 if (actions[agent] == 1) else -1
            new_positions[i] = (new_positions[i] + offset) % self.length
        return State(
            length=self.length,
            positions=tuple(new_positions),  # keep positions a tuple, as in make_initial
            i_timestep=self.i_timestep + 1,
        )

    def _get_distance_to_closest_neighbor(self, i: int) -> int:
        """Return the ring distance from agent `i` to its nearest other agent."""
        position = self.positions[i]
        other_positions = set(self.positions[:i] + self.positions[i + 1:])
        # Walk left and right simultaneously until hitting another agent.
        # (Loop variable renamed: the original shadowed the parameter `i`.)
        for distance in range(self.length):
            if other_positions & {(position - distance) % self.length,
                                  (position + distance) % self.length}:
                return distance
        raise RuntimeError('No neighbor found on the ring.')

    @property
    def text(self) -> str:
        """One-line ASCII rendering of the board plus a '+' reward meter."""
        cells = [' '] * self.length
        for i, position in enumerate(self.positions):
            # '*' marks a cell occupied by more than one agent.
            cells[position] = str(i) if cells[position] == ' ' else '*'
        # Bug fix: rewards is keyed by agent name ('agent_0'), not by integer
        # index — the original `self.rewards[0]` raised KeyError.
        score = int((10 * self.rewards[self.agents[0]]) // ((self.length / 2) ** 2))
        return '[' + ''.join(cells) + ']' + ' ' + ('+' * score) + '\n'
class Env(ray.rllib.env.multi_agent_env.MultiAgentEnv):
    """Multi-agent ring environment: each agent moves left or right per step.

    Config keys (all optional): 'length' (ring size, default 15),
    'n_agents' (default N_AGENTS), 'ts' (episode length, default 100).
    """

    def __init__(self, config=None):
        settings = config or {}
        self.length = settings.get('length', 15)
        self.n_agents = settings.get('n_agents', N_AGENTS)
        self.agents = tuple(f'agent_{i}' for i in range(self.n_agents))
        self.timestep_limit = settings.get('ts', 100)
        # Every agent observes all n_agents positions (its own rotated first).
        self.observation_space = Box(0, self.length - 1, shape=(self.n_agents,), dtype=int)
        self.action_space = Discrete(2)
        self.reset()

    def reset(self):
        """Begin a fresh episode and return its initial observations."""
        self.state = State.make_initial(self.n_agents, self.length)
        return self.observations

    @property
    def observations(self):
        """Per-agent observation dict, delegated to the current state."""
        return self.state.observations

    def step(self, actions: dict):
        """Apply all agents' actions; the episode ends at the timestep limit."""
        self.state = self.state.step(actions)
        episode_over = self.state.i_timestep >= self.timestep_limit
        dones = dict.fromkeys(self.agents + ('__all__',), episode_over)
        return self.observations, self.state.rewards, dones, {}
# Throwaway env instance, used only to read the observation/action spaces below.
env = Env()
# One independent PPO policy per agent; all share the same spaces. The `None`
# policy class tells RLlib to use the trainer's default policy implementation.
policies = {
    f'policy_{i}': (None, env.observation_space, env.action_space, {}) for i in range(N_AGENTS)
}
def policy_mapping_fn(agent_id: str) -> str:
    """Map an agent id of the form 'agent_<i>' (single digit) to 'policy_<i>'.

    Raises ValueError on any other id — the original crashed with an opaque
    AttributeError (`None.group`) instead, and its range `assert` disappears
    under `python -O`.
    """
    # `fullmatch` already anchors the pattern, so no '^'/'$' are needed;
    # a single digit keeps the mapping within the <= 10 policies created above.
    match = re.fullmatch(r'agent_(\d)', agent_id)
    if match is None:
        raise ValueError(f'Unrecognized agent id: {agent_id!r}')
    return f'policy_{match.group(1)}'
# Trainer configuration passed to RLlib's PPO.
config = {
    'env': Env,
    'env_config': {
        # NOTE(review): Env.__init__ reads 'length'/'n_agents'/'ts' directly
        # from env_config; this nested 'config' key is never read and looks
        # vestigial — confirm before relying on it.
        'config': {},
    },
    # Also build an Env on the driver process (not only on rollout workers).
    'create_env_on_driver': True,
    'multiagent': {
        'policies': policies,
        'policy_mapping_fn': policy_mapping_fn,
    },
}
# Build the PPO trainer and run a few training iterations, printing the
# mean episode return after each one.
rllib_trainer = PPOTrainer(config=config)
for _ in range(5):
    results = rllib_trainer.train()
    print(f"Iteration={rllib_trainer.iteration}: R(\"return\")={results['episode_reward_mean']}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment