import tensorflow as tf

# Post-training quantization with the default optimization set.
converter = tf.lite.TFLiteConverter.from_keras_model(non_qat_flower_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()
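To give a feel for what 8-bit quantization does to a weight tensor, here is a minimal NumPy sketch of affine quantization. It is an illustration only and not the converter's actual algorithm; the helper names `quantize_int8` and `dequantize` and the random weights are assumptions made for this example.

```python
import numpy as np


def quantize_int8(weights):
    """Affine-quantize a float array to int8 (hypothetical helper)."""
    w_min, w_max = float(weights.min()), float(weights.max())
    # Widen the range to include zero so that 0.0 stays representable.
    w_min, w_max = min(w_min, 0.0), max(w_max, 0.0)
    # Fall back to a scale of 1.0 for an all-zero tensor.
    scale = (w_max - w_min) / 255.0 or 1.0
    zero_point = int(round(-128 - w_min / scale))
    q = np.clip(np.round(weights / scale) + zero_point, -128, 127)
    return q.astype(np.int8), scale, zero_point


def dequantize(q, scale, zero_point):
    """Map int8 values back to approximate float values."""
    return (q.astype(np.float32) - zero_point) * scale


# Made-up weights standing in for one layer of a model.
weights = np.random.default_rng(0).normal(size=(4, 4)).astype(np.float32)
q, scale, zero_point = quantize_int8(weights)
restored = dequantize(q, scale, zero_point)

# The int8 copy is 4x smaller; the price is a small rounding error.
max_error = float(np.abs(weights - restored).max())
print(weights.nbytes, q.nbytes, max_error)
```

The storage drops from 4 bytes to 1 byte per weight, and the reconstruction error stays within one quantization step (`scale`).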
# Write the converted model to disk.
with open("normal_flower_model.tflite", "wb") as f:
    f.write(quantized_tflite_model)
import tensorflow_model_optimization as tfmot

# Wrap the Keras model for quantization-aware training.
qat_model = tfmot.quantization.keras.quantize_model(your_keras_model)
import numpy as np

# Copy the kernel weights and get ranked indices of the
# column-wise L2 norms
kernel_weights = np.copy(k_weights)
ind = np.argsort(np.linalg.norm(kernel_weights, axis=0))
# Number of indices to be set to 0
sparsity_percentage = 0.7
cutoff = int(len(ind) * sparsity_percentage)
# The indices in the 2D kernel weight matrix to be set to 0
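The snippet above stops at the comment describing the final step. One plausible way to finish it, shown here as a self-contained sketch, is to zero the columns with the smallest L2 norms. The random `k_weights` matrix and the name `sparse_cutoff_inds` are assumptions for illustration and may differ from the original code.

```python
import numpy as np

# Hypothetical 2D kernel: 6 inputs feeding 10 units.
rng = np.random.default_rng(0)
k_weights = rng.normal(size=(6, 10))

# Rank the columns (units) by their L2 norm, smallest first.
kernel_weights = np.copy(k_weights)
ind = np.argsort(np.linalg.norm(kernel_weights, axis=0))

# Number of columns to zero out.
sparsity_percentage = 0.7
cutoff = int(len(ind) * sparsity_percentage)

# Assumed completion: zero the `cutoff` columns with the smallest norms.
sparse_cutoff_inds = ind[:cutoff]
kernel_weights[:, sparse_cutoff_inds] = 0.0

# Count the columns that are now entirely zero.
zeroed_columns = int(np.sum(~kernel_weights.any(axis=0)))
print(zeroed_columns, "of", kernel_weights.shape[1], "columns pruned")
```

Because whole columns are removed, this is a form of unit (structured) pruning; the untouched `k_weights` copy stays available for comparison.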
# Prune to a fixed target sparsity between begin_step and end_step.
pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
    target_sparsity=target_sparsity,
    begin_step=begin_step,
    end_step=end_step,
    frequency=frequency
)
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    trained_model, pruning_schedule=pruning_schedule
)
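The schedule arguments are left undefined above. As a rough sketch of how they might be chosen, the plain-Python example below derives `end_step` from a training setup and lists the steps at which pruning would be applied. The dataset size, batch size, and epoch count are made-up values, and `is_pruning_step` is a hypothetical helper reflecting my reading of the schedule's gating rule, not a library function.

```python
import math

# Hypothetical training setup; none of these values come from the original.
num_train_samples = 3000
batch_size = 32
epochs = 4

# One step is one batch, so the last step is the total number of batches.
steps_per_epoch = math.ceil(num_train_samples / batch_size)
begin_step = 0
end_step = steps_per_epoch * epochs
frequency = 100


def is_pruning_step(step, begin_step, end_step, frequency):
    """Assumed rule: prune every `frequency` steps inside the range.

    A negative `end_step` is treated as "no upper bound".
    """
    in_range = step >= begin_step and (end_step < 0 or step <= end_step)
    return in_range and (step - begin_step) % frequency == 0


pruning_steps = [
    step for step in range(end_step + 1)
    if is_pruning_step(step, begin_step, end_step, frequency)
]
print(steps_per_epoch, end_step, pruning_steps)
```

With these numbers the masks would be updated only a handful of times over the whole run, so a smaller `frequency` may be worth trying on short training schedules.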
# Recompile the wrapped model before fine-tuning, then inspect it.
pruned_model.compile(loss='sparse_categorical_crossentropy',
                     optimizer='adam',
                     metrics=['accuracy'])
pruned_model.summary()
Layer (type)                 Output Shape              Param #
=================================================================
prune_low_magnitude_conv2d ( (None, 26, 26, 12)        230
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1
_________________________________________________________________
prune_low_magnitude_flatten  (None, 2028)              1
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10)                40572
=================================================================
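The parameter counts in this summary are roughly double what the unwrapped layers would report. A likely explanation, sketched below, is that each pruning wrapper stores a mask with the same shape as the kernel plus two scalars (a threshold and a step counter). The 3x3 kernel and single input channel are assumptions inferred from the output shapes, and `pruned_param_count` is a hypothetical helper.

```python
def pruned_param_count(kernel_size, bias_size):
    """Parameters reported for a wrapped layer under the assumed layout.

    kernel + bias + mask (same size as the kernel) + threshold + step counter
    """
    return kernel_size + bias_size + kernel_size + 1 + 1


# Conv2D: assumed 3x3 kernel, 1 input channel, 12 filters.
conv_kernel = 3 * 3 * 1 * 12
conv_params = pruned_param_count(conv_kernel, bias_size=12)

# Dense: 2028 flattened inputs, 10 outputs.
dense_kernel = 2028 * 10
dense_params = pruned_param_count(dense_kernel, bias_size=10)

print(conv_params, dense_params)  # 230 40572
```

Under the same assumption, layers without weights (max pooling, flatten) report a single parameter, the step counter, which matches the summary.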
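import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

def get_sparsity(weights):
    """Return the fraction of entries in `weights` that are exactly zero."""
    return 1.0 - np.count_nonzero(weights) / float(weights.size)

# Check that every prunable weight reached the target sparsity.
for layer in model.layers:
    if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
        for weight in layer.layer.get_prunable_weights():
            print(np.allclose(
                target_sparsity, get_sparsity(tf.keras.backend.get_value(weight)),
                rtol=1e-6, atol=1e-6))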
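As a quick sanity check of `get_sparsity` in isolation, the sketch below applies it to a hand-built matrix where 70% of the entries are zero. The matrix is made up for illustration; only the function body comes from the snippet above.

```python
import numpy as np


def get_sparsity(weights):
    """Return the fraction of entries in `weights` that are exactly zero."""
    return 1.0 - np.count_nonzero(weights) / float(weights.size)


# Made-up 10x10 weight matrix with 7 of its 10 columns zeroed.
weights = np.ones((10, 10), dtype=np.float32)
weights[:, :7] = 0.0

sparsity = get_sparsity(weights)
target_sparsity = 0.7

# Same tolerance-based comparison as in the layer loop.
matches_target = bool(np.allclose(target_sparsity, sparsity, rtol=1e-6, atol=1e-6))
print(sparsity, matches_target)
```

The tolerance-based comparison matters because the measured sparsity is a floating-point ratio and may not equal the target bit for bit.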
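import os
import tempfile
import zipfile

def get_gzipped_model_size(file):
    """Return the size in bytes of `file` after ZIP (DEFLATE) compression."""
    fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(fd)  # mkstemp returns an open descriptor; close it to avoid a leak
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)
    return os.path.getsize(zipped_file)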
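To illustrate why compressed size is the number worth measuring for pruned models, the sketch below compares a file of random bytes with a copy in which about 70% of the bytes are zeroed. The byte strings stand in for model weights and the `write_temp` helper is hypothetical; a real comparison would use saved model files.

```python
import os
import random
import tempfile
import zipfile


def get_gzipped_model_size(file):
    """Return the size in bytes of `file` after ZIP (DEFLATE) compression."""
    fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(fd)
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)
    size = os.path.getsize(zipped_file)
    os.remove(zipped_file)  # clean up the temporary archive
    return size


def write_temp(data):
    """Write bytes to a temporary file and return its path (hypothetical helper)."""
    fd, path = tempfile.mkstemp('.bin')
    with os.fdopen(fd, 'wb') as f:
        f.write(data)
    return path


random.seed(0)
# Stand-in for dense weights: incompressible random bytes.
dense_bytes = bytes(random.getrandbits(8) for _ in range(100_000))
# Stand-in for pruned weights: roughly 70% of the bytes set to zero.
sparse_bytes = bytes(b if random.random() > 0.7 else 0 for b in dense_bytes)

dense_path = write_temp(dense_bytes)
sparse_path = write_temp(sparse_bytes)
dense_size = get_gzipped_model_size(dense_path)
sparse_size = get_gzipped_model_size(sparse_path)
os.remove(dense_path)
os.remove(sparse_path)

print(dense_size, sparse_size)
```

Both files have the same size on disk before compression; the saving from the zeros only appears once the file is compressed.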
# Courtesy: TFOD API Authors
# https://github.com/tensorflow/models/blob/master/research/object_detection
model {
  ssd {
    inplace_batchnorm_update: true
    freeze_batchnorm: false
    num_classes: 37
    box_coder {
      faster_rcnn_box_coder {