Skip to content

Instantly share code, notes, and snippets.

@TheTrope
Last active May 10, 2022 11:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TheTrope/5fe56bd683408cd9c18f503b6ea88e8d to your computer and use it in GitHub Desktop.
Save TheTrope/5fe56bd683408cd9c18f503b6ea88e8d to your computer and use it in GitHub Desktop.
from gym import spaces
import ray
from ray import tune
from ray.rllib.env.multi_agent_env import MultiAgentEnv
###################################
##### Game simulation Mock
##################################
class BaseAgent:
    """Common interface for every agent in the traffic simulation.

    Subclasses are expected to provide ``get_obs`` and ``set_action``;
    calling either on the base class raises ``NotImplementedError``.
    ``get_rew`` has a default implementation that always reports 0.
    """

    def get_obs(self):
        """Return this agent's current observation (must be overridden)."""
        raise NotImplementedError

    def set_action(self, action):
        """Apply ``action`` to this agent (must be overridden)."""
        raise NotImplementedError

    def get_rew(self):
        """Return this agent's reward; the default is a constant 0."""
        return 0
class Car(BaseAgent):
    """Mock car agent; observations are random samples of its own space.

    Args:
        max_traffic_lights: sizes the ``traffic_lights`` observation entry,
            which has ``3 * max_traffic_lights`` components.
    """

    def __init__(self, max_traffic_lights):
        def unit_box(size):
            # 1-D continuous vector with every component in [-1, 1].
            return spaces.Box(low=-1, high=1, shape=(size,))

        self.observation_space = spaces.Dict({
            "position": unit_box(3),
            "velocity": unit_box(3),
            "sensors": unit_box(10),
            "traffic_lights": unit_box(3 * max_traffic_lights),
        })
        # Two independent discrete choices, three options each.
        self.action_space = spaces.Dict({
            "mov": spaces.Discrete(3),
            "turn": spaces.Discrete(3),
        })

    def get_obs(self):
        """Return a random sample of ``observation_space`` (mock observation)."""
        return self.observation_space.sample()

    def set_action(self, action):
        """Accept ``action`` and ignore it; the mock car has no dynamics."""
        return None
class TrafficLight(BaseAgent):
    """Mock traffic-light agent; observations are random samples of its space.

    Args:
        max_cars: sizes the ``cars`` observation entry, which has
            ``9 * max_cars`` components.
    """

    def __init__(self, max_cars):
        def unit_box(size):
            # 1-D continuous vector with every component in [-1, 1].
            return spaces.Box(low=-1, high=1, shape=(size,))

        self.observation_space = spaces.Dict({
            "cars": unit_box(9 * max_cars),
            "state": unit_box(6),
        })
        # A single discrete choice with three options.
        self.action_space = spaces.Dict({"state": spaces.Discrete(3)})

    def get_obs(self):
        """Return a random sample of ``observation_space`` (mock observation)."""
        return self.observation_space.sample()

    def set_action(self, action):
        """Accept ``action`` and ignore it; the mock light has no dynamics."""
        return None
class TrafficSimulation:
    """Mock traffic game holding a fixed population of car and light agents.

    Args:
        num_cars: number of ``Car`` agents, with ids ``car_0`` ... ``car_<n-1>``.
        num_traffic_lights: number of ``TrafficLight`` agents, with ids
            ``traffic_light_0`` ... ``traffic_light_<n-1>``.

    Raises:
        ValueError: if either count is negative.
    """

    def __init__(self, num_cars=4, num_traffic_lights=2):
        if num_cars < 0 or num_traffic_lights < 0:
            raise ValueError(
                "num_cars and num_traffic_lights must be non-negative, got "
                + str(num_cars) + " and " + str(num_traffic_lights)
            )
        self.num_cars = num_cars
        self.num_traffic_lights = num_traffic_lights
        # Car spaces are sized by the number of traffic lights, and
        # TrafficLight spaces by the number of cars.
        self.cars = {
            "car_" + str(i): Car(self.num_traffic_lights)
            for i in range(self.num_cars)
        }
        self.traffic_lights = {
            "traffic_light_" + str(i): TrafficLight(self.num_cars)
            for i in range(self.num_traffic_lights)
        }
        # Every agent keyed by agent id (cars first, then traffic lights).
        self.agents = {**self.cars, **self.traffic_lights}

    def step(self, action_dict):
        """Apply actions, then return mock per-agent results.

        Args:
            action_dict: maps agent id -> action. Every id must be a key of
                ``self.agents``; an unknown id raises ``KeyError``. Agents
                missing from the dict receive no action this step.

        Returns:
            ``(obs, rew, dones, infos)``: four dicts keyed by every agent id,
            holding a fresh observation, the agent's reward, ``False``, and
            an empty info dict. There is no ``"__all__"`` entry in ``dones``.
        """
        for agent_id, action in action_dict.items():
            self.agents[agent_id].set_action(action)
        obs = {agent_id: agent.get_obs() for agent_id, agent in self.agents.items()}
        rew = {agent_id: agent.get_rew() for agent_id, agent in self.agents.items()}
        dones = {agent_id: False for agent_id in self.agents}
        infos = {agent_id: {} for agent_id in self.agents}
        return obs, rew, dones, infos

    def reset(self):
        """Return a fresh observation for every agent, keyed by agent id."""
        return {agent_id: agent.get_obs() for agent_id, agent in self.agents.items()}
################################
######### ENV
################################
class SpacesPerAgentTrafficMultiAgentEnv(MultiAgentEnv):
    """RLlib multi-agent env wrapping a default ``TrafficSimulation``.

    The agents are exactly those the simulation creates (cars and traffic
    lights). Each agent has its own observation/action space; both are
    exposed as ``spaces.Dict`` objects keyed by agent id. An episode ends for
    all agents once ``max_steps`` calls to ``step`` have been made.

    Args:
        max_steps: episode length in steps (default 100).
    """

    def __init__(self, max_steps=100):
        super().__init__()
        self.trafficSimulation = TrafficSimulation()
        self.max_steps = max_steps
        agents = self.trafficSimulation.agents
        # Per-agent spaces in multi-agent-dict form: {agent_id: space}.
        self.observation_space = spaces.Dict(
            {agent_id: agent.observation_space for agent_id, agent in agents.items()}
        )
        self.action_space = spaces.Dict(
            {agent_id: agent.action_space for agent_id, agent in agents.items()}
        )
        # Same ids as the space dicts above, read from the plain dict.
        self._agent_ids = set(agents)
        self.steps = 0

    def reset(self):
        """Start a new episode and return every agent's initial observation."""
        self.steps = 0
        return self.trafficSimulation.reset()

    def step(self, action_dict):
        """Advance the simulation by one step.

        The simulation's per-agent done flags are discarded and replaced by a
        single ``"__all__"`` flag. That flag becomes True once ``max_steps``
        steps have been taken since the last ``reset``.
        """
        self.steps += 1
        obs, rew, _per_agent_dones, infos = self.trafficSimulation.step(action_dict)
        dones = {"__all__": self.steps >= self.max_steps}
        return obs, rew, dones, infos
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.examples.policy.random_policy import RandomPolicy
def main():
    """Register the traffic env and train one PPO policy per agent type.

    Car agents are mapped to the ``"car"`` policy and traffic-light agents to
    the ``"traffic_light"`` policy; both are trained. Each policy is given its
    agent type's observation/action spaces explicitly, because the env
    exposes a different space per agent rather than one shared space.
    """
    tune.register_env(
        "ExampleEnv",
        lambda c: SpacesPerAgentTrafficMultiAgentEnv(),
    )

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Agent ids are "car_<i>" or "traffic_light_<i>" (TrafficSimulation).
        return "car" if agent_id.startswith("car_") else "traffic_light"

    # All agents of one type are built with the same sizes, so one
    # representative of each type supplies that type's spaces.
    sim = TrafficSimulation()
    car = next(iter(sim.cars.values()))
    light = next(iter(sim.traffic_lights.values()))

    ray.init()
    tune.run(
        "PPO",
        # NOTE: the mock agents always return reward 0, so this threshold is
        # never reached with the mock simulation; training runs until stopped.
        stop={"episode_reward_mean": 200},
        config={
            "env": "ExampleEnv",
            "num_gpus": 0,
            "num_workers": 1,
            "multiagent": {
                # Policy ids must match what policy_mapping_fn returns.
                "policies": {
                    "car": PolicySpec(
                        observation_space=car.observation_space,
                        action_space=car.action_space,
                    ),
                    "traffic_light": PolicySpec(
                        observation_space=light.observation_space,
                        action_space=light.action_space,
                    ),
                },
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["car", "traffic_light"],
            },
            "framework": "torch",
        },
    )
# Launch training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment