CartPole server/client example demonstrating the consequences of different framework settings on a log_action call.
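The gist bundles three files: the client (cartpole_client.py), the server (cartpole_server.py), and the custom Keras model both sides register (custom_keras_model.py). The crux is the two ways a PolicyClient can feed experiences to the server; here is a minimal sketch, reusing the names from the client file below:

# On-policy: the (local or remote) policy picks the action.
action = client.get_action(eid, obs)

# Off-policy: the action is chosen externally (here: at random) and only
# reported to the server. For algorithms such as PPO that need action
# probabilities, this may still trigger a forward pass through the model,
# which is exactly what this gist instruments.
action = env.action_space.sample()
client.log_action(eid, obs, action)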
cartpole_client.py:

#!/usr/bin/env python
"""Example of training with a policy server. Copy this file for your use case.

To try this out, in two separate shells run:
    $ python cartpole_server.py --run=[PPO|DQN]
    $ python cartpole_client.py --inference-mode=local|remote

Local inference mode offloads inference to the client for better performance.
"""
import argparse

import gym

from ray.rllib.env.policy_client import PolicyClient
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.custom_keras_model import MyKerasModel

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-train", action="store_true", help="Whether to disable training.")
parser.add_argument(
    "--inference-mode", type=str, default="local", choices=["local", "remote"])
parser.add_argument(
    "--off-policy",
    action="store_true",
    help="Whether to take random instead of on-policy actions.")
parser.add_argument(
    "--stop-reward",
    type=int,
    default=9999,
    help="Stop once the specified reward is reached.")
if __name__ == "__main__":
    args = parser.parse_args()
    # args.inference_mode = "remote"
    # Hard-coded for this demo: always take random (off-policy) actions and
    # report them to the server via log_action().
    args.off_policy = True

    # Register the custom model so that, in local inference mode, the client
    # can build the same policy network as the server.
    ModelCatalog.register_custom_model("keras_model", MyKerasModel)

    env = gym.make("CartPole-v0")
    client = PolicyClient(
        "http://localhost:9900", inference_mode=args.inference_mode)

    eid = client.start_episode(training_enabled=not args.no_train)
    obs = env.reset()
    rewards = 0

    while True:
        if args.off_policy:
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        else:
            action = client.get_action(eid, obs)
        obs, reward, done, info = env.step(action)
        rewards += reward
        client.log_returns(eid, reward, info=info)
        if done:
            print("Total reward:", rewards)
            if rewards >= args.stop_reward:
                print("Target reward achieved, exiting")
                exit(0)
            rewards = 0
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=not args.no_train)
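For reference, a hedged sketch of the two inference modes the client's --inference-mode flag selects. With "local", the client holds its own copy of the policy and periodically pulls weight updates from the server; with "remote", every observation is shipped over HTTP and the server computes the action:

# Both clients talk to the same server; only where the NN runs differs.
local_client = PolicyClient("http://localhost:9900", inference_mode="local")
remote_client = PolicyClient("http://localhost:9900", inference_mode="remote")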
cartpole_server.py:

#!/usr/bin/env python
"""Example of running a policy server. Copy this file for your use case.

To try this out, in two separate shells run:
    $ python cartpole_server.py
    $ python cartpole_client.py --inference-mode=local|remote
"""
import argparse
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.custom_keras_model import MyKerasModel

SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
CHECKPOINT_FILE = "last_checkpoint_{}.out"

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="DQN")
parser.add_argument(
    "--framework", type=str, choices=["tf", "torch"], default="tf")
parser.add_argument(
    "--no-restore",
    action="store_true",
    help="Do not restore from a previously saved checkpoint (location of "
    "which is saved in `last_checkpoint_[algo-name].out`).")
if __name__ == "__main__":
    args = parser.parse_args()
    # Hard-coded overrides for this demo: always run PPO, always start
    # fresh, and force "tf2" (eager mode) even though the parser only
    # offers "tf" and "torch".
    args.run = "PPO"
    args.no_restore = True
    args.framework = "tf2"

    ray.init()
    ModelCatalog.register_custom_model("keras_model", MyKerasModel)

    env = "CartPole-v0"
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        "callbacks": MyCallbacks,
    }
    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "learning_starts": 100,
                    "timesteps_per_iteration": 200,
                    "framework": args.framework,
                }))
    elif args.run == "PPO":
        # Example of using PPO (does NOT support off-policy actions).
        trainer = PPOTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "rollout_fragment_length": 100,
                    "train_batch_size": 400,
                    "framework": args.framework,
                    "model": {
                        "custom_model": "keras_model"
                    },
                }))
    else:
        raise ValueError("--run must be DQN or PPO")
    # File that stores the *location* of the most recent checkpoint. Kept
    # separate from the checkpoint path itself: the original code reused one
    # variable, so after a restore the loop below would have overwritten the
    # checkpoint rather than the pointer file.
    checkpoint_file = CHECKPOINT_FILE.format(args.run)

    # Attempt to restore from checkpoint, if possible.
    if not args.no_restore and os.path.exists(checkpoint_file):
        checkpoint_path = open(checkpoint_file).read()
        print("Restoring from checkpoint path", checkpoint_path)
        trainer.restore(checkpoint_path)

    # Serving and training loop.
    while True:
        print(pretty_print(trainer.train()))
        checkpoint = trainer.save()
        print("Last checkpoint", checkpoint)
        with open(checkpoint_file, "w") as f:
            f.write(checkpoint)
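The tail of the server file is a simple checkpoint round-trip; a sketch of the intended flow (the paths are illustrative, not taken from a real run):

# trainer.save() returns the path of the freshly written checkpoint.
checkpoint = trainer.save()  # e.g. ".../checkpoint_000010/checkpoint-10"
# Persist that path so the next server run can resume from it.
with open("last_checkpoint_PPO.out", "w") as f:
    f.write(checkpoint)
# On a later run (started without --no-restore):
trainer.restore(open("last_checkpoint_PPO.out").read())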
"""Example of using a custom ModelV2 Keras-style model.""" | |
import argparse | |
import os | |
import ray | |
from ray import tune | |
from ray.rllib.agents.dqn.distributional_q_tf_model import \ | |
DistributionalQTFModel | |
from ray.rllib.models import ModelCatalog | |
from ray.rllib.models.tf.misc import normc_initializer | |
from ray.rllib.models.tf.tf_modelv2 import TFModelV2 | |
from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork | |
from ray.rllib.policy.policy import LEARNER_STATS_KEY | |
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID | |
from ray.rllib.utils.framework import try_import_tf | |
tf1, tf, tfv = try_import_tf() | |
parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="DQN",
    help="The RLlib-registered algorithm to use.")
parser.add_argument("--stop", type=int, default=200)
parser.add_argument("--use-vision-network", action="store_true")
parser.add_argument("--num-cpus", type=int, default=0)
class MyKerasModel(TFModelV2):
    """Custom model for policy gradient algorithms."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(MyKerasModel, self).__init__(obs_space, action_space,
                                           num_outputs, model_config, name)
        self.inputs = tf.keras.layers.Input(
            shape=obs_space.shape, name="observations")
        layer_1 = tf.keras.layers.Dense(
            128,
            name="my_layer1",
            activation=tf.nn.relu,
            kernel_initializer=normc_initializer(1.0))(self.inputs)
        layer_out = tf.keras.layers.Dense(
            num_outputs,
            name="my_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(layer_1)
        value_out = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(layer_1)
        self.base_model = tf.keras.Model(self.inputs, [layer_out, value_out])

    def forward(self, input_dict, state, seq_lens):
        # These prints are the instrumentation of the gist: in eager mode
        # ("tf2") they fire on every call; in graph mode ("tf") only while
        # the graph is being traced.
        print()
        print("Calling NN model")
        print()
        model_out, self._value_out = self.base_model(input_dict["obs"])
        print("NN model called")
        print()
        return model_out, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    def metrics(self):
        return {"foo": tf.constant(42.0)}
class MyKerasQModel(DistributionalQTFModel):
    """Custom model for DQN."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        super(MyKerasQModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)

        # Define the core model layers which will be used by the other
        # output heads of DistributionalQModel.
        self.inputs = tf.keras.layers.Input(
            shape=obs_space.shape, name="observations")
        layer_1 = tf.keras.layers.Dense(
            128,
            name="my_layer1",
            activation=tf.nn.relu,
            kernel_initializer=normc_initializer(1.0))(self.inputs)
        layer_out = tf.keras.layers.Dense(
            num_outputs,
            name="my_out",
            activation=tf.nn.relu,
            kernel_initializer=normc_initializer(1.0))(layer_1)
        self.base_model = tf.keras.Model(self.inputs, layer_out)

    # Implement the core forward method.
    def forward(self, input_dict, state, seq_lens):
        model_out = self.base_model(input_dict["obs"])
        return model_out, state

    def metrics(self):
        return {"foo": tf.constant(42.0)}
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)

    ModelCatalog.register_custom_model(
        "keras_model", MyVisionNetwork
        if args.use_vision_network else MyKerasModel)
    ModelCatalog.register_custom_model(
        "keras_q_model", MyVisionNetwork
        if args.use_vision_network else MyKerasQModel)

    # Tests https://github.com/ray-project/ray/issues/7293
    def check_has_custom_metric(result):
        r = result["result"]["info"]["learner"]
        if DEFAULT_POLICY_ID in r:
            r = r[DEFAULT_POLICY_ID].get(LEARNER_STATS_KEY,
                                         r[DEFAULT_POLICY_ID])
        assert r["model"]["foo"] == 42, result

    if args.run == "DQN":
        extra_config = {"learning_starts": 0}
    else:
        extra_config = {}

    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config=dict(
            extra_config,
            **{
                "env": "BreakoutNoFrameskip-v4"
                if args.use_vision_network else "CartPole-v0",
                # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
                "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
                "callbacks": {
                    "on_train_result": check_has_custom_metric,
                },
                "model": {
                    "custom_model": "keras_q_model"
                    if args.run == "DQN" else "keras_model"
                },
                "framework": "tf",
            }))
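The bare print() calls inside MyKerasModel.forward are what make the server's framework setting observable: with framework="tf" the model is compiled into a static graph, so Python-level side effects fire only while the graph is traced, whereas with "tf2" (eager) they fire on every forward pass, which is how the gist shows whether a log_action call reaches the NN under each setting. A minimal, self-contained illustration of that TensorFlow behavior (assumes TF 2.x is installed):

import tensorflow as tf

@tf.function
def traced_forward(x):
    print("Python print: runs only while the graph is being traced")
    return x * 2

traced_forward(tf.constant(1.0))  # first call traces the graph -> prints once
traced_forward(tf.constant(2.0))  # reuses the traced graph -> prints nothing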