Last active
February 16, 2018 15:21
-
-
Save peune/8767d796bf4165f427fbbf07f2cf7678 to your computer and use it in GitHub Desktop.
Given a VGG16 model named org_model, construct a similar CNN named base_model that reuses the weights of org_model.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# hyp: org_model is VGG16
def create_base_model(org_model, input_shape):
    """Build a frozen Sequential copy of the convolutional base of *org_model*.

    Every layer named like the VGG16 conv/pool layers (``block<i>_conv<j>`` and
    ``block<i>_pool``) is re-created in order with the same configuration and
    a ``my_`` name prefix. Conv2D weights are copied from the original layer;
    pooling layers carry no weights. ``input_shape`` is passed to the first
    copied layer, so the new model may be built for a different input size
    than ``org_model`` (assumes the copied conv kernels do not depend on the
    spatial size -- true for plain Conv2D kernels).

    Args:
        org_model: source Keras model, assumed to be VGG16 (layer names as
            above).
        input_shape: input shape for the new model.

    Returns:
        A Sequential model whose layers are all non-trainable, compiled with
        ``loss='mean_squared_error'`` and ``optimizer='adam'``.

    Raises:
        ValueError: if ``org_model`` contains no ``block*_conv*`` or
            ``block*_pool`` layers.
    """
    base_model = Sequential()
    is_first = True
    for l in org_model.layers:
        name = l.name
        if not name.startswith('block'):
            # Skip InputLayer and any classifier head (flatten, fc1, ...).
            continue
        # Only the first layer added to a Sequential model needs input_shape.
        shape_kwargs = {'input_shape': input_shape} if is_first else {}
        if '_conv' in name:
            base_model.add(Conv2D(l.filters, l.kernel_size,
                                  strides=l.strides, padding=l.padding,
                                  use_bias=l.use_bias,
                                  data_format=l.data_format,
                                  activation=l.activation,
                                  name='my_' + name,
                                  **shape_kwargs))
            # The new layer is built on add(), so weights can be copied now.
            base_model.layers[-1].set_weights(l.get_weights())
        elif '_pool' in name:
            base_model.add(MaxPooling2D(pool_size=l.pool_size,
                                        padding=l.padding,
                                        strides=l.strides,
                                        data_format=l.data_format,
                                        name='my_' + name,
                                        **shape_kwargs))
        else:
            continue
        is_first = False

    if is_first:
        raise ValueError('org_model has no block*_conv*/block*_pool layers '
                         'to copy; expected a VGG16 model')

    # Freeze before compile so the trainable flags take effect.
    for l in base_model.layers:
        l.trainable = False
    base_model.compile(loss='mean_squared_error', optimizer='adam')
    return base_model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment