Last active: October 2, 2019 21:36
-
-
Save staceysv/2724732c6632220a5ff6078bf0c8ae14 to your computer and use it in GitHub Desktop.
Example W&B (Weights & Biases) callback for logging per-class precision.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.callbacks import Callback | |
import numpy as np | |
import pandas | |
from sklearn.metrics import precision_score | |
import time | |
import wandb | |
# other metrics from sklearn we may use | |
# confusion_matrix, f1_score, recall_score | |
class PerClassMetrics(Callback):
    """Keras callback that computes per-class metrics (precision for now) at
    the end of each training epoch and logs them to Weights & Biases."""

    def __init__(self, generator=None, num_batches=25, mode="image_table", verbose=False):
        """Store the validation data generator, the number of batches to draw
        from it when computing metrics, and the class names used as metric keys.

        Args:
            generator: Keras data generator exposing `class_indices`; required.
            num_batches: number of batches to pull from the generator per epoch.
            mode: which metric pipeline to run ("precision" logs per-class precision).
            verbose: if True, also run the naive precision implementation and
                print a timing comparison against sklearn's.

        Raises:
            ValueError: if no generator is supplied.
        """
        # Fix: the base Callback must be initialized so Keras can wire up
        # `self.model` and related callback state correctly.
        super().__init__()
        if generator is None:
            # Fail fast with a clear message instead of the AttributeError the
            # `generator.class_indices` access below would otherwise raise.
            raise ValueError("PerClassMetrics requires a data generator")
        self.generator = generator
        self.num_batches = num_batches
        # Invert the generator's {name: index} map to {index: name} for logging.
        self.class_names = {v: k for k, v in generator.class_indices.items()}
        self.mode = mode
        self.verbose = verbose
        self.run = wandb.init()

    def per_class_precision(self, ground_truth, guesses):
        """Naive per-class precision: one pass over the paired truth/guess
        arrays, incrementing true-positive and false-positive counts per class.

        Classes that were never predicted get a precision of 0.0 — matching
        sklearn's default `zero_division=0` behavior — instead of NaN.
        This can probably be much faster :)
        """
        num_classes = len(self.class_names)
        tp = np.zeros(num_classes)
        fp = np.zeros(num_classes)
        for truth, guess in zip(ground_truth, guesses):
            if truth == guess:
                tp[truth] += 1
            else:
                fp[guess] += 1
        # Fix: guard against 0/0 (NaN) for classes with no predictions.
        totals = tp + fp
        return np.divide(tp, totals, out=np.zeros_like(tp), where=totals > 0)

    def on_epoch_end(self, epoch, logs=None):
        """Compute per-class precision on `num_batches` of validation data and
        log it to W&B (only when mode == "precision")."""
        # Fix: replaced the mutable default argument `logs={}`.
        logs = logs or {}
        # Collect validation data and ground-truth labels from the generator.
        val_data, val_labels = zip(*(self.generator[i] for i in range(self.num_batches)))
        val_data, val_labels = np.vstack(val_data), np.vstack(val_labels)
        # Make predictions for all of the collected validation data.
        val_predictions = self.model.predict(val_data)
        if self.mode == "precision":
            # Convert class probabilities and one-hot labels to class ids.
            guessed_class_ids = val_predictions.argmax(axis=1)
            ground_truth_class_ids = val_labels.argmax(axis=1)
            # sklearn's implementation is 2-3x faster than our naive one.
            prec_start = time.time()
            val_precision = precision_score(ground_truth_class_ids, guessed_class_ids, average=None)
            sklearn_time = time.time() - prec_start
            # Track precision for every class under a "prec/<name>" key.
            precisions = {"prec/" + self.class_names[i]: round(prec, 2)
                          for i, prec in enumerate(val_precision)}
            # Optionally compare the naive implementation against sklearn's.
            if self.verbose:
                ours_start = time.time()
                ours = self.per_class_precision(ground_truth_class_ids, guessed_class_ids)
                ours_time = time.time() - ours_start
                print("sklearn: " + str(val_precision) + ", " + str(sklearn_time))
                print("ours: " + str(ours) + ", " + str(ours_time))
            wandb.log(precisions)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment