Created
February 20, 2019 16:02
-
-
Save xmaayy/152e26955587ee37571fb768e10021d3 to your computer and use it in GitHub Desktop.
A brief example of how to use knowledge distillation to train a model for CIFAR10
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Written by Xander May on the 15th of January 2019 to accompany the blog post on IMRSV.ai
discussing knowledge distillation.
This file trains 10 separate networks for the purpose of testing the efficacy of the
knowledge distillation method.
""" | |
import pdb | |
import os | |
import tensorflow as tf | |
from keras.models import Model | |
from keras import backend as K | |
from keras import initializers | |
from keras.datasets import cifar10 | |
from keras.optimizers import RMSprop | |
from keras.utils import to_categorical | |
from keras.layers.merge import concatenate | |
from keras.layers import Conv2D, MaxPooling2D | |
from keras.layers import Dense, Embedding, Input | |
from keras.engine.topology import Layer, InputSpec | |
from keras.layers import Dropout, Activation, Flatten | |
# Directory the script is launched from.
sdir = os.getcwd()

# The output filename.
model_name = 'cifar10KD.h5'

# Size of the training batches; make this bigger for faster GPUs.
batch_size = 32

# How many classes (could also be derived from the dataset).
num_classes = 10

# How many epochs each call to fit() trains for.
epochs = 20

# How many teacher/student generations to run.
generations = 5

num_predictions = 20
# Fetch the CIFAR-10 train/test splits.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# One-hot encode the labels of both splits.
y_train, y_test = (
    to_categorical(labels, num_classes) for labels in (y_train, y_test)
)

# Convert the images to float32 and divide by 255.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
def get_model(is_student=False):
    """Build and compile the small convolutional classifier used in this script.

    The network is two conv/conv/pool/dropout blocks followed by a 512-unit
    dense layer and a softmax classification layer.

    Args:
        is_student: When true, the softmax output is exposed twice so the
            model can be fit against two targets at once: the first with
            categorical cross-entropy, the second with mean squared error.
            When false, a single-output classifier trained with categorical
            cross-entropy is returned.

    Returns:
        A compiled Keras ``Model``.
    """
    input_layer = Input(shape=x_train.shape[1:])

    # First convolutional block.
    x = Conv2D(32, (3, 3), padding='same')(input_layer)
    x = Activation('relu')(x)
    x = Conv2D(32, (3, 3))(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)

    # Second convolutional block. It must consume the output of the first
    # block (x); feeding it input_layer again would leave the first block
    # disconnected from the model's output.
    x = Conv2D(32, (3, 3), padding='same')(x)
    x = Activation('relu')(x)
    x = Conv2D(32, (3, 3))(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)

    # Classification head.
    x = Flatten()(x)
    x = Dense(512)(x)
    x = Activation('relu')(x)
    x = Dropout(0.25)(x)
    # One unit per class, matching the one-hot encoding of the labels.
    x = Dense(num_classes)(x)
    output_layer = Activation('softmax')(x)

    if is_student:
        # The same softmax tensor is listed twice so that two targets can be
        # supplied to fit(): the first output is scored with cross-entropy
        # (weight 0.4), the second with mean squared error (weight 0.7).
        model = Model(inputs=input_layer, outputs=[output_layer, output_layer])
        model.compile(loss=['categorical_crossentropy', 'mean_squared_error'],
                      optimizer=RMSprop(clipvalue=1, clipnorm=1),
                      loss_weights=[0.4, 0.7],
                      metrics=['accuracy'])
    else:
        model = Model(inputs=input_layer, outputs=output_layer)
        model.compile(loss='categorical_crossentropy',
                      optimizer=RMSprop(clipvalue=1, clipnorm=1),
                      metrics=['accuracy'])
    return model
# Collected evaluate() results, three entries per generation.
evals = []

# Keyword arguments shared by every teacher fit() call.
fit_options = dict(
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_test, y_test),
    shuffle=True,
)
# The student has two outputs, so its validation targets are given as a pair.
student_fit_options = dict(fit_options,
                           validation_data=(x_test, [y_test, y_test]))

for i in range(generations):
    # Start each generation from a freshly built teacher.
    model = get_model()

    # First round of teacher training, then record its test metrics.
    model.fit(x_train, y_train, **fit_options)
    evals.append(model.evaluate(x_test, y_test))

    # The teacher's predictions at this point become the soft labels.
    soft_labels = model.predict(x_train)

    # Second round of teacher training, then record its final test metrics.
    model.fit(x_train, y_train, **fit_options)
    evals.append(model.evaluate(x_test, y_test))

    # Train a new student against both the hard labels and the soft labels.
    model = get_model(True)
    model.fit(x_train, [y_train, soft_labels], **student_fit_options)

    # Record the distilled student's test metrics.
    evals.append(model.evaluate(x_test, [y_test, y_test]))

print(evals)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment