Skip to content

Instantly share code, notes, and snippets.

@horoiwa
Created May 24, 2020 06:54
Show Gist options
  • Save horoiwa/8efa5edb6eceef140760fc69a84f7862 to your computer and use it in GitHub Desktop.
Save horoiwa/8efa5edb6eceef140760fc69a84f7862 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import tensorflow.keras.layers as kl
import tensorflow_probability as tfp
import numpy as np
class ActorCriticNet(tf.keras.Model):
    """Actor-critic network with independent policy and value heads.

    The policy head (``dense1`` -> ``policy_logits``) and the value head
    (``dense2`` -> ``values``) do not share weights; each reads the raw input.
    """

    def __init__(self, action_space=2, hidden_units=100):
        """Create the layers.

        Args:
            action_space: Number of discrete actions; width of the logits output.
            hidden_units: Width of each head's hidden ReLU layer. Defaults to
                100, the size previously hard-coded.
        """
        super().__init__()
        self.action_space = action_space
        self.dense1 = kl.Dense(hidden_units, activation="relu")
        self.dense2 = kl.Dense(hidden_units, activation="relu")
        self.values = kl.Dense(1)
        self.policy_logits = kl.Dense(action_space)

    def call(self, x):
        """Run both heads on a batch of states.

        Args:
            x: Batch of states; presumably shape (batch, state_dim) -- TODO
                confirm against callers.

        Returns:
            Tuple ``(values, logits)``: value estimates (last dimension 1) and
            unnormalized action logits (last dimension ``action_space``).
        """
        logits = self.policy_logits(self.dense1(x))
        values = self.values(self.dense2(x))
        return values, logits

    def sample_action(self, state):
        """Sample an action from the current policy.

        Args:
            state: A single state (1-D) or a batch of states; ``np.atleast_2d``
                turns a single state into a batch of one.

        Returns:
            The sampled action index for the first row of the batch, as a
            NumPy scalar.
        """
        state = tf.convert_to_tensor(np.atleast_2d(state), dtype=tf.float32)
        _, logits = self(state)
        # Hand the logits straight to Categorical: it normalizes them itself,
        # so an explicit softmax (which TFP would convert back to log-space)
        # is redundant and only costs precision.
        cdist = tfp.distributions.Categorical(logits=logits)
        action = cdist.sample()
        return action.numpy()[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment