@YiqinZhao
Created December 24, 2018 14:18
Training Result Utility
import math
import numpy as np
from keras.callbacks import Callback
from sklearn.metrics import confusion_matrix

def pretty_result(y_true, y_pred, info):
    """
    Display a classification result in an elegant way.

    :param y_true: true labels
    :param y_pred: predicted labels
    :param info: description shown in the report header
    :return: [ua, wa]
    """
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    cfx = confusion_matrix(y_true, y_pred)
    # UA: per-class recall averaged over classes (unweighted accuracy).
    # WA: plain accuracy over all samples (weighted accuracy).
    ua = np.average([row[idx] / np.sum(row) if np.sum(row) else 0
                     for idx, row in enumerate(cfx)])
    wa = np.sum(y_true == y_pred) / len(y_pred)

    # Extend the confusion matrix with a recall column, then (after a
    # transpose) a precision row; guard divisions against empty rows/columns.
    max_label = len(cfx)
    cfx_data = [list(row) + [row[idx] / np.sum(row) if np.sum(row) else 0]
                for idx, row in enumerate(cfx)]
    cfx_data = np.transpose(cfx_data)
    cfx_data = [list(row) +
                ([(row[idx] / np.sum(row)) if np.sum(row) else 0]
                 if idx < max_label else [0])
                for idx, row in enumerate(cfx_data)]
    cfx_data = np.transpose(cfx_data)

    data_distribution = [np.sum(x) for x in cfx]
    data_correctness = [x[i] for i, x in enumerate(cfx)]

    print('--------'.join(['' for _ in range(max_label + 3)]))
    print(' %s' % info)
    print('========'.join(['' for _ in range(max_label + 3)]))
    print('- WA             : ', wa)
    print('- UA             : ', ua)
    print('- Y Distribution : ', data_distribution)
    print('- Y Correctness  : ', data_correctness)
    print('- Confusion Matrix :')
    print('\t%s\tRCAL' % '\t'.join([str(v) for v in range(max_label)]))
    for idx, row in enumerate(cfx_data):
        # The appended row holds precision values; label it 'PRCS'.
        mark = idx if idx != max_label else 'PRCS'
        digitals = [0 if math.isnan(v) else v for v in row]
        if idx == max_label:
            inner = '\t'.join(['%3.2f' for _ in range(max_label)]) % tuple(
                digitals[:max_label])
        else:
            inner = '\t'.join(['%d' for _ in range(max_label)]) % tuple(
                digitals[:max_label])
        print('%s\t%s\t%3.2f' % (mark, inner, digitals[max_label]))
    print('--------'.join(['' for _ in range(max_label + 3)]))
    print('[%s], WA: %s, UA: %s' % (info, wa, ua))
    return [ua, wa]
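
# Hedged sketch: a tiny smoke test for pretty_result on made-up labels. The
# arrays and the 'toy example' tag are illustrative additions, not part of
# the original gist.
def _demo_pretty_result():
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([0, 1, 1, 1, 2, 0])
    # 4 of 6 predictions are correct, so WA = 0.67; the per-class recalls
    # are 0.5, 1.0 and 0.5, so UA = 0.67 as well.
    ua, wa = pretty_result(y_true, y_pred, 'toy example')
    return ua, wa
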
class TrainingResult(Callback):
    """Keras callback that prints a pretty_result report after every epoch."""

    def __init__(self, test_data, valid_data, batch_size=32,
                 use_generator=False, show_valid=True, label='exp'):
        super().__init__()
        self.test_data = test_data
        self.valid_data = valid_data
        self.batch_size = batch_size
        self.use_generator = use_generator
        self.show_valid = show_valid
        self.label = label
        print('[Register Statistic], %s test data' % label)
        if show_valid:
            print('[Register Statistic], %s valid data' % label)

    def on_epoch_end(self, epoch, logs=None):
        # Test data
        x, y = self.test_data
        res = self.model.predict(x, batch_size=self.batch_size)
        y, p = np.argmax(y, axis=-1), np.argmax(res, axis=-1)
        pretty_result(y, p, '%s test data' % self.label)

        # Optionally hide the validation report
        if not self.show_valid:
            return

        # Validation data
        x, y = self.valid_data
        res = self.model.predict(x, batch_size=self.batch_size)
        y, p = np.argmax(y, axis=-1), np.argmax(res, axis=-1)
        pretty_result(y, p, '%s valid data' % self.label)
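
# Hedged usage sketch: wiring TrainingResult into model.fit. The network
# architecture, random data, and shapes below are hypothetical placeholders
# (not from the original gist); substitute your own dataset and model.
def _demo_training_result():
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.utils import to_categorical

    # Toy 3-class problem on random features, with one-hot encoded labels.
    x_train = np.random.rand(128, 10)
    y_train = to_categorical(np.random.randint(0, 3, 128), num_classes=3)
    x_test = np.random.rand(32, 10)
    y_test = to_categorical(np.random.randint(0, 3, 32), num_classes=3)

    model = Sequential([
        Dense(16, activation='relu', input_shape=(10,)),
        Dense(3, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy')

    # The callback prints a confusion-matrix report after every epoch; here
    # the same held-out split doubles as test and validation data for brevity.
    cb = TrainingResult((x_test, y_test), (x_test, y_test),
                        batch_size=32, label='demo')
    model.fit(x_train, y_train, epochs=2, batch_size=32, callbacks=[cb])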