Skip to content

Instantly share code, notes, and snippets.

@wonglkd
Created January 5, 2017 14:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wonglkd/56de56ab05abf3e1c74994d02b40f913 to your computer and use it in GitHub Desktop.
Save wonglkd/56de56ab05abf3e1c74994d02b40f913 to your computer and use it in GitHub Desktop.
Bypass sample weights
# Implement the REINFORCE rule by using categorical crossentropy as the loss
# and feeding (reward - baseline) in as the per-sample weight.
self.model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=self.lr))

# Adapted from model.compile: redo the computation of the loss so that each
# output's raw crossentropy is multiplied by its sample_weight tensor directly
# and then averaged.
# NOTE(review): presumably this replaces the weighting that compile() sets up
# because (reward - baseline) weights can be negative or zero -- confirm
# against the Keras version in use.
loss_function = objectives.get('categorical_crossentropy')
total_loss = None
# The same loss function is applied to every output; indexing a one-element
# list of loss functions by output position would raise IndexError as soon as
# the model has more than one output.
# Assumes targets, outputs and sample_weights are parallel lists of equal
# length (one entry per model output) -- TODO confirm.
for y_true, y_pred, sample_weight in zip(self.model.targets,
                                         self.model.outputs,
                                         self.model.sample_weights):
    output_loss = K.mean(loss_function(y_true, y_pred) * sample_weight)
    if total_loss is None:
        total_loss = output_loss
    else:
        total_loss += output_loss

# Pass the summed loss through each of the model's regularizers in turn.
for regularizer in self.model.regularizers:
    total_loss = regularizer(total_loss)
self.model.total_loss = total_loss

# For callbacks; normally set in .fit
self.model.validation_data = None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment