@AndreiMoraru123
Created January 7, 2023 13:11
Short MNIST model pruning using an MLPClassifier in scikit-learn
"""
A simple implementation of L1 unstructured pruning.
Uses the magnitude of the weights to determine which weights to prune.
Written as an exercise in understanding pruning.
Built on top of the scikit-learn MLPClassifier.
Trained on the MNIST (28 x 28) dataset.
Output:
Accuracy before pruning: 0.901
Model size before pruning: 1291299
Accuracy after pruning: 0.732
Model size after pruning: 1291352
Time to predict 100 samples with non-pruned model: 0.005106449127197266
Time to predict 100 samples with pruned model: 0.004999876022338867
"""
# ---------------- Importing Libraries ----------------
import copy
import os
import pickle
import time
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import fetch_openml
# -----------------------------------------------------
# Load the MNIST dataset
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
# Train a multi-layer perceptron classifier
model = MLPClassifier(hidden_layer_sizes=(100,), max_iter=15, alpha=1e-4,
                      solver='sgd', verbose=10, tol=1e-4, random_state=1)
model.fit(X_train, y_train)
# Save the model to a file
with open("model.pkl", "wb") as f:
pickle.dump(model, f)
# Load the model from the file
with open("model.pkl", "rb") as f:
model = pickle.load(f)
# Measure the accuracy of the model on the test set
accuracy_before_pruning = model.score(X_test, y_test)
print("Accuracy before pruning:", accuracy_before_pruning)
# Get the size of the model on disk
model_size_before_pruning = os.stat("model.pkl").st_size
print("Model size before pruning:", model_size_before_pruning)
# Deep-copy the model so the original weights stay intact
pruned_model = copy.deepcopy(model)
# Prune: zero out first-layer weights whose absolute value is below 0.05
pruned_model.coefs_[0][np.abs(pruned_model.coefs_[0]) < 0.05] = 0
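# (Illustrative sketch) Report how much of the first weight matrix the 0.05 threshold
# actually zeroed out; the exact fraction depends on the trained weights, so treat the
# printed value as indicative only.
num_zeroed = int(np.sum(pruned_model.coefs_[0] == 0))
total_weights = pruned_model.coefs_[0].size
print(f"Zeroed {num_zeroed} of {total_weights} first-layer weights "
      f"({num_zeroed / total_weights:.1%} sparsity)")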
# Save the pruned model to a file
with open("model_pruned.pkl", "wb") as f:
pickle.dump(pruned_model, f)
# Load the pruned model from the file
with open("model_pruned.pkl", "rb") as f:
pruned_model = pickle.load(f)
# Measure the accuracy of the pruned model on the test set
accuracy_after_pruning = pruned_model.score(X_test, y_test)
print("Accuracy after pruning:", accuracy_after_pruning)
# Get the size of the pruned model on disk
model_size_after_pruning = os.stat("model_pruned.pkl").st_size
print("Model size after pruning:", model_size_after_pruning)
# Compare the speed of the pruned and non-pruned models
start = time.time()
model.predict(X_test[:100])
end = time.time()
print("Time to predict 100 samples with non-pruned model:", end - start)
start = time.time()
pruned_model.predict(X_test[:100])
end = time.time()
print("Time to predict 100 samples with pruned model:", end - start)