Skip to content

Instantly share code, notes, and snippets.

@nicoandmee
Created April 23, 2023 14:09
Show Gist options
  • Save nicoandmee/b5e4b1fcc840c001f49f8401222918bd to your computer and use it in GitHub Desktop.
Save nicoandmee/b5e4b1fcc840c001f49f8401222918bd to your computer and use it in GitHub Desktop.
# Materialize the validation set once, keeping each batch's images and labels
# paired. Iterating ``validation_dataset`` twice (once per field) would run the
# input pipeline twice and could misalign images and labels if the dataset
# reshuffles on every iteration.
validation_images = []
validation_labels = []
for batch in validation_dataset:  # type: ignore
    validation_images.append(batch["image"])
    validation_labels.append(batch["label"])
def calculate_edit_distance(labels, predictions):
    """Return the batch-mean edit distance between CTC-decoded predictions and labels.

    Args:
        labels: Dense integer label ids for one batch. Zero-valued entries are
            dropped by ``tf.sparse.from_dense`` and therefore do not count as
            characters.
            NOTE(review): if labels are padded with a nonzero token, that
            padding still counts toward the distance; confirm the padding value.
        predictions: Raw model output for the same batch. Presumably shaped
            (batch, time_steps, num_classes) as ``ctc_decode`` expects;
            TODO confirm.

    Returns:
        Scalar tensor holding the unnormalized edit distance, averaged over
        the batch.

    Decoded sequences are truncated to the module-level ``max_length``, which
    is defined elsewhere.
    """
    sparse_labels = tf.cast(tf.sparse.from_dense(labels), dtype=tf.int64)

    # Greedy CTC decoding over the full time axis of every sample.
    input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
    predictions_decoded = tf.keras.backend.ctc_decode(
        predictions, input_length=input_len, greedy=True
    )[0][0][:, :max_length]
    # ctc_decode pads its dense result with -1, but tf.sparse.from_dense only
    # drops zeros. Without this step the padding would be counted as extra
    # predicted characters. Map it to 0 so it is dropped, matching how zero
    # entries in ``labels`` are treated.
    predictions_decoded = tf.where(
        predictions_decoded < 0,
        tf.zeros_like(predictions_decoded),
        predictions_decoded,
    )
    sparse_predictions = tf.cast(
        tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
    )

    # Compute per-sample (unnormalized) edit distances and average them.
    edit_distances = tf.edit_distance(
        sparse_predictions, sparse_labels, normalize=False
    )
    return tf.reduce_mean(edit_distances)
class EditDistanceCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints the mean CTC edit distance after each epoch.

    Runs ``pred_model`` on every batch in the module-level ``validation_images``.
    Each output is scored against the matching entry of ``validation_labels``
    with ``calculate_edit_distance``.
    """

    def __init__(self, pred_model):
        """Store the model used to generate validation predictions.

        Args:
            pred_model: Model whose ``predict`` output is passed to
                ``calculate_edit_distance``. Presumably the inference-only
                model emitting per-timestep class scores; TODO confirm.
        """
        super().__init__()
        self.prediction_model = pred_model

    def on_epoch_end(self, epoch, logs=None):
        """Print the edit distance averaged over all validation batches.

        Args:
            epoch: Epoch index supplied by Keras; printed as ``epoch + 1``.
            logs: Unused; accepted to match the Keras callback signature.
        """
        edit_distances = [
            calculate_edit_distance(
                labels, self.prediction_model.predict(images)
            ).numpy()
            for images, labels in zip(validation_images, validation_labels)
        ]
        print(
            f"Mean edit distance for epoch {epoch + 1}: {np.mean(edit_distances):.4f}"
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment