Skip to content

Instantly share code, notes, and snippets.

@sachinkmohan
Last active April 13, 2022 12:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sachinkmohan/61b1da5d0ba2b0c818a4ec55d0e332fd to your computer and use it in GitHub Desktop.
Save sachinkmohan/61b1da5d0ba2b0c818a4ec55d0e332fd to your computer and use it in GitHub Desktop.
TensorFlow model optimization: pruning and quantization

Pruning the whole model

import math

import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Training configuration; used below to size the pruning schedule so that
# pruning finishes after `epochs` epochs.
batch_size = 8
epochs = 2
validation_split = 0.1  # 10% of the training set will be used for validation.

num_images = 40000
# One optimizer step per batch, for every epoch. math.ceil counts the final
# partial batch as a step and yields a plain Python int (no numpy needed).
end_step = math.ceil(num_images / batch_size) * epochs

# Defining pruning parameters: ramp sparsity polynomially from 50% to 80%
# between step 0 and the computed end_step. The schedule must use the
# computed value; a hard-coded step count (previously 1000) stops the ramp
# at a point unrelated to the configured epochs/batch size.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step,
    ),
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# NOTE(review): compile() is called with no arguments, so no loss is
# configured and this model is only set up for inspection (summary below).
# To train it, compile with the real optimizer/loss (earlier drafts used
# `optimizer=adam, loss=ssd_loss.compute_loss`). Per the TFMOT pruning guide,
# pass tfmot.sparsity.keras.UpdatePruningStep() as a callback to fit().
model_for_pruning.compile()

model_for_pruning.summary()

Pruning only selected layers (here, Conv2D layers)

    ### Layer pruning begins here

    def apply_pruning_to_conv2d(layer):
        """Wrap Conv2D layers for magnitude pruning; return other layers unchanged.

        Used below as the ``clone_function`` passed to
        ``tf.keras.models.clone_model``.

        Args:
            layer: A Keras layer from the model being cloned.

        Returns:
            If ``layer`` is a ``tf.keras.layers.Conv2D`` (or a subclass), the
            layer wrapped by ``prune_low_magnitude`` with a constant 50%
            sparsity target starting at step 0; otherwise ``layer`` itself.
        """
        if not isinstance(layer, tf.keras.layers.Conv2D):
            return layer
        # Build the schedule through the public tfmot API instead of the
        # internal tensorflow_model_optimization.python.core... module path,
        # which is not a stable import location across releases.
        schedule = tfmot.sparsity.keras.ConstantSparsity(
            target_sparsity=0.5, begin_step=0
        )
        return tfmot.sparsity.keras.prune_low_magnitude(
            layer, pruning_schedule=schedule
        )

    # Clone `model`, letting apply_pruning_to_conv2d decide per layer whether
    # to wrap it for pruning; non-Conv2D layers are passed through unchanged.
    model_for_pruning = tf.keras.models.clone_model(
        model, clone_function=apply_pruning_to_conv2d
    )

    # Compile with the optimizer and loss from the surrounding scope
    # (presumably defined in the enclosing function), then print the layer
    # structure, which now includes the pruning wrappers.
    model_for_pruning.compile(loss=loss, optimizer=optimizer)
    model_for_pruning.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment