Skip to content

Instantly share code, notes, and snippets.

@zhreshold
Last active June 23, 2017 17:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhreshold/4b3bcf3a8b51e1dc49b3c20834e5e401 to your computer and use it in GitHub Desktop.
Save zhreshold/4b3bcf3a8b51e1dc49b3c20834e5e401 to your computer and use it in GitHub Desktop.
inflated network
import mxnet as mx
def inflated_layer(data, num_in, num_out, name):
    """Inflate ``data`` to ``num_out`` channels by concatenating biased copies.

    ``num_out // num_in`` copies of ``data`` are produced. Each copy gets its
    own bias variable of shape ``(1, num_in, 1, 1)`` added via broadcasting,
    and the copies are concatenated along axis 1.

    Args:
        data: Input symbol. Assumed to be laid out as (batch, channel, h, w)
            with ``num_in`` channels -- TODO confirm against callers.
        num_in: Channel count of ``data``; also the size of each bias.
        num_out: Desired channel count; must be a multiple of ``num_in``.
        name: Prefix for the bias variables, named ``"<name>_<i>_bias"``.

    Returns:
        The symbol produced by concatenating all biased copies on ``dim=1``.

    Raises:
        AssertionError: If ``num_out`` is not a multiple of ``num_in``.
    """
    assert num_out % num_in == 0, (
        "num_out ({}) must be a multiple of num_in ({})".format(num_out, num_in))
    # Floor division keeps the group count an int; true division would yield
    # a float on Python 3 and make range() raise TypeError.
    num_group = num_out // num_in
    outputs = []
    for i in range(num_group):
        bias = mx.sym.Variable(shape=(1, num_in, 1, 1),
                               name="{}_{}_bias".format(name, i))
        outputs.append(mx.sym.broadcast_add(lhs=data, rhs=bias))
    return mx.sym.Concat(*outputs, dim=1)
def inflated_group(data, ratio, num_out, kernel, pad,
                   stride, name):
    """Build one conv -> (optional inflation) -> BatchNorm -> ReLU stage.

    The convolution emits ``num_out // ratio`` filters. When ``ratio`` is
    greater than 1 the result is passed through ``inflated_layer`` to bring
    the channel count up to ``num_out``; otherwise the convolution output is
    used directly.

    Args:
        data: Input symbol.
        ratio: Factor by which the convolution width is reduced relative to
            ``num_out``.
        num_out: Channel count after the (optional) inflation step.
        kernel: Convolution kernel size, forwarded to ``mx.sym.Convolution``.
        pad: Convolution padding, forwarded to ``mx.sym.Convolution``.
        stride: Convolution stride, forwarded to ``mx.sym.Convolution``.
        name: Name prefix; the convolution is named ``name + '_conv'`` and
            ``name`` is passed on to ``inflated_layer``.

    Returns:
        The ReLU activation symbol that closes the stage.
    """
    # Floor division: num_filter and the bias shape need an int, and true
    # division would hand them a float on Python 3.
    num_inter = num_out // ratio
    conv = mx.sym.Convolution(data=data, kernel=kernel, pad=pad, stride=stride,
                              num_filter=num_inter, name=name+'_conv')
    if ratio > 1:
        conv = inflated_layer(conv, num_inter, num_out, name)
    bn = mx.sym.BatchNorm(data=conv)
    relu = mx.sym.Activation(data=bn, act_type='relu')
    return relu
def get_symbol(num_classes, **kwargs):
    """Assemble the full network symbol ending in a softmax output.

    Layout: a standard 3x3 stride-2 convolution stem with BatchNorm and ReLU,
    thirteen ``inflated_group`` stages named ``conv2`` .. ``conv14``, global
    average pooling, flatten, a fully connected layer and ``SoftmaxOutput``.

    Args:
        num_classes: Number of hidden units of the final fully connected layer.
        **kwargs: Accepted but not used by this function.

    Returns:
        The ``SoftmaxOutput`` symbol named ``'softmax'``.
    """
    net = mx.sym.Variable(name='data')

    # Stem: plain convolution, not inflated.
    net = mx.sym.Convolution(data=net, num_filter=32, kernel=(3, 3),
                             pad=(1, 1), stride=(2, 2), name='conv1')
    net = mx.sym.BatchNorm(data=net)
    net = mx.sym.Activation(data=net, act_type='relu')

    # Per-stage configuration as (ratio, num_filter, stride) triples.
    stage_cfg = [
        (1, 64, 1), (1, 128, 2), (1, 128, 1), (1, 256, 2), (1, 256, 1),
        (2, 512, 2), (2, 512, 1), (2, 512, 1), (2, 512, 1), (2, 512, 1),
        (2, 512, 1), (4, 1024, 2), (4, 1024, 1),
    ]
    # Stage numbering starts at 2 because the stem already took 'conv1'.
    for layer_id, (ratio, num_filter, step) in enumerate(stage_cfg, 2):
        net = inflated_group(data=net, ratio=ratio, num_out=num_filter,
                             kernel=(3, 3), pad=(1, 1), stride=(step, step),
                             name='conv{}'.format(layer_id))

    # Classifier head: global average pool -> flatten -> fc -> softmax.
    net = mx.sym.Pooling(data=net, pool_type='avg', global_pool=True,
                         kernel=(7, 7))
    net = mx.sym.Flatten(data=net)
    net = mx.sym.FullyConnected(data=net, num_hidden=num_classes, name='fc')
    return mx.sym.SoftmaxOutput(data=net, name='softmax')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment