Skip to content

Instantly share code, notes, and snippets.

@aymuos15
Created September 5, 2024 15:24
Show Gist options
  • Save aymuos15/10c8348148d3de5954a6b4cc5cfebf2c to your computer and use it in GitHub Desktop.
Save aymuos15/10c8348148d3de5954a6b4cc5cfebf2c to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np
import matplotlib.pyplot as plt
class SimpleCNN(nn.Module):
    """Small two-convolution CNN used as the baseline for the pruning comparison.

    Args:
        conv1_out: Number of output channels (filters) of the first convolution.
        conv2_out: Number of output channels (filters) of the second convolution.

    ``fc1`` takes ``conv2_out * 5 * 5`` input features. With 3x3, stride-1,
    unpadded convolutions, that is the size produced by a 1x28x28 input when
    each convolution is followed by 2x2 max pooling
    (28 -> 26 -> 13 -> 11 -> 5), which is the pipeline ``forward`` implements.
    """

    def __init__(self, conv1_out=32, conv2_out=64):
        super().__init__()
        self.conv1 = nn.Conv2d(1, conv1_out, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(conv1_out, conv2_out, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(conv2_out * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for an input of shape (batch, 1, 28, 28).

        NOTE(review): the ReLU / 2x2 max-pool choice is derived from the
        ``5 * 5`` sizing of ``fc1``; confirm it matches the intended network.
        """
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv1(x)), 2)
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv2(x)), 2)
        # Channel-major flatten: features of one channel stay contiguous.
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)
def count_parameters(model):
    """Return the total number of trainable scalar values in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def apply_pytorch_unstructured_pruning(model, amount):
    """
    Prune every Conv2d and Linear layer with PyTorch's L1 unstructured pruning.

    Individual weights are masked to zero; the weight tensors keep their
    shape, so the total parameter count of the model does not change.
    The model is modified in place and returned.
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    targets = [layer for layer in model.modules() if isinstance(layer, prunable_types)]
    for layer in targets:
        prune.l1_unstructured(layer, name='weight', amount=amount)
    return model
def apply_pytorch_structured_pruning(model, amount):
    """
    Prune every Conv2d layer with PyTorch's L2-norm structured pruning.

    Whole output filters (dim 0 of the weight) are masked to zero. The
    filters stay in the tensor, so the layer shapes and the total parameter
    count do not change. The model is modified in place and returned.
    """
    conv_layers = (layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d))
    for conv in conv_layers:
        prune.ln_structured(conv, name='weight', amount=amount, n=2, dim=0)
    return model
def custom_unstructured_prune(model, amount):
    """
    Apply custom magnitude-based unstructured pruning.

    For every Conv2d and Linear layer, the ``round(amount * numel)`` weights
    with the smallest absolute value are set to zero in place. The weights
    are not removed from the tensors, so the total parameter count remains
    unchanged.

    Args:
        model: Module whose Conv2d/Linear weights are pruned in place.
        amount: Fraction of each layer's weights to zero, in ``[0, 1]``.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        weight = module.weight.data
        # Count-based selection (same rule as prune.l1_unstructured): a
        # threshold comparison would still drop the smallest weight at
        # amount=0 and over-prune when magnitudes tie.
        num_prune = round(amount * weight.numel())
        if num_prune == 0:
            continue
        magnitudes = weight.abs().flatten()
        prune_idx = torch.topk(magnitudes, num_prune, largest=False).indices
        keep = torch.ones_like(magnitudes)
        keep[prune_idx] = 0
        # In-place multiply on .data so the parameter itself is updated.
        weight.mul_(keep.view_as(weight))
    return model
def custom_structured_prune(model, amount):
    """
    Apply custom structured pruning to every Conv2d layer.

    Output filters are ranked by the L2 norm of their weights; the
    ``int(out_channels * (1 - amount))`` strongest filters are kept and the
    weights and biases of the remaining filters are zeroed in place.
    The filters are not removed from the tensors, so the parameter count is
    unchanged; create_pruned_model_structure builds the smaller model.

    Args:
        model: Module whose Conv2d layers are pruned in place.
        amount: Fraction of output filters to zero in each Conv2d layer.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        if not isinstance(module, nn.Conv2d):
            continue
        out_channels = module.out_channels
        # One row per output filter; reshape also copes with non-contiguous weights.
        filter_norms = torch.norm(module.weight.data.reshape(out_channels, -1), 2, dim=1)
        num_keep = int(out_channels * (1 - amount))
        top_indices = filter_norms.argsort(descending=True)[:num_keep]
        # Build the mask from the norms so it shares the weights' device and
        # dtype; a plain torch.zeros() mask fails for models not on the CPU.
        mask = torch.zeros_like(filter_norms)
        mask[top_indices] = 1
        module.weight.data *= mask.view(-1, 1, 1, 1)
        if module.bias is not None:
            module.bias.data *= mask
    return model
def create_pruned_model_structure(original_model, amount):
    """
    Create a new, smaller SimpleCNN from a (structurally pruned) SimpleCNN.

    This function actually reduces the number of parameters in the model by:
    1. Creating a new model with ``int(out_channels * (1 - amount))`` filters
       in each convolution.
    2. Copying the weights of the strongest filters (largest L2 norm) from the
       original model into the new model, together with the matching input
       channels of conv2 and the matching input columns of fc1.

    Args:
        original_model: Model with the SimpleCNN layout (conv1, conv2, fc1, fc2).
            When it was first passed through custom_structured_prune with the
            same ``amount``, the zeroed filters are the ones left out.
        amount: Fraction of filters to drop from each convolution.

    Returns:
        A new SimpleCNN with fewer filters; ``original_model`` is not modified.
    """
    conv1, conv2 = original_model.conv1, original_model.conv2
    conv1_out = int(conv1.out_channels * (1 - amount))
    conv2_out = int(conv2.out_channels * (1 - amount))
    new_model = SimpleCNN(conv1_out=conv1_out, conv2_out=conv2_out)

    keep1 = _strongest_filter_indices(conv1, conv1_out)
    keep2 = _strongest_filter_indices(conv2, conv2_out)

    with torch.no_grad():
        new_model.conv1.weight.copy_(conv1.weight[keep1])
        new_model.conv1.bias.copy_(conv1.bias[keep1])
        # conv2 loses output filters AND the inputs fed by dropped conv1 filters.
        new_model.conv2.weight.copy_(conv2.weight[keep2][:, keep1])
        new_model.conv2.bias.copy_(conv2.bias[keep2])

        # fc1 sees the flattened conv2 feature map. This assumes a
        # channel-major flatten, i.e. each conv2 channel owns a contiguous
        # run of fc1 input columns -- TODO confirm against the forward pass.
        fc1 = original_model.fc1
        per_channel = fc1.in_features // conv2.out_channels
        fc1_weight = fc1.weight.reshape(fc1.out_features, conv2.out_channels, per_channel)
        new_model.fc1.weight.copy_(fc1_weight[:, keep2].reshape(fc1.out_features, -1))
        new_model.fc1.bias.copy_(fc1.bias)

        # fc2 is untouched by filter pruning.
        new_model.fc2.weight.copy_(original_model.fc2.weight)
        new_model.fc2.bias.copy_(original_model.fc2.bias)
    return new_model


def _strongest_filter_indices(conv, num_keep):
    """Return the sorted indices of the ``num_keep`` output filters of ``conv`` with the largest L2 norm."""
    norms = conv.weight.detach().flatten(1).norm(p=2, dim=1)
    return norms.argsort(descending=True)[:num_keep].sort().values
# Build one freshly initialised model per pruning strategy (all pruned at 50%)
# and record the trainable-parameter count of each.
models = [
    ("Normal", SimpleCNN()),
    ("PyTorch Unstructured", apply_pytorch_unstructured_pruning(SimpleCNN(), amount=0.5)),
    ("PyTorch Structured", apply_pytorch_structured_pruning(SimpleCNN(), amount=0.5)),
    ("Custom Unstructured", custom_unstructured_prune(SimpleCNN(), amount=0.5)),
    (
        "Custom Structured",
        create_pruned_model_structure(
            custom_structured_prune(SimpleCNN(), amount=0.5),
            amount=0.5,
        ),
    ),
]
names = [entry[0] for entry in models]
sizes = [count_parameters(entry[1]) for entry in models]
# Bar chart: one bar per pruning technique, bar height = parameter count.
plt.figure(figsize=(12, 6))
bars = plt.bar(names, sizes)
plt.xlabel("Pruning Technique")
plt.ylabel("Number of Parameters")
plt.title("Model Sizes After Different Pruning Techniques")
plt.xticks(rotation=45, ha='right')
# Annotate each bar with its value, centred just above the top edge.
for bar in bars:
    height = bar.get_height()
    plt.text(
        bar.get_x() + bar.get_width()/2.,
        height,
        f'{height:,}',
        ha='center',
        va='bottom',
        rotation=0,
    )
plt.tight_layout()
plt.show()
# Print a plain-text explanation to accompany the chart: the masking-based
# techniques keep the tensor shapes (and therefore the parameter count), while
# the custom structured path rebuilds a smaller model and so reports fewer parameters.
print("""
Explanation of pruning results:
1. Normal model: This is the baseline model with no pruning applied.
2. PyTorch Unstructured Pruning: This method sets individual weights to zero but doesn't remove them from the model structure.
The total parameter count remains unchanged because the weight tensors maintain their original shape.
3. PyTorch Structured Pruning: Similar to unstructured pruning, this method applies masks to entire rows or columns of weight tensors,
but it doesn't actually remove these structures from the model. The original tensor dimensions are preserved.
4. Custom Unstructured Pruning: This method also sets individual weights to zero without changing the model structure,
resulting in no change to the total parameter count.
5. Custom Structured Pruning: This is the only method that actually reduces the parameter count because:
a) It identifies entire structural components (e.g., filters in convolutional layers) to remove.
b) It creates a new, smaller model structure that doesn't include these pruned components.
c) It adjusts the connections between layers to account for the removed components.
The key difference with custom structured pruning is that it modifies the model architecture itself,
removing entire structural components and their associated connections. This results in a model with fewer parameters overall.
To achieve actual parameter reduction and potential speed improvements, you need to:
1. Identify which structures to remove (e.g., using L1 or L2 norm of filters).
2. Create a new, smaller model architecture based on the pruning results.
3. Copy the remaining weights to this new, smaller model structure.
This is why only the custom structured pruning approach shows a decrease in the total number of model parameters.
""")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment