Skip to content

Instantly share code, notes, and snippets.

@horoiwa
Created May 30, 2020 07:18
Show Gist options
  • Save horoiwa/6e522ac7e12513554cdc8bc90df799f8 to your computer and use it in GitHub Desktop.
Save horoiwa/6e522ac7e12513554cdc8bc90df799f8 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as kl
import tensorflow_probability as tfp
class ActorCriticNet(tf.keras.Model):
    """Convolutional actor-critic network with a shared trunk and two heads.

    The input passes through three ReLU convolutions, a flatten and a
    512-unit ReLU dense layer. The resulting features feed two independent
    linear heads: ``logits`` (one output per action) and ``values`` (a
    single output).

    The model also owns an Adam optimizer (``self.optimizer``). Nothing in
    this class uses it; it is presumably driven by an external training
    loop -- confirm against the caller.
    """

    def __init__(self, action_space: int, lr: float = 0.00005):
        """Build the layers and the optimizer.

        Args:
            action_space: Number of units in the ``logits`` head. It is
                passed directly as the ``Dense`` unit count, so it must be
                a positive integer.
            lr: Learning rate handed to the Adam optimizer.
        """
        super().__init__()

        self.action_space = action_space

        # Shared convolutional trunk.
        self.conv1 = kl.Conv2D(32, 8, strides=4, activation="relu",
                               kernel_initializer="he_normal")
        self.conv2 = kl.Conv2D(64, 4, strides=2, activation="relu",
                               kernel_initializer="he_normal")
        self.conv3 = kl.Conv2D(64, 3, strides=1, activation="relu",
                               kernel_initializer="he_normal")

        self.flat1 = kl.Flatten()

        self.dense1 = kl.Dense(512, activation="relu",
                               kernel_initializer="he_normal")

        # Output heads: no activation, so both emit raw linear outputs.
        self.logits = kl.Dense(self.action_space,
                               kernel_initializer="he_normal")
        self.values = kl.Dense(1, kernel_initializer="he_normal")

        # ``learning_rate`` is the supported keyword; the ``lr`` alias is
        # deprecated and rejected by newer Keras optimizer implementations.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Input batch for the first ``Conv2D`` layer. Assumed to be a
                4-D image tensor (batch, height, width, channels) -- TODO
                confirm against the caller's preprocessing.

        Returns:
            A ``(values, logits)`` tuple -- note the order. ``values`` comes
            from the 1-unit head and ``logits`` from the
            ``action_space``-unit head.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flat1(x)
        x = self.dense1(x)

        logits = self.logits(x)
        values = self.values(x)

        return values, logits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment