Created
January 7, 2023 13:11
-
-
Save AndreiMoraru123/0c6787e7dec1180c0932006f8dc6ff72 to your computer and use it in GitHub Desktop.
Short MNIST model pruning using an MLPClassifier in scikit-learn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
A simple implementation of L1 unstructured pruning.
Uses the magnitude of the weights to determine which weights to prune.
Made for the purpose of understanding pruning.
Built on top of the scikit-learn MLPClassifier.
Trained on the MNIST (28 x 28) dataset.

Note: pruned weights are only set to zero and are still stored in the dense
weight arrays, so neither the pickled model size nor the prediction time is
expected to improve (see the output below).

Output:
Accuracy before pruning: 0.901
Model size before pruning: 1291299
Accuracy after pruning: 0.732
Model size after pruning: 1291352
Time to predict 100 samples with non-pruned model: 0.005106449127197266
Time to predict 100 samples with pruned model: 0.004999876022338867
"""
# ---------------- Importing Libraries ---------------- | |
import copy | |
import os | |
import pickle | |
import time | |
import numpy as np | |
from sklearn.neural_network import MLPClassifier | |
from sklearn.datasets import fetch_openml | |
# ----------------------------------------------------- | |
# Fetch the MNIST digits from OpenML as a feature matrix X and a label vector y
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)

# The first 60,000 rows are used for training; everything after is held out
X_train, y_train = X[:60000], y[:60000]
X_test, y_test = X[60000:], y[60000:]

# Fit a perceptron with a single hidden layer of 100 units using the SGD
# solver. max_iter=15 bounds the training iterations, verbose=10 prints the
# training progress and random_state=1 fixes the random seed.
model = MLPClassifier(
    hidden_layer_sizes=(100,),
    max_iter=15,
    alpha=1e-4,
    solver='sgd',
    verbose=10,
    tol=1e-4,
    random_state=1,
).fit(X_train, y_train)
# Round-trip the trained model through pickle so that the object evaluated
# below is the one read back from "model.pkl"
with open("model.pkl", "wb") as dump_file:
    pickle.dump(model, dump_file)
with open("model.pkl", "rb") as load_file:
    model = pickle.load(load_file)

# Baseline: score of the unpruned model on the held-out test set
accuracy_before_pruning = model.score(X_test, y_test)
print("Accuracy before pruning:", accuracy_before_pruning)

# Baseline: size in bytes of the pickled model file
model_size_before_pruning = os.path.getsize("model.pkl")
print("Model size before pruning:", model_size_before_pruning)
# Prune a deep copy so that the original model keeps its trained weights
pruned_model = copy.deepcopy(model)

# Magnitude pruning of the first weight matrix only (coefs_[0]): every weight
# whose absolute value is below 0.05 is set to zero, in place.
first_layer_weights = pruned_model.coefs_[0]
small_weights = np.abs(first_layer_weights) < 0.05
first_layer_weights[small_weights] = 0

# Round-trip the pruned model through pickle so that the object evaluated
# below is the one read back from "model_pruned.pkl"
with open("model_pruned.pkl", "wb") as dump_file:
    pickle.dump(pruned_model, dump_file)
with open("model_pruned.pkl", "rb") as load_file:
    pruned_model = pickle.load(load_file)

# Score of the pruned model on the same held-out test set
accuracy_after_pruning = pruned_model.score(X_test, y_test)
print("Accuracy after pruning:", accuracy_after_pruning)

# Size in bytes of the pickled pruned model. The zeroed weights remain
# entries of the same array, so this is not expected to be smaller.
model_size_after_pruning = os.path.getsize("model_pruned.pkl")
print("Model size after pruning:", model_size_after_pruning)
# Compare the prediction speed of the pruned and non-pruned models.
# time.perf_counter() is used rather than time.time(): it is monotonic and
# uses the highest-resolution clock available, whereas the wall clock can be
# too coarse (or be adjusted mid-run) for a call lasting a few milliseconds.
# The sample slice is taken once, up front, so that only predict() is timed
# and both models receive exactly the same input.
samples = X_test[:100]

start = time.perf_counter()
model.predict(samples)
end = time.perf_counter()
print("Time to predict 100 samples with non-pruned model:", end - start)

start = time.perf_counter()
pruned_model.predict(samples)
end = time.perf_counter()
print("Time to predict 100 samples with pruned model:", end - start)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment