Skip to content

Instantly share code, notes, and snippets.

@georgwiese
Created April 7, 2020 12:00
Show Gist options
Save georgwiese/57568ac518813b9d9f0e6785d8f707fa to your computer and use it in GitHub Desktop.
From fb7dd5c67d70f86082efca7cd3f29061b808467f Mon Sep 17 00:00:00 2001
From: Georg Wiese <georgwiese@gmail.com>
Date: Thu, 2 Apr 2020 15:38:00 +0200
Subject: [PATCH] Fix Keras target indexing when some outputs have no loss
---
tensorflow/python/keras/engine/training_eager.py | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index be1b2e89d9..e39a571a91 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -143,10 +143,11 @@ def _model_loss(model,
output_losses = []
with backend.name_scope('loss'):
- loss_fns = [
- loss_fn for loss_fn in model.loss_functions if loss_fn is not None
- ]
- for i, loss_fn in enumerate(loss_fns):
+ i_target = 0
+ for i, loss_fn in enumerate(model.loss_functions):
+ if loss_fn is None:
+ # Output i has no loss.
+ continue
weights = sample_weights[i] if sample_weights else None
mask = masks[i]
with backend.name_scope(model.output_names[i] + '_loss'):
@@ -163,7 +164,7 @@ def _model_loss(model,
weights *= mask
if hasattr(loss_fn, 'reduction'):
- per_sample_losses = loss_fn.call(targets[i], outs[i])
+ per_sample_losses = loss_fn.call(targets[i_target], outs[i])
weighted_losses = losses_utils.compute_weighted_loss(
per_sample_losses,
sample_weight=weights,
@@ -193,13 +194,14 @@ def _model_loss(model,
# as part of the loss_metrics.
if len(model.outputs) > 1:
# Keep track of the stateful output loss result.
- output_losses.append(output_loss_metrics[i](output_loss))
+ output_losses.append(output_loss_metrics[i_target](output_loss))
# Scale output loss for distribution. For custom losses we assume
# reduction was mean.
if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
output_loss = losses_utils.scale_loss_for_distribution(output_loss)
total_loss += model._loss_weights_list[i] * output_loss
+ i_target += 1
# Add regularization losses
custom_losses = model.losses
--
2.21.1 (Apple Git-122.3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment