Created February 28, 2021 21:08
-
-
Save benblack769/7ea33e31d31a819d25bf64fb2bf1d430 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from numpy import float32 | |
import os | |
from supersuit import dtype_v0 | |
import ray | |
from ray import tune | |
from ray.rllib.agents.registry import get_agent_class | |
from ray.rllib.env import PettingZooEnv | |
from pettingzoo.classic import leduc_holdem_v2 | |
import tensorflow as tf | |
from ray.rllib.models import ModelCatalog | |
from ray.tune.registry import register_env | |
from gym.spaces import Box | |
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel | |
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC | |
from ray.rllib.utils.framework import try_import_tf, try_import_torch | |
from ray.rllib.utils.torch_ops import FLOAT_MIN, FLOAT_MAX | |
tf1, tf, tfv = try_import_tf() | |
torch, nn = try_import_torch() | |
class TorchMaskedActions(DQNTorchModel):
    """DQN torch model that adds a log-space action mask to its outputs.

    A fully connected sub-network scores every action from the
    ``"observation"`` entry of the input, and the ``"action_mask"`` entry is
    passed through ``log`` (clamped from below) and added to those scores.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        """Build the base DQN model and the inner action-scoring network.

        Args:
            obs_space: Box-like space whose first dimension holds the
                observation features plus ``action_space.n`` mask entries.
            action_space: Discrete-like space exposing ``n``.
            num_outputs: Forwarded unchanged to ``DQNTorchModel``.
            model_config: Forwarded to both the base class and the inner
                fully connected network.
            name: Model name; the inner network gets ``name + "_action_embed"``.
            **kw: Extra keyword arguments forwarded to ``DQNTorchModel``.
        """
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name, **kw)

        n_actions = action_space.n
        # Number of features left once the mask entries are removed.
        feature_len = obs_space.shape[0] - n_actions
        # NOTE(review): this takes the *leading* ``feature_len`` bounds, which
        # presumably assumes the observation precedes the mask in the
        # flattened space -- TODO confirm against the env's flattening order.
        feature_space = Box(
            low=obs_space.low[:feature_len],
            high=obs_space.high[:feature_len],
            shape=(feature_len,),
        )
        self.action_embed_model = TorchFC(
            feature_space,
            action_space,
            n_actions,
            model_config,
            name + "_action_embed",
        )

    def forward(self, input_dict, state, seq_lens):
        """Return masked action scores together with the unchanged state.

        Args:
            input_dict: Mapping whose ``"obs"`` entry holds ``"action_mask"``
                and ``"observation"`` tensors.
            state: Passed through untouched.
            seq_lens: Unused here.

        Returns:
            ``(scores + log_mask, state)``.
        """
        obs = input_dict["obs"]
        mask = obs["action_mask"]

        # Score every action from the raw observation only.
        scores, _ = self.action_embed_model({"obs": obs["observation"]})

        # log(1) == 0 leaves a score untouched, while log(0) == -inf is
        # clamped to -1e10 so the sum stays finite.
        log_mask = torch.clamp(torch.log(mask), -1e10, FLOAT_MAX)
        return scores + log_mask, state

    def value_function(self):
        """Delegate the value estimate to the inner action-scoring network."""
        return self.action_embed_model.value_function()
if __name__ == "__main__":
    # Script entry point: trains one DQN policy per player on Leduc Hold'em
    # using the masked-action model defined above in this file.
    alg_name = "DQN"
    num_cpus = 1

    ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)

    def env_creator():
        """Return a fresh Leduc Hold'em PettingZoo environment."""
        return leduc_holdem_v2.env()

    # The registered factory ignores the env config handed to it by RLlib.
    register_env(
        "leduc_holdem",
        lambda _env_config: PettingZooEnv(env_creator()),
    )

    # A throwaway wrapped env, built only to read the spaces for the policies.
    probe_env = PettingZooEnv(env_creator())
    obs_space = probe_env.observation_space
    print(obs_space)
    act_space = probe_env.action_space

    # Start from the algorithm's defaults and override what this run needs.
    config = deepcopy(get_agent_class(alg_name)._default_config)
    config.update({
        "multiagent": {
            "policies": {
                "player_0": (None, obs_space, act_space, {}),
                "player_1": (None, obs_space, act_space, {}),
            },
            # Each agent id maps to the policy of the same name.
            "policy_mapping_fn": lambda agent_id: agent_id,
        },
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "log_level": "DEBUG",
        "num_workers": 1,
        "rollout_fragment_length": 30,
        "train_batch_size": 200,
        "horizon": 200,
        "no_done_at_end": False,
        "framework": "torch",
        "model": {
            "custom_model": "pa_model",
        },
        "n_step": 1,
        "exploration_config": {
            # The Exploration class to use.
            "type": "EpsilonGreedy",
            # Config for the Exploration class' constructor:
            "initial_epsilon": 0.1,
            "final_epsilon": 0.0,
            # Timesteps over which to anneal epsilon.
            "epsilon_timesteps": 100000,
        },
        "hiddens": [],
        "dueling": False,
        "env": "leduc_holdem",
    })

    ray.init(num_cpus=num_cpus + 1)

    tune.run(
        alg_name,
        name="DQN",
        stop={"timesteps_total": 10000000},
        checkpoint_freq=10,
        config=config,
    )
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.