@zhreshold
Last active April 27, 2018 13:00
python mxnet/example/image-classification/train_imagenet.py \
    --network shufflenet \
    --data-train ~/efs/users/joshuazz/data/imagenet/record/train_480_q95.rec \
    --data-val ~/efs/users/joshuazz/data/imagenet/record/val_256_q90.rec \
    --batch-size 512 \
    --gpus 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 \
    --num-epochs 150 \
    --lr-step-epochs 30,60,90 \
    --min-random-scale 0.533 \
    --lr 0.01 \
    --disp-batches 100 \
    --top-k 5 \
    --data-nthreads 32 \
    --random-mirror 1 \
    --max-random-shear-ratio 0 \
    --max-random-rotate-angle 0 \
    --max-random-h 0 \
    --max-random-l 0 \
    --max-random-s 0 \
    --model-prefix model/shufflenet \
    | tee ~/efs/users/joshuazz/temp/train_imagenet_logs/shufflenet.log
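For reference, the checkpoints written under --model-prefix model/shufflenet can be loaded back for a quick inference sanity check once training finishes. This is a minimal sketch using the standard Module API; the epoch number (150) and the 3x224x224 input shape are assumptions, not values taken from the command above.

import numpy as np
import mxnet as mx

# Load the symbol and weights saved by --model-prefix model/shufflenet.
# Epoch 150 is an assumption; use whichever checkpoint was actually written.
sym, arg_params, aux_params = mx.model.load_checkpoint('model/shufflenet', 150)

mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))])
mod.set_params(arg_params, aux_params, allow_missing=True)

# Forward a random image-shaped batch just to confirm the graph runs.
batch = mx.io.DataBatch(data=[mx.nd.array(np.random.uniform(size=(1, 3, 224, 224)))])
mod.forward(batch, is_train=False)
print(mod.get_outputs()[0].shape)  # expect (1, 1000) for the ImageNet run above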
import mxnet as mx


def combine(residual, data, combine_type):
    """Merge the shortcut branch with the main branch by addition or channel concat."""
    if combine_type == 'add':
        return residual + data
    elif combine_type == 'concat':
        return mx.sym.concat(residual, data, dim=1)
    return None


def channel_shuffle(data, groups):
    # (N, C, H, W) -> (N, groups, C/groups, H, W): -4 splits the channel dim,
    # -1 infers C/groups, -2 keeps the remaining (H, W) dims.
    data = mx.sym.reshape(data, shape=(0, -4, groups, -1, -2))
    # Transpose the group axis and the per-group channel axis.
    data = mx.sym.swapaxes(data, 1, 2)
    # (N, C/groups, groups, H, W) -> (N, C, H, W): -3 merges the two channel dims.
    data = mx.sym.reshape(data, shape=(0, -3, -2))
    return data


def shuffleUnit(residual, in_channels, out_channels, combine_type, groups=3, grouped_conv=True):
    if combine_type == 'add':
        # Stride-1 unit: identity shortcut, element-wise addition.
        DWConv_stride = 1
    elif combine_type == 'concat':
        # Stride-2 unit: average-pooled shortcut, channel concatenation,
        # so the main branch only produces the remaining channels.
        DWConv_stride = 2
        out_channels -= in_channels

    first_groups = groups if grouped_conv else 1

    bottleneck_channels = out_channels // 4

    # 1x1 grouped convolution (compression) + BN + ReLU.
    data = mx.sym.Convolution(data=residual, num_filter=bottleneck_channels,
                              kernel=(1, 1), stride=(1, 1), num_group=first_groups)
    data = mx.sym.BatchNorm(data=data)
    data = mx.sym.Activation(data=data, act_type='relu')

    data = channel_shuffle(data, groups)

    # 3x3 grouped convolution + BN (no ReLU).
    data = mx.sym.Convolution(data=data, num_filter=bottleneck_channels, kernel=(3, 3),
                              pad=(1, 1), stride=(DWConv_stride, DWConv_stride), num_group=groups)
    data = mx.sym.BatchNorm(data=data)

    # 1x1 grouped convolution (expansion) + BN.
    data = mx.sym.Convolution(data=data, num_filter=out_channels,
                              kernel=(1, 1), stride=(1, 1), num_group=groups)
    data = mx.sym.BatchNorm(data=data)

    if combine_type == 'concat':
        # Downsample the shortcut path to match the main branch.
        residual = mx.sym.Pooling(data=residual, kernel=(3, 3), pool_type='avg',
                                  stride=(2, 2), pad=(1, 1))

    data = combine(residual, data, combine_type)
    return data


def make_stage(data, stage, groups=3):
    stage_repeats = [3, 7, 3]

    # The first unit of stage 2 uses a plain 1x1 convolution; later stages use grouped ones.
    grouped_conv = stage > 2

    # Output channels per stage for each group setting (ShuffleNet 1x).
    if groups == 1:
        out_channels = [-1, 24, 144, 288, 576]
    elif groups == 2:
        out_channels = [-1, 24, 200, 400, 800]
    elif groups == 3:
        out_channels = [-1, 24, 240, 480, 960]
    elif groups == 4:
        out_channels = [-1, 24, 272, 544, 1088]
    elif groups == 8:
        out_channels = [-1, 24, 384, 768, 1536]

    # One stride-2 (concat) unit followed by stride-1 (add) units.
    data = shuffleUnit(data, out_channels[stage - 1], out_channels[stage],
                       'concat', groups, grouped_conv)
    for i in range(stage_repeats[stage - 2]):
        data = shuffleUnit(data, out_channels[stage], out_channels[stage],
                           'add', groups, True)
    return data


def get_shufflenet(num_classes=10):
    data = mx.sym.var('data')
    # Stem: 3x3/2 convolution followed by 3x3/2 max pooling.
    data = mx.sym.Convolution(data=data, num_filter=24,
                              kernel=(3, 3), stride=(2, 2), pad=(1, 1))
    data = mx.sym.Pooling(data=data, kernel=(3, 3), pool_type='max',
                          stride=(2, 2), pad=(1, 1))
    data = make_stage(data, 2)
    data = make_stage(data, 3)
    data = make_stage(data, 4)
    # Global average pooling + classifier.
    data = mx.sym.Pooling(data=data, kernel=(1, 1), global_pool=True, pool_type='avg')
    data = mx.sym.flatten(data=data)
    data = mx.sym.FullyConnected(data=data, num_hidden=num_classes)
    out = mx.sym.SoftmaxOutput(data=data, name='softmax')
    return out
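
As a quick sanity check of the definition above, the symbol can be built and its shapes inferred without any training. A small sketch, assuming the code above is in scope and using num_classes=1000 to match the ImageNet command at the top:

import numpy as np

sym = get_shufflenet(num_classes=1000)

# Infer shapes for a single 3x224x224 input; SoftmaxOutput infers the label shape itself.
arg_shapes, out_shapes, aux_shapes = sym.infer_shape(data=(1, 3, 224, 224))
print(out_shapes)  # [(1, 1000)]

# Rough parameter count from the inferred argument shapes (excluding the inputs).
n_params = sum(np.prod(s) for name, s in zip(sym.list_arguments(), arg_shapes)
               if name not in ('data', 'softmax_label'))
print('parameters:', n_params)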
@zzningxp

"bottleneck_channels = out_channels // 4" should be at the place before "out_channels -= in_channels"
