from typing import Any, Callable, Optional

import numpy as np
import tensorflow as tf


class DataMapCallback(tf.keras.callbacks.Callback):
""" | |
Gather training dynamics for data map generation. Assumes a binary or multi-class model, no support for multi label. | |
Arguments | |
--------- | |
- `dataset` (``tf.data.: Dataset``): Usually, as the paper suggests, this is the training dataset. It should be: | |
1. Non-shuffled, so each iteration over the dataset should yield samples in the same order | |
2. Already batched, the ``.batch(n)`` method should already be applied on this dataset | |
3. Should yield batches of ``(features, labels)``, sample weights are not supported | |
- | `outputs_to_probabilities` (``Optional[Callable[[Any], tf.Tensor]]``): | |
Callable to convert model's output to probabilities. Use this if the model outputs logits, dictionary or any | |
other form which is not a tensor of probabilities. Defaults to ``None``. | |
- | `sparse_labels` (``bool``): Set to ``True`` if the labels are given as integers (not one hot encoded). Defaults | |
to ``False``. | |
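
    For example, a compliant ``dataset`` could be built as follows (a minimal sketch; ``features`` and
    ``labels`` are placeholder arrays, not part of the original gist)

    .. code-block:: python3

        import tensorflow as tf

        features, labels = ..., ...  # e.g. arrays of shape (n_samples, ...) and (n_samples,)

        # Non-shuffled and batched, yielding (features, labels) - as required
        dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)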

    Attributes
    ----------

    - `gold_labels_probabilities` (``np.ndarray``): Gold label predicted probabilities, with shape
      ``(n_samples, n_epochs)``, where entry ``(i, j)`` is the probability of the gold label for sample ``i``
      at epoch ``j``.
    - `confidence` (``np.ndarray``): Mean of the true label probability across epochs.
    - `variability` (``np.ndarray``): Standard deviation of the true label probability across epochs.
    - `correctness` (``np.ndarray``): Fraction of epochs in which the sample was predicted correctly.

    Examples
    --------

    Calculate training dynamics during training

    .. code-block:: python3

        import tensorflow as tf
        import tavolo as tvl

        # Load data
        train = ...             # Instance of dataset
        train_unshuffled = ...  # Instance of dataset, unshuffled so that each iteration over it
                                # yields samples in the same order

        # Prepare
        train = train.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
        train_unshuffled = train_unshuffled.batch(BATCH_SIZE * 10)  # No gradient updates in the data map
                                                                    # pass, so bigger batches can be used

        # Create the data map callback
        datamap = tvl.learning.DataMapCallback(dataset=train_unshuffled)

        # Train
        model.fit(train, epochs=N_EPOCHS, callbacks=[datamap])

        # Get training dynamics
        confidence, variability, correctness = datamap.confidence, datamap.variability, datamap.correctness

    Calculate training dynamics from a model that outputs logits (and NOT probabilities)

    .. code-block:: python3

        import tensorflow as tf
        import tavolo as tvl

        # Create the data map callback - using the outputs_to_probabilities option
        datamap = tvl.learning.DataMapCallback(dataset=train_unshuffled, outputs_to_probabilities=tf.nn.softmax)

        # Train
        model.fit(train, epochs=N_EPOCHS, callbacks=[datamap])
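
    Plot a data map from the gathered dynamics (a minimal sketch, not part of the original callback; assumes
    ``matplotlib`` is installed and ``model.fit`` has already run with the callback)

    .. code-block:: python3

        import matplotlib.pyplot as plt

        # Data map as in the paper: variability (x) vs. confidence (y), colored by correctness
        plt.scatter(datamap.variability, datamap.confidence, c=datamap.correctness, cmap='coolwarm', s=4)
        plt.xlabel('variability')
        plt.ylabel('confidence')
        plt.colorbar(label='correctness')
        plt.show()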

    References
    ----------
    - `Dataset Cartography: Mapping and Diagnosing Datasets with Training Dynamics`_

    .. _`Dataset Cartography: Mapping and Diagnosing Datasets with Training Dynamics`: https://arxiv.org/pdf/2009.10795
    """

    # TODO - The implementation saves all the gold label probabilities across epochs for the training dynamics
    #  computations. This can be optimized by calculating a running version of each training dynamic.
    #  Once tfp.stats releases RunningVariance and RunningMean to the stable tfp versions, the training dynamics
    #  calculations should be reimplemented using them, thus avoiding (n_epochs - 1) * n_samples floating points
    #  of memory usage.
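    #
    # A minimal sketch of that running-statistics idea (an assumption: Welford's online algorithm,
    # not the tfp.stats API), kept as a comment so the class itself stays as published:
    #
    #     count, mean, m2 = 0, 0.0, 0.0            # per-sample accumulators
    #     for p in per_epoch_gold_probabilities:   # one gold label probability per epoch
    #         count += 1
    #         delta = p - mean
    #         mean += delta / count                # running mean -> confidence
    #         m2 += delta * (p - mean)             # running sum of squared deviations
    #     confidence, variability = mean, (m2 / count) ** 0.5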

    def __init__(self, dataset: tf.data.Dataset,
                 outputs_to_probabilities: Optional[Callable[[Any], tf.Tensor]] = None,
                 sparse_labels: bool = False):
        """
        :param dataset: Dataset. Usually, as the paper suggests, this is the training dataset. It should be:

            - Non-shuffled, so each iteration over the dataset yields samples in the same order
            - Already batched, i.e. the ``.batch(n)`` method was already applied on it
            - Yielding batches of ``(features, labels)``; sample weights are not supported

        :param outputs_to_probabilities: Callable to convert the model's output to probabilities. Use this if the
            model outputs logits, a dictionary or any other form which is not a tensor of probabilities.
        :param sparse_labels: Set to ``True`` if the labels are given as integers (not one-hot encoded)
        """

        self._dataset = dataset
        self._outputs2probabilities = outputs_to_probabilities
        self._sparse_labels = sparse_labels

        # Gold label probabilities, shape (n_samples, n_epochs); a new column is appended after every epoch
        self._gold_labels_probabilities = None

    def on_epoch_end(self, epoch, logs=None):
        # Gather gold label probabilities over the dataset
        gold_label_probabilities = list()
        for x, y in self._dataset:
            probabilities = self.model.predict(x)

            # Convert outputs to probabilities if necessary
            if self._outputs2probabilities is not None:
                probabilities = self._outputs2probabilities(probabilities)

            # Convert sparse integer labels to one-hot encoding
            if self._sparse_labels:
                y = tf.one_hot(y, depth=probabilities.shape[-1])

            # Process the batch
            if tf.rank(tf.squeeze(y)) == 1:  # Binary classification
                probabilities, y = tf.squeeze(probabilities), tf.squeeze(y)
                batch_gold_label_probabilities = tf.where(y == 0, 1 - probabilities, probabilities)
            elif tf.rank(tf.squeeze(y)) == 2:  # Multi-class classification
                # Verify that each sample has exactly one gold label
                if not tf.reduce_all(tf.reduce_sum(tf.cast(y == 1, tf.int32), axis=-1) == 1):
                    raise ValueError('DataMapCallback does not support multi-label classification')

                # Select the probability assigned to the gold label of each sample
                batch_gold_label_probabilities = tf.boolean_mask(probabilities, tf.cast(y, tf.bool)).numpy()
            else:
                raise ValueError(
                    'tf.squeeze(y) (y == labels from the dataset) must be of rank 1 for binary classification or '
                    '2 for multi-class. Instead got ({})'.format(tf.rank(tf.squeeze(y))))

            gold_label_probabilities = np.append(gold_label_probabilities, [batch_gold_label_probabilities])

        # Append this epoch's gold label probabilities as a new column
        if self._gold_labels_probabilities is None:  # Happens only on the first epoch
            self._gold_labels_probabilities = np.expand_dims(gold_label_probabilities, axis=-1)
        else:
            stack = [self._gold_labels_probabilities, np.expand_dims(gold_label_probabilities, axis=-1)]
            self._gold_labels_probabilities = np.hstack(stack)

    @property
    def gold_labels_probabilities(self) -> np.ndarray:
        """
        Gold label predicted probabilities, with shape ``(n_samples, n_epochs)``, where entry ``(i, j)`` is the
        probability of the gold label for sample ``i`` at epoch ``j``

        :return: Gold label probabilities
        """
        return self._gold_labels_probabilities

    @property
    def confidence(self) -> np.ndarray:
        """
        Mean of the true label probability across epochs

        :return: Confidence
        """
        return np.mean(self._gold_labels_probabilities, axis=-1)

    @property
    def variability(self) -> np.ndarray:
        """
        Standard deviation of the true label probability across epochs

        :return: Variability
        """
        return np.std(self._gold_labels_probabilities, axis=-1)

    @property
    def correctness(self) -> np.ndarray:
        """
        Fraction of epochs in which the sample was predicted correctly

        :return: Correctness
        """
        return np.mean(self._gold_labels_probabilities > 0.5, axis=-1)
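
# A follow-up sketch (not part of the gist): using the gathered dynamics to pick out the paper's
# "ambiguous" region, i.e. the samples with the highest variability. Assumes `datamap` is a
# DataMapCallback that was passed to a completed `model.fit` call.
#
#     import numpy as np
#
#     n_select = int(0.33 * len(datamap.variability))               # e.g. keep the top third
#     ambiguous_idx = np.argsort(-datamap.variability)[:n_select]   # indices in dataset order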