Skip to content

Instantly share code, notes, and snippets.

@kashif
Created November 1, 2021 10:08
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save kashif/4bf65f5a9cd726718b6ec709f4013fec to your computer and use it in GitHub Desktop.
pt-keras-metrics
import tensorflow as tf
import torch
from torchmetrics import Metric
def tf2pt(x_tf=None):
    """Convert a TensorFlow tensor into a PyTorch tensor through DLPack.

    Args:
        x_tf: TensorFlow tensor to convert, or ``None``.

    Returns:
        The ``torch.Tensor`` rebuilt from ``x_tf``'s DLPack capsule, or
        ``None`` when ``x_tf`` is ``None`` (so optional arguments can be
        passed straight through).
    """
    if x_tf is not None:
        # DLPack exchange is designed to avoid copying, so the result may
        # share memory with x_tf.
        capsule = tf.experimental.dlpack.to_dlpack(x_tf)
        return torch.utils.dlpack.from_dlpack(capsule)
    return None
def pt2tf(x_torch=None):
    """Convert a PyTorch tensor into a TensorFlow tensor through DLPack.

    Args:
        x_torch: PyTorch tensor to convert, or ``None``.

    Returns:
        The TensorFlow tensor rebuilt from a contiguous version of
        ``x_torch``, or ``None`` when ``x_torch`` is ``None``.
    """
    if x_torch is not None:
        # .contiguous() returns the tensor itself when it is already
        # contiguous and a dense copy otherwise. NOTE(review): presumably
        # TF's DLPack importer requires a dense row-major layout — confirm
        # against the TF version in use.
        dense = x_torch.contiguous()
        capsule = torch.utils.dlpack.to_dlpack(dense)
        return tf.experimental.dlpack.from_dlpack(capsule)
    return None
class BatchRecall(Metric):
    """In-batch recall, computed by a wrapped ``tf.keras.metrics.Recall``.

    PyTorch inputs are handed to TensorFlow through DLPack (``pt2tf``). For
    each batch, the columns of ``y_pred`` indexed by the labels in ``y_true``
    are gathered into a similarity matrix. That matrix is softmax-normalized
    along its last axis and scored against an identity target: row ``i`` is
    treated as positive only in column ``i``.

    NOTE(review): all accumulated counts live inside the Keras metric, not in
    torchmetrics states registered with ``add_state``. torchmetrics therefore
    cannot sync them across processes. Calling the metric through
    ``forward``/``__call__`` (which resets and restores only registered
    states) may also discard what was accumulated. Prefer explicit
    ``update``/``compute`` calls. TODO confirm against the installed
    torchmetrics version.
    """

    def __init__(
        self,
        thresholds=None,
        top_k=1,
        class_id=None,
        name="batch_recall",
    ):
        """Build the wrapped Keras ``Recall`` metric.

        Args:
            thresholds: passed through to ``tf.keras.metrics.Recall``.
            top_k: passed through to ``tf.keras.metrics.Recall``.
            class_id: passed through to ``tf.keras.metrics.Recall``.
            name: name given to the Keras metric.
        """
        super().__init__()
        self.metric = tf.keras.metrics.Recall(
            thresholds=thresholds, top_k=top_k, class_id=class_id, name=name
        )

    def _get_batch_similarities(self, batch_label, full_vocab_similarities):
        """Gather the columns of ``full_vocab_similarities`` named by ``batch_label``.

        Args:
            batch_label: label tensor of shape ``(B, 1)`` whose single column
                holds integer column indices into ``full_vocab_similarities``.
            full_vocab_similarities: score tensor gathered along axis 1
                (presumably ``(N, V)`` — TODO confirm with callers).

        Returns:
            For an ``(N, V)`` input, a tensor of shape ``(N, B)``.
        """
        # transpose((B, 1))[0] is the flat (B,) vector of label ids.
        return tf.gather(
            full_vocab_similarities, tf.transpose(batch_label)[0], axis=1
        )

    def update(self, y_true, y_pred, sample_weight=None):
        """Accumulate recall statistics for one batch.

        Args:
            y_true: integer label tensor of shape ``(B,)`` or ``(B, 1)``. Each
                entry is a column index into ``y_pred``.
            y_pred: score tensor whose columns are indexed by ``y_true``
                (presumably ``(B, V)`` — TODO confirm with callers).
            sample_weight: optional weights forwarded to the Keras metric.
        """
        # reshape(-1, 1) maps both (B,) and (B, 1) labels onto the (B, 1)
        # layout that _get_batch_similarities expects.
        tf_y_true = pt2tf(y_true.reshape(-1, 1))
        tf_y_pred = pt2tf(y_pred)
        tf_sample_weight = pt2tf(sample_weight)

        batch_similarities = self._get_batch_similarities(tf_y_true, tf_y_pred)
        normalized_logits = tf.nn.softmax(batch_similarities)
        # Row i's positive is column i, i.e. its similarity to its own label.
        # The size is taken from the gathered matrix's row count, which
        # equals y_pred's row count.
        target = tf.eye(tf.shape(batch_similarities)[0], dtype=tf.dtypes.float32)
        self.metric.update_state(target, normalized_logits, tf_sample_weight)

    def compute(self):
        """Return the Keras recall result converted to a PyTorch tensor."""
        return tf2pt(self.metric.result())

    def reset(self):
        """Clear the Keras metric's state, then torchmetrics' own bookkeeping."""
        self.metric.reset_state()
        return super().reset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment