Last active
June 25, 2019 13:10
-
-
Save cshimmin/546472887feea457984325c075d4dabe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def some_model():
    """Build and compile an autoencoder whose loss terms are tracked separately.

    The total loss is ``loss1_ + hp_loss2_weight * loss2_`` where

    * ``loss1_`` is the mean squared difference between the model input and
      the decoder output, and
    * ``loss2_`` is the mean squared value of the encoder output.

    Each term is also exposed as a metric so that it shows up in the
    training history next to the usual ``loss`` / ``val_loss`` entries.

    NOTE(review): ``K``, ``Model``, ``encoder``, ``decoder`` and ``whatever``
    are not defined in this snippet; they are assumed to be provided by the
    surrounding module (Keras namespace, two callable sub-networks and the
    input dimension, respectively) -- confirm before use.

    Returns:
        The compiled model. The weight of the second loss term is attached
        to it as ``model.hp_loss2_weight``.
    """
    auto_input = K.Input((whatever,))
    auto_z = encoder(auto_input)
    auto_output = decoder(auto_z)

    # The model maps the input tensor to the decoder output built above.
    model = Model(auto_input, auto_output)

    loss1_ = K.mean(K.square(auto_input - auto_output))
    loss2_ = K.mean(K.square(auto_z))

    # Kept as a backend variable attached to the model (presumably so the
    # weight can be changed after compilation -- confirm with callers).
    model.hp_loss2_weight = K.variable(1.0, name='loss2_weight')

    # Instead of using model.add_loss(), we have to make a "dummy" loss
    # function that looks like a function of true vs. predicted outputs.
    # The arguments are deliberately ignored; the loss tensors captured
    # from the enclosing scope are returned instead.
    def loss_total(y_true, y_pred):
        return loss1_ + model.hp_loss2_weight * loss2_

    # Also, we have to make similar dummy functions for each metric we
    # want to monitor.
    def loss1(y_true, y_pred):
        return loss1_

    def loss2(y_true, y_pred):
        return loss2_

    def loss2_weighted(y_true, y_pred):
        return model.hp_loss2_weight * loss2_

    # Finally, compile the model using these dummy functions as loss and
    # metrics.
    model.compile(
        optimizer='adam',
        loss=loss_total,
        metrics=[loss1, loss2, loss2_weighted],
    )
    return model


# Backward-compatible alias: the function was originally (mis)named
# ``some_mode`` while its caller uses ``some_model``.
some_mode = some_model
model = some_model()

# Caveat: because the loss is wrapped in a (y_true, y_pred) function, both
# inputs and targets must be passed when training, even though the targets
# are simply the inputs again.
model.fit(
    X_train,
    X_train,
    validation_data=(X_val, X_val),
)

# The extra monitored loss terms should now be listed in the history object
# alongside the usual "loss" and "val_loss" entries.
monitored_keys = model.history.history.keys()
print(monitored_keys)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment