@ericl, created July 24, 2019 01:25
Gist: ericl/6501eb32054c1e000dbd7ba2492ff9b1
import numpy as np
import tensorflow as tf

# NOTE: RLlib import paths vary across versions; these match ~0.7.x.
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension


class MaskingLayerRNNmodel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        super(MaskingLayerRNNmodel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        self.initialize_lstm_with_prev_state = (
            model_config['custom_options']['initialize_lstm_with_prev_state'])
        cell_size = model_config['lstm_cell_size']

        # Inputs: a [batch, time, obs] observation tensor, the two LSTM
        # state tensors, and the per-sample sequence lengths.
        self.input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]), name='inputLayer')
        self.state_in_h = tf.keras.layers.Input(shape=(cell_size,), name='h')
        self.state_in_c = tf.keras.layers.Input(shape=(cell_size,), name='c')
        # tf.sequence_mask expects integer lengths.
        self.seq_in = tf.keras.layers.Input(
            shape=(), name='seqLens', dtype=tf.int32)

        dense_layer_1 = tf.keras.layers.Dense(
            model_config['fcnet_hiddens'][0],
            activation=tf.nn.relu,
            name='denseLayer1')(self.input_layer)
        # masking_layer = tf.keras.layers.Masking(
        #     mask_value=0.0)(dense_layer_1)

        # Keras expects initial_state in [h, c] order.
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size,
            return_sequences=True,
            return_state=True,
            name='lstmLayer')(
                inputs=dense_layer_1,
                mask=tf.sequence_mask(self.seq_in)
                if model_config['max_seq_len'] > 1 else None,
                initial_state=[self.state_in_h, self.state_in_c])
        # If we had access to the batch shape, we could instead set
        # stateful=True on the LSTM and call reset_states() rather than
        # passing the state in explicitly.

        # A Lambda/reshape layer does not accept the mask that Keras
        # propagates through the model when Masking() is used upstream,
        # so enabling the Masking layer above makes this layer fail:
        # reshape_layer = tf.keras.layers.Lambda(
        #     lambda x: tf.reshape(x, [-1, cell_size]))(lstm_out)

        dense_layer_2 = tf.keras.layers.Dense(
            model_config['fcnet_hiddens'][1],
            activation=tf.nn.relu,
            name='denseLayer2')(lstm_out)
        logits_layer = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name='logitsLayer')(dense_layer_2)
        value_layer = tf.keras.layers.Dense(
            1, activation=None, name='valueLayer')(dense_layer_2)

        self.base_model = tf.keras.Model(
            inputs=[self.input_layer, self.seq_in,
                    self.state_in_h, self.state_in_c],
            outputs=[logits_layer, value_layer, state_h, state_c])
        self.register_variables(self.base_model.variables)
        self.base_model.summary()

    # Core forward pass: add a time dimension to flat inputs, run the
    # Keras model, then flatten time back out of the logits.
    def forward(self, input_dict, state, seq_lens):
        x = input_dict['obs']
        if len(x.shape) < 3:
            x = add_time_dimension(x, seq_lens)
        logits, self._value_out, h, c = self.base_model(
            [x, seq_lens, state[0], state[1]])
        return tf.reshape(logits, [-1, self.num_outputs]), [h, c]

    def get_initial_state(self):
        cell_size = self.model_config['lstm_cell_size']
        return [np.zeros(cell_size, np.float32),
                np.zeros(cell_size, np.float32)]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])
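
# The "does not support masking" failure noted above, in isolation:
# tf.keras.layers.Masking attaches a mask that every downstream layer
# must declare support for, and a plain Lambda layer does not. This is
# a minimal, self-contained sketch of that pitfall (not from the gist;
# layer sizes are arbitrary).
def _mask_propagation_demo():
    inp = tf.keras.layers.Input(shape=(None, 4))
    masked = tf.keras.layers.Masking(mask_value=0.0)(inp)
    lstm_seq = tf.keras.layers.LSTM(8, return_sequences=True)(masked)
    try:
        tf.keras.layers.Lambda(
            lambda x: tf.reshape(x, [-1, 8]))(lstm_seq)
    except TypeError as e:
        # e.g. "Layer lambda does not support masking, but was passed
        # an input_mask: ..."
        print('Lambda rejected the propagated mask:', e)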
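
# A minimal usage sketch, assuming RLlib's ~0.7.x API: register the
# custom model and select it via the model config. The algorithm, env,
# and hyperparameter values below are illustrative assumptions, not
# part of the original gist.
import ray
from ray import tune
from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model('masking_rnn', MaskingLayerRNNmodel)

ray.init()
tune.run(
    'PPO',
    config={
        'env': 'CartPole-v0',  # assumed env (2 actions match num_outputs)
        'model': {
            'custom_model': 'masking_rnn',
            # Standard model-config keys read by the model above:
            'fcnet_hiddens': [64, 64],
            'lstm_cell_size': 64,
            'max_seq_len': 20,
            # Custom option consumed in __init__:
            'custom_options': {'initialize_lstm_with_prev_state': False},
        },
    })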