Skip to content

Instantly share code, notes, and snippets.

@achalddave
Created October 11, 2017 18:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save achalddave/48bbf8661e28c5d4a29f37298fa164e0 to your computer and use it in GitHub Desktop.
Save achalddave/48bbf8661e28c5d4a29f37298fa164e0 to your computer and use it in GitHub Desktop.
Compute number of flops in a CNN
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def convolution(kernel_width, kernel_height):\n",
"    \"\"\"Return the number of multiplications for one kernel application.\n",
"\n",
"    Applying a kernel_width x kernel_height kernel at a single position\n",
"    costs kernel_width * kernel_height multiplications, and the same number\n",
"    of additions. Only the multiplication count is returned.\n",
"    \"\"\"\n",
"    multiplications = kernel_width * kernel_height\n",
"    return multiplications\n",
"\n",
"def convolution_layer(input_width, input_height, input_channels,\n",
"                      kernel_width, kernel_height, output_channels):\n",
"    \"\"\"Return the multiplication count of a whole convolution layer.\n",
"\n",
"    The kernel cost is charged once per input position and once per\n",
"    (input channel, output channel) pair. Because the count uses\n",
"    input_width * input_height positions, it treats the output feature map\n",
"    as having the same spatial size as the input.\n",
"    \"\"\"\n",
"    per_position = convolution(kernel_width, kernel_height)\n",
"    positions = input_width * input_height\n",
"    channel_pairs = input_channels * output_channels\n",
"    return per_position * positions * channel_pairs\n",
"\n",
"def fc_layer(num_inputs, num_outputs):\n",
"    \"\"\"Return the multiplication count of a fully connected layer.\n",
"\n",
"    One multiplication is counted per (input, output) pair.\n",
"    \"\"\"\n",
"    multiplications = num_inputs * num_outputs\n",
"    return multiplications"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"VGG-16:\t\t\t\t\t 15,466,434,560\n",
"Predictive Corrective fc7-4:\t\t 15,466,589,184\n",
"Additional parameters in fc7-4:\t\t 154,624, 0.0010%\n"
]
}
],
"source": [
"def vgg_convolution(image_size, input_channels, output_channels):\n",
"    \"\"\"Return the multiplication count of a 3x3 convolution on a square input.\n",
"\n",
"    image_size is used as both the width and the height of the input.\n",
"    \"\"\"\n",
"    kernel_size = 3\n",
"    return convolution_layer(\n",
"        input_width=image_size,\n",
"        input_height=image_size,\n",
"        input_channels=input_channels,\n",
"        kernel_width=kernel_size,\n",
"        kernel_height=kernel_size,\n",
"        output_channels=output_channels)\n",
"\n",
"# Multiplication counts of the convolutional stages shared by the networks\n",
"# below. Each stage is defined once and reused so the variants cannot drift\n",
"# apart through copy-paste edits.\n",
"vgg_conv1 = [\n",
"    vgg_convolution(224, 3, 64),\n",
"    vgg_convolution(224, 64, 64)]\n",
"vgg_conv2 = [\n",
"    vgg_convolution(112, 64, 128),\n",
"    vgg_convolution(112, 128, 128)]\n",
"vgg_conv3 = [\n",
"    vgg_convolution(56, 128, 256),\n",
"    vgg_convolution(56, 256, 256),\n",
"    vgg_convolution(56, 256, 256)]\n",
"vgg_conv4 = [\n",
"    vgg_convolution(28, 256, 512),\n",
"    vgg_convolution(28, 512, 512),\n",
"    vgg_convolution(28, 512, 512)]\n",
"vgg_conv5 = [\n",
"    vgg_convolution(14, 512, 512),\n",
"    vgg_convolution(14, 512, 512),\n",
"    vgg_convolution(14, 512, 512)]\n",
"vgg_16_convs = vgg_conv1 + vgg_conv2 + vgg_conv3 + vgg_conv4 + vgg_conv5\n",
"\n",
"# Fully connected layers shared by the networks below.\n",
"FINAL_OUTPUTS = 65  # output size of the last fully connected layer\n",
"fc6 = fc_layer(7 * 7 * 512, 4096)\n",
"fc7 = fc_layer(4096, 4096)\n",
"fc8 = fc_layer(4096, FINAL_OUTPUTS)\n",
"\n",
"vgg_16 = vgg_16_convs + [fc6, fc7, fc8]\n",
"\n",
"# VGG-16 with a 300-unit layer inserted before the final outputs.\n",
"lstm_vgg_16 = vgg_16_convs + [\n",
"    fc6,\n",
"    fc7,\n",
"    fc_layer(4096, 300),\n",
"    fc_layer(300, FINAL_OUTPUTS)]\n",
"\n",
"# VGG-19: VGG-16 plus one extra convolution at the end of conv3, conv4, conv5.\n",
"vgg_19 = (\n",
"    vgg_conv1 + vgg_conv2\n",
"    + vgg_conv3 + [vgg_convolution(56, 256, 256)]  # Added\n",
"    + vgg_conv4 + [vgg_convolution(28, 512, 512)]  # Added\n",
"    + vgg_conv5 + [vgg_convolution(14, 512, 512)]  # Added\n",
"    + [fc6, fc7, fc8])\n",
"\n",
"# Predictive corrective: Conv3-3 reinit 1, FC7 reinit 4.\n",
"c33_1_fc7_4 = (\n",
"    vgg_conv1 + vgg_conv2 + vgg_conv3\n",
"    + [56 * 56 * 256]  # subtraction\n",
"    + vgg_conv4 + vgg_conv5\n",
"    + [fc6,\n",
"       4096,  # addition\n",
"       fc7,\n",
"       fc8])\n",
"\n",
"fc7_4 = (\n",
"    [224 * 224 * 3]  # subtraction\n",
"    + vgg_16_convs\n",
"    + [fc6,\n",
"       4096,  # addition\n",
"       fc7,\n",
"       fc8])\n",
"\n",
"print('VGG-16:\\t\\t\\t\\t\\t {:,}'.format(sum(vgg_16)))\n",
"\n",
"# NOTE(review): every list entry is an operation count, so this difference is\n",
"# in operations, not parameters, despite the variable name and printed label.\n",
"new_params_fc7_4 = sum(fc7_4) - sum(vgg_16)\n",
"print('Predictive Corrective fc7-4:\\t\\t {:,}'.format(sum(fc7_4)))\n",
"print('Additional parameters in fc7-4:\\t\\t {:,}, {:.4f}%'.format(\n",
"    new_params_fc7_4, 100 * new_params_fc7_4 / sum(vgg_16)))\n",
"\n",
"#new_params_c33_1_fc7_4 = sum(c33_1_fc7_4) - sum(vgg_16)\n",
"#print('Predictive Corrective c33-1,fc7-4:\\t {:,}'.format(sum(c33_1_fc7_4)))\n",
"#print('Additional parameters in c33-1_fc7-4:\\t {:,}, {:.4f}%'.format(new_params_c33_1_fc7_4, 100 * (new_params_c33_1_fc7_4) / sum(vgg_16)))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15,467,416,620\n"
]
}
],
"source": [
"# Total operation count of lstm_vgg_16, printed with thousands separators.\n",
"print(format(sum(lstm_vgg_16), ','))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.005349828524934073"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Extra cost of lstm_vgg_16 over fc7_4, as a percentage of the fc7_4 total.\n",
"fc7_4_total = sum(fc7_4)\n",
"100 * (sum(lstm_vgg_16) - fc7_4_total) / fc7_4_total"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment