Created May 10, 2020 18:45
-
-
Save sol0invictus/344747cd358a27422bfc123251327880 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Two implementations are shown below.
# The first is a plain Sequential stack: a 3x3 convolution with 100 ReLU
# filters and no padding (so a 52x52 input yields a 50x50 map), followed by a
# 1x1 convolution that produces one sigmoid probability per output cell.
model = Sequential([
    Conv2D(100, 3, padding='valid', activation='relu', strides=1,
           input_shape=(52, 52, 1)),
    Conv2D(1, 1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()
# The second implementation is almost the same as the first, but it can apply
# the automaton update for several time-steps in a single call.
class CAtf(tf.keras.Model):
    """Convolutional cellular automaton whose update rule can be iterated.

    Each step softmax-normalises the per-cell state channels, then applies a
    3x3 convolution with 100 ReLU filters ('same' padding, so the spatial size
    is preserved from step to step) and a 1x1 convolution back to
    ``num_states`` channels. The returned tensor holds unnormalised logits.

    Args:
        num_states: number of channels produced by the final 1x1 convolution.
    """

    def __init__(self, num_states):
        super().__init__()
        # ReLU matches the hidden layer of the Sequential model. Without a
        # nonlinearity, conv1 followed by conv2 is a single linear map and the
        # 100 hidden filters add no expressive power.
        self.conv1 = tf.keras.layers.Conv2D(
            100, kernel_size=3, padding='same', activation='relu')
        self.conv2 = tf.keras.layers.Conv2D(num_states, 1, padding='valid')

    def call(self, x, steps=1):
        """Apply ``steps`` automaton updates to ``x`` and return the logits.

        Args:
            x: input tensor. Assumes channels-last layout (batch, H, W, C), the
                Keras Conv2D default — TODO confirm against callers. For
                ``steps > 1``, C must equal ``num_states``, because conv1's
                input channel count is fixed when it is first built.
            steps: number of updates to apply. With ``steps=0`` the input is
                returned unchanged.

        Returns:
            The conv2 output of the last step (logits over the states).
        """
        for _ in range(steps):
            # Normalise over the state/channel axis. axis=1 would normalise
            # down the image height in a channels-last tensor.
            x = tf.nn.softmax(x, axis=-1)
            x = self.conv1(x)
            x = self.conv2(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.