Skip to content

Instantly share code, notes, and snippets.

@gngdb
Last active October 17, 2019 20:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gngdb/be8944dfae73f143e09c901c7a1229ea to your computer and use it in GitHub Desktop.
Save gngdb/be8944dfae73f143e09c901c7a1229ea to your computer and use it in GitHub Desktop.
Parameter redundancy from 5 pruning papers through time.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook makes a figure showing how parameter redundancy has changed over time: neural networks used to be able to lose most of their parameters without affecting accuracy, but this is no longer the case."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.style.use('ggplot')\n",
"import numpy as np\n",
"import published # load csv\n",
"results = published.pruning"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'year': [2015.0, 2016.0, 2016.0, 2017.0, 2017.0],\n",
" 'top-1 before': [31.5, 25.03, 31.9, 36.69, 25.02],\n",
" 'top-1 after': [31.17, 27.83, 32.0, 36.66, 26.2],\n",
" 'params before': [138000000.0,\n",
" 21800000.0,\n",
" 15190000.0,\n",
" 132900000.0,\n",
" 9500000.0],\n",
" 'params after': [10350000.0, 19300000.0, 7450000.0, 23200000.0, 4800000.0],\n",
" 'bibtex_id': ['han2016deep',\n",
" 'li2016pruning',\n",
" 'alvarez2016learning',\n",
" 'liu2017learning',\n",
" 'huang2017condensenet'],\n",
" 'title': ['Deep Compression',\n",
" 'Pruning Filters for Efficient ConvNets',\n",
" 'Learning The Number of Neurons in Deep Networks',\n",
" 'Learning Efficient Convolutional Networks through Network Slimming',\n",
" 'CondenseNet: An Efficient DenseNet using Learned Group Convolutions']}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x324 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"labels = results['bibtex_id']\n",
"before = results['params before']\n",
"after = results['params after']\n",
"\n",
"\n",
"x = np.arange(len(labels)) # the label locations\n",
"width = 0.35 # the width of the bars\n",
"\n",
"fig, ax = plt.subplots(1,2)\n",
"axl, axr = ax\n",
"rects1 = axl.bar(x - width/2, before, width, label='Baseline', alpha=0.5)\n",
"rects2 = axl.bar(x + width/2, after, width, label='Pruned', alpha=0.5)\n",
"\n",
"# Parameters barplot\n",
"axl.set_ylabel('Parameters')\n",
"axl.set_title('Parameters Before/After Pruning')\n",
"axl.set_xticks(x)\n",
"axl.set_xticklabels(labels, rotation=70)\n",
"axl.legend()\n",
"\n",
"def label(rects):\n",
" \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
" for rect in rects:\n",
" height = rect.get_height()\n",
" axl.annotate('{}'.format(height),\n",
" xy=(rect.get_x() + rect.get_width() / 2, height),\n",
" xytext=(0, 3), # 3 points vertical offset\n",
" textcoords=\"offset points\",\n",
" ha='center', va='bottom')\n",
"\n",
"cratios = [a/b for b,a in zip(before, after)]\n",
"for rect, ratio in zip(rects2, cratios):\n",
" height = rect.get_height()\n",
" axl.annotate(f\"{100.*ratio:.2f}%\",\n",
" xy=(rect.get_x() + rect.get_width() / 2, height),\n",
" xytext=(0, 3), # 3 points vertical offset\n",
" textcoords=\"offset points\",\n",
" ha='center', va='bottom')\n",
"\n",
"#rects = axr.bar(x, results['top-1 after'], width, alpha=0.5, color='green')\n",
"#axr.set_ylabel('top-1 error')\n",
"#axr.set_title('Top-1 Error Progression')\n",
"#axr.set_xticks(x)\n",
"#axr.set_xticklabels(labels, rotation=70)\n",
"\n",
"# 50,000 examples in the validation set\n",
"# http://www.image-net.org/challenges/LSVRC/2012/\n",
"examples_parameters = [1000000*(1.-(e/100.))/p for e, p in zip(results['top-1 after'], after)]\n",
"rects = axr.bar(x, examples_parameters, width, alpha=0.5)\n",
"axr.set_ylabel('$10^6$*Accuracy/Parameters')\n",
"axr.set_title('Correct ImageNet Examples Per Parameter')\n",
"axr.set_xticks(x)\n",
"axr.set_xticklabels(labels, rotation=70)\n",
"for rect, year in zip(rects, results['year']):\n",
" height = rect.get_height()\n",
" axr.annotate(f\"{int(year)}\",\n",
" xy=(rect.get_x() + rect.get_width() / 2, height),\n",
" xytext=(0, 3), # 3 points vertical offset\n",
" textcoords=\"offset points\",\n",
" ha='center', va='bottom')\n",
"\n",
"fig.set_size_inches(10., 4.5)\n",
"\n",
"fig.tight_layout()\n",
"\n",
"plt.savefig(\"pruning-progression.pdf\", bbox_inches='tight')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Published pruning results. The data is stored in a Python module rather than
# a separate CSV file so the companion notebook can load it with a plain
# ``import published`` and read ``published.pruning``.
import csv
from io import StringIO

table =\
"""year, top-1 before, top-1 after, params before, params after, bibtex_id, title
2015, 31.50, 31.17, 138e6, 10.35e6, han2016deep, Deep Compression
2016, 25.03, 27.83, 21.8e6, 1.93e7, li2016pruning, Pruning Filters for Efficient ConvNets
2016, 31.9, 32.0, 15.19e6, 7.45e6, alvarez2016learning, Learning The Number of Neurons in Deep Networks
2017, 36.69, 36.66, 132.9e6, 23.2e6, liu2017learning, Learning Efficient Convolutional Networks through Network Slimming
2017, 25.02, 26.2, 9.5e6, 4.8e6, huang2017condensenet, CondenseNet: An Efficient DenseNet using Learned Group Convolutions
"""

# Columns whose values are kept as strings; every other column is parsed as float.
TEXT_COLUMNS = ('bibtex_id', 'title')


def parse_table(text, text_columns=TEXT_COLUMNS):
    """Parse comma-separated text with a header row into a dict of columns.

    Args:
        text: CSV text whose first row holds the column names. Whitespace
            immediately after each comma is ignored.
        text_columns: names of columns whose values are kept as strings; all
            other columns are converted with ``float``.

    Returns:
        A dict mapping each column name (in header order) to a list holding
        that column's value from every data row, so all lists have the same
        length. Blank lines are skipped. Empty ``text`` yields ``{}``.

    Raises:
        ValueError: if a data row does not have exactly one field per header
            column (e.g. an unquoted comma inside a title), or if a numeric
            column holds a value ``float`` cannot parse.
    """
    with StringIO(text) as f:
        reader = csv.reader(f, delimiter=',', skipinitialspace=True)
        header = next(reader, None)
        if header is None:
            return {}
        columns = {name: [] for name in header}
        for row in reader:
            if not row:
                continue  # csv yields [] for a blank line; it carries no data
            # Without this check, zip() would silently drop extra fields or
            # leave short columns, misaligning the per-column lists.
            if len(row) != len(header):
                raise ValueError(
                    f"line {reader.line_num}: expected {len(header)} fields, "
                    f"got {len(row)}: {row!r}")
            for name, value in zip(header, row):
                columns[name].append(
                    value if name in text_columns else float(value))
    return columns


pruning = parse_table(table)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment