Skip to content

Instantly share code, notes, and snippets.

@tanzhenyu
Last active August 23, 2019 17:56
Show Gist options
  • Save tanzhenyu/b8a26ec85130fa08f8b1cf975a8c5528 to your computer and use it in GitHub Desktop.
PPO Actor Critic Model
def mlp(ob_space, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
    """Build a fully-connected feed-forward Keras model.

    Args:
        ob_space: observation space; only its ``.shape`` is read, to fix the
            model's input shape (a Gym-style space, presumably — TODO confirm).
        hidden_sizes: layer widths; the last entry is the output layer width.
        activation: activation for every layer except the last.
        output_activation: activation for the final layer (``None`` = linear).

    Returns:
        A built ``tf.keras.Sequential`` accepting inputs of shape
        ``(None,) + ob_space.shape``.
    """
    hidden_layers = [
        tf.keras.layers.Dense(units=width, activation=activation)
        for width in hidden_sizes[:-1]
    ]
    output_layer = tf.keras.layers.Dense(
        units=hidden_sizes[-1], activation=output_activation)
    net = tf.keras.Sequential(hidden_layers + [output_layer])
    # Build eagerly so variables exist immediately (batch dimension left open).
    net.build(input_shape=(None,) + ob_space.shape)
    return net
class MlpCategoricalActorCritic(tf.keras.Model):
    """MLP actor-critic for discrete (categorical) action spaces, as used in PPO.

    Holds two separate MLPs built by :func:`mlp`:
      * ``actor_mlp``  — maps observations to unnormalized action logits.
      * ``critic_mlp`` — maps observations to a scalar state-value estimate.
    """

    def __init__(self, ob_space, ac_space, hidden_sizes=(64, 64),
                 activation=tf.keras.activations.tanh, output_activation=None):
        """Create the policy and value networks.

        Args:
            ob_space: observation space (``.shape`` is used for input sizing).
            ac_space: discrete action space; ``.n`` gives the number of actions.
            hidden_sizes: widths of the shared-architecture hidden layers.
            activation: hidden-layer activation for both networks.
            output_activation: accepted for signature compatibility; note it is
                not forwarded to ``mlp`` here (both heads end linear).
        """
        super(MlpCategoricalActorCritic, self).__init__()
        self.act_dim = ac_space.n
        # Policy head: outputs one logit per discrete action.
        with tf.name_scope('pi'):
            self.actor_mlp = mlp(
                ob_space=ob_space,
                hidden_sizes=list(hidden_sizes) + [self.act_dim],
                activation=activation)
        # Value head: outputs a single state-value estimate.
        with tf.name_scope('v'):
            self.critic_mlp = mlp(
                ob_space=ob_space,
                hidden_sizes=list(hidden_sizes) + [1],
                activation=activation)

    @tf.function
    def get_pi_logpi_vf(self, observations):
        """Sample actions and return (actions, their log-probs, value estimates).

        Returns:
            pi: sampled action indices, shape ``(batch,)``.
            logp_pi: log-probability of each sampled action, shape ``(batch,)``.
            vf: critic output; shape ``(batch, 1)`` — unlike ``get_v`` this is
                NOT squeezed (NOTE(review): confirm callers expect the extra axis).
        """
        logits = self.actor_mlp(observations)
        log_probs = tf.nn.log_softmax(logits)
        # NOTE(review): seed=0 pins the op-level sampling seed — confirm this is
        # intentional, since it constrains exploration randomness across runs.
        sampled = tf.random.categorical(logits, num_samples=1, seed=0)
        pi = tf.squeeze(sampled, axis=1)
        # Select the log-prob of the chosen action via a one-hot mask.
        chosen_mask = tf.one_hot(pi, depth=self.act_dim)
        logp_pi = tf.reduce_sum(chosen_mask * log_probs, axis=1)
        vf = self.critic_mlp(observations)
        return pi, logp_pi, vf

    @tf.function
    def get_logp(self, observations, actions):
        """Log-probability of the given actions under the current policy.

        Args:
            observations: batch of observations.
            actions: integer action indices, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch,)`` with per-sample log-probabilities.
        """
        log_probs = tf.nn.log_softmax(self.actor_mlp(observations))
        return tf.reduce_sum(
            tf.one_hot(actions, depth=self.act_dim) * log_probs, axis=1)

    @tf.function
    def get_v(self, observations):
        """State-value estimates for a batch, squeezed to shape ``(batch,)``."""
        return tf.squeeze(self.critic_mlp(observations), axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment