Skip to content

Instantly share code, notes, and snippets.

@staceysv
Last active October 2, 2019 21:36
Show Gist options
  • Save staceysv/2724732c6632220a5ff6078bf0c8ae14 to your computer and use it in GitHub Desktop.
example W&B callback for per-class precision
from keras.callbacks import Callback
import numpy as np
import pandas
from sklearn.metrics import precision_score
import time
import wandb
# other metrics from sklearn we may use
# confusion_matrix, f1_score, recall_score
class PerClassMetrics(Callback):
    """Keras callback that computes per-class metrics (precision for now)
    on a sample of validation data at the end of each training epoch and
    logs them to Weights & Biases.
    """

    def __init__(self, generator=None, num_batches=25, mode="image_table", verbose=False):
        """Store the validation generator and logging configuration.

        Args:
            generator: Keras data generator exposing ``class_indices``
                (class name -> class id) and yielding (data, one-hot label)
                batches by index. Assumed non-None — TODO confirm with callers.
            num_batches: number of batches to draw from the generator when
                computing per-class metrics.
            mode: which metric to log; only "precision" is handled in
                ``on_epoch_end``.
            verbose: if True, also run the naive precision implementation
                and print a timing comparison against sklearn's.
        """
        self.generator = generator
        self.num_batches = num_batches
        # invert class_indices so we can map class id -> class name for log keys
        self.class_names = {v: k for k, v in generator.class_indices.items()}
        self.mode = mode
        self.verbose = verbose
        self.run = wandb.init()

    def per_class_precision(self, ground_truth, guesses):
        """First pass at computing precision for each class--iterate over the
        array of truth and guesses and increment the count of true positives
        and false positives for each class. This can probably be much faster :)

        Args:
            ground_truth: 1-D sequence of true class ids.
            guesses: 1-D sequence of predicted class ids (same length).

        Returns:
            np.ndarray of length num_classes with precision per class.
            Classes that were never predicted get 0.0 (BUGFIX: the original
            computed 0/0 there, yielding NaN plus a RuntimeWarning, and
            disagreeing with sklearn's zero_division=0 behavior).
        """
        num_classes = len(self.class_names)
        tp = np.zeros(num_classes)
        fp = np.zeros(num_classes)
        for truth, guess in zip(ground_truth, guesses):
            if truth == guess:
                tp[truth] += 1
            else:
                # a wrong guess is a false positive for the guessed class
                fp[guess] += 1
        # Guard against division by zero for classes never predicted.
        denom = tp + fp
        return np.divide(tp, denom, out=np.zeros_like(tp), where=denom > 0)

    def on_epoch_end(self, epoch, logs=None):
        """Compute and log per-class precision on sampled validation data.

        Args:
            epoch: index of the epoch that just finished (unused here).
            logs: metrics dict supplied by Keras (unused). BUGFIX: default
                changed from the mutable ``{}`` to ``None``; Keras always
                passes ``logs``, so callers are unaffected.
        """
        # collect validation data and ground truth labels from generator
        val_data, val_labels = zip(*(self.generator[i] for i in range(self.num_batches)))
        val_data, val_labels = np.vstack(val_data), np.vstack(val_labels)
        # make predictions for all the collected validation data
        val_predictions = self.model.predict(val_data)
        # NOTE(review): only "precision" is handled; the default mode
        # "image_table" currently logs nothing — confirm this is intended.
        if self.mode == "precision":
            # convert class probabilities and one-hot encoded labels to class ids
            guessed_class_ids = val_predictions.argmax(axis=1)
            ground_truth_class_ids = val_labels.argmax(axis=1)
            # sklearn implementation is 2-3x faster than our naive implementation
            prec_start = time.time()
            val_precision = precision_score(ground_truth_class_ids, guessed_class_ids, average=None)
            sklearn_time = time.time() - prec_start
            # track precision for all classes under "prec/<class name>" keys
            precisions = {"prec/" + self.class_names[i]: round(prec, 2)
                          for i, prec in enumerate(val_precision)}
            # optionally compare naive implementation to sklearn's
            if self.verbose:
                ours_start = time.time()
                ours = self.per_class_precision(ground_truth_class_ids, guessed_class_ids)
                ours_time = time.time() - ours_start
                print("sklearn: " + str(val_precision) + ", " + str(sklearn_time))
                print("ours: " + str(ours) + ", " + str(ours_time))
            wandb.log(precisions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment