Skip to content

Instantly share code, notes, and snippets.

@ginrou
Last active April 23, 2019 15:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ginrou/01e3b5f7acba87cd05901f7590957b4a to your computer and use it in GitHub Desktop.
Save ginrou/01e3b5f7acba87cd05901f7590957b4a to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import numpy as np
import os
import itertools
from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout
from keras.layers.core import Lambda, RepeatVector, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D, GlobalMaxPool2D
from keras.layers.merge import concatenate, add
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
class Unet:
    """U-Net style encoder/decoder segmentation network built from Keras layers.

    The constructor wires four down-sampling stages, a bottleneck and four
    up-sampling stages with skip connections, and exposes the resulting
    Keras ``Model`` as ``self.model``.
    """

    @staticmethod
    def conv2d_block(input, filters, skip_dropout = False, dropout = 0.5, kernel_size=3, batchnorm = True):
        """Apply two ``Conv2D`` -> (``BatchNormalization``) -> ReLU stages.

        Args:
            input: Tensor fed to the first convolution.
            filters: Number of filters for both convolutions.
            skip_dropout: When true, return only the convolved tensor and
                skip the pooling/dropout tail.
            dropout: Rate passed to ``Dropout`` after pooling.
            kernel_size: Side length of the square convolution kernel.
            batchnorm: Whether to insert ``BatchNormalization`` after each
                convolution.

        Returns:
            The convolved tensor ``c`` when ``skip_dropout`` is true,
            otherwise the tuple ``(c, p)`` where ``p`` is ``c`` after 2x2
            max pooling and dropout.
        """
        c = input
        for _ in range(2):
            c = Conv2D(
                filters=filters,
                kernel_size=(kernel_size, kernel_size),
                padding="same",
                kernel_initializer="he_normal",
            )(c)
            if batchnorm:
                c = BatchNormalization()(c)
            c = Activation("relu")(c)

        if skip_dropout:
            return c

        p = MaxPooling2D((2, 2))(c)
        p = Dropout(dropout)(p)
        return (c, p)

    @staticmethod
    def trans_conv2d_block(input_upsample, input_concat, filters, kernel_size=3, strides=2, dropout=0.5):
        """Up-sample, merge with a skip connection and convolve.

        Args:
            input_upsample: Tensor to up-sample with ``Conv2DTranspose``.
            input_concat: Skip-connection tensor concatenated with the
                up-sampled tensor.
            filters: Number of filters for the transposed convolution and
                the following convolution block.
            kernel_size: Side length of the transposed-convolution kernel.
            strides: Stride of the transposed convolution in both axes.
            dropout: Rate passed to ``Dropout`` after the concatenation.

        Returns:
            Tuple ``(u, c)``: the concatenated tensor after dropout, and the
            output of ``conv2d_block`` applied to it.
        """
        u = Conv2DTranspose(
            filters,
            (kernel_size, kernel_size),
            strides=(strides, strides),
            padding="same",
        )(input_upsample)
        u = concatenate([u, input_concat])
        u = Dropout(dropout)(u)
        # The decoder side never pools, so only the convolved tensor is needed.
        c = Unet.conv2d_block(u, filters, skip_dropout=True)
        return (u, c)

    def __init__(self):
        """Build the network graph and store it in ``self.model``."""
        self.IMG_HEIGHT = 256
        self.IMG_WIDTH = 256
        self.IMG_CHANNELS = 3
        self.NUM_CLASS = 21
        self.FILTER_BASE = 16

        self.input = Input((self.IMG_HEIGHT, self.IMG_WIDTH, self.IMG_CHANNELS))

        # Encoder: keep each stage's pre-pooling tensor for the skip connections.
        skips = []
        x = self.input
        for multiplier in (1, 2, 4, 8):
            c, x = self.conv2d_block(x, self.FILTER_BASE * multiplier)
            skips.append(c)

        # Bottleneck: convolutions only, no pooling/dropout.
        x = self.conv2d_block(x, self.FILTER_BASE * 16, skip_dropout=True)

        # Decoder: consume the skip tensors deepest-first.
        for multiplier, skip in zip((8, 4, 2, 1), reversed(skips)):
            _, x = self.trans_conv2d_block(x, skip, self.FILTER_BASE * multiplier)

        # NOTE(review): sigmoid scores each class independently; if the classes
        # are meant to be mutually exclusive, softmax may be intended - confirm.
        self.output = Conv2D(self.NUM_CLASS, (1, 1), activation="sigmoid")(x)
        self.model = Model(inputs=[self.input], outputs=[self.output])

    def predict(self, input_img):
        """Run ``self.model.predict`` on ``input_img`` and return its result.

        ``input_img`` is expected to be a numpy array compatible with
        ``self.input``.
        """
        return self.model.predict(input_img)

    def to_index_color_img(self, predict_result):
        """Convert per-class scores into an RGB image using ``color_map``.

        Args:
            predict_result: Array whose last axis holds one score per class,
                e.g. ``(H, W, C)`` for one image or ``(B, H, W, C)`` for a
                batch.

        Returns:
            ``uint8`` array with the class axis replaced by an RGB axis of
            length 3, where each pixel is coloured by its highest-scoring
            class index.
        """
        # Reduce over the class axis; -1 works for single and batched input.
        arg_max_img = np.argmax(predict_result, axis=-1)
        cmap = Unet.color_map()
        # Integer-array indexing maps every class index to its RGB triple.
        return cmap[arg_max_img]

    @staticmethod
    def color_map(N=256, normalized=False):
        """Build an ``N`` x 3 colour table by spreading index bits over RGB.

        For each index, bits 0/1/2 of successive 3-bit groups are written to
        the red/green/blue channels from the most significant bit downwards.

        Args:
            N: Number of table entries.
            normalized: When true, return ``float32`` values divided by 255;
                otherwise return ``uint8`` values.

        Returns:
            Array of shape ``(N, 3)``.
        """
        def bitget(byteval, idx):
            return ((byteval & (1 << idx)) != 0)

        dtype = 'float32' if normalized else 'uint8'
        cmap = np.zeros((N, 3), dtype=dtype)
        for i in range(N):
            r = g = b = 0
            c = i
            for j in range(8):
                shift = 7 - j
                r = r | (bitget(c, 0) << shift)
                g = g | (bitget(c, 1) << shift)
                b = b | (bitget(c, 2) << shift)
                c = c >> 3
            cmap[i] = np.array([r, g, b])

        cmap = cmap/255 if normalized else cmap
        return cmap
class ImageLoader:
    """Load image/mask pairs from a dataset directory and cut them into crops."""

    def crop(self, img, shape):
        """Cut five fixed-size crops (four corners and the centre) from ``img``.

        Args:
            img: Object exposing ``size`` as ``(width, height)`` and a
                ``crop(box)`` method taking ``(left, upper, right, lower)``.
            shape: Crop size as ``(X, Y)``, i.e. width then height.

        Returns:
            Tuple ``(top_left, top_right, bottom_left, bottom_right, center)``
            of whatever ``img.crop`` returns for each box.
        """
        W, H = img.size
        X, Y = shape

        # Integer arithmetic keeps the centre box exactly X by Y; true
        # division would hand fractional coordinates to img.crop under
        # Python 3.
        left = (W - X) // 2
        upper = (H - Y) // 2

        top_left = img.crop((0, 0, X, Y))
        top_right = img.crop((W - X, 0, W, Y))
        bottom_left = img.crop((0, H - Y, X, H))
        bottom_right = img.crop((W - X, H - Y, W, H))
        center = img.crop((left, upper, left + X, upper + Y))
        return top_left, top_right, bottom_left, bottom_right, center

    @staticmethod
    def _split_holdout(items, num_holdout):
        """Split ``items`` into (all but the last ``num_holdout``, the last ``num_holdout``)."""
        # An explicit split index avoids the ``[:-0]`` trap (which yields an
        # empty list) and clamps when fewer items than requested exist.
        split = max(len(items) - max(num_holdout, 0), 0)
        return items[:split], items[split:]

    def load(self, data_dir, img_size, num_train = 100):
        """Load every image/mask pair, crop them and split into train/test.

        File stems are taken from the ``.png`` files in
        ``<data_dir>/SegmentationClass``; the matching images are read from
        ``<data_dir>/JPEGImages/<stem>.jpg``. Each file contributes the five
        crops produced by ``crop``.

        Args:
            data_dir: Dataset root containing ``SegmentationClass`` and
                ``JPEGImages``.
            img_size: Crop size ``(X, Y)`` forwarded to ``crop``.
            num_train: Number of crops held out at the end of the list as
                the test split (despite the name); the training split is
                everything before them.

        Side effects:
            Sets ``keys``, ``img_path_list``, ``mask_path_list``,
            ``train_imgs``, ``test_imgs``, ``train_masks`` and ``test_masks``.
        """
        mask_dir = os.path.join(data_dir, "SegmentationClass")
        self.keys = sorted(
            os.path.splitext(f)[0] for f in os.listdir(mask_dir) if f.endswith(".png")
        )
        self.img_path_list = [os.path.join(data_dir, "JPEGImages", k + ".jpg") for k in self.keys]
        self.mask_path_list = [os.path.join(mask_dir, k + ".png") for k in self.keys]

        imgs = list(itertools.chain.from_iterable(
            self.crop(load_img(p), img_size) for p in self.img_path_list
        ))
        masks = list(itertools.chain.from_iterable(
            self.crop(load_img(p), img_size) for p in self.mask_path_list
        ))

        # Train and test must be disjoint: hold out the tail, train on the rest.
        self.train_imgs, self.test_imgs = self._split_holdout(imgs, num_train)
        self.train_masks, self.test_masks = self._split_holdout(masks, num_train)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment