Skip to content

Instantly share code, notes, and snippets.

@llj098
Created December 3, 2016 03:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save llj098/0c27580364e27d53b1f386a248e36499 to your computer and use it in GitHub Desktop.
Save llj098/0c27580364e27d53b1f386a248e36499 to your computer and use it in GitHub Desktop.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.framework import add_arg_scope
from tensorflow.contrib.layers.python.layers import utils
# Short alias for the TF-Slim layer library; the layer builders below call it as `slim.<layer>`.
slim = tf.contrib.slim
def squeeze(inputs, num_outputs, fire_id):
    """Apply the squeeze stage of a fire module: a single 1x1 convolution.

    Args:
        inputs: Input tensor handed to ``slim.conv2d``.
        num_outputs: Number of output channels of the 1x1 convolution.
        fire_id: Identifier appended to the scope name so that each fire
            module gets its own variables.

    Returns:
        The output tensor of the 1x1 convolution.
    """
    scope_name = "fire/squeeze/" + str(fire_id)
    return slim.conv2d(inputs, num_outputs, 1, stride=1, scope=scope_name)
def expand(inputs, num_outputs, fire_id):
    """Apply the expand stage of a fire module.

    Runs a 1x1 and a 3x3 convolution side by side on the same input and
    joins their outputs along dimension 3.

    Args:
        inputs: Input tensor (the squeeze output when called from
            ``fire_module``).
        num_outputs: Number of output channels of *each* branch; the
            concatenated result therefore carries ``2 * num_outputs``
            channels.
        fire_id: Identifier appended to both branch scope names.

    Returns:
        The concatenation of the two branch outputs along dimension 3.
    """
    suffix = str(fire_id)
    with tf.variable_scope('expand'):
        branch_1x1 = slim.conv2d(
            inputs, num_outputs, 1, stride=1, scope='fire/ex1x1/' + suffix)
        branch_3x3 = slim.conv2d(
            inputs, num_outputs, 3, scope='fire/ex3x3/' + suffix)
    # NOTE(review): (dim, values) is the pre-1.0 TensorFlow argument order for
    # tf.concat; TF >= 1.0 expects (values, axis). Confirm the targeted version.
    return tf.concat(3, [branch_1x1, branch_3x3])
def fire_module(x, s=16, e=64, fire_id=0):
    """Build one fire module: a squeeze stage followed by an expand stage.

    Args:
        x: Input tensor.
        s: Number of channels produced by the squeeze convolution.
        e: Number of channels produced by each expand branch.
        fire_id: Identifier used to make the module's scope names unique.

    Returns:
        The output tensor of the expand stage.
    """
    squeezed = squeeze(x, s, fire_id)
    return expand(squeezed, e, fire_id)
def inference(images, _a, phase_train=True, weight_decay=0.8, reuse=None):
    """Build the SqueezeNet-style feature extractor.

    Args:
        images: Input tensor; it is reshaped to ``[-1, 160, 160, 3]``.
        _a: Unused. Kept so existing callers' positional arguments line up.
        phase_train: Whether the graph is built for training. It controls
            dropout, which is only active while training.
        weight_decay: Unused by this function. Kept for interface
            compatibility.
        reuse: Forwarded to ``tf.variable_scope`` to control variable reuse.

    Returns:
        A ``(net, end_points)`` tuple: ``net`` is the flattened output of the
        final average pool and ``end_points`` is an empty dict.
    """
    with tf.variable_scope('SqueezeNet', 'SqueezeNet', [images], reuse=reuse):
        x_image = tf.reshape(images, [-1, 160, 160, 3])

        # NOTE(review): the shapes in the comments below assume slim's default
        # padding (SAME for conv2d, VALID for pooling) and a default pooling
        # stride of 2 -- confirm against the installed slim version.
        net = slim.conv2d(x_image, 96, 7, scope="conv_1", stride=2)  # [?, 80, 80, 96]
        net = slim.max_pool2d(net, 3, stride=2, scope="maxpool1")  # [?, 39, 39, 96]

        net = fire_module(net, fire_id=2)
        net = fire_module(net, fire_id=3)
        net = fire_module(net, 32, 128, fire_id=4)
        net = slim.max_pool2d(net, 3, scope="maxpool3")  # [?, 19, 19, 256]

        net = fire_module(net, 32, 128, fire_id=5)
        net = fire_module(net, 48, 192, fire_id=6)
        net = fire_module(net, 48, 192, fire_id=7)
        net = fire_module(net, 64, 256, fire_id=8)
        net = slim.max_pool2d(net, 2, stride=2, scope="maxpool8")  # [?, 9, 9, 512]

        net = fire_module(net, 64, 256, fire_id=9)

        # Dropout must be switched off outside training; without is_training
        # slim keeps it enabled and evaluation outputs become stochastic.
        net = slim.dropout(net, 0.5, is_training=phase_train)

        net = slim.conv2d(net, 1792, 1, scope="conv10", padding="VALID")
        net = slim.avg_pool2d(net, 9)  # presumably collapses 9x9 -> 1x1
        net = slim.flatten(net)
    return net, {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment