# Pruning the whole model
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched
# Wrap the whole model for low-magnitude pruning using a polynomial sparsity schedule.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute the end step so that pruning finishes after `epochs` epochs.
batch_size = 8
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.
# NOTE(review): validation_split is not applied to num_images below (the TF pruning
# guide uses num_train_images * (1 - validation_split)). Confirm that 40000 is already
# the number of images actually trained on per epoch.
num_images = 40000
# Plain int (not numpy int32) so the schedule's config holds an ordinary Python value.
end_step = int(np.ceil(num_images / batch_size)) * epochs

# Sparsity decays polynomially from 50% to 80% between begin_step and end_step.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        # Use the value computed above; a hard-coded 1000 ended pruning long
        # before the 2 epochs the step count was derived for.
        end_step=end_step,
    )
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)
# NOTE(review): compile() is called without an optimizer or loss. Earlier attempts
# used optimizer=adam with loss=ssd_loss.compute_loss; supply the real optimizer and
# loss before training this model.
model_for_pruning.compile()
model_for_pruning.summary()
# Pruning only selected layers (here: Conv2D)
### Layer pruning begins here
def apply_pruning_to_conv2d(layer):
    """Wrap ``Conv2D`` layers for pruning and return every other layer unchanged.

    Meant to be passed as ``clone_function`` to ``tf.keras.models.clone_model``
    so that only convolution layers of the cloned model are pruned.

    Args:
        layer: A Keras layer from the model being cloned.

    Returns:
        If ``layer`` is a ``tf.keras.layers.Conv2D`` (including subclasses, since
        ``isinstance`` is used), the layer wrapped by ``prune_low_magnitude``
        with a constant 50% sparsity schedule starting at step 0. Otherwise
        ``layer`` itself.
    """
    if isinstance(layer, tf.keras.layers.Conv2D):
        # Public tfmot API; same class as the private pruning_schedule module's
        # ConstantSparsity, but not tied to an internal import path.
        return tfmot.sparsity.keras.prune_low_magnitude(
            layer,
            pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(
                target_sparsity=0.5,
                begin_step=0,
            ),
        )
    return layer
# Rebuild `model`, letting apply_pruning_to_conv2d decide per layer whether to
# wrap it for pruning (only Conv2D layers are wrapped).
model_for_pruning = tf.keras.models.clone_model(model, clone_function=apply_pruning_to_conv2d)

# NOTE(review): per the tfmot docs, fitting a pruning-wrapped model requires the
# tfmot.sparsity.keras.UpdatePruningStep() callback. Confirm the fit() call passes it.
model_for_pruning.compile(loss=loss, optimizer=optimizer)
model_for_pruning.summary()