Skip to content

Instantly share code, notes, and snippets.

@tawnkramer
Forked from rocking5566/keras_quant.py
Created May 3, 2019 22:55
Show Gist options
  • Save tawnkramer/d9288fdf79e35566f3190d66b8c67ca9 to your computer and use it in GitHub Desktop.
Save tawnkramer/d9288fdf79e35566f3190d66b8c67ca9 to your computer and use it in GitHub Desktop.
Quantization-aware training in Keras
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Conv2D, Flatten
from tensorflow.keras.optimizers import RMSprop
# Fetch MNIST; the first call downloads it to '~/.keras/datasets/'.
# Train split: 60,000 grayscale 28x28 images with 60,000 labels.
# Test split: 10,000 images with 10,000 labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Append a trailing channel axis, (N, 28, 28) -> (N, 28, 28, 1), and divide
# by 255 to normalize the pixel values.
x_train = np.expand_dims(x_train, axis=-1) / 255.
x_test = np.expand_dims(x_test, axis=-1) / 255.

# One-hot encode the digit labels over the 10 classes.
y_train, y_test = (
    to_categorical(labels, num_classes=10) for labels in (y_train, y_test)
)
# Build the network: a single 3x3 convolution with 16 filters over the
# 28x28x1 input, flattened into a 256-unit hidden layer and a 10-way softmax.
# Activations are separate layers rather than `activation=` arguments; the
# output layer is explicitly named 'pred'.
model = Sequential([
    Conv2D(16, (3, 3), input_shape=(28, 28, 1)),
    Activation('relu'),
    Flatten(),
    Dense(256),
    Activation('relu'),
    Dense(10),
    Activation('softmax', name='pred'),
])
# Quantization-aware training.
# Take the session that backs the Keras model and hand its graph to
# tf.contrib.quantize so it can be rewritten for quantization-aware training.
# NOTE(review): this has to run after the model above has been added to the
# graph and before compile/fit below -- presumably the rewrite only affects
# ops that already exist; confirm against the tf.contrib.quantize docs.
sess = tf.keras.backend.get_session()
tf.contrib.quantize.create_training_graph(sess.graph)
# Initialize all global variables in this session. Running this after the
# rewrite means any variables it created are initialized as well.
# NOTE(review): this presumably re-initializes the model weights too, which is
# harmless only because no training has happened yet -- confirm.
sess.run(tf.global_variables_initializer())
# You can plot the quantize training graph on tensorboard
# tf.summary.FileWriter('/workspace/tensorboard', graph=sess.graph)
# Optimizer: RMSprop with every hyperparameter spelled out explicitly.
rmsprop = RMSprop(
    lr=0.001,
    rho=0.9,
    epsilon=1e-08,
    decay=0.0,
)

# Report accuracy alongside the categorical cross-entropy loss.
model.compile(
    loss='categorical_crossentropy',
    optimizer=rmsprop,
    metrics=['accuracy'],
)

# Train for one epoch over the training set in mini-batches of 256.
model.fit(x=x_train, y=y_train, batch_size=256, epochs=1)
print('\nTesting ------------')
# Score the trained model on the held-out test split; evaluate() yields the
# loss followed by the metrics requested at compile time (accuracy here).
loss, accuracy = model.evaluate(x=x_test, y=y_test)
for label, value in (('test loss: ', loss), ('test accuracy: ', accuracy)):
    print(label, value)
# Print the min/max range tracked by every weight fake-quant op.
# Ops are matched by name; only the 'weights_quant' AssignMinLast /
# AssignMaxLast ops are selected. Walking graph.get_operations() avoids
# serializing the whole graph to a GraphDef just to read op names.
range_ops = [
    op for op in sess.graph.get_operations()
    if 'weights_quant/AssignMaxLast' in op.name
    or 'weights_quant/AssignMinLast' in op.name
]
if range_ops:
    # Output 0 of each matched op is the tensor named '<op name>:0'. All of
    # them are fetched in a single session call instead of one call per op.
    # NOTE(review): these are assign ops, so fetching them presumably also
    # re-writes the stored range value -- confirm against tf.contrib.quantize.
    range_values = sess.run([op.outputs[0] for op in range_ops])
    for op, value in zip(range_ops, range_values):
        print('{} = {}'.format(op.name, value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment