Forked from wassname/keras_weighted_categorical_crossentropy.py
Created
July 30, 2019 22:34
-
-
Save iflament/48028239a2c730f7ebd3faa61dcd7a75 to your computer and use it in GitHub Desktop.
Keras weighted categorical_crossentropy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A weighted version of categorical_crossentropy for keras (2.0.6). This lets you apply a weight to unbalanced classes. | |
@url: https://gist.github.com/wassname/ce364fddfc8a025bfab4348cf5de852d | |
@author: wassname | |
""" | |
from keras import backend as K | |
def weighted_categorical_crossentropy(weights):
    """
    Build a categorical cross-entropy loss with a multiplicative weight per class.

    Variables:
        weights: numpy array of shape (C,) where C is the number of classes

    Returns:
        A function ``loss(y_true, y_pred)`` suitable for ``model.compile``.

    Usage:
        weights = np.array([0.5,2,10]) # Class one at 0.5, class 2 twice the normal weights, class 3 10x.
        loss = weighted_categorical_crossentropy(weights)
        model.compile(loss=loss,optimizer='adam')
    """
    # Wrap the weights once as a backend variable; the closure below reuses it.
    class_weights = K.variable(weights)

    def loss(y_true, y_pred):
        # Renormalise so the class scores of each sample sum to 1 on the last axis.
        normalised = y_pred / K.sum(y_pred, axis=-1, keepdims=True)
        # Keep values strictly inside (0, 1) so the log cannot produce NaN/Inf.
        eps = K.epsilon()
        clipped = K.clip(normalised, eps, 1 - eps)
        # Per-class weighted log-likelihood, summed over the last axis and negated.
        weighted_terms = y_true * K.log(clipped) * class_weights
        return -K.sum(weighted_terms, -1)

    return loss
import numpy as np | |
from keras.activations import softmax | |
from keras.objectives import categorical_crossentropy | |
# init tests
# Small random tensors of shape (samples, maxlen, vocab); both the predictions
# and the targets are passed through softmax so the last axis sums to 1.
samples = 3
maxlen = 4
vocab = 5
y_pred_n = np.random.random((samples, maxlen, vocab)).astype(K.floatx())
y_pred = K.variable(y_pred_n)
y_pred = softmax(y_pred)
y_true_n = np.random.random((samples, maxlen, vocab)).astype(K.floatx())
y_true = K.variable(y_true_n)
y_true = softmax(y_true)

# test 1 that it works the same as categorical_crossentropy with weights of one
weights = np.ones(vocab)
loss_weighted = weighted_categorical_crossentropy(weights)(y_true, y_pred).eval(session=K.get_session())
loss = categorical_crossentropy(y_true, y_pred).eval(session=K.get_session())
np.testing.assert_almost_equal(loss_weighted, loss)
print('OK test1')

# test 2 that it works differently than categorical_crossentropy with weights of less than one
weights = np.array([0.1, 0.3, 0.5, 0.3, 0.5])
loss_weighted = weighted_categorical_crossentropy(weights)(y_true, y_pred).eval(session=K.get_session())
loss = categorical_crossentropy(y_true, y_pred).eval(session=K.get_session())
# Every weight is below 1, so each weighted loss must be strictly smaller.
np.testing.assert_array_less(loss_weighted, loss)
print('OK test2')

# same keras version as I tested it on?
import keras
# The message must read `__version__`; the previous `__version` attribute does
# not exist, so a mismatch raised AttributeError instead of this assertion.
assert keras.__version__.split('.')[:2] == ['2', '0'], 'this was tested on keras 2.0.6 you have %s' % keras.__version__
print('OK version')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
test weighted_categorical_crossentropy on a real dataset | |
''' | |
from __future__ import print_function | |
import keras | |
from keras.datasets import cifar10 | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.models import Sequential | |
from keras.layers import Dense, Dropout, Activation, Flatten | |
from keras.layers import Conv2D, MaxPooling2D | |
import os | |
import pickle | |
import numpy as np | |
# Run configuration for the CIFAR-10 experiment.
batch_size, num_classes, epochs = 32, 10, 200
data_augmentation = False
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

# Fetch the dataset, already split into train and test sets.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(len(x_train), 'train samples')
print(len(x_test), 'test samples')

# Turn the class-index vectors into one-hot (binary) class matrices.
y_train, y_test = (keras.utils.to_categorical(y, num_classes)
                   for y in (y_train, y_test))
# Convolutional classifier: two conv/conv/pool/dropout stages followed by a
# dense head that ends in a softmax over `num_classes` outputs.
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# initiate RMSprop optimizer
# Use the canonical class name: the lowercase `rmsprop` alias is not available
# in newer Keras releases, while `RMSprop` exists in old and new versions.
opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)

# Let's train the model using RMSprop
# One weight per class, sized from `num_classes` so the vector always matches
# the width of the output layer. All ones means every class is weighted equally.
weights = np.ones((num_classes,))
model.compile(loss=weighted_categorical_crossentropy(weights),
              optimizer=opt,
              metrics=['accuracy'])
# Cast to float32 and divide by 255
# (presumably 8-bit pixel values, giving a 0-1 range -- confirm against the dataset).
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Keep only the first `nc` samples of each split so the run stays short.
nc = 100
x_train, y_train = x_train[:nc], y_train[:nc]
x_test, y_test = x_test[:nc], y_test[:nc]
if data_augmentation:
    print('Using real-time data augmentation.')
    # Generator that preprocesses and augments batches on the fly.
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,  # randomly flip images
        vertical_flip=False)  # randomly flip images

    # Compute the statistics feature-wise normalization would need
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Train on the batches produced by datagen.flow().
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        steps_per_epoch=x_train.shape[0] // batch_size,
                        epochs=epochs,
                        validation_data=(x_test, y_test))
else:
    print('Not using data augmentation.')
    # Train directly on the in-memory arrays.
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
# Save model and weights
# NOTE(review): the model was compiled with a custom loss closure; reloading
# this file presumably requires passing that loss via custom_objects -- confirm.
model_path = os.path.join(save_dir, model_name)
if not os.path.isdir(save_dir):
    # Create the output directory on first use.
    os.makedirs(save_dir)
model.save(model_path)
print('Saved trained model at %s ' % model_path)
# Load label names to use in prediction results
# Prefer ~/.keras; if that directory is not writable, fall back to /tmp/.keras.
keras_dir = os.path.expanduser(os.path.join('~', '.keras'))
datadir_base = os.path.expanduser(keras_dir)
datadir_base = (datadir_base if os.access(datadir_base, os.W_OK)
                else os.path.join('/tmp', '.keras'))
label_list_path = os.path.join(datadir_base,
                               'datasets/cifar-10-batches-py/batches.meta')
# The metadata file is a pickle; only load it from the trusted local cache.
with open(label_list_path, mode='rb') as f:
    labels = pickle.load(f)
# Evaluate model with test data set and share sample prediction results
# Score and predict on the in-memory arrays rather than through `datagen`:
#   * `datagen` only exists when data_augmentation is enabled, so the previous
#     code raised NameError in the non-augmented configuration;
#   * a generator created with `flow()` shuffles by default, so its outputs do
#     not line up with `y_test[predict_index]`;
#   * `steps = n // batch_size` silently dropped the tail of the test set.
evaluation = model.evaluate(x_test, y_test,
                            batch_size=batch_size,
                            verbose=0)
# `evaluation` is [loss, accuracy] because the model was compiled with
# metrics=['accuracy'].
print('Model Accuracy = %.2f' % (evaluation[1]))

# Row i of `predictions` corresponds to x_test[i] / y_test[i].
predictions = model.predict(x_test, batch_size=batch_size)

# Show exactly `num_predictions` rows (the earlier loop printed one extra).
for predict_index, predicted_y in enumerate(predictions[:num_predictions]):
    actual_label = labels['label_names'][np.argmax(y_test[predict_index])]
    predicted_label = labels['label_names'][np.argmax(predicted_y)]
    print('Actual Label = %s vs. Predicted Label = %s' % (actual_label,
                                                          predicted_label))
""" | |
Epoch 195/200 | |
100/100 [==============================] - 2s - loss: 0.2921 - acc: 0.9300 - val_loss: 3.1197 - val_acc: 0.2300 | |
Epoch 196/200 | |
100/100 [==============================] - 2s - loss: 0.3474 - acc: 0.9300 - val_loss: 3.1419 - val_acc: 0.2200 | |
Epoch 197/200 | |
100/100 [==============================] - 2s - loss: 0.3614 - acc: 0.9000 - val_loss: 3.2418 - val_acc: 0.2300 | |
Epoch 198/200 | |
100/100 [==============================] - 2s - loss: 0.4221 - acc: 0.8800 - val_loss: 3.1150 - val_acc: 0.2100 | |
Epoch 199/200 | |
100/100 [==============================] - 2s - loss: 0.3901 - acc: 0.8900 - val_loss: 3.1687 - val_acc: 0.2400 | |
Epoch 200/200 | |
100/100 [==============================] - 2s - loss: 0.3228 - acc: 0.9400 - val_loss: 3.3791 - val_acc: 0.2200 | |
Saved trained model at D:\NotBackedUp\MyDocumentsLarge_mclark52\WinPython-64bit-3.5.3.1Qt5\notebooks\saved_models\keras_cifar10_trained_model.h5 | |
Model Accuracy = 0.21 | |
Actual Label = cat vs. Predicted Label = ship | |
Actual Label = ship vs. Predicted Label = cat | |
Actual Label = ship vs. Predicted Label = truck | |
Actual Label = airplane vs. Predicted Label = dog | |
Actual Label = frog vs. Predicted Label = bird | |
Actual Label = frog vs. Predicted Label = horse | |
Actual Label = automobile vs. Predicted Label = truck | |
Actual Label = frog vs. Predicted Label = airplane | |
Actual Label = cat vs. Predicted Label = automobile | |
Actual Label = automobile vs. Predicted Label = horse | |
Actual Label = airplane vs. Predicted Label = airplane | |
Actual Label = truck vs. Predicted Label = truck | |
Actual Label = dog vs. Predicted Label = bird | |
Actual Label = horse vs. Predicted Label = truck | |
Actual Label = truck vs. Predicted Label = bird | |
Actual Label = ship vs. Predicted Label = truck | |
Actual Label = dog vs. Predicted Label = truck | |
Actual Label = horse vs. Predicted Label = bird | |
Actual Label = ship vs. Predicted Label = automobile | |
Actual Label = frog vs. Predicted Label = cat | |
Actual Label = horse vs. Predicted Label = automobile | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment