-
-
Save TheTrope/c0d8efafa3853caeb5a083a25af5bffe to your computer and use it in GitHub Desktop.
Smaller example to reproduce the MultiAgentDict-spaces bug in Ray 1.12.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from gym import spaces
import ray
from ray import tune
from ray.rllib.env.multi_agent_env import MultiAgentEnv
import gym
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.examples.policy.random_policy import RandomPolicy
class BasicMultiAgentMultiSpaces(MultiAgentEnv):
    """Minimal multi-agent env with distinct per-agent spaces.

    Reproduces the MultiAgentDict-spaces handling bug in ray 1.12:
    observation/action spaces are declared in the "preferred format"
    (a ``gym.spaces.Dict`` keyed by agent id), with a different Box /
    Discrete space per agent.
    """

    def __init__(self):
        self.agents = {"agent0", "agent1"}
        self.dones = set()
        # Replace the single env space (e.g. gym.spaces.Box(...)) with a
        # MultiAgentDict space, one entry per agent.
        # BUG FIX: the original passed low="-1" (a string); gym.spaces.Box
        # requires numeric bounds.
        self.observation_space = gym.spaces.Dict({
            "agent0": gym.spaces.Box(low=-1.0, high=1.0, shape=(10,)),
            "agent1": gym.spaces.Box(low=-1.0, high=1.0, shape=(20,)),
        })
        self.action_space = gym.spaces.Dict({
            "agent0": gym.spaces.Discrete(2),
            "agent1": gym.spaces.Discrete(3),
        })
        self._agent_ids = set(self.agents)
        self._spaces_in_preferred_format = True
        super().__init__()

    def reset(self):
        """Reset done-tracking and return an initial obs for every agent."""
        self.dones = set()
        return {i: self.observation_space[i].sample() for i in self.agents}

    def step(self, action_dict):
        """Take one step: random obs, zero reward, agents never done."""
        obs, rew, done, info = {}, {}, {}, {}
        for i, action in action_dict.items():
            obs[i] = self.observation_space[i].sample()
            rew[i] = 0.0
            done[i] = False
            info[i] = {}
            if done[i]:
                self.dones.add(i)
        done["__all__"] = len(self.dones) == len(self.agents)
        print("step")
        return obs, rew, done, info
def main():
    """Register the example env and train it with PPO via tune.

    Runs one policy per agent ("main0" for agent0, "main1" for agent1)
    until the stop criterion is reached.
    """
    tune.register_env(
        "ExampleEnv",
        lambda c: BasicMultiAgentMultiSpaces()
    )

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Fix here after feedback from sven: map each agent to its own policy.
        return "main0" if agent_id == "agent0" else "main1"

    ray.init()
    tune.run(
        "PPO",
        stop={"episode_reward_mean": 200},
        config={
            "env": "ExampleEnv",
            "num_gpus": 0,
            "num_workers": 1,
            "multiagent": {
                # BUG FIX: policy_mapping_fn returns "main0"/"main1" and
                # policies_to_train lists them, but the original config only
                # declared "main" and "random" — so the mapped policy ids did
                # not exist. Declare the policies actually referenced.
                "policies": {
                    "main0": PolicySpec(),
                    "main1": PolicySpec(),
                },
                "policy_mapping_fn": policy_mapping_fn,
                # Fix here
                "policies_to_train": ["main0", "main1"],
            },
            "framework": "torch",
        },
    )
# Script entry point: only run training when executed directly.
if __name__ == "__main__":
    main()
# ValueError: The two structures don't have the same nested structure. | |
# First structure: type=ndarray str=[ 0.9564604 -0.5021941 -0.5057076 0.25486383 0.46681443 0.8739439 | |
# 0.5189227 0.24640203 -0.6063386 -0.69826615 -0.9067742 -0.5668338 | |
# 0.74380994 -0.848033 -0.9815409 -0.5252794 -0.00128481 -0.18890348 | |
# -0.40913624 0.10361464] | |
# Second structure: type=OrderedDict str=OrderedDict([('agent0', array([ 0.40599328, -0.90823233, -0.90186906, -0.00112456, -0.82994014, | |
# 0.94780207, 0.12449201, -0.043993 , -0.22595881, 0.7862879 ], | |
# dtype=float32)), ('agent1', array([ 0.30436766, 0.32244647, 0.86386913, -0.39648414, 0.98302525, | |
# -0.5297048 , -0.9484631 , 0.3249894 , -0.45240116, 0.3745865 , | |
# -0.709602 , 0.34890667, 0.12114473, -0.87268084, -0.95431393, | |
# -0.79948044, 0.4677822 , -0.15980895, 0.9318225 , -0.9016584 ], | |
# dtype=float32))]) | |
# More specifically: Substructure "type=OrderedDict str=OrderedDict([('agent0', array([ 0.40599328, -0.90823233, -0.90186906, -0.00112456, -0.82994014, | |
# 0.94780207, 0.12449201, -0.043993 , -0.22595881, 0.7862879 ], | |
# dtype=float32)), ('agent1', array([ 0.30436766, 0.32244647, 0.86386913, -0.39648414, 0.98302525, | |
# -0.5297048 , -0.9484631 , 0.3249894 , -0.45240116, 0.3745865 , | |
# -0.709602 , 0.34890667, 0.12114473, -0.87268084, -0.95431393, | |
# -0.79948044, 0.4677822 , -0.15980895, 0.9318225 , -0.9016584 ], | |
# dtype=float32))])" is a sequence, while substructure "type=ndarray str=[ 0.9564604 -0.5021941 -0.5057076 0.25486383 0.46681443 0.8739439 | |
# 0.5189227 0.24640203 -0.6063386 -0.69826615 -0.9067742 -0.5668338 | |
# 0.74380994 -0.848033 -0.9815409 -0.5252794 -0.00128481 -0.18890348 | |
# -0.40913624 0.10361464]" is not | |
# Entire first structure: | |
# . | |
# Entire second structure: | |
# OrderedDict([('agent0', .), ('agent1', .)]) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment