Skip to content

Instantly share code, notes, and snippets.

@NHZlX
Created March 26, 2018 06:34
Show Gist options
  • Save NHZlX/ca05f002072995307de9914a81b8b7b2 to your computer and use it in GitHub Desktop.
Save NHZlX/ca05f002072995307de9914a81b8b7b2 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding=utf-8
from __future__ import print_function
#from mobilenet import mobile_net
#from mobilenet_ssd_net_pascal import net_conf
#from mobilenet_ssd_net_face_2 import net_conf
#from new_mobilenet_ssd_net_face_4 import net_conf
#from vgg_ssd_net_v4_4 import net_conf
import paddle.v2 as paddle
from paddle.v2.topology import Topology
#from enet import Enet
from mobilenet import mobile_net
#from enet_depthwise import Enet
# Network under analysis. Alternative models that were tried (each needs its
# matching import re-enabled in the import block):
#   out = Enet(3 * 128 * 128)
#   out, _ = net_conf('train', 1.0)
#   out = mobile_net(3 * 160 * 160, 102, 0.75)
datadim = 3 * 32 * 32
# NOTE(review): the arguments look like (flattened input size, class count,
# width multiplier) — confirm against mobilenet.mobile_net.
out = mobile_net(datadim, 10, 1.0)

# Build the topology's proto description and keep its list of layer configs.
topo = Topology(out).proto()
layers = topo.layers

# Module-level running total and per-layer result columns.
# NOTE: ``sum`` shadows the builtin sum() for the whole module.
sum = 0.0
layer_name, parameter_num, cals_num = [], [], []
in_cs, out_cs = [], []
all_groups, all_kernels, all_outsize = [], [], []
def _conv_layer_stats(layer):
    """Return statistics for one conv layer config, or None if it must be skipped.

    The result is ``(name, kernel, groups, in_c, out_c, out_size, params, cals)``:
    ``params`` is the weight count (no bias term) and ``cals`` is the
    multiply-accumulate count, assuming a square output map (only
    ``output_x`` is read).
    """
    conf = layer.inputs[0].conv_conf
    groups = conf.groups
    if not groups:
        # groups == 0 cannot be divided by. NOTE(review): presumably this is a
        # layer whose type merely contains "conv" and whose conv_conf is unset.
        return None
    in_c = conf.channels
    # Paddle's ConvConfig stores filter_channels = channels / groups (the input
    # channels each filter sees), not the output channel count; the number of
    # output channels is the layer's num_filters.
    out_c = layer.num_filters
    k = conf.filter_size
    out_s = conf.output_x
    # Each of the out_c filters spans in_c / groups input channels, so grouped
    # (e.g. depthwise) convolutions have groups-times fewer weights.
    params = in_c * out_c * k * k // groups
    cals = out_s * out_s * k * k * in_c * out_c / float(groups)
    return (layer.name, k, groups, in_c, out_c, out_s, params, cals)


def analysis(layer_configs=None):
    """Print per-layer parameter and compute counts for every conv layer.

    Args:
        layer_configs: iterable of layer configs to inspect. Defaults to the
            module-level ``layers`` list.

    Returns:
        ``(rows, total)``. ``rows`` is a list of
        ``(name, kernel, groups, in_c, out_c, out_size, params, cals)`` tuples
        in layer order; ``total`` is the summed ``cals`` as a float.
    """
    if layer_configs is None:
        layer_configs = layers
    rows = []
    # Accumulate locally so repeated calls do not double-count. The builtin
    # sum() is not used because this module rebinds the name ``sum``.
    total = 0.0
    for layer in layer_configs:
        if 'conv' not in layer.type:
            continue
        stats = _conv_layer_stats(layer)
        if stats is None:
            continue
        rows.append(stats)
        total += stats[7]

    for name, k, groups, in_c, out_c, out_s, params, cals in rows:
        print(name, end='\t')
        print('kernel: %d groups: %d in_c: %d out_c: %d out_size: %d'
              % (k, groups, in_c, out_c, out_s), end='\t')
        print('param num: ', params, end='\t')
        print('cals_num: ', cals, end='\t')
        # Guard against an all-zero total (e.g. zero-sized outputs).
        print('cals ratio: ', cals / total if total else 0.0, end='\t')
        print('', end='\n')
    print('all calcs: ', total)
    return rows, total


if __name__ == '__main__':
    analysis()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment