- In TensorFlow 2.0, metrics take a brand new form of "stateful" objects that share a uniform API consisting of four methods:
```python
import tensorflow as tf

class MyMetric(tf.keras.metrics.Metric):
    def __init__(self):
        ...  # create state variables, e.g. with self.add_weight(...)
    def update_state(self, y_true, y_pred, sample_weight=None):
        ...  # accumulate statistics for the current batch
    def result(self):
        ...  # compute the metric value from the accumulated state
    def reset_states(self):
        ...  # clear the accumulated state between epochs/evaluations
```
(Details here: https://www.tensorflow.org/beta/guide/keras/training_and_evaluation#writing_custom_losses_and_metrics)
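  To make the call pattern concrete, here is a minimal usage sketch of how such a stateful metric is driven in a custom loop, using the built-in `tf.keras.metrics.SparseCategoricalAccuracy` as a stand-in (`batches` is a placeholder iterable, not anything from this repo):

```python
import tensorflow as tf

accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

# batches: hypothetical iterable of (labels, predictions) pairs
for y_true, y_pred in batches:
    accuracy.update_state(y_true, y_pred)  # accumulate statistics batch by batch

epoch_accuracy = accuracy.result().numpy()  # aggregate value over all batches seen
accuracy.reset_states()                     # clear the state before the next epoch
```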
- Since there is no metric measuring perplexity in the TF API so far, I made a custom one using the formula shown by Kirill Mavreshko in his Keras-Transformer implementation (https://github.com/kpot/keras-transformer) and the logic for calculating the SparseCategoricalCrossentropy loss found at: https://www.tensorflow.org/beta/tutorials/text/transformer#loss_and_metrics
- Perplexity seems to be a better metric than accuracy for language generation models, or at least a more interpretable one. (Shout-out to Aerin Kim for a cool article on perplexity! :) https://towardsdatascience.com/perplexity-intuition-and-derivation-105dd481c8f3)
- The metric expects y_true and y_pred, which is compliant with the new Metrics API. Bear in mind, though, that it expects y_pred in the form of logits - that is exactly what the Transformer model outputs during the training loop (inspect the predictions output in the train_step method to see the logits: https://www.tensorflow.org/beta/tutorials/text/transformer#training_and_checkpointing). Logits are the non-normalised results of the model's predictions (raw vectors); if you apply softmax to them, you get a probability distribution :) (https://developers.google.com/machine-learning/glossary/#logits). Keep that in mind if you use this metric for a different model/scenario - see the sketch below this list.
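  For illustration, here is a minimal sketch of how such a logits-based perplexity metric could be put together under the assumptions above (the class name, state variables, and the padding-id-0 mask are mine, following the loss logic from the TF Transformer tutorial; this is not the exact code from this repo):

```python
import tensorflow as tf

class Perplexity(tf.keras.metrics.Metric):
    """Running perplexity: exp of the mean per-token cross-entropy."""

    def __init__(self, name="perplexity", **kwargs):
        super().__init__(name=name, **kwargs)
        # per-position cross-entropy on logits, no reduction so padding can be masked out
        self.cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        self.total_ce = self.add_weight(name="total_ce", initializer="zeros")
        self.token_count = self.add_weight(name="token_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # mask padding positions (token id 0, as assumed in the TF Transformer tutorial)
        mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
        ce = self.cross_entropy(y_true, y_pred) * mask
        self.total_ce.assign_add(tf.reduce_sum(ce))
        self.token_count.assign_add(tf.reduce_sum(mask))

    def result(self):
        # perplexity = exp(average cross-entropy per non-padding token)
        return tf.exp(self.total_ce / self.token_count)

    def reset_states(self):
        self.total_ce.assign(0.0)
        self.token_count.assign(0.0)
```

  A metric built this way can be passed to `model.compile(..., metrics=[Perplexity()])` or driven manually via `update_state`/`result`/`reset_states`, as long as the model outputs logits rather than softmaxed probabilities.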
- I suggest updating from `print` to `logging`:
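  As an illustration only (the variable names and message are placeholders, not taken from this repo), the swap could look like this:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

epoch, perplexity_value = 1, 42.7  # placeholder values for illustration

# before: print(f"Epoch {epoch}: perplexity = {perplexity_value}")
logger.info("Epoch %d: perplexity = %.4f", epoch, perplexity_value)
```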