Last active: June 23, 2017 17:15
-
-
Save zhreshold/4b3bcf3a8b51e1dc49b3c20834e5e401 to your computer and use it in GitHub Desktop.
Inflated network: an MXNet symbol definition for a classification network built from "inflated" convolution groups.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mxnet as mx | |
def inflated_layer(data, num_in, num_out, name):
    """Widen ``data`` from ``num_in`` to ``num_out`` channels by bias-shifted copies.

    The input is replicated ``num_out // num_in`` times. Each copy gets its
    own separately named bias Variable (``"<name>_<i>_bias"``) added to it,
    and the copies are concatenated along axis 1.

    Args:
        data: Input symbol. Assumed to be NCHW with ``num_in`` channels,
            because the bias shape ``(1, num_in, 1, 1)`` broadcasts over
            axis 1. TODO confirm for other layouts.
        num_in: Channel count of ``data``. May arrive as an integral float
            from callers that used true division.
        num_out: Desired channel count. Must be a positive multiple of ``num_in``.
        name: Prefix for the created bias Variables.

    Returns:
        A ``Concat`` symbol with ``num_out`` channels along axis 1.

    Raises:
        ValueError: if ``num_in`` is not positive or does not divide ``num_out``.
    """
    # Explicit check instead of ``assert`` so the validation survives ``python -O``.
    if num_in <= 0 or num_out % num_in != 0:
        raise ValueError(
            "num_out ({}) must be a positive multiple of num_in ({})".format(
                num_out, num_in))
    # Coerce to int: range() and the Variable shape both need integers.
    # On Python 3, ``/`` would give a float here and range() would fail.
    num_in = int(num_in)
    num_group = int(num_out // num_in)
    outputs = []
    for i in range(num_group):
        bias = mx.sym.Variable(shape=(1, num_in, 1, 1),
                               name="{}_{}_bias".format(name, i))
        outputs.append(mx.sym.broadcast_add(lhs=data, rhs=bias))
    return mx.sym.Concat(*outputs, dim=1)
def inflated_group(data, ratio, num_out, kernel, pad,
                   stride, name):
    """Build conv -> (optional inflation) -> BatchNorm -> ReLU.

    The convolution produces ``num_out // ratio`` channels. When ``ratio > 1``,
    those channels are widened back to ``num_out`` by ``inflated_layer``, which
    concatenates ``ratio`` bias-shifted copies.

    Args:
        data: Input symbol.
        ratio: Inflation factor. Must be a positive divisor of ``num_out``.
        num_out: Output channel count of the group.
        kernel, pad, stride: Passed straight to ``mx.sym.Convolution``.
        name: Prefix; the convolution is named ``name + '_conv'`` and the
            inflation biases use ``name`` as their prefix.

    Returns:
        The ReLU activation symbol.

    Raises:
        ValueError: if ``ratio`` is not a positive divisor of ``num_out``.
    """
    # Without this check, a non-dividing ratio would build a group count
    # different from ``ratio`` (e.g. 10 // 4 = 2 -> 5 copies) without any error.
    if ratio <= 0 or num_out % ratio != 0:
        raise ValueError(
            "ratio ({}) must be a positive divisor of num_out ({})".format(
                ratio, num_out))
    # Integer division: ``num_filter`` must be an int, and ``/`` is true
    # division on Python 3.
    num_inter = num_out // ratio
    conv = mx.sym.Convolution(data=data, kernel=kernel, pad=pad, stride=stride,
                              num_filter=num_inter, name=name + '_conv')
    if ratio > 1:
        conv = inflated_layer(conv, num_inter, num_out, name)
    bn = mx.sym.BatchNorm(data=conv)
    return mx.sym.Activation(data=bn, act_type='relu')
def get_symbol(num_classes, **kwargs):
    """Return the inflated-network classifier as an MXNet symbol.

    Layout:
    - a stride-2 3x3 stem convolution (``conv1``, 32 filters) with BatchNorm and ReLU;
    - thirteen ``inflated_group`` stages named ``conv2`` .. ``conv14``;
    - global average pooling, a flatten, a ``FullyConnected`` layer ``fc``;
    - a ``SoftmaxOutput`` named ``softmax``.

    Args:
        num_classes: Number of output units of the ``fc`` layer.
        **kwargs: Accepted and ignored. Presumably kept so the signature
            matches other ``get_symbol(num_classes, **kwargs)`` factories.
            Verify against callers.

    Returns:
        The ``SoftmaxOutput`` symbol.
    """
    # One (inflation ratio, output filters, stride) row per stage conv2..conv14.
    stage_specs = (
        (1, 64, 1),
        (1, 128, 2),
        (1, 128, 1),
        (1, 256, 2),
        (1, 256, 1),
        (2, 512, 2),
        (2, 512, 1),
        (2, 512, 1),
        (2, 512, 1),
        (2, 512, 1),
        (2, 512, 1),
        (4, 1024, 2),
        (4, 1024, 1),
    )

    net = mx.sym.Variable(name='data')

    # Stem: plain (non-inflated) convolution.
    net = mx.sym.Convolution(data=net, num_filter=32, kernel=(3, 3),
                             pad=(1, 1), stride=(2, 2), name='conv1')
    net = mx.sym.BatchNorm(data=net)
    net = mx.sym.Activation(data=net, act_type='relu')

    # Inflated stages. Numbering continues from the stem, so the first is conv2.
    for stage_id, (ratio, width, step) in enumerate(stage_specs, start=2):
        net = inflated_group(data=net, ratio=ratio, num_out=width,
                             kernel=(3, 3), pad=(1, 1), stride=(step, step),
                             name='conv{}'.format(stage_id))

    # Head. With global_pool=True, MXNet pools over the whole feature map,
    # so the (7, 7) kernel does not limit the pooling window.
    net = mx.sym.Pooling(data=net, pool_type='avg', global_pool=True,
                         kernel=(7, 7))
    net = mx.sym.Flatten(data=net)
    net = mx.sym.FullyConnected(data=net, num_hidden=num_classes, name='fc')
    return mx.sym.SoftmaxOutput(data=net, name='softmax')
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment