Skip to content

Instantly share code, notes, and snippets.

@PistonY

PistonY/flops.py

Created May 23, 2019
Embed
What would you like to do?
FLOPs counter for Gluon (MXNet) models
# -*- coding: utf-8 -*-
# Author: pistonyang@gmail.com
from collections import OrderedDict
from mxnet import ndarray
from mxnet.gluon.nn import HybridBlock
def summary(block, *inputs):
    """Print a per-layer summary of output shapes, FLOPs and parameter counts.

    A forward hook is attached to every child of ``block`` except
    ``Sequential``/``HybridSequential`` containers. ``block`` is then run once
    on ``inputs``, and the statistics recorded by the hooks are printed as a
    table followed by totals. The hooks are always detached afterwards, even
    if the forward pass raises.

    FLOPs are counted as multiply-accumulate operations per sample; the batch
    dimension (axis 0 of the output) is ignored. They are only computed for
    ``Conv1D``/``Conv2D``/``Conv3D`` and ``Dense`` layers. Every other layer
    is reported as 0. A bias adds one operation per output element.

    The network must have been initialized, and must not have been hybridized.

    Parameters
    ----------
    block : mxnet.gluon.Block
        The initialized, non-hybridized network to summarize.
    inputs : object
        Any input that the model supports. For any tensor in the input, only
        :class:`mxnet.ndarray.NDArray` is supported.

    Raises
    ------
    AssertionError
        If ``block`` or any of its children is an active (hybridized)
        :class:`~mxnet.gluon.HybridBlock`.
    """
    # Imported once here rather than once per child block.
    from mxnet.gluon.nn.basic_layers import Sequential, HybridSequential

    # Layer types whose cost is derived from their weight tensor (see
    # _calculate_weighted_flops). Matched on the exact class name.
    weighted_layers = ('Conv1D', 'Conv2D', 'Conv3D', 'Dense')

    layers = OrderedDict()  # "<ClassName>-<index>" -> per-layer statistics
    seen = set()  # parameters counted at least once, used to detect sharing
    hooks = []

    def _get_shape_str(args):
        """Render the shape(s) of an arbitrarily nested list/tuple of outputs."""
        def flatten(args):
            if not isinstance(args, (list, tuple)):
                return [args], int(0)
            flat = []
            fmts = []
            for i in args:
                arg, fmt = flatten(i)
                flat.extend(arg)
                fmts.append(fmt)
            return flat, fmts

        def regroup(args, fmt):
            if isinstance(fmt, int):
                if fmt == 0:
                    return args[0], args[1:]
                return args[:fmt], args[fmt:]
            ret = []
            for i in fmt:
                res, args = regroup(args, i)
                ret.append(res)
            return ret, args

        flat_args, fmts = flatten(args)
        flat_arg_shapes = [x.shape if isinstance(x, ndarray.NDArray) else x
                           for x in flat_args]
        shapes = regroup(flat_arg_shapes, fmts)[0]
        if isinstance(shapes, list):
            shape_str = str(shapes)[1:-1]
        else:
            shape_str = str(shapes)
        # Strip Python 2 long-integer suffixes such as "224L".
        return shape_str.replace('L', '')

    def _flops_str(flops):
        """Format a FLOP count with a T/G/M/K suffix and one decimal place."""
        for scale, suffix in [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')]:
            if flops // scale > 0:
                return "%.1f%s" % (flops / scale, suffix)
        return "%.1f" % flops

    def _prod(values):
        """Return the product of an iterable of ints (1 for an empty one)."""
        result = 1
        for value in values:
            result *= value
        return result

    def _calculate_weighted_flops(layer, output_shape):
        """Return the per-sample multiply-accumulate count of a conv/dense layer.

        Each output element needs one MAC per weight element of its output
        channel/unit, i.e. ``prod(weight.shape[1:])``. For convolutions that
        is (in_channels / groups) * kernel size, because the Gluon weight
        already stores only the per-group input channels. For Dense it is
        in_units. A bias adds one addition per output element.

        The count is built from element totals rather than named axes, so it
        does not depend on the spatial layout or on Dense's ``flatten``
        setting. It assumes axis 0 of the weight is the output channel/unit
        axis and axis 0 of the output is the batch axis.
        """
        per_sample_outputs = _prod(output_shape[1:])
        flops = per_sample_outputs * _prod(layer.weight.shape[1:])
        if layer.bias is not None:
            flops += per_sample_outputs
        return flops

    def _register_summary_hook(child):
        """Attach the statistics-recording forward hook to one child block."""
        assert not isinstance(child, HybridBlock) or not child._active, \
            '"{}" must not be hybridized to print summary.'.format(child.name)

        def _summary_hook(layer, layer_inputs, outputs):
            class_name = layer.__class__.__name__
            # 'Input' is already in `layers`, so the first real layer gets -1.
            m_key = '%s-%i' % (class_name, len(layers))
            stats = OrderedDict()
            layers[m_key] = stats
            stats['output_shape'] = _get_shape_str(outputs)

            params = 0
            stats['trainable'] = 0
            stats['shared'] = 0
            for p in layer.params.values():
                size = p.data().size
                params += size
                stats['trainable'] += 0 if p.grad_req == 'null' else size
                if p in seen:
                    stats['shared'] += size
                else:
                    seen.add(p)
            stats['n_params'] = params

            flops = 0
            if class_name in weighted_layers:
                flops = _calculate_weighted_flops(layer, outputs.shape)
            stats['n_flops'] = int(flops)

        # Containers just chain their children, which are hooked themselves.
        if not isinstance(child, (Sequential, HybridSequential)):
            hooks.append(child.register_forward_hook(_summary_hook))

    layers['Input'] = OrderedDict()
    layers['Input']['output_shape'] = _get_shape_str(inputs)
    layers['Input']['n_flops'] = 0
    layers['Input']['n_params'] = 0
    layers['Input']['trainable'] = 0
    layers['Input']['shared'] = 0

    try:
        block.apply(_register_summary_hook)
        block(*inputs)

        line_format = '{:>20} {:>42} {:>15} {:>15}'
        print('-' * 96)
        print(line_format.format('Layer (type)', 'Output Shape', 'FLOPs', 'Param #'))
        print('=' * 96)
        total_flops = 0
        total_params = 0
        trainable_params = 0
        shared_params = 0
        for layer_key, stats in layers.items():
            print(line_format.format(layer_key,
                                     str(stats['output_shape']),
                                     stats['n_flops'],
                                     stats['n_params']))
            total_flops += stats['n_flops']
            total_params += stats['n_params']
            trainable_params += stats['trainable']
            shared_params += stats['shared']
        print('=' * 96)
        print('Parameters in forward computation graph, duplicate included')
        print(' Total FLOPs: ' + str(total_flops) + " " + _flops_str(total_flops))
        print(' Total params: ' + str(total_params))
        print(' Trainable params: ' + str(trainable_params))
        print(' Non-trainable params: ' + str(total_params - trainable_params))
        print('Shared params in forward computation graph: ' + str(shared_params))
        print('Unique parameters in model: ' + str(total_params - shared_params))
        # Same width as the other separators of the table.
        print('-' * 96)
    finally:
        for h in hooks:
            h.detach()
if __name__ == '__main__':
    # Demo: summarize a GluonCV ResNet-50 v1 on a single 224x224 RGB image.
    import mxnet as mx
    from mxnet import nd
    from gluoncv.model_zoo.resnet import resnet50_v1

    # Fall back to the CPU so the demo also runs on machines without a GPU.
    ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
    dt = nd.random.randn(1, 3, 224, 224, ctx=ctx)
    model = resnet50_v1()
    model.initialize(ctx=ctx)
    summary(model, dt)
@PistonY

This comment has been minimized.

Copy link
Owner Author

@PistonY PistonY commented May 23, 2019

------------------------------------------------------------------------------------------------
        Layer (type)                                Output Shape           FLOPs         Param #
================================================================================================
               Input                            (1, 3, 224, 224)               0               0
            Conv2D-1                           (1, 64, 112, 112)       118013952            9408
         BatchNorm-2                           (1, 64, 112, 112)               0             256
        Activation-3                           (1, 64, 112, 112)               0               0
         MaxPool2D-4                             (1, 64, 56, 56)               0               0
            Conv2D-5                             (1, 64, 56, 56)        13045760            4160
         BatchNorm-6                             (1, 64, 56, 56)               0             256
        Activation-7                             (1, 64, 56, 56)               0               0
            Conv2D-8                             (1, 64, 56, 56)       115605504           36864
         BatchNorm-9                             (1, 64, 56, 56)               0             256
       Activation-10                             (1, 64, 56, 56)               0               0
           Conv2D-11                            (1, 256, 56, 56)        52183040           16640
        BatchNorm-12                            (1, 256, 56, 56)               0            1024
           Conv2D-13                            (1, 256, 56, 56)        51380224           16384
        BatchNorm-14                            (1, 256, 56, 56)               0            1024
     BottleneckV1-15                            (1, 256, 56, 56)               0               0
           Conv2D-16                             (1, 64, 56, 56)        51580928           16448
        BatchNorm-17                             (1, 64, 56, 56)               0             256
       Activation-18                             (1, 64, 56, 56)               0               0
           Conv2D-19                             (1, 64, 56, 56)       115605504           36864
        BatchNorm-20                             (1, 64, 56, 56)               0             256
       Activation-21                             (1, 64, 56, 56)               0               0
           Conv2D-22                            (1, 256, 56, 56)        52183040           16640
        BatchNorm-23                            (1, 256, 56, 56)               0            1024
     BottleneckV1-24                            (1, 256, 56, 56)               0               0
           Conv2D-25                             (1, 64, 56, 56)        51580928           16448
        BatchNorm-26                             (1, 64, 56, 56)               0             256
       Activation-27                             (1, 64, 56, 56)               0               0
           Conv2D-28                             (1, 64, 56, 56)       115605504           36864
        BatchNorm-29                             (1, 64, 56, 56)               0             256
       Activation-30                             (1, 64, 56, 56)               0               0
           Conv2D-31                            (1, 256, 56, 56)        52183040           16640
        BatchNorm-32                            (1, 256, 56, 56)               0            1024
     BottleneckV1-33                            (1, 256, 56, 56)               0               0
           Conv2D-34                            (1, 128, 28, 28)        25790464           32896
        BatchNorm-35                            (1, 128, 28, 28)               0             512
       Activation-36                            (1, 128, 28, 28)               0               0
           Conv2D-37                            (1, 128, 28, 28)       115605504          147456
        BatchNorm-38                            (1, 128, 28, 28)               0             512
       Activation-39                            (1, 128, 28, 28)               0               0
           Conv2D-40                            (1, 512, 28, 28)        51781632           66048
        BatchNorm-41                            (1, 512, 28, 28)               0            2048
           Conv2D-42                            (1, 512, 28, 28)       102760448          131072
        BatchNorm-43                            (1, 512, 28, 28)               0            2048
     BottleneckV1-44                            (1, 512, 28, 28)               0               0
           Conv2D-45                            (1, 128, 28, 28)        51480576           65664
        BatchNorm-46                            (1, 128, 28, 28)               0             512
       Activation-47                            (1, 128, 28, 28)               0               0
           Conv2D-48                            (1, 128, 28, 28)       115605504          147456
        BatchNorm-49                            (1, 128, 28, 28)               0             512
       Activation-50                            (1, 128, 28, 28)               0               0
           Conv2D-51                            (1, 512, 28, 28)        51781632           66048
        BatchNorm-52                            (1, 512, 28, 28)               0            2048
     BottleneckV1-53                            (1, 512, 28, 28)               0               0
           Conv2D-54                            (1, 128, 28, 28)        51480576           65664
        BatchNorm-55                            (1, 128, 28, 28)               0             512
       Activation-56                            (1, 128, 28, 28)               0               0
           Conv2D-57                            (1, 128, 28, 28)       115605504          147456
        BatchNorm-58                            (1, 128, 28, 28)               0             512
       Activation-59                            (1, 128, 28, 28)               0               0
           Conv2D-60                            (1, 512, 28, 28)        51781632           66048
        BatchNorm-61                            (1, 512, 28, 28)               0            2048
     BottleneckV1-62                            (1, 512, 28, 28)               0               0
           Conv2D-63                            (1, 128, 28, 28)        51480576           65664
        BatchNorm-64                            (1, 128, 28, 28)               0             512
       Activation-65                            (1, 128, 28, 28)               0               0
           Conv2D-66                            (1, 128, 28, 28)       115605504          147456
        BatchNorm-67                            (1, 128, 28, 28)               0             512
       Activation-68                            (1, 128, 28, 28)               0               0
           Conv2D-69                            (1, 512, 28, 28)        51781632           66048
        BatchNorm-70                            (1, 512, 28, 28)               0            2048
     BottleneckV1-71                            (1, 512, 28, 28)               0               0
           Conv2D-72                            (1, 256, 14, 14)        25740288          131328
        BatchNorm-73                            (1, 256, 14, 14)               0            1024
       Activation-74                            (1, 256, 14, 14)               0               0
           Conv2D-75                            (1, 256, 14, 14)       115605504          589824
        BatchNorm-76                            (1, 256, 14, 14)               0            1024
       Activation-77                            (1, 256, 14, 14)               0               0
           Conv2D-78                           (1, 1024, 14, 14)        51580928          263168
        BatchNorm-79                           (1, 1024, 14, 14)               0            4096
           Conv2D-80                           (1, 1024, 14, 14)       102760448          524288
        BatchNorm-81                           (1, 1024, 14, 14)               0            4096
     BottleneckV1-82                           (1, 1024, 14, 14)               0               0
           Conv2D-83                            (1, 256, 14, 14)        51430400          262400
        BatchNorm-84                            (1, 256, 14, 14)               0            1024
       Activation-85                            (1, 256, 14, 14)               0               0
           Conv2D-86                            (1, 256, 14, 14)       115605504          589824
        BatchNorm-87                            (1, 256, 14, 14)               0            1024
       Activation-88                            (1, 256, 14, 14)               0               0
           Conv2D-89                           (1, 1024, 14, 14)        51580928          263168
        BatchNorm-90                           (1, 1024, 14, 14)               0            4096
     BottleneckV1-91                           (1, 1024, 14, 14)               0               0
           Conv2D-92                            (1, 256, 14, 14)        51430400          262400
        BatchNorm-93                            (1, 256, 14, 14)               0            1024
       Activation-94                            (1, 256, 14, 14)               0               0
           Conv2D-95                            (1, 256, 14, 14)       115605504          589824
        BatchNorm-96                            (1, 256, 14, 14)               0            1024
       Activation-97                            (1, 256, 14, 14)               0               0
           Conv2D-98                           (1, 1024, 14, 14)        51580928          263168
        BatchNorm-99                           (1, 1024, 14, 14)               0            4096
    BottleneckV1-100                           (1, 1024, 14, 14)               0               0
          Conv2D-101                            (1, 256, 14, 14)        51430400          262400
       BatchNorm-102                            (1, 256, 14, 14)               0            1024
      Activation-103                            (1, 256, 14, 14)               0               0
          Conv2D-104                            (1, 256, 14, 14)       115605504          589824
       BatchNorm-105                            (1, 256, 14, 14)               0            1024
      Activation-106                            (1, 256, 14, 14)               0               0
          Conv2D-107                           (1, 1024, 14, 14)        51580928          263168
       BatchNorm-108                           (1, 1024, 14, 14)               0            4096
    BottleneckV1-109                           (1, 1024, 14, 14)               0               0
          Conv2D-110                            (1, 256, 14, 14)        51430400          262400
       BatchNorm-111                            (1, 256, 14, 14)               0            1024
      Activation-112                            (1, 256, 14, 14)               0               0
          Conv2D-113                            (1, 256, 14, 14)       115605504          589824
       BatchNorm-114                            (1, 256, 14, 14)               0            1024
      Activation-115                            (1, 256, 14, 14)               0               0
          Conv2D-116                           (1, 1024, 14, 14)        51580928          263168
       BatchNorm-117                           (1, 1024, 14, 14)               0            4096
    BottleneckV1-118                           (1, 1024, 14, 14)               0               0
          Conv2D-119                            (1, 256, 14, 14)        51430400          262400
       BatchNorm-120                            (1, 256, 14, 14)               0            1024
      Activation-121                            (1, 256, 14, 14)               0               0
          Conv2D-122                            (1, 256, 14, 14)       115605504          589824
       BatchNorm-123                            (1, 256, 14, 14)               0            1024
      Activation-124                            (1, 256, 14, 14)               0               0
          Conv2D-125                           (1, 1024, 14, 14)        51580928          263168
       BatchNorm-126                           (1, 1024, 14, 14)               0            4096
    BottleneckV1-127                           (1, 1024, 14, 14)               0               0
          Conv2D-128                              (1, 512, 7, 7)        25715200          524800
       BatchNorm-129                              (1, 512, 7, 7)               0            2048
      Activation-130                              (1, 512, 7, 7)               0               0
          Conv2D-131                              (1, 512, 7, 7)       115605504         2359296
       BatchNorm-132                              (1, 512, 7, 7)               0            2048
      Activation-133                              (1, 512, 7, 7)               0               0
          Conv2D-134                             (1, 2048, 7, 7)        51480576         1050624
       BatchNorm-135                             (1, 2048, 7, 7)               0            8192
          Conv2D-136                             (1, 2048, 7, 7)       102760448         2097152
       BatchNorm-137                             (1, 2048, 7, 7)               0            8192
    BottleneckV1-138                             (1, 2048, 7, 7)               0               0
          Conv2D-139                              (1, 512, 7, 7)        51405312         1049088
       BatchNorm-140                              (1, 512, 7, 7)               0            2048
      Activation-141                              (1, 512, 7, 7)               0               0
          Conv2D-142                              (1, 512, 7, 7)       115605504         2359296
       BatchNorm-143                              (1, 512, 7, 7)               0            2048
      Activation-144                              (1, 512, 7, 7)               0               0
          Conv2D-145                             (1, 2048, 7, 7)        51480576         1050624
       BatchNorm-146                             (1, 2048, 7, 7)               0            8192
    BottleneckV1-147                             (1, 2048, 7, 7)               0               0
          Conv2D-148                              (1, 512, 7, 7)        51405312         1049088
       BatchNorm-149                              (1, 512, 7, 7)               0            2048
      Activation-150                              (1, 512, 7, 7)               0               0
          Conv2D-151                              (1, 512, 7, 7)       115605504         2359296
       BatchNorm-152                              (1, 512, 7, 7)               0            2048
      Activation-153                              (1, 512, 7, 7)               0               0
          Conv2D-154                             (1, 2048, 7, 7)        51480576         1050624
       BatchNorm-155                             (1, 2048, 7, 7)               0            8192
    BottleneckV1-156                             (1, 2048, 7, 7)               0               0
 GlobalAvgPool2D-157                             (1, 2048, 1, 1)               0               0
           Dense-158                                   (1, 1000)         4094952         2049000
        ResNetV1-159                                   (1, 1000)               0               0
================================================================================================
Parameters in forward computation graph, duplicate included
   Total FLOPs: 3866919400  3.9G
   Total params: 25629032
   Trainable params: 25575912
   Non-trainable params: 53120
Shared params in forward computation graph: 0
Unique parameters in model: 25629032
--------------------------------------------------------------------------------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.