Skip to content

Instantly share code, notes, and snippets.

@lxuechen
Last active August 3, 2018 19:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lxuechen/c113f74e7ef059294f05606a8f23b54c to your computer and use it in GitHub Desktop.
Save lxuechen/c113f74e7ef059294f05606a8f23b54c to your computer and use it in GitHub Desktop.
Define a reversible residual block.
class Residual(tf.keras.Model):
    """Reversible residual block built from two ``ResidualInner`` branches.

    The input is split into two equal halves ``(x1, x2)`` along ``axis`` and
    the block computes::

        y1 = x1 + f(x2)
        y2 = x2 + g(y1)

    returning ``concat([y1, y2], axis)``. Because ``y2`` depends only on
    ``x2`` and ``y1``, the inputs are (mathematically) recoverable from the
    outputs: ``x2 = y2 - g(y1)`` and then ``x1 = y1 - f(x2)``. This class
    only implements the forward pass.
    """

    def __init__(self, filters, axis=-1):
        """Create the two residual branches ``f`` and ``g``.

        Args:
            filters: Number of filters passed to each ``ResidualInner``.
                NOTE(review): each branch's output is added to one half of
                the input, so this presumably has to produce a tensor whose
                shape matches that half (i.e. half the input channels) —
                confirm against ResidualInner.
            axis: Axis along which the input is split in two and the halves
                are re-joined. Defaults to ``-1`` (channels-last, the Keras
                default layout); pass ``1`` for channels-first inputs. It
                should agree with the layout ResidualInner expects. The input
                size along this axis must be even.
        """
        super(Residual, self).__init__()
        # `call` splits and concatenates along `self.axis`; it must be set
        # here or the first forward pass fails with AttributeError.
        self.axis = axis
        self.f = ResidualInner(filters=filters, strides=(1, 1))
        self.g = ResidualInner(filters=filters, strides=(1, 1))

    def call(self, x, training=True):
        """Run the reversible residual forward pass.

        Args:
            x: Input tensor; its size along ``self.axis`` must be even
                because it is split into exactly two equal parts.
            training: Forwarded unchanged to both ``ResidualInner`` branches.

        Returns:
            Tensor ``concat([y1, y2], self.axis)`` where
            ``y1 = x1 + f(x2)`` and ``y2 = x2 + g(y1)``.
        """
        x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
        # First coupling: update x1 using a function of the untouched x2.
        f_x2 = self.f(x2, training=training)
        y1 = f_x2 + x1
        # Second coupling: update x2 using a function of the new y1.
        g_y1 = self.g(y1, training=training)
        y2 = g_y1 + x2
        return tf.concat([y1, y2], axis=self.axis)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment