from gym import spaces

import ray
from ray import tune
from ray.rllib.env.multi_agent_env import MultiAgentEnv


###################################
##### Game simulation mock
###################################
class BaseAgent:
    def get_obs(self):
        raise NotImplementedError()

    def set_action(self, action):
        raise NotImplementedError()

    def get_rew(self):
        return 0


class Car(BaseAgent):
    def __init__(self, max_traffic_lights):
        self.observation_space = spaces.Dict({
            "position": spaces.Box(low=-1, high=1, shape=(3,)),
            "velocity": spaces.Box(low=-1, high=1, shape=(3,)),
            "sensors": spaces.Box(low=-1, high=1, shape=(10,)),
            "traffic_lights": spaces.Box(low=-1, high=1, shape=(3 * max_traffic_lights,)),
        })
        self.action_space = spaces.Dict({
            "mov": spaces.Discrete(3),
            "turn": spaces.Discrete(3),
        })

    def get_obs(self):
        # Mock observation: just sample from the observation space.
        return self.observation_space.sample()

    def set_action(self, action):
        return


class TrafficLight(BaseAgent):
    def __init__(self, max_cars):
        self.observation_space = spaces.Dict({
            "cars": spaces.Box(low=-1, high=1, shape=(9 * max_cars,)),
            "state": spaces.Box(low=-1, high=1, shape=(6,)),
        })
        self.action_space = spaces.Dict({
            "state": spaces.Discrete(3),
        })

    def get_obs(self):
        # Mock observation: just sample from the observation space.
        return self.observation_space.sample()

    def set_action(self, action):
        return


class TrafficSimulation:
    def __init__(self):
        self.num_cars = 4
        self.num_traffic_lights = 2
        self.cars = {
            "car_" + str(x): Car(self.num_traffic_lights) for x in range(self.num_cars)
        }
        self.traffic_lights = {
            "traffic_light_" + str(x): TrafficLight(self.num_cars)
            for x in range(self.num_traffic_lights)
        }
        # Agents, with agent_id as key and a BaseAgent as value.
        self.agents = {**self.cars, **self.traffic_lights}

    def step(self, action_dict):
        # Distribute actions to the agents, then simulate one env tick.
        for agent_id, action in action_dict.items():
            self.agents[agent_id].set_action(action)
        # Mocked result: (obs, rewards, dones, infos), each keyed by agent_id.
        return (
            {agent_id: agent.get_obs() for agent_id, agent in self.agents.items()},
            {agent_id: agent.get_rew() for agent_id, agent in self.agents.items()},
            {agent_id: False for agent_id in self.agents.keys()},
            {agent_id: {} for agent_id in self.agents.keys()},
        )

    def reset(self):
        return {agent_id: agent.get_obs() for agent_id, agent in self.agents.items()}
################################
######### ENV
################################
class SpacesPerAgentTrafficMultiAgentEnv(MultiAgentEnv):
    """Env with 6 agents (4 cars and 2 traffic lights), each with its own spaces."""

    def __init__(self):
        super().__init__()
        self.trafficSimulation = TrafficSimulation()
        # Per-agent obs/action spaces, exposed in MultiAgentDict form.
        self.observation_space = spaces.Dict({
            agent_id: agent.observation_space
            for agent_id, agent in self.trafficSimulation.agents.items()
        })
        self.action_space = spaces.Dict({
            agent_id: agent.action_space
            for agent_id, agent in self.trafficSimulation.agents.items()
        })
        self._agent_ids = set(self.observation_space.keys())
        self.steps = 0

    def reset(self):
        self.steps = 0
        return self.trafficSimulation.reset()

    def step(self, action_dict):
        self.steps += 1
        obs, rew, dones, infos = self.trafficSimulation.step(action_dict)
        # RLlib only requires the "__all__" key; end the episode after 100 steps.
        dones["__all__"] = self.steps >= 100
        return obs, rew, dones, infos
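
# Optional sanity check (not part of the original gist; a minimal sketch assuming only the
# classes defined above): roll the env for a few steps with randomly sampled per-agent
# actions to verify that the per-agent Dict spaces line up with what the simulation returns.
def sanity_check_env(num_steps=3):
    env = SpacesPerAgentTrafficMultiAgentEnv()
    obs = env.reset()
    assert set(obs.keys()) == env._agent_ids
    for _ in range(num_steps):
        # Sample one action per agent from that agent's own action space.
        actions = {
            agent_id: env.action_space[agent_id].sample()
            for agent_id in env._agent_ids
        }
        obs, rew, dones, infos = env.step(actions)
        # Each observation should lie in the corresponding agent's observation space.
        assert all(env.observation_space[aid].contains(o) for aid, o in obs.items())
        if dones["__all__"]:
            break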
from ray.rllib.policy.policy import PolicySpec


def main():
    tune.register_env(
        "ExampleEnv",
        lambda c: SpacesPerAgentTrafficMultiAgentEnv(),
    )

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Map every "car_*" agent to the "car" policy and every
        # "traffic_light_*" agent to the "traffic_light" policy.
        return "car" if agent_id.startswith("car") else "traffic_light"

    ray.init()
    tune.run(
        "PPO",
        # Note: rewards are mocked to 0 above, so this stop criterion is a placeholder.
        stop={"episode_reward_mean": 200},
        config={
            "env": "ExampleEnv",
            "num_gpus": 0,
            "num_workers": 1,
            "multiagent": {
                # One policy per agent type; obs/action spaces are inferred per agent
                # from the env's Dict spaces.
                "policies": {
                    "car": PolicySpec(),
                    "traffic_light": PolicySpec(),
                },
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["car", "traffic_light"],
            },
            "framework": "torch",
        },
    )


if __name__ == "__main__":
    main()