Created
December 24, 2018 14:18
-
-
Save YiqinZhao/f13515ec0888c21f1d30e34ed6a7b591 to your computer and use it in GitHub Desktop.
Training Result Utility
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import pandas as pd | |
import numpy as np | |
from keras.models import Model | |
from keras.callbacks import Callback | |
from sklearn.metrics import confusion_matrix | |
def pretty_result(y_true, y_pred, info):
    """
    Print a classification report and return the headline accuracies.

    The report shows the weighted accuracy (WA, overall fraction of correct
    predictions), the unweighted accuracy (UA, mean per-class recall), the
    per-class sample and hit counts, and the confusion matrix extended with
    a recall column (``RCAL``) and a precision row (``PRCS``).

    :@param y_true: true labels (1-D array-like)
    :@param y_pred: predicted labels (1-D array-like, same length as y_true)
    :@param info: title printed above the report and in the summary line
    :@return: ``[ua, wa]``
    """
    # Compare as arrays: with plain lists ``y_true == y_pred`` collapses to a
    # single bool, which would silently turn WA into 0 or 1 / len(y_pred).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    cfx = confusion_matrix(y_true, y_pred)
    max_label = len(cfx)

    hits = np.diag(cfx)              # correctly classified samples per class
    true_totals = cfx.sum(axis=1)    # samples per true class (row sums)
    pred_totals = cfx.sum(axis=0)    # samples per predicted class (column sums)

    # A class can show up only in the predictions (empty row) or only in the
    # ground truth (empty column). Divide only where the denominator is
    # non-zero so those cells read 0 instead of producing NaN and a warning.
    has_truth = true_totals > 0
    recall = np.divide(hits, true_totals,
                       out=np.zeros(max_label), where=has_truth)
    precision = np.divide(hits, pred_totals,
                          out=np.zeros(max_label), where=pred_totals > 0)

    # UA averages recall over the classes that actually occur in y_true; a
    # class without true samples has no defined recall and must not turn the
    # whole score into NaN.
    ua = np.mean(recall[has_truth])
    wa = np.sum(y_true == y_pred) / len(y_pred)

    thin_rule = '--------' * (max_label + 2)
    thick_rule = '========' * (max_label + 2)

    print(thin_rule)
    print(' %s' % info)
    print(thick_rule)
    print('- WA : ', wa)
    print('- UA : ', ua)
    # tolist() yields plain Python ints so the lists print as [3, 2, ...]
    # regardless of the installed NumPy version's scalar repr.
    print('- Y Distribution : ', true_totals.tolist())
    print('- Y Correctness : ', hits.tolist())
    print('- Confusion Matrix :')
    print('\t%s\tRCAL' % '\t'.join(str(v) for v in range(max_label)))

    # One line per true class: raw counts followed by that class' recall.
    for idx, counts in enumerate(cfx):
        cells = '\t'.join('%d' % v for v in counts)
        print('%s\t%s\t%3.2f' % (idx, cells, recall[idx]))

    # Final line: per-class precision; the bottom-right corner cell is a
    # placeholder and always 0.
    cells = '\t'.join('%3.2f' % v for v in precision)
    print('%s\t%s\t%3.2f' % ('PRCS', cells, 0))

    print(thin_rule)
    print('[%s], WA: %s, UA: %s' % (info, wa, ua))

    return [ua, wa]
class TrainingResult(Callback):
    """
    Keras callback that prints a ``pretty_result`` report after every epoch.

    At the end of each epoch the current model predicts on the test data
    and, unless ``show_valid`` is false, on the validation data; each set of
    predictions is summarised with ``pretty_result``.
    """

    def __init__(self, test_data, valid_data, batch_size=32, use_generator=False, show_valid=True, label='exp'):
        """
        :@param test_data: ``(x, y)`` pair evaluated after every epoch
        :@param valid_data: ``(x, y)`` pair evaluated when ``show_valid`` is true
        :@param batch_size: batch size passed to ``model.predict``
        :@param use_generator: stored on the instance; not read by this class
        :@param show_valid: whether to also report on ``valid_data``
        :@param label: experiment name used as prefix of the report titles
        """
        super().__init__()

        self.test_data = test_data
        self.valid_data = valid_data
        self.batch_size = batch_size
        self.use_generator = use_generator
        self.show_valid = show_valid
        self.label = label

        print('[Register Statistic], %s test data' % label)
        if show_valid:
            print('[Register Statistic], %s valid data' % label)

    def on_epoch_end(self, epoch, logs=None):
        """
        Report on the test data and, if enabled, on the validation data.

        ``epoch`` and ``logs`` are accepted for the callback interface but
        are not used. ``logs`` defaults to ``None`` rather than a shared
        mutable dict.
        """
        self._report(self.test_data, '%s test data' % self.label)

        if self.show_valid:
            self._report(self.valid_data, '%s valid data' % self.label)

    def _report(self, data, title):
        """Predict on ``data`` (an ``(x, y)`` pair) and print its report."""
        x, y = data
        res = self.model.predict(x, batch_size=self.batch_size)
        # Both targets and model output are reduced to class indices along
        # the last axis (presumably one-hot targets / per-class scores --
        # confirm against the caller).
        y_true = np.argmax(y, axis=-1)
        y_pred = np.argmax(res, axis=-1)
        pretty_result(y_true, y_pred, title)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment