Last active May 25, 2020 02:54
U-Net implementation of the MATLAB semantic-segmentation example. Outputs were produced with different hyperparameters.
import imageio
import numpy as np
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, Callback
from keras import optimizers
import keras.backend as K
import matplotlib.pyplot as plt
from models import Pix2Pix, SegNet, vgg19_unet, UNetMatlab
# Print arrays with up to 64**4 elements in full (no "..." summarisation), on 300-character lines.
np.set_printoptions(linewidth=300, threshold=64 ** 4)
def random_crop(image, top, left, crop_size):
    """Return the ``crop_size`` window of ``image`` whose top-left corner is (top, left).

    The function itself is deterministic; the caller chooses the (random) offsets.
    Only the first two axes (rows, columns) are sliced; any further axes, e.g. colour
    channels, are kept whole, so both (H, W) and (H, W, C) arrays are accepted.

    Args:
        image: array of rank >= 2.
        top: first row of the window.
        left: first column of the window.
        crop_size: sequence whose first two entries are the window (height, width);
            extra entries (such as a channel count) are ignored.

    Returns:
        A NumPy view (not a copy) of the requested window.
    """
    bottom = top + crop_size[0]
    right = left + crop_size[1]
    return image[top:bottom, left:right, ...]
def get_datagen(img_path, seg_path, img_size=(256, 256), batch_size=16, train=True, sample_weights=None):
    """Yield endless batches of random (image, segmentation) crops from one image pair.

    Both files are read as RGB and cut to the fixed region rows 400:2000, columns
    270:2200, then scaled to [0, 1]. The segmentation keeps two channels: index 0 is
    the defect mask (red channel of ``seg_path``) and index 1 (background) is its
    complement.

    Args:
        img_path: path of the input image.
        seg_path: path of the segmentation image.
        img_size: crop (height, width); extra entries such as a channel count are ignored.
        batch_size: number of crops per yielded batch (must be >= 1).
        train: when True, apply random horizontal/vertical flips and small Gaussian noise.
        sample_weights: if not None, returned as a third tuple element with every batch.

    Yields:
        ``(imgs, segs)`` or ``(imgs, segs, sample_weights)``; ``imgs`` has shape
        (batch_size, h, w, 3) and ``segs`` has shape (batch_size, h, w, 2), both float64.

    Raises:
        ValueError: if ``batch_size`` < 1 or the crop is larger than the region.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    img = imageio.imread(img_path, pilmode='RGB')
    seg = imageio.imread(seg_path, pilmode='RGB')

    # Fixed region of interest of the source images.
    rows, cols = slice(400, 2000), slice(270, 2200)
    img = img[rows, cols, :].astype(np.float64) / 255.
    # Index 0: Defect (red channel), Index 1: Background (its complement).
    seg = seg[rows, cols, :2].astype(np.float64)
    seg[:, :, 1] = 255. - seg[:, :, 0]
    seg /= 255.

    h, w = img.shape[:2]
    crop_h, crop_w = img_size[0], img_size[1]
    if crop_h > h or crop_w > w:
        raise ValueError('crop size {} exceeds region size {}'.format((crop_h, crop_w), (h, w)))

    imgs, segs = [], []
    while True:
        # randint's upper bound is exclusive; +1 makes the bottom/right-most window reachable.
        top = np.random.randint(0, h - crop_h + 1)
        left = np.random.randint(0, w - crop_w + 1)
        cropped_img = random_crop(img, top, left, img_size)
        cropped_seg = random_crop(seg, top, left, img_size)
        if train:
            if np.random.rand() > 0.5:  # horizontal flip
                cropped_img = cropped_img[:, ::-1, :]
                cropped_seg = cropped_seg[:, ::-1, :]
            if np.random.rand() > 0.5:  # vertical flip
                cropped_img = cropped_img[::-1, :, :]
                cropped_seg = cropped_seg[::-1, :, :]
            # Build a new array: the crop is a view of `img`, so `+=` would permanently
            # add noise to the source image and accumulate it over iterations.
            cropped_img = cropped_img + 0.001 * np.random.randn(*cropped_img.shape)
        imgs.append(cropped_img)
        segs.append(cropped_seg)
        if len(imgs) == batch_size:
            batch = (np.array(imgs), np.array(segs))
            imgs, segs = [], []
            if sample_weights is not None:
                yield batch + (sample_weights,)
            else:
                yield batch
def decode_img(x):
    """Convert an array of intensities in [0, 1] to uint8 pixel values.

    Values are scaled by 255, clipped to [0, 255] and truncated (not rounded) to uint8.
    The input array is NOT modified. Earlier versions scaled it in place, which
    corrupted arrays the caller kept using.

    Args:
        x: array-like of floats, nominally in [0, 1].

    Returns:
        A new uint8 array with the same shape as ``x``.
    """
    # Clip first: casting negative or >255 floats to uint8 wraps or is undefined.
    scaled = np.clip(np.asarray(x) * 255, 0, 255)
    return scaled.astype(np.uint8)
def decode_onehot(y):
    """Render a (defect, background) two-channel map as an RGB uint8 image.

    ``decode_img`` scales ``y`` to 0-255 uint8. A zero third (blue) channel is
    appended, and channel 1 (background, shown as green) is zeroed, so only the
    defect channel remains visible, in red.

    Args:
        y: float array of shape (..., H, W, 2) with values in [0, 1]. Works for a
            single map or a batch.

    Returns:
        uint8 array of shape (..., H, W, 3).
    """
    y = decode_img(y)
    zero_channel = np.zeros((*y.shape[:-1], 1), dtype=np.uint8)
    y = np.concatenate((y, zero_channel), axis=-1)
    # `...` instead of `[:, :, :, 1]` so single (H, W, C) images work as well as batches.
    y[..., 1] = 0
    return y
def convert_prob_into_onehot(x, depth=2):
    """Turn per-class probabilities into one-hot vectors along the last axis.

    Equivalent to ``tf.one_hot(tf.argmax(x, -1), depth)``, but computed in NumPy.
    It therefore needs no TensorFlow session and adds no graph nodes per call.

    Args:
        x: array-like whose last axis holds class scores.
        depth: length of the one-hot vectors (default 2, as before). An argmax
            index >= ``depth`` produces an all-zero vector, matching ``tf.one_hot``.

    Returns:
        float32 NumPy array of shape ``x.shape[:-1] + (depth,)``.
    """
    labels = np.argmax(np.asarray(x), axis=-1)
    return (labels[..., np.newaxis] == np.arange(depth)).astype(np.float32)
def weighted_crossentropy_wrapper(class_weights):
    """Build a Keras loss computing class-weighted categorical cross entropy.

    Args:
        class_weights: sequence with one weight per class. Entry ``c`` scales the
            log-probability of pixels whose one-hot label is class ``c``. Its length
            must equal the last (class) axis of the labels, so it broadcasts over it.

    Returns:
        ``weighted_cross_entropy(onehot_labels, output)``, usable as ``model.compile(loss=...)``.
    """
    # Use the argument, not a global, and convert once so it broadcasts over the class axis.
    class_weights = np.asarray(class_weights, dtype=np.float32)

    def weighted_cross_entropy(onehot_labels, output):
        """Return the scalar weighted cross entropy.

        Args:
            onehot_labels: one-hot labels, (batch_size, height, width, num_classes).
            output: class probabilities of the same shape. The models in this file
                end in a softmax, so these are probabilities, not logits.

        Returns:
            ``-mean(onehot_labels * class_weights * log(output))``, averaged over every
            element (batch, spatial and class axes).
        """
        # Clip so log() never sees 0: log(0) = -inf would turn the loss and gradients into NaN.
        probs = K.clip(output, K.epsilon(), 1.0)
        return -K.mean(onehot_labels * class_weights * K.log(probs))

    return weighted_cross_entropy
class ImageWriter(Callback):
    """Keras callback that plots the model's predictions on a fixed batch after each epoch.

    One batch is drawn once from ``get_datagen('img91.png', 'seg91.png', train=False)``.
    After every epoch the prediction for that batch is appended to ``self.preds`` and a
    grid is saved. The grid has one row per sample; its columns are the input, the
    ground truth, then one prediction column per epoch seen so far.
    """

    def __init__(self, img_shape, batch_size, out_pattern='pred_epoch{:03d}.png'):
        """
        Args:
            img_shape: crop size, passed to ``get_datagen`` as ``img_size``.
            batch_size: number of samples (grid rows) to visualise.
            out_pattern: ``str.format`` pattern that receives the epoch index; each
                epoch's figure is saved to the resulting path.
        """
        super().__init__()
        self.batch_size = batch_size
        self.out_pattern = out_pattern
        test_gen = get_datagen('img91.png', 'seg91.png', train=False, batch_size=batch_size, img_size=img_shape)
        self.x, self.y = next(test_gen)
        self.y_shape = (self.x.shape[1], self.x.shape[2], 2)
        # Decode copies so self.x (fed to the model every epoch) and self.y stay unscaled.
        self.img = decode_img(np.copy(self.x))
        self.gth = decode_onehot(np.copy(self.y))
        self.preds = []

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the stored batch, record it, and save the comparison grid."""
        self.p = self.model.predict_on_batch(self.x)
        self.pre = decode_onehot(self.p)
        self.preds.append(self.pre)

        n_cols = 2 + len(self.preds)
        height, width = self.x.shape[1], self.x.shape[2]
        # At matplotlib's default 100 dpi this gives roughly one screen pixel per image pixel.
        figsize = (width * n_cols / 100, height * (self.batch_size + 1) / 100)
        # squeeze=False keeps `axes` 2-D even when batch_size == 1.
        fig, axes = plt.subplots(self.batch_size, n_cols, figsize=figsize, squeeze=False)

        axes[0, 0].set_title('X')
        axes[0, 1].set_title('GT')
        for j in range(len(self.preds)):
            axes[0, j + 2].set_title(str(j))

        for i in range(self.batch_size):
            row_images = [self.img[i], self.gth[i]] + [pred[i] for pred in self.preds]
            for ax, image in zip(axes[i], row_images):
                ax.imshow(image, vmin=0, vmax=255)
                ax.axis('off')

        fig.savefig(self.out_pattern.format(epoch))
        # pyplot keeps every figure alive until closed; close to avoid leaking one per epoch.
        plt.close(fig)
if __name__ == '__main__':
    img_shape = (256, 256, 3)
    steps_per_epoch = 128
    validation_steps = 4
    epochs = 50
    batch_size = 16
    weight_decay_l2 = 0.01  # used by the vgg19_unet alternative below

    train_gen = get_datagen('img91.png', 'seg91.png', img_size=img_shape, batch_size=batch_size, sample_weights=None, train=True)
    test_gen = get_datagen('img91.png', 'seg91.png', img_size=img_shape, batch_size=batch_size, sample_weights=None, train=False)

    # Class weights: inverse of (pixels of the class) / (crops containing the class),
    # estimated from 1024 random crops and normalised to sum to 1.
    _, y = next(get_datagen('img91.png', 'seg91.png', img_size=img_shape, batch_size=1024, sample_weights=None, train=False))
    pixcount = np.count_nonzero(y, axis=(0, 1, 2))
    imgcount = np.count_nonzero(np.count_nonzero(y, axis=(1, 2)), axis=0)
    if np.any(pixcount == 0):
        raise ValueError('a class has no pixels in the sampled crops; cannot derive class weights')
    freq = pixcount / imgcount
    weights = 1. / freq
    weights /= weights.sum()
    print('weights =', weights)

    # Alternatives:
    # model = vgg19_unet(input_shape=img_shape, classes=2, weight_decay=weight_decay_l2)
    # model = SegNet(input_shape=img_shape, classes=2)
    # model = UNetMatlab(input_shape=img_shape, classes=2).build()
    model = Pix2Pix(input_shape=img_shape, classes=2).build()

    # Alternative: optimizers.SGD(lr=5e-2, momentum=0.9, clipnorm=0.05)
    # (No trailing comma: it would turn the optimizer into a 1-tuple.)
    optimizer = optimizers.Adam(lr=1e-4, clipnorm=0.05)
    model.compile(optimizer=optimizer, loss=weighted_crossentropy_wrapper(weights))

    mc_cb = ModelCheckpoint('model.h5', monitor='val_loss')
    im_cb = ImageWriter(img_shape, 32)
    history = model.fit_generator(
        train_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=test_gen,
        validation_steps=validation_steps,
        callbacks=[mc_cb, im_cb],
    )
import os
import as io
import skimage.transform as trans
from keras.engine import InputSpec
from keras import initializers, regularizers
from keras.layers import Input, Concatenate, BatchNormalization, Activation, MaxPooling2D, Dropout, Conv2DTranspose
from keras.layers.advanced_activations import LeakyReLU, ReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
import keras.backend as K
import tensorflow as tf
class UNetMatlab:
    """U-Net modelled on the MATLAB example this class is named after.

    Four encoder stages (two 3x3 ReLU convs + 2x2 max-pool; filters 64-512, dropout 0.5
    before the last pool), a 1024-filter bridge with dropout 0.5, four decoder stages
    (2x2 stride-2 transposed conv, concatenation with the mirrored encoder output, two
    3x3 ReLU convs) and a 1x1 conv + softmax over ``classes``. Every convolution is
    he-normal initialised and L2-regularised with ``l2reg``. Spatial input size should
    be divisible by 16 so the skip concatenations line up.
    """

    def __init__(self, input_shape, classes, l2reg=0.0001):
        self.input_shape = input_shape
        self.classes = classes
        self.l2reg = l2reg

    def _conv(self, filters, kernel_size=3, activation='relu', padding='same'):
        """Return a he-normal, L2-regularised Conv2D layer (ReLU unless told otherwise)."""
        return Conv2D(filters, kernel_size=kernel_size, padding=padding,
                      kernel_initializer='he_normal', activation=activation,
                      kernel_regularizer=regularizers.l2(self.l2reg))

    def _upconv(self, filters):
        """Return the 2x2 stride-2 transposed convolution that doubles spatial size."""
        return Conv2DTranspose(filters, 2, strides=2, padding='valid',
                               kernel_initializer='he_normal', activation='relu',
                               kernel_regularizer=regularizers.l2(self.l2reg))

    def build(self):
        x = Input(shape=self.input_shape)
        h = x

        # Encoder; skips hold each stage's output before pooling (and before dropout).
        skips = []
        for filters in (64, 128, 256, 512):
            h = self._conv(filters)(h)
            h = self._conv(filters)(h)
            skips.append(h)
            if filters == 512:
                h = Dropout(0.5)(h)
            h = MaxPooling2D(2)(h)

        # Bridge.
        h = self._conv(1024)(h)
        h = self._conv(1024)(h)
        h = Dropout(0.5)(h)

        # Decoder: both 3x3 convs are ReLU-activated. Without an activation, each
        # pair would collapse into a single linear map.
        for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
            h = self._upconv(filters)(h)
            h = Concatenate(axis=-1)([h, skip])
            h = self._conv(filters)(h)
            h = self._conv(filters)(h)

        logit = self._conv(self.classes, kernel_size=1, activation=None, padding='valid')(h)
        prob = Activation('softmax')(logit)
        model = Model(x, prob)
        return model
class Pix2Pix:
    """Pix2Pix-style U-Net generator used as a segmentation network.

    Seven encoder stages (two dilated 3x3 spectral-normalised convs, BatchNorm,
    LeakyReLU(0.2), then 2x2 max-pool), a 1024-filter bottleneck, and seven decoder
    stages (2x upsampling, concatenation with the mirrored encoder features, two 3x3
    convs). The output is a 1x1 conv + softmax over ``classes``. Spatial input size
    must be divisible by 2**7 = 128 for the skip concatenations to line up.
    """

    def __init__(self, input_shape, classes):
        self.input_shape = input_shape
        self.classes = classes

    def build(self):
        def conv_bn_act(tensor, n_filters, **conv_kwargs):
            """ConvSN2D(3x3, same) -> BatchNorm(momentum 0.9) -> LeakyReLU(0.2)."""
            tensor = ConvSN2D(n_filters, kernel_size=3, strides=1, padding='same', **conv_kwargs)(tensor)
            tensor = BatchNormalization(momentum=0.9)(tensor)
            return LeakyReLU(alpha=0.2)(tensor)

        def encoder_stage(tensor, n_filters):
            """Two dilated conv units; returns (pooled output, pre-pool skip features)."""
            skip = conv_bn_act(tensor, n_filters, dilation_rate=2)
            skip = conv_bn_act(skip, n_filters, dilation_rate=2)
            return MaxPooling2D(2)(skip), skip

        def decoder_stage(tensor, skip, n_filters):
            """Upsample, concatenate the skip features, then two conv units."""
            upsampled = UpSampling2D(size=2)(tensor)
            merged = Concatenate(axis=-1)([upsampled, skip])
            merged = conv_bn_act(merged, n_filters)
            return conv_bn_act(merged, n_filters)

        inputs = Input(shape=self.input_shape)

        features = inputs
        skips = []
        for n_filters in (64, 128, 256, 512, 512, 512, 1024):
            features, skip = encoder_stage(features, n_filters)
            skips.append(skip)

        # Bottleneck: two plain (non-residual) conv units.
        features = conv_bn_act(features, 1024)
        features = conv_bn_act(features, 1024)

        for skip, n_filters in zip(reversed(skips), (512, 512, 512, 256, 128, 64, 64)):
            features = decoder_stage(features, skip, n_filters)

        logit = ConvSN2D(self.classes, kernel_size=1)(features)
        prob = Activation('softmax')(logit)
        return Model(inputs, prob)
def vgg19_unet(input_shape, weight_decay=0., classes=2):
    """Build a U-Net whose encoder is VGG19, initialised from ImageNet weights.

    Encoder: the five VGG19 conv blocks (layers ``block1_conv1`` ... ``block5_conv4``,
    with ``weight_decay`` L2 on their kernels). Each block is followed by
    BatchNormalization and, for blocks 1-4, a 2x2 max-pool. Decoder: four stages
    (``block6`` ... ``block9``) of 2x upsampling, concatenation with the mirrored
    encoder output, 3x3 ReLU convs and BatchNormalization. Output: per-pixel softmax
    over ``classes`` (layer ``prob``).

    The encoder layers reuse VGG19's layer names, so ImageNet weights are copied in
    with ``load_weights(by_name=True)``. Layers without a VGG19 counterpart keep their
    initial weights. Spatial input size should be divisible by 16 for the
    concatenations to line up.
    """
    import os
    import tempfile

    from keras.applications.vgg19 import VGG19

    def conv_stack(x, filters, block, n_convs, regularize):
        """Apply ``n_convs`` 3x3 ReLU convs named ``block{block}_conv{k}``, then BatchNorm."""
        for k in range(1, n_convs + 1):
            regularizer = regularizers.l2(weight_decay) if regularize else None
            x = Conv2D(filters, (3, 3), activation='relu', padding='same',
                       name='block%d_conv%d' % (block, k),
                       kernel_regularizer=regularizer)(x)
        return BatchNormalization()(x)

    img = Input(shape=input_shape, name='image')

    # Encoder (VGG19 blocks 1-4 with pooling); keep each block's normalised output as a skip.
    skips = []
    x = img
    for block, (filters, n_convs) in enumerate(((64, 2), (128, 2), (256, 4), (512, 4)), start=1):
        x = conv_stack(x, filters, block, n_convs, regularize=True)
        skips.append(x)
        x = MaxPooling2D((2, 2), strides=(2, 2), name='block%d_pool' % block)(x)
    x = conv_stack(x, 512, 5, 4, regularize=True)

    # Decoder: upsample, concatenate the mirrored encoder output, convolve.
    for block, (filters, n_convs) in enumerate(((512, 4), (256, 4), (128, 2), (64, 2)), start=6):
        x = UpSampling2D(2)(x)
        x = Concatenate(axis=-1)([x, skips.pop()])
        x = conv_stack(x, filters, block, n_convs, regularize=False)

    output = Conv2D(classes, (1, 1), padding='same', activation='softmax', name="prob")(x)
    model = Model(inputs=img, outputs=output)

    # Transfer weights via a private temporary directory. It is removed even if loading
    # fails, and no file in the working directory is overwritten.
    with tempfile.TemporaryDirectory() as tmp_dir:
        weights_path = os.path.join(tmp_dir, 'vgg19_notop.h5')
        VGG19(input_shape=input_shape, weights='imagenet', include_top=False).save_weights(weights_path)
        model.load_weights(weights_path, by_name=True)
    return model
def SegNet(input_shape=(360, 480, 3), classes=12):
    """Build a simplified SegNet-style encoder-decoder.

    Encoder: three (3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool) stages with 64, 128
    and 256 filters, then a 512-filter conv unit. Decoder: three (conv unit -> 2x
    UpSampling2D) stages with 512, 256 and 128 filters, a 64-filter conv unit, and a
    1x1 conv + softmax over ``classes``. There are no skip connections or pooling
    indices. The output matches the input's spatial size when H and W are divisible by 8.
    """
    def conv_bn_relu(tensor, n_filters):
        tensor = Conv2D(n_filters, (3, 3), padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    inputs = Input(shape=input_shape)
    h = inputs

    # Encoder
    for n_filters in (64, 128, 256):
        h = conv_bn_relu(h, n_filters)
        h = MaxPooling2D(pool_size=(2, 2))(h)
    h = conv_bn_relu(h, 512)

    # Decoder
    for n_filters in (512, 256, 128):
        h = conv_bn_relu(h, n_filters)
        h = UpSampling2D(size=(2, 2))(h)
    h = conv_bn_relu(h, 64)

    h = Conv2D(classes, (1, 1), padding="valid")(h)
    h = Activation("softmax")(h)
    return Model(inputs, h)
class ConvSN2D(Conv2D):
    """2-D convolution whose kernel is spectrally normalised.

    The kernel is flattened to a matrix W of shape (prod(kernel_size) * in_channels,
    filters). Its largest singular value sigma is estimated with one power-iteration
    step per forward pass, and the convolution uses W / sigma. The singular-vector
    estimate is kept in the non-trainable weight ``u`` (shape (1, filters)), so the
    estimate improves across steps.
    """

    def build(self, input_shape):
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        # Power-iteration state: updated in call(), never by the optimizer.
        self.u = self.add_weight(shape=(1, self.filters),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)

        # Set input spec.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs, training=None):
        def _l2normalize(v, eps=1e-12):
            return v / (K.sum(v ** 2) ** 0.5 + eps)

        def power_iteration(W, u):
            # One step per call, as in the spectral-normalisation paper; u persists between calls.
            _v = _l2normalize(K.dot(u, K.transpose(W)))
            _u = _l2normalize(K.dot(_v, W))
            return _u, _v

        # Spectral normalisation on the kernel flattened to 2-D.
        W_shape = self.kernel.shape.as_list()
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        _u, _v = power_iteration(W_reshaped, self.u)
        # sigma = v . W . u^T, a (1, 1) estimate of the largest singular value.
        sigma = K.dot(K.dot(_v, W_reshaped), K.transpose(_u))
        W_bar = W_reshaped / sigma

        if training in {0, False}:
            W_bar = K.reshape(W_bar, W_shape)
        else:
            # Store the new estimate of u whenever the normalised kernel is used.
            with tf.control_dependencies([self.u.assign(_u)]):
                W_bar = K.reshape(W_bar, W_shape)

        outputs = K.conv2d(
            inputs,
            W_bar,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs
