@mikigom
Created May 28, 2017 23:27

import tensorflow as tf

slim = tf.contrib.slim


class Unet(object):
    def __init__(self, input, class_num, reuse=False):
        self.input = input
        self.reuse = reuse
        self.class_num = class_num
        self.build_model()
    def build_model(self):
        # (30, 572, 572, 1)
        self.input = tf.expand_dims(self.input, -1)
        if self.reuse:
            tf.get_variable_scope().reuse_variables()
        with tf.name_scope('encoder'):
            with tf.name_scope('e0'):
                # (30, 570, 570, 64)
                self.e0_0 = lrelu(slim.conv2d(self.input, 64, [3, 3], scope='e0_0', padding='VALID', activation_fn=None))
                # (30, 568, 568, 64)
                self.e0_1 = lrelu(slim.conv2d(self.e0_0, 64, [3, 3], scope='e0_1', padding='VALID', activation_fn=None))
            with tf.name_scope('max_pool1'):
                # (30, 284, 284, 64)
                self.max_pool1 = slim.max_pool2d(self.e0_1, [2, 2], scope='max_pool1')
            with tf.name_scope('e1'):
                # (30, 282, 282, 128)
                self.e1_0 = BL(slim.conv2d(self.max_pool1, 128, [3, 3], scope='e1_0', padding='VALID', activation_fn=None))
                # (30, 280, 280, 128)
                self.e1_1 = BL(slim.conv2d(self.e1_0, 128, [3, 3], scope='e1_1', padding='VALID', activation_fn=None))
            with tf.name_scope('max_pool2'):
                # (30, 140, 140, 128)
                self.max_pool2 = slim.max_pool2d(self.e1_1, [2, 2], scope='max_pool2')
            with tf.name_scope('e2'):
                # (30, 138, 138, 256)
                self.e2_0 = BL(slim.conv2d(self.max_pool2, 256, [3, 3], scope='e2_0', padding='VALID', activation_fn=None))
                # (30, 136, 136, 256)
                self.e2_1 = BL(slim.conv2d(self.e2_0, 256, [3, 3], scope='e2_1', padding='VALID', activation_fn=None))
            with tf.name_scope('max_pool3'):
                # (30, 68, 68, 256)
                self.max_pool3 = slim.max_pool2d(self.e2_1, [2, 2], scope='max_pool3')
            with tf.name_scope('e3'):
                # (30, 66, 66, 512)
                self.e3_0 = BL(slim.conv2d(self.max_pool3, 512, [3, 3], scope='e3_0', padding='VALID', activation_fn=None))
                # (30, 64, 64, 512)
                self.e3_1 = BL(slim.conv2d(self.e3_0, 512, [3, 3], scope='e3_1', padding='VALID', activation_fn=None))
        with tf.name_scope('middle'):
            with tf.name_scope('max_pool4'):
                # (30, 32, 32, 512)
                self.max_pool4 = slim.max_pool2d(self.e3_1, [2, 2], scope='max_pool4')
            with tf.name_scope('middle_conv'):
                # (30, 32, 32, 1024)
                self.middle_0 = BL(slim.conv2d(self.max_pool4, 1024, [3, 3], scope='middle_0', padding='SAME', activation_fn=None))
                # (30, 30, 30, 1024)
                self.middle_1 = BL(slim.conv2d(self.middle_0, 1024, [3, 3], scope='middle_1', padding='VALID', activation_fn=None))
                # (30, 30, 30, 1024)
                self.middle_2 = BDR(slim.conv2d(self.middle_1, 1024, [3, 3], scope='middle_2', padding='SAME', activation_fn=None))
                # (30, 28, 28, 1024)
                self.middle_3 = BDR(slim.conv2d(self.middle_2, 1024, [3, 3], scope='middle_3', padding='VALID', activation_fn=None))
            with tf.name_scope('up_conv1'):
                # (30, 56, 56, 512)
                self.up_conv1 = slim.batch_norm(slim.conv2d_transpose(self.middle_3, 512, [2, 2], [2, 2], scope='up_conv1', padding='VALID', activation_fn=None), activation_fn=None)
        with tf.name_scope('decoder'):
            with tf.name_scope('d0'):
                _, self.e3_1_w, self.e3_1_h, _ = self.e3_1.get_shape().as_list()
                _, self.up_conv1_w, self.up_conv1_h, _ = self.up_conv1.get_shape().as_list()
                # Center-crop the 64x64 skip feature map to the 56x56 upsampled map.
                # Integer division (//) keeps the slice bounds ints under Python 3.
                self.crop_0 = self.e3_1[:, self.e3_1_w // 2 - self.up_conv1_w // 2:self.e3_1_w // 2 + self.up_conv1_w // 2,
                                        self.e3_1_h // 2 - self.up_conv1_h // 2:self.e3_1_h // 2 + self.up_conv1_h // 2, :]
                # (30, 54, 54, 512)
                self.d0_0 = BR(slim.conv2d(tf.concat([self.up_conv1, self.crop_0], axis=3), 512, [3, 3], scope='d0_0', padding='VALID', activation_fn=None))
                # (30, 52, 52, 512)
                self.d0_1 = BR(slim.conv2d(self.d0_0, 512, [3, 3], scope='d0_1', padding='VALID', activation_fn=None))
            with tf.name_scope('up_conv2'):
                # (30, 104, 104, 256)
                self.up_conv2 = slim.batch_norm(slim.conv2d_transpose(self.d0_1, 256, [2, 2], [2, 2], scope='up_conv2', padding='VALID', activation_fn=None), activation_fn=None)
            with tf.name_scope('d1'):
                _, self.e2_1_w, self.e2_1_h, _ = self.e2_1.get_shape().as_list()
                _, self.up_conv2_w, self.up_conv2_h, _ = self.up_conv2.get_shape().as_list()
                # Center-crop the 136x136 skip feature map to 104x104.
                self.crop_1 = self.e2_1[:, self.e2_1_w // 2 - self.up_conv2_w // 2:self.e2_1_w // 2 + self.up_conv2_w // 2,
                                        self.e2_1_h // 2 - self.up_conv2_h // 2:self.e2_1_h // 2 + self.up_conv2_h // 2, :]
                # (30, 102, 102, 256)
                self.d1_0 = BR(slim.conv2d(tf.concat([self.up_conv2, self.crop_1], axis=3), 256, [3, 3], scope='d1_0', padding='VALID', activation_fn=None))
                # (30, 100, 100, 256)
                self.d1_1 = BR(slim.conv2d(self.d1_0, 256, [3, 3], scope='d1_1', padding='VALID', activation_fn=None))
            with tf.name_scope('up_conv3'):
                # (30, 200, 200, 128)
                self.up_conv3 = slim.batch_norm(slim.conv2d_transpose(self.d1_1, 128, [2, 2], [2, 2], scope='up_conv3', padding='VALID', activation_fn=None), activation_fn=None)
            with tf.name_scope('d2'):
                _, self.e1_1_w, self.e1_1_h, _ = self.e1_1.get_shape().as_list()
                _, self.up_conv3_w, self.up_conv3_h, _ = self.up_conv3.get_shape().as_list()
                # Center-crop the 280x280 skip feature map to 200x200.
                self.crop_2 = self.e1_1[:, self.e1_1_w // 2 - self.up_conv3_w // 2:self.e1_1_w // 2 + self.up_conv3_w // 2,
                                        self.e1_1_h // 2 - self.up_conv3_h // 2:self.e1_1_h // 2 + self.up_conv3_h // 2, :]
                # (30, 198, 198, 128)
                self.d2_0 = BR(slim.conv2d(tf.concat([self.up_conv3, self.crop_2], axis=3), 128, [3, 3], scope='d2_0', padding='VALID', activation_fn=None))
                # (30, 196, 196, 128)
                self.d2_1 = BR(slim.conv2d(self.d2_0, 128, [3, 3], scope='d2_1', padding='VALID', activation_fn=None))
            with tf.name_scope('up_conv4'):
                # (30, 392, 392, 64)
                self.up_conv4 = slim.batch_norm(slim.conv2d_transpose(self.d2_1, 64, [2, 2], [2, 2], scope='up_conv4', padding='VALID', activation_fn=None), activation_fn=None)
            with tf.name_scope('d3'):
                _, self.e0_1_w, self.e0_1_h, _ = self.e0_1.get_shape().as_list()
                _, self.up_conv4_w, self.up_conv4_h, _ = self.up_conv4.get_shape().as_list()
                # Center-crop the 568x568 skip feature map to 392x392.
                self.crop_3 = self.e0_1[:, self.e0_1_w // 2 - self.up_conv4_w // 2:self.e0_1_w // 2 + self.up_conv4_w // 2,
                                        self.e0_1_h // 2 - self.up_conv4_h // 2:self.e0_1_h // 2 + self.up_conv4_h // 2, :]
                # (30, 390, 390, 64)
                self.d3_0 = BR(slim.conv2d(tf.concat([self.up_conv4, self.crop_3], axis=3), 64, [3, 3], scope='d3_0', padding='VALID', activation_fn=None))
                # (30, 388, 388, 64)
                self.d3_1 = BR(slim.conv2d(self.d3_0, 64, [3, 3], scope='d3_1', padding='VALID', activation_fn=None))
            with tf.name_scope('conv1x1'):
                # (30, 388, 388, 2)
                self.output = slim.conv2d(self.d3_1, self.class_num, [1, 1], scope='output', padding='VALID', activation_fn=tf.sigmoid)
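
# Shape bookkeeping behind the per-layer comments above (assumes the square
# 572x572 input and batch of 30 those comments use): each VALID 3x3 conv trims
# the side by 2, each 2x2 max-pool halves it, and each stride-2 2x2 transpose
# conv doubles it.
#   encoder: 572 -> 570 -> 568 | 284 -> 282 -> 280 | 140 -> 138 -> 136 | 68 -> 66 -> 64
#   middle:  32 -> 32 -> 30 -> 30 -> 28, upsampled to 56
#   decoder: 56 -> 54 -> 52 | 104 -> 102 -> 100 | 200 -> 198 -> 196 | 392 -> 390 -> 388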
def lrelu(x, leak=0.2, name="lrelu"):
    # Leaky ReLU: max(x, leak * x).
    return tf.maximum(x, leak * x, name=name)


def BL(x, leak=0.2):
    # Batch norm followed by leaky ReLU.
    return lrelu(slim.batch_norm(x, activation_fn=None), leak)


def BDR(x):
    # Batch norm, dropout (slim's default keep_prob of 0.5), then ReLU.
    return tf.nn.relu(slim.dropout(slim.batch_norm(x, activation_fn=None)))


def BR(x):
    # Batch norm followed by ReLU.
    return tf.nn.relu(slim.batch_norm(x, activation_fn=None))
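
# A minimal smoke test (an addition, not part of the original gist; the
# placeholder shape and class_num=2 are assumptions taken from the shape
# comments above): build the graph once and check the output resolution.
if __name__ == '__main__':
    images = tf.placeholder(tf.float32, [30, 572, 572], name='images')
    net = Unet(images, class_num=2)
    print(net.output.get_shape().as_list())  # expected: [30, 388, 388, 2]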