

@jamesonthecrow
Last active November 27, 2018 23:38
A smaller style transfer network.
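```python
import keras

# `layers` is the project's own module that defines the custom
# DeprocessStylizedImage layer; it is not part of this gist
# (module name and import path assumed).
import layers


# Excerpt: a classmethod of the style transfer network class. The helper
# classmethods _convolution, _residual_block and _upsample are defined
# elsewhere in that class and are not shown in this gist.
@classmethod
def build(
        cls,
        image_size,
        alpha=1.0,
        input_tensor=None,
        checkpoint_file=None):
    """Build a small style transfer network using Keras's functional API.

    This architecture uses fewer blocks of layers and fewer filters per
    convolution than a full-size style transfer network, to save on
    computation.

    Args:
        image_size: the (H, W) size of the input and output images.
        alpha: a width multiplier; each layer has int(alpha * 32) channels.
        input_tensor: an optional existing tensor to use as the model input.
        checkpoint_file: an optional path to saved weights to load.

    Returns:
        model: a Keras model object.
    """
    x = keras.layers.Input(
        shape=(image_size[0], image_size[1], 3), tensor=input_tensor)
    # A 9x9 convolution at full resolution, then two stride-2 convolutions.
    out = cls._convolution(x, int(alpha * 32), 9, strides=1)
    out = cls._convolution(out, int(alpha * 32), 3, strides=2)
    out = cls._convolution(out, int(alpha * 32), 3, strides=2)
    # Three residual blocks at the reduced resolution.
    out = cls._residual_block(out, int(alpha * 32))
    out = cls._residual_block(out, int(alpha * 32))
    out = cls._residual_block(out, int(alpha * 32))
    # Two upsampling steps, mirroring the two stride-2 convolutions.
    out = cls._upsample(out, int(alpha * 32), 3)
    out = cls._upsample(out, int(alpha * 32), 3)
    # Project back to 3 (RGB) channels, with no ReLU on the output.
    out = cls._convolution(out, 3, 9, relu=False, padding='same')
    # Restrict output pixel values to the range [-1, 1].
    out = keras.layers.Activation('tanh')(out)
    # Deprocess the output into valid image data. Note that Core ML will
    # also need a custom layer for this step.
    out = layers.DeprocessStylizedImage()(out)
    model = keras.models.Model(inputs=x, outputs=out)
    # Assumed intent of checkpoint_file (unused in the original excerpt):
    # load previously trained weights when a path is given.
    if checkpoint_file:
        model.load_weights(checkpoint_file)
    return model
```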
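Because the network downsamples twice and then upsamples twice, the output only matches the input size for some image sizes. The sketch below traces (H, W) through the layer sequence in `build()`. It assumes, since the helpers are not shown in the gist, that every convolution uses 'same' padding (so a stride-2 layer outputs `ceil(size / 2)`) and that each `_upsample` call scales both dimensions by exactly 2. The function name `trace_spatial_shapes` is hypothetical.

```python
import math


def trace_spatial_shapes(height, width):
    """Trace (H, W) through the layer sequence used in build().

    Assumptions (not confirmed by the gist): 'same' padding everywhere, so a
    stride-s layer outputs ceil(size / s), and each _upsample doubles H and W.
    """
    h, w = height, width
    stages = [("input", h, w)]
    # 9x9 convolution with stride 1: spatial size unchanged.
    stages.append(("conv 9x9, stride 1", h, w))
    # Two 3x3 convolutions with stride 2: each halves H and W, rounding up.
    for _ in range(2):
        h, w = math.ceil(h / 2), math.ceil(w / 2)
        stages.append(("conv 3x3, stride 2", h, w))
    # Residual blocks keep the size so the skip connection can be added.
    for _ in range(3):
        stages.append(("residual block", h, w))
    # Two upsampling steps, each assumed to double H and W.
    for _ in range(2):
        h, w = h * 2, w * 2
        stages.append(("upsample x2", h, w))
    # Final 9x9 convolution back to RGB, 'same' padding, stride 1.
    stages.append(("conv 9x9 -> 3 channels", h, w))
    return stages


for name, h, w in trace_spatial_shapes(250, 250):
    print(f"{name:<24} {h} x {w}")
```

Under these assumptions a 250 x 250 input comes back as 252 x 252 (250 -> 125 -> 63 -> 126 -> 252), while a 256 x 256 input round-trips exactly, so choosing an `image_size` whose height and width are divisible by 4 avoids a size mismatch.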
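The `alpha` parameter sets every hidden layer to `int(alpha * 32)` channels, and a k x k convolution from `c_in` to `c_out` channels holds `k * k * c_in * c_out` weights. So the channel-to-channel 3x3 layers shrink with the square of `alpha`, while the first and last 9x9 layers shrink linearly. The rough count below illustrates this. It is a sketch under stated assumptions: each `_residual_block` is taken to contain two 3x3 convolutions, each `_upsample` one 3x3 convolution, every convolution has one bias per output channel, and normalization parameters are ignored. The helper names are hypothetical.

```python
def conv_params(k, c_in, c_out, bias=True):
    """Parameters in a k x k convolution from c_in to c_out channels."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def estimate_params(alpha):
    """Rough parameter count for the layer list in build().

    Assumed structure (helpers not shown in the gist): two 3x3 convs per
    residual block, one 3x3 conv per upsample, biases on, no norm params.
    """
    c = int(alpha * 32)  # channel width used by every hidden layer
    total = conv_params(9, 3, c)            # first 9x9 conv on RGB input
    total += 2 * conv_params(3, c, c)       # two stride-2 3x3 convs
    total += 3 * 2 * conv_params(3, c, c)   # three residual blocks
    total += 2 * conv_params(3, c, c)       # two upsampling convs
    total += conv_params(9, c, 3)           # final 9x9 conv back to RGB
    return c, total


for alpha in (1.0, 0.5, 0.25):
    width, params = estimate_params(alpha)
    print(f"alpha={alpha}: {width} channels, ~{params} parameters")
```

With these assumptions `alpha=1.0` gives 32 channels and 108,067 parameters, and `alpha=0.5` gives 16 channels and 30,995 parameters, roughly a 3.5x reduction. Because of the `int()` truncation, widths are rounded down (for example, `alpha=0.3` yields 9 channels).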
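The `tanh` activation leaves pixel values in [-1, 1], and the custom `DeprocessStylizedImage` layer then turns them into valid image data. Its source is not part of this gist, so the function below is only a hypothetical stand-in: it assumes a linear mapping `(x + 1) * 127.5` onto 8-bit values in [0, 255]. The name `deprocess_tanh_output` is made up for illustration.

```python
import numpy as np


def deprocess_tanh_output(x):
    """Map tanh outputs in [-1, 1] to 8-bit pixel values in [0, 255].

    Hypothetical stand-in for layers.DeprocessStylizedImage; the linear
    mapping (x + 1) * 127.5 is an assumption, not the gist's code.
    """
    # Clip defensively, then shift to [0, 2] and scale to [0, 255].
    x = np.clip(np.asarray(x, dtype=np.float32), -1.0, 1.0)
    return np.round((x + 1.0) * 127.5).astype(np.uint8)


pixels = deprocess_tanh_output([[-1.0, 0.0, 1.0]])
print(pixels)
```

An affine mapping like this would be straightforward to reproduce in the Core ML custom layer that the comment in `build()` mentions.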