Skip to content

Instantly share code, notes, and snippets.

@tmorgan4
Created September 22, 2019 04:25
Show Gist options
  • Save tmorgan4/c59f82a49e8d4a300f47f11e745d1f42 to your computer and use it in GitHub Desktop.
Save tmorgan4/c59f82a49e8d4a300f47f11e745d1f42 to your computer and use it in GitHub Desktop.
"""Example of using a custom RNN keras model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' # sometimes required for Anaconda installations due to conflict in MKL linking
import numpy as np
import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class RecurrentPGModel(RecurrentTFModelV2):
    """Recurrent policy/value network: two conv encoders feeding a GRU.

    The observation is fed to two structurally identical convolutional
    stacks (named "rgb" and "map"). Their 512-unit embeddings are
    concatenated, passed through a 1024-unit dense layer, given a time
    dimension, and run through a GRU whose outputs drive a policy head
    (``num_outputs`` units) and a scalar value head.

    The GRU width is read from ``model_config['lstm_cell_size']``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kw):
        """Build the Keras graph and register its variables with RLlib.

        Args:
            obs_space: Observation space; its ``shape`` is used as the input
                shape of both convolutional branches.
            action_space: Action space, forwarded to the base class.
            num_outputs: Number of units of the policy head.
            model_config: Model config dict; must contain ``lstm_cell_size``.
            name: Model name, forwarded to the base class.
            **kw: Extra keyword arguments forwarded to the base class.
        """
        super(RecurrentPGModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        cell_size = model_config['lstm_cell_size']

        # --- "rgb" branch: three ReLU convolutions, then a dense embedding.
        rgb_input = tf.keras.layers.Input(
            shape=obs_space.shape, name="rgb_input", dtype=tf.float32)
        x = rgb_input
        x = tf.keras.layers.Conv2D(
            32, 8, strides=(4, 4), activation=tf.nn.relu, padding="valid", name="conv11")(x)
        x = tf.keras.layers.Conv2D(
            64, 4, strides=(2, 2), activation=tf.nn.relu, padding="valid", name="conv12")(x)
        x = tf.keras.layers.Conv2D(
            64, 3, strides=(1, 1), activation=tf.nn.relu, padding="valid", name="conv13")(x)
        x = tf.keras.layers.Flatten()(x)
        rgb_h = tf.keras.layers.Dense(512, name="rgb_h", activation=tf.nn.relu)(x)

        # --- "map" branch: same layout, ELU convolutions.
        map_input = tf.keras.layers.Input(
            shape=obs_space.shape, name="map_input", dtype=tf.float32)
        x = map_input
        x = tf.keras.layers.Conv2D(
            32, 8, strides=(4, 4), activation=tf.nn.elu, padding="valid", name="conv21")(x)
        x = tf.keras.layers.Conv2D(
            64, 4, strides=(2, 2), activation=tf.nn.elu, padding="valid", name="conv22")(x)
        x = tf.keras.layers.Conv2D(
            64, 3, strides=(1, 1), activation=tf.nn.elu, padding="valid", name="conv23")(x)
        x = tf.keras.layers.Flatten()(x)
        map_h = tf.keras.layers.Dense(512, name="map_h", activation=tf.nn.relu)(x)

        # The map embedding is multiplied by 0, so it contributes only zeros
        # to the concatenated features; the branch stays in the graph.
        # NOTE(review): presumably a placeholder/debugging switch - confirm.
        h = tf.concat([rgb_h, 0 * map_h], 1)
        h = tf.keras.layers.Dense(1024, activation=tf.nn.relu)(h)

        # --- Recurrent core. The time dimension is added here, inside the
        # graph, after the (non-recurrent) convolutional encoders.
        state_in_h = tf.keras.layers.Input(
            shape=(cell_size,), name="state_in_h", dtype=tf.float32)
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)
        h = add_time_dimension(h, seq_in)
        mask = tf.sequence_mask(seq_in)
        gru_out, state_h = tf.keras.layers.GRU(
            cell_size, return_sequences=True, return_state=True, name="gru")(
                h, mask=mask, initial_state=state_in_h)

        # --- Output heads (linear).
        pi_out = tf.keras.layers.Dense(num_outputs, name="pi_out", activation=None)(gru_out)
        v_out = tf.keras.layers.Dense(1, name="v_out", activation=None)(gru_out)

        self.base_model = tf.keras.Model(
            inputs=[rgb_input, map_input, seq_in, state_in_h],
            outputs=[pi_out, v_out, state_h])
        self.register_variables(self.base_model.variables)
        self.base_model.summary()

    @override(ModelV2)
    def value_function(self):
        """Return the value head output of the last forward pass, flattened to 1-D.

        ``self.value`` is only assigned in ``forward_rnn``, so a forward pass
        must have happened before this is called.
        """
        return tf.reshape(self.value, [-1])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Run the model and return flattened policy outputs plus the new state.

        This override does NOT add a time dimension to the inputs before
        calling ``forward_rnn``: the first layers are convolutional, so the
        whole ``input_dict`` is passed through and the time dimension is added
        inside the Keras graph (see ``__init__``).

        Returns:
            A tuple ``(outputs, new_state)`` where ``outputs`` is reshaped to
            ``[-1, self.num_outputs]`` and ``new_state`` is the state list
            returned by ``forward_rnn``.
        """
        output, new_state = self.forward_rnn(input_dict, state, seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state

    @override(RecurrentTFModelV2)
    def forward_rnn(self, input_dict, state, seq_lens):
        """Call the Keras model on the observation, sequence lengths and state.

        Args:
            input_dict: Input dict; only ``input_dict['obs']`` is read, and it
                is fed to both the "rgb" and the "map" inputs.
            state: List holding the GRU hidden state (one entry, matching
                ``get_initial_state``).
            seq_lens: Sequence lengths, fed to the ``seq_in`` input.

        Returns:
            A tuple ``(policy_out, [new_hidden_state])``. The value head output
            is stored on ``self.value`` for ``value_function``.
        """
        obs = input_dict['obs']
        # Splice the state entries into the flat input list so it lines up
        # one-to-one with the four inputs the Keras model was built with,
        # instead of nesting the state list inside the input list.
        model_in = [obs, obs, seq_lens] + list(state)
        model_out, self.value, h = self.base_model(model_in)
        return model_out, [h]

    @override(ModelV2)
    def get_initial_state(self):
        """Return the initial GRU state: one float32 zero vector of ``lstm_cell_size``."""
        return [np.zeros(self.model_config['lstm_cell_size'], np.float32)]
if __name__ == "__main__":
    # Local mode together with num_cpus=0 allows stepping through the run
    # in a debugger. Alternative for a regular run: ray.init()
    ray.init(logging_level='DEBUG', local_mode=True, num_cpus=0)

    ModelCatalog.register_custom_model("RecurrentPGModel", RecurrentPGModel)

    # Settings handed to the custom model through model_config.
    model_settings = {
        "custom_model": "RecurrentPGModel",
        "lstm_cell_size": 128,
    }

    # PPO trainer configuration.
    trainer_settings = {
        "env": 'BreakoutNoFrameskip-v4',
        "log_level": "DEBUG",
        "num_workers": 0,  # must be 0 if running in local mode
        "num_envs_per_worker": 1,
        "num_sgd_iter": 1,
        "model": model_settings,
    }

    stop_criteria = {"timesteps_total": 10000}

    tune.run("PPO", stop=stop_criteria, config=trainer_settings)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment