Skip to content

Instantly share code, notes, and snippets.

@standbyme
Created July 30, 2019 07:24
Show Gist options
  • Save standbyme/eb5ba94e589af96be681b30301b96f53 to your computer and use it in GitHub Desktop.
Model Class with AutoGraph
class MyModel(tf.keras.Model):
    """Two stacked Dense layers, followed by dropout when training.

    ``call`` is wrapped in ``tf.function``, so AutoGraph traces it into a
    graph and converts the Python ``if training:`` branch into graph code.

    Args:
        keep_probability: Fraction of activations *kept* by dropout while
            training; must lie in (0, 1]. TF 2's ``tf.nn.dropout`` expects a
            *drop* rate, so ``call`` passes ``rate = 1 - keep_probability``.

    Raises:
        ValueError: If ``keep_probability`` is outside (0, 1].
    """

    def __init__(self, keep_probability=0.2):
        super().__init__()
        # tf.nn.dropout needs rate in [0, 1), i.e. keep_probability in (0, 1];
        # fail at construction time instead of on the first training call.
        if not 0.0 < keep_probability <= 1.0:
            raise ValueError(
                f"keep_probability must be in (0, 1], got {keep_probability!r}"
            )
        self.dense1 = tf.keras.layers.Dense(4)
        self.dense2 = tf.keras.layers.Dense(5)
        self.keep_probability = keep_probability

    @tf.function
    def call(self, inputs, training=True):
        """Run the forward pass.

        Args:
            inputs: Batch fed to ``dense1``.
            training: When truthy, apply dropout to the output. Defaults to
                True, as in the original example; pass ``training=False``
                for inference.

        Returns:
            Output of ``dense2``, with dropout applied when ``training``.
        """
        y = self.dense2(self.dense1(inputs))
        if training:
            # In TF 2 the second argument of tf.nn.dropout is the drop *rate*
            # (it was keep_prob in TF 1.x). Passing keep_probability directly
            # would invert its meaning, so convert it and name the keyword.
            return tf.nn.dropout(y, rate=1.0 - self.keep_probability)
        else:
            return y
# Also required by MyModel; imported here so the demo runs on its own.
import tensorflow as tf

# Example input batch: 2 rows of 3 features. The original snippet used `x`
# without defining it, which raised NameError.
x = tf.ones((2, 3))

model = MyModel()
y_train = model(x, training=True)   # executes a graph, with dropout
y_infer = model(x, training=False)  # executes a graph, without dropout
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment