Skip to content

Instantly share code, notes, and snippets.

@AkiyonKS
Created August 9, 2022 05:09
Show Gist options
  • Save AkiyonKS/8c888a22f04c4ab4ef3312d322f02611 to your computer and use it in GitHub Desktop.
Save AkiyonKS/8c888a22f04c4ab4ef3312d322f02611 to your computer and use it in GitHub Desktop.
import math

import tensorflow_model_optimization as tfmot
from keras import optimizers
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Magnitude-pruning setup: wrap `model` with tfmot's low-magnitude pruning
# wrapper, using a schedule that spans the whole training run, then compile
# the wrapped model.
#
# Names read here are defined earlier in the script/notebook (not visible):
#   params        -- mapping with 'batch_size' and 'epochs' entries
#   df_photo_ids  -- mapping whose 'train' entry has one element per training
#                    image (presumably a DataFrame; confirm against caller)
#   model         -- the Keras model to be pruned

# The pruning schedule is measured in optimizer steps, so compute the total
# number of training steps: ceil(images / batch) steps per epoch (a final
# partial batch still counts as a step) times params['epochs'].
batch_size = params['batch_size']
epochs = params['epochs']
num_images = len(df_photo_ids['train'])
steps_per_epoch = math.ceil(num_images / batch_size)
end_step = steps_per_epoch * epochs

# Sparsity ramps from 50% at step 0 to 80% at `end_step`, so the final
# sparsity is reached at the end of the last training epoch.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step,
    ),
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# The model returned by `prune_low_magnitude` must be (re)compiled before
# training. Use the public `optimizers.SGD` entry point rather than the
# internal `optimizers.gradient_descent_v2` module alias.
# NOTE(review): the pruning step counter only advances when fit() receives
# tfmot.sparsity.keras.UpdatePruningStep() as a callback -- confirm the
# training code adds it.
model_for_pruning.compile(
    optimizer=optimizers.SGD(learning_rate=1e-4, momentum=0.9),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model_for_pruning.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment