Skip to content

Instantly share code, notes, and snippets.

View sayakpaul's full-sized avatar
:octocat:
Learn, unlearn and relearn.

Sayak Paul sayakpaul

:octocat:
Learn, unlearn and relearn.
View GitHub Profile
# Convert the (non-QAT) Keras flower model to a TFLite flatbuffer.
# `tf.lite.Optimize.DEFAULT` turns on the converter's default optimization
# strategy, which is what makes the resulting model quantized.
converter = tf.lite.TFLiteConverter.from_keras_model(non_qat_flower_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

# Serialize the converted model to disk. The context manager guarantees the
# file handle is closed even if the write raises.
with open("normal_flower_model.tflite", "wb") as f:
    f.write(quantized_tflite_model)
import tensorflow_model_optimization as tfmot

# Wrap an existing Keras model with the toolkit's `quantize_model` helper.
# `your_keras_model` is a placeholder for a model defined elsewhere.
# NOTE(review): the returned model presumably still needs to be compiled and
# trained/fine-tuned before it is useful -- confirm against the tfmot docs.
qat_model = tfmot.quantization.keras.quantize_model(your_keras_model)
# Work on a copy so the original `k_weights` array is left untouched, then
# rank the columns by their L2 norm. `np.argsort` sorts ascending, so `ind`
# lists column indices from the smallest norm to the largest.
# (assumes `k_weights` is a 2D matrix, so axis=0 yields one norm per column
# -- TODO confirm against the caller)
kernel_weights = np.copy(k_weights)
ind = np.argsort(np.linalg.norm(kernel_weights, axis=0))
# Number of ranked indices to be set to 0: 70% of them, truncated toward
# zero by int().
sparsity_percentage = 0.7
cutoff = int(len(ind)*sparsity_percentage)
# The indices in the 2D kernel weight matrix to be set to 0
# Build a constant-sparsity pruning schedule. `target_sparsity`, `begin_step`,
# `end_step` and `frequency` are expected to be defined elsewhere.
pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
    target_sparsity=target_sparsity,
    begin_step=begin_step,
    end_step=end_step,
    frequency=frequency
)
# Wrap the already-trained model so that pruning follows the schedule above.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    trained_model, pruning_schedule=pruning_schedule
)
# The wrapped model is a new Keras model, so it is compiled before use.
# NOTE(review): 'sparse_categorical_crossentropy' presumably implies integer
# class labels -- confirm against the training data.
pruned_model.compile(loss='sparse_categorical_crossentropy',
                     optimizer='adam',
                     metrics=['accuracy'])
# Print the layer-by-layer summary of the wrapped model.
pruned_model.summary()
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_conv2d ( (None, 26, 26, 12) 230
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12) 1
_________________________________________________________________
prune_low_magnitude_flatten (None, 2028) 1
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10) 40572
=================================================================
# Check that every prunable weight reached the requested sparsity.
# Prints one boolean per prunable weight tensor: True when the measured
# sparsity matches `target_sparsity` within the given tolerances.
for layer in model.layers:
    # Only layers wrapped by the pruning API are inspected.
    if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
        for weight in layer.layer.get_prunable_weights():
            # Pull the variable's current value out of the backend before
            # counting zeros.
            print(np.allclose(
                target_sparsity, get_sparsity(tf.keras.backend.get_value(weight)),
                rtol=1e-6, atol=1e-6)
            )
def get_sparsity(weights):
    """Return the fraction of entries in *weights* that are exactly zero.

    Args:
        weights: A NumPy array, or anything `np.asarray` can convert to one
            (for example a nested list).

    Returns:
        A float in [0.0, 1.0]: 1.0 means every entry is zero, 0.0 means none
        are.

    Raises:
        ZeroDivisionError: If *weights* has no elements.
    """
    # Accept plain sequences as well as arrays; an ndarray passes through
    # np.asarray without being copied.
    weights = np.asarray(weights)
    return 1.0 - np.count_nonzero(weights) / float(weights.size)
def get_gzipped_model_size(file):
    """Return the size in bytes of *file* after DEFLATE compression.

    The file is written into a temporary single-entry ZIP archive and the
    size of that archive is measured, so the result includes the ZIP
    container overhead. The temporary archive is removed before returning.

    Args:
        file: Path of the file to compress.

    Returns:
        Size of the compressed archive in bytes.
    """
    fd, zipped_file = tempfile.mkstemp('.zip')
    # mkstemp hands back an already-open descriptor; ZipFile reopens the file
    # by path, so close the descriptor right away to avoid leaking it.
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(file)
        return os.path.getsize(zipped_file)
    finally:
        # The archive path is never exposed to the caller, so clean it up.
        os.remove(zipped_file)
# Courtesy: TFOD API Authors
# https://github.com/tensorflow/models/blob/master/research/object_detection
model {
ssd {
inplace_batchnorm_update: true
freeze_batchnorm: false
num_classes: 37
box_coder {
faster_rcnn_box_coder {