import keras.backend as K
from keras.layers import Input
from keras.models import Model

def some_model():
    # `whatever` is the input dimensionality; `encoder` and `decoder`
    # are assumed to be existing Keras models defined elsewhere.
    auto_input = Input((whatever,))
    auto_z = encoder(auto_input)
    auto_output = decoder(auto_z)
    model = Model(auto_input, auto_output)

    # reconstruction loss, plus a regularization term on the latent code
    loss1_ = K.mean(K.square(auto_input - auto_output))
    loss2_ = K.mean(K.square(auto_z))

    # keeping the weight in a K.variable lets us change it later
    # (e.g. between epochs) without recompiling the model
    model.hp_loss2_weight = K.variable(1.0, name='loss2_weight')

    # instead of using model.add_loss(), we have to make a
    # "dummy" loss function that looks like a function of
    # true vs. predicted outputs:
    def loss_total(y_true, y_pred):
        return loss1_ + model.hp_loss2_weight * loss2_

    # also, we have to make similar dummy functions for each
    # term we want to monitor as a metric:
    def loss1(y_true, y_pred):
        return loss1_

    def loss2(y_true, y_pred):
        return loss2_

    def loss2_weighted(y_true, y_pred):
        return model.hp_loss2_weight * loss2_

    # finally, we compile the model using these dummy functions
    # as the loss and metrics:
    model.compile(optimizer='adam', loss=loss_total,
                  metrics=[loss1, loss2, loss2_weighted])
    return model
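# For contrast, the model.add_loss() route would attach the combined loss
# directly and compile with loss=None (a rough sketch; exact behavior
# varies by Keras version):
#
#     model.add_loss(loss1_ + model.hp_loss2_weight * loss2_)
#     model.compile(optimizer='adam', loss=None)
#
# but then loss1_ and loss2_ would not appear as individually named metrics
# in the training history, which is what the dummy functions provide.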
model = some_model()

# The final caveat is that we now have to specify targets when training;
# for an autoencoder, the targets are just the inputs again:
model.fit(X_train, X_train, validation_data=(X_val, X_val))

# Now all of the monitored loss terms should appear in the history object,
# in addition to the usual "loss" and "val_loss" items:
print(model.history.history.keys())
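# Because hp_loss2_weight lives in a K.variable, its value can be updated
# between epochs with K.set_value(), without recompiling the model. A
# minimal sketch using a LambdaCallback (the annealing schedule below is
# purely illustrative, not part of the original snippet):
from keras.callbacks import LambdaCallback

def anneal_loss2_weight(epoch, logs):
    # e.g. halve the weight of the regularization term each epoch
    K.set_value(model.hp_loss2_weight, 0.5 ** epoch)

model.fit(X_train, X_train, validation_data=(X_val, X_val), epochs=10,
          callbacks=[LambdaCallback(on_epoch_begin=anneal_loss2_weight)])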