Skip to content

Instantly share code, notes, and snippets.

@matt-gardner
Created March 22, 2017 22:25
Show Gist options
  • Save matt-gardner/693f869088f30dda5abf9e66573d5afa to your computer and use it in GitHub Desktop.
Save matt-gardner/693f869088f30dda5abf9e66573d5afa to your computer and use it in GitHub Desktop.
import numpy
from keras.models import Model, load_model
from keras.layers import Input, Concatenate, Dense
def build_first_model():
    """Build a small feed-forward network with a single sigmoid output.

    The network takes a 5-dimensional input named ``feature_input``, passes
    it through three 4-unit ReLU layers named ``hidden_layer_0`` through
    ``hidden_layer_2``, and ends in a 1-unit sigmoid layer.

    Returns:
        An uncompiled ``Model`` mapping the feature input to the sigmoid
        output.
    """
    features = Input(shape=(5,), name="feature_input")
    # Stack the hidden layers one on top of the other. The explicit names
    # allow the layers to be looked up by name later on.
    activations = features
    for index in range(3):
        layer_name = 'hidden_layer_{}'.format(index)
        activations = Dense(4, activation='relu', name=layer_name)(activations)
    prediction = Dense(1, activation='sigmoid')(activations)
    return Model(inputs=features, outputs=prediction)
def build_dependent_model(model):
    """Build a two-input model on top of intermediate outputs of ``model``.

    Two sub-models are carved out of ``model`` with ``repurpose_model``: one
    ending at ``hidden_layer_2`` and one ending at ``hidden_layer_1``, both
    starting from ``feature_input``. Their outputs for the same 5-dimensional
    input are concatenated with a fresh 3-dimensional input and fed to a
    1-unit sigmoid layer.

    Args:
        model: A model containing layers named ``feature_input``,
            ``hidden_layer_1`` and ``hidden_layer_2``.

    Returns:
        An uncompiled ``Model`` whose inputs are
        ``[feature_input, new_feature_input]`` and whose output is the
        sigmoid layer.
    """
    features = Input(shape=(5,), name="feature_input")
    extra_features = Input(shape=(3,), name="new_feature_input")

    # Build and apply one sub-model per reused hidden layer, in this order
    # (deepest layer first) so the concatenated feature layout is fixed.
    reused_features = []
    for layer_name in ('hidden_layer_2', 'hidden_layer_1'):
        sub_model = repurpose_model(model, ['feature_input'], [layer_name])
        reused_features.append(sub_model(features))

    merged = Concatenate()(reused_features + [extra_features])
    prediction = Dense(1, activation='sigmoid')(merged)
    return Model(inputs=[features, extra_features], outputs=prediction)
def repurpose_model(model, input_layer_names, output_layer_names):
    """Build a new model from named input and output layers of ``model``.

    Args:
        model: Model whose ``layers`` are looked up by name.
        input_layer_names: Names of layers whose input tensors (node index 0)
            become the inputs of the new model.
        output_layer_names: Names of layers whose output tensors (node
            index 0) become the outputs of the new model.

    Returns:
        A ``Model`` constructed from the selected input and output tensors.

    Raises:
        KeyError: If a requested name does not match any layer in ``model``.
    """
    # Index each layer's first-node (input, output) tensors by layer name.
    tensors_by_name = {
        layer.name: (layer.get_input_at(0), layer.get_output_at(0))
        for layer in model.layers
    }
    selected_inputs = [tensors_by_name[name][0] for name in input_layer_names]
    selected_outputs = [tensors_by_name[name][1] for name in output_layer_names]
    return Model(inputs=selected_inputs, outputs=selected_outputs)
def main():
    """Train a model, round-trip it through disk, and reuse its layers.

    Steps:
      1. Build the first model and fit it on 10 random 5-dimensional samples
         with random 0/1 labels.
      2. Save the model to an HDF5 file and load it back.
      3. Build the dependent model from the *loaded* model and fit it on the
         same features plus 10 random 3-dimensional samples.

    The model file is written into a private temporary directory that is
    removed afterwards, so running the script neither leaves a stray
    ``tmp_model.h5`` in the working directory nor overwrites an existing one.
    """
    # Local imports keep this block self-contained; they are stdlib only.
    import os
    import shutil
    import tempfile

    first_model = build_first_model()
    first_model_input = numpy.random.rand(10, 5)
    first_model_output = numpy.random.randint(0, 2, (10,))
    first_model.compile('adam', 'binary_crossentropy')
    first_model.fit(first_model_input, first_model_output)

    # Save and reload via a scratch directory; clean it up even if saving or
    # loading raises.
    scratch_dir = tempfile.mkdtemp()
    try:
        model_path = os.path.join(scratch_dir, "tmp_model.h5")
        first_model.save(model_path)
        loaded_first_model = load_model(model_path)
    finally:
        shutil.rmtree(scratch_dir, ignore_errors=True)

    second_model = build_dependent_model(loaded_first_model)
    second_model_input = numpy.random.rand(10, 3)
    second_model.compile('adam', 'binary_crossentropy')
    second_model.fit([first_model_input, second_model_input], first_model_output)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment