@benblack769 · Created February 28, 2021
from copy import deepcopy
import os

import ray
from ray import tune
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.env import PettingZooEnv
from pettingzoo.classic import leduc_holdem_v2
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from gym.spaces import Box
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import FLOAT_MAX

torch, nn = try_import_torch()
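

# Leduc Hold'em exposes a Dict observation with two keys: "observation" (the
# game state) and "action_mask" (a 0/1 vector marking which actions are
# currently legal). The model below follows the parametric-actions pattern:
# compute Q-values from the observation alone, then add log(mask) so illegal
# actions can never be selected.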
class TorchMaskedActions(DQNTorchModel):
    """PyTorch DQN model that masks out the Q-values of invalid actions.

    The environment's Dict observation carries both the raw observation and
    a 0/1 action mask; the model scores actions from the observation alone,
    then uses the mask to suppress illegal actions.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        # RLlib hands the model the flattened Dict observation (action mask
        # and raw observation concatenated); dropping action_space.n
        # dimensions recovers a Box matching the raw observation.
        obs_len = obs_space.shape[0] - action_space.n
        orig_obs_space = Box(
            shape=(obs_len,),
            low=obs_space.low[:obs_len],
            high=obs_space.high[:obs_len])
        self.action_embed_model = TorchFC(
            orig_obs_space, action_space, action_space.n,
            model_config, name + "_action_embed")

    def forward(self, input_dict, state, seq_lens):
        # Extract the available-actions mask from the Dict observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute Q-value logits from the raw observation alone.
        action_logits, _ = self.action_embed_model({
            "obs": input_dict["obs"]["observation"]
        })

        # Convert the 0/1 mask into additive log-space penalties: valid
        # actions get log(1) = 0, invalid actions get a large negative
        # value (log(0) = -inf, clamped to -1e10) so they are never chosen.
        inf_mask = torch.clamp(torch.log(action_mask), -1e10, FLOAT_MAX)
        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()
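
# A minimal sanity check of the masking math above (a sketch, not part of the
# original gist): with a 0/1 mask, valid entries are unchanged and invalid
# entries are pushed to about -1e10, so an argmax can never pick them.
#
#   mask = torch.tensor([1.0, 0.0, 1.0])
#   logits = torch.tensor([0.5, 2.0, -0.3])
#   masked = logits + torch.clamp(torch.log(mask), -1e10, FLOAT_MAX)
#   # masked ~= [0.5, -1e10, -0.3]; index 1 (the illegal action) loses argmax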


if __name__ == "__main__":
    alg_name = "DQN"
    ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)

    # Function that returns the environment to be registered with RLlib.
    def env_creator():
        env = leduc_holdem_v2.env()
        return env
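
    # Note (assumption about this RLlib version, not stated in the gist):
    # PettingZooEnv expects every agent to share one observation and action
    # space, which holds for Leduc Hold'em's two players.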

    num_cpus = 1
    config = deepcopy(get_agent_class(alg_name)._default_config)

    register_env("leduc_holdem",
                 lambda config: PettingZooEnv(env_creator()))

    # A throwaway env instance, used only to read off the (shared)
    # observation and action spaces for the policy specs below.
    test_env = PettingZooEnv(env_creator())
    obs_space = test_env.observation_space
    print(obs_space)
    act_space = test_env.action_space
config["multiagent"] = {
"policies": {
"player_0": (None, obs_space, act_space, {}),
"player_1": (None, obs_space, act_space, {}),
},
"policy_mapping_fn": lambda agent_id: agent_id
}
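
    # A common alternative (a sketch, not in the original gist) is full
    # parameter sharing, where both seats train a single policy via self-play:
    #
    #   config["multiagent"] = {
    #       "policies": {"shared": (None, obs_space, act_space, {})},
    #       "policy_mapping_fn": lambda agent_id: "shared",
    #   }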
config["num_gpus"] = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
config["log_level"] = "DEBUG"
config["num_workers"] = 1
config["rollout_fragment_length"] = 30
config["train_batch_size"] = 200
config["horizon"] = 200
config["no_done_at_end"] = False
config["framework"] = "torch"
config["model"] = {
"custom_model":"pa_model",
}
config['n_step'] = 1
config["exploration_config"]= {
# The Exploration class to use.
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 0.1,
"final_epsilon": 0.0,
"epsilon_timesteps": 100000, # Timesteps over which to anneal epsilon.
}
    # The custom model already emits one (masked) Q-value per action, so
    # disable DQN's post-model layers: extra Q-head hiddens and the dueling
    # split would re-mix the masked outputs and break the masking.
    config["hiddens"] = []
    config["dueling"] = False
    config["env"] = "leduc_holdem"

    ray.init(num_cpus=num_cpus + 1)

    tune.run(
        alg_name,
        name="DQN",
        stop={"timesteps_total": 10000000},
        checkpoint_freq=10,
        config=config,
    )
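
    # Hypothetical follow-up (a sketch, not part of the original gist):
    # restoring a saved checkpoint for evaluation. The path below is a
    # placeholder; tune writes checkpoints under ~/ray_results/DQN/.
    #
    #   from ray.rllib.agents.dqn import DQNTrainer
    #   trainer = DQNTrainer(config=config)
    #   trainer.restore("/path/to/checkpoint/checkpoint-10")
    #   policy = trainer.get_policy("player_0")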