@xmaayy
Created February 20, 2019 16:02
A brief example of how to use knowledge distillation to train a model for CIFAR10
"""
Written by Xander May on the 15th of January 2019 to accompany the blogpost on IMRSV.ai
discussing knowledge distillation.
This file trains 10 separate networks for the purpose of testing the efficacy of the
knowledge distillation method.
"""
import os

from keras.models import Model
from keras.datasets import cifar10
from keras.optimizers import RMSprop
from keras.utils import to_categorical
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Dense, Input
from keras.layers import Dropout, Activation, Flatten
sdir = os.path.join(os.getcwd())
model_name = 'cifar10KD.h5'  # The output filename
batch_size = 32   # The size of the training batches; make this bigger for faster GPUs
num_classes = 10  # How many classes (could also get this from the dataset)
epochs = 20       # How many epochs to train each model for
generations = 5   # How many models we're going to train
num_predictions = 20
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
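# One-hot encode the labels and scale pixel values into [0, 1]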
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
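# Build the small CNN used for every model in this experiment. Passing
# is_student=True returns the two-output variant used for knowledge distillation.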
def get_model(is_student=False):
    input_layer = Input(shape=x_train.shape[1:])
    # First convolutional block
    x = Conv2D(32, (3, 3), padding='same')(input_layer)
    x = Activation('relu')(x)
    x = Conv2D(32, (3, 3))(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)
    # Second convolutional block (chained onto the first block, not the input)
    x = Conv2D(32, (3, 3), padding='same')(x)
    x = Activation('relu')(x)
    x = Conv2D(32, (3, 3))(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)
    # Classification head
    x = Flatten()(x)
    x = Dense(512)(x)
    x = Activation('relu')(x)
    x = Dropout(0.25)(x)
    x = Dense(num_classes)(x)
    output_layer = Activation('softmax')(x)
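    # The student lists the same softmax tensor twice so two losses can be
    # applied to it: cross-entropy against the hard labels and mean squared
    # error against the teacher's soft labels.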
    if is_student:
        model = Model(inputs=input_layer, outputs=[output_layer, output_layer])
        model.compile(loss=['categorical_crossentropy', 'mean_squared_error'],
                      optimizer=RMSprop(clipvalue=1, clipnorm=1),
                      loss_weights=[0.4, 0.7],
                      metrics=['accuracy'])
    else:
        model = Model(inputs=input_layer, outputs=output_layer)
        model.compile(loss='categorical_crossentropy',
                      optimizer=RMSprop(clipvalue=1, clipnorm=1),
                      metrics=['accuracy'])
    return model
evals = []
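# Each generation: train a teacher for `epochs` epochs, record its accuracy,
# snapshot its predictions as soft labels, keep training it for another
# `epochs` epochs as a baseline, then train a fresh student on the hard labels
# plus the teacher's soft labels and record its accuracy.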
for i in range(generations):
    # Get a fresh teacher model
    model = get_model()
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
    # Take a measurement of accuracy at the halfway point
    evals.append(model.evaluate(x_test, y_test))
    # Get the model's current predictions for use as soft labels
    soft_labels = model.predict(x_train)
    # Continue training the teacher as a baseline
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
    # Measure the model's accuracy at the end of training
    evals.append(model.evaluate(x_test, y_test))
    # Get and train a new student model with knowledge distillation
    model = get_model(True)
    model.fit(x_train, [y_train, soft_labels],
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, [y_test, y_test]),
              shuffle=True)
    # Measure the distilled student's accuracy
    evals.append(model.evaluate(x_test, [y_test, y_test]))
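# evals holds, per generation: the teacher's metrics after `epochs` epochs,
# the teacher's metrics after 2 * epochs epochs, and the distilled student's
# metrics after `epochs` epochs.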
print(evals)