Skip to content

Instantly share code, notes, and snippets.

@peune
Last active February 16, 2018 15:21
Show Gist options
  • Save peune/8767d796bf4165f427fbbf07f2cf7678 to your computer and use it in GitHub Desktop.
Save peune/8767d796bf4165f427fbbf07f2cf7678 to your computer and use it in GitHub Desktop.
Given a VGG16 model named `org_model`, construct a similar CNN named `base_model` that reuses the weights of `org_model`.
def _add_conv_copy(model, src, **extra):
    """Append a Conv2D to *model* that mirrors *src* and copy its weights.

    The new layer is named ``'my_' + src.name``. Extra keyword arguments
    (e.g. ``input_shape`` for the first layer) go to the Conv2D constructor.
    """
    model.add(Conv2D(src.filters, src.kernel_size, strides=src.strides,
                     padding=src.padding, use_bias=src.use_bias,
                     data_format=src.data_format, activation=src.activation,
                     name='my_' + src.name, **extra))
    # Setting weights right after add() relies on the layer being built on
    # add; the first layer receives input_shape, so every later one is too.
    model.layers[-1].set_weights(src.get_weights())


def _add_pool_copy(model, src):
    """Append a MaxPooling2D to *model* that mirrors *src*'s configuration.

    Pooling has no weights, so only the configuration is copied. data_format
    is copied as well, to stay consistent with the convolution copies.
    """
    model.add(MaxPooling2D(pool_size=src.pool_size, strides=src.strides,
                           padding=src.padding, data_format=src.data_format,
                           name='my_' + src.name))


# Assumption: org_model is VGG16 (layers are looked up by their VGG16 names).
def create_base_model(org_model, input_shape):
    """Build a frozen Sequential copy of VGG16 layers taken from *org_model*.

    Each copied layer is recreated with the configuration of the layer of the
    same name in *org_model*, renamed with a ``my_`` prefix, and (for
    convolutions) initialised with that layer's weights.

    Args:
        org_model: source model, assumed to be VGG16.
        input_shape: input shape given to the first convolution.

    Returns:
        The new Sequential model, with every layer non-trainable, compiled
        with mean-squared-error loss and the Adam optimizer.
    """
    base_model = Sequential()
    _add_conv_copy(base_model, org_model.get_layer('block1_conv1'),
                   input_shape=input_shape)
    _add_conv_copy(base_model, org_model.get_layer('block1_conv2'))
    _add_pool_copy(base_model, org_model.get_layer('block1_pool'))
    _add_conv_copy(base_model, org_model.get_layer('block2_conv1'))
    # The original snippet elides the remaining layers here; they follow the
    # same pattern via _add_conv_copy / _add_pool_copy.
    ...
    # Freeze before compile() so the trainable flags take effect.
    for layer in base_model.layers:
        layer.trainable = False
    base_model.compile(loss='mean_squared_error', optimizer='adam')
    return base_model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment