@geografif
Forked from ad-1/build_unet.py
Created February 7, 2023 19:02
Semantic Segmentation of MBRSC Aerial Imagery of Dubai using a TensorFlow U-Net model in Python. https://towardsdatascience.com/semantic-segmentation-of-aerial-imagery-using-u-net-in-python-552705238514
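The script below assumes the third-party packages implied by its imports: tensorflow/keras, opencv-python (cv2), numpy, matplotlib, Pillow, patchify, scikit-learn and tqdm. A typical environment setup (an assumption; the gist pins no versions) would be:

pip install tensorflow opencv-python numpy matplotlib Pillow patchify scikit-learn tqdm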
import datetime
import math
import os
from enum import Enum
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from keras import backend as K
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.models import Model, load_model
from keras.utils import to_categorical
from patchify import patchify
from sklearn.model_selection import train_test_split
from keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose, Dropout
from keras.layers.experimental.preprocessing import Rescaling
from tqdm import tqdm
# =======================================================
# image preprocessing
def load_images_and_patchify(directory_path, patch_size):
    """
    :param patch_size: image patchify square size
    :param directory_path: path to root directory containing training and test images
    :return: list of image patches from directory
    """
    # initialize empty list for images
    instances = []
    # iterate through files in directory
    for file_number, filepath in tqdm(enumerate(os.listdir(directory_path))):
        extension = filepath.split(".")[-1]
        if extension in ("jpg", "png"):
            # current image path
            img_path = rf"{directory_path}/{filepath}"
            # cv2 reads the image as BGR
            image = cv2.imread(img_path)
            # convert image to RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            size_x = (image.shape[1] // patch_size) * patch_size  # width to nearest size divisible by patch size
            size_y = (image.shape[0] // patch_size) * patch_size  # height to nearest size divisible by patch size
            image = Image.fromarray(image)
            # crop original image, from the top-left corner, to a size divisible by patch size
            image = np.array(image.crop((0, 0, size_x, size_y)))
            # extract patches from each image; step=patch_size means no overlap
            patch_img = patchify(image, (patch_size, patch_size, 3), step=patch_size)
            # iterate over vertical patch axis
            for j in range(patch_img.shape[0]):
                # iterate over horizontal patch axis
                for k in range(patch_img.shape[1]):
                    # patches are laid out as a grid; use (j, k) indices to extract a single patch
                    single_patch_img = patch_img[j, k]
                    # drop the extra dimension that patchify adds
                    instances.append(np.squeeze(single_patch_img))
    return instances
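# worked example (illustrative numbers, not from the source): with patch_size=160, a
# 1000x800 image is cropped to 960x800 from the top-left corner, then patchified into
# a 5x6 grid of 30 non-overlapping 160x160x3 patches.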
def reshape_images(instances):
    """
    :param instances: list of images
    :return: reshaped images
    """
    for j in range(len(instances)):
        instances[j] = instances[j].reshape(-1, 1)
    return instances
def get_minimum_image_size(instances):
    """
    :param instances: list of images
    :return: minimum height and width across all images
    """
    # initialize minimum values to infinity
    min_x = math.inf
    min_y = math.inf
    # loop through each instance
    for image in instances:
        # check min x (rows)
        min_x = image.shape[0] if image.shape[0] < min_x else min_x
        # check min y (columns)
        min_y = image.shape[1] if image.shape[1] < min_y else min_y
    return min_x, min_y
def display_images(instances, rows=2, titles=None):
    """
    :param instances: list of images
    :param rows: number of rows in subplot
    :param titles: subplot titles
    :return:
    """
    n = len(instances)
    # number of columns needed to fit n images into the requested number of rows
    cols = n // rows if n % rows == 0 else n // rows + 1
    # iterate through images and display subplots
    for j, image in enumerate(instances):
        plt.subplot(rows, cols, j + 1)
        plt.title('') if titles is None else plt.title(titles[j])
        plt.axis("off")
        plt.imshow(image)
    # show the figure
    plt.show()
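# usage sketch (hypothetical variables): display_images([img, mask], rows=1,
# titles=['Image', 'Mask']) draws the two arrays side by side in a single figure.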
# =====================================================
# prepare training data input images
def get_training_data(root_directory):
    # initialise lists
    image_dataset, mask_dataset = [], []
    # define image patch size
    patch_size = 160
    # walk through root directory
    for path, directories, files in os.walk(root_directory):
        for subdirectory in directories:
            # extract training input images and patchify
            if subdirectory == "images":
                image_dataset.extend(
                    load_images_and_patchify(os.path.join(path, subdirectory), patch_size=patch_size))
            # extract training label masks and patchify
            elif subdirectory == "masks":
                mask_dataset.extend(
                    load_images_and_patchify(os.path.join(path, subdirectory), patch_size=patch_size))
    # return input images and masks
    return np.array(image_dataset), np.array(mask_dataset)
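# layout note (inferred from the os.walk logic above): root_directory may nest folders
# arbitrarily; every subdirectory named "images" or "masks" is patchified, so paired
# image and mask folders must contain matching files for X and Y to stay aligned.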
# mask color codes
class MaskColorMap(Enum):
    Unlabelled = (155, 155, 155)
    Building = (60, 16, 152)
    Land = (132, 41, 246)
    Road = (110, 193, 228)
    Vegetation = (254, 221, 58)
    Water = (226, 169, 41)
def one_hot_encode_masks(masks, num_classes):
    """
    :param masks: Y_train patched mask dataset
    :param num_classes: number of classes
    :return: one-hot encoded masks
    """
    # initialise list for integer encoded masks
    integer_encoded_labels = []
    # iterate over each mask
    for mask in tqdm(masks):
        # get image shape
        _img_height, _img_width, _img_channels = mask.shape
        # create new mask of zeros
        encoded_image = np.zeros((_img_height, _img_width, 1)).astype(int)
        # assign the integer class label wherever the RGB mask matches a class colour
        for j, cls in enumerate(MaskColorMap):
            encoded_image[np.all(mask == cls.value, axis=-1)] = j
        # append encoded image
        integer_encoded_labels.append(encoded_image)
    # return one-hot encoded labels
    return to_categorical(y=integer_encoded_labels, num_classes=num_classes)
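# shape sketch (assuming the 160x160 patches produced by get_training_data): the input
# is a stack of RGB masks of shape (m, 160, 160, 3); each becomes an integer-encoded
# (160, 160, 1) mask, and to_categorical returns labels of shape (m, 160, 160, 6),
# one binary channel per class, ready for categorical_crossentropy.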
# =====================================================
# output directories
# datetime for filename saving
dt_now = str(datetime.datetime.now()).replace(".", "_").replace(":", "_")
model_img_save_path = f"{os.getcwd()}/models/final_aerial_segmentation_{dt_now}.png"
model_save_path = f"{os.getcwd()}/models/final_aerial_segmentation_{dt_now}.hdf5"
model_checkpoint_filepath = os.getcwd() + "/models/weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5"
csv_logger_path = rf"{os.getcwd()}/logs/aerial_segmentation_log_{dt_now}.csv"
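# note: ModelCheckpoint and CSVLogger do not create missing directories; if models/ or
# logs/ does not exist, a guard such as os.makedirs(f"{os.getcwd()}/models",
# exist_ok=True) (not in the original script) avoids a FileNotFoundError mid-training.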
# =======================================================
# training metrics
# jaccard similarity: the size of the intersection divided by the size of the union of two sets
def jaccard_index(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    # smoothed intersection over union; the +1.0 terms avoid division by zero
    return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)
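# worked example (illustrative, not from the source): for y_true = [1, 0, 1] and
# y_pred = [1, 1, 0], intersection = 1 and union = 2 + 2 - 1 = 3, so the smoothed
# index is (1 + 1) / (3 + 1) = 0.5, versus an exact Jaccard similarity of 1/3.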
# =====================================================
# get training data
# number of classes in segmentation dataset
n_classes = 6
# dataset directory
data_dir = r"/Users/andrewdavies/Code/Python/MachineLearning/earth-observation/data/semantic-segmentation-dataset"
# create (X, Y) training data
X, Y = get_training_data(root_directory=data_dir)
# extract X_train shape parameters
m, img_height, img_width, img_channels = X.shape
print('number of patched image training data:', m)
# display images from both training and test sets
display_count = 6
random_index = [np.random.randint(0, m) for _ in range(display_count)]
sample_images = [x for z in zip(list(X[random_index]), list(Y[random_index])) for x in z]
display_images(sample_images, rows=2)
# convert RGB mask values to one-hot encoded labels for categorical_crossentropy
Y = one_hot_encode_masks(Y, num_classes=n_classes)
# split dataset into training and test groups
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.10, random_state=42)
# =====================================================
# define U-Net model architecture
def build_unet(img_shape):
    # input layer shape is equal to patch image size
    inputs = Input(shape=img_shape)
    # rescale images from (0, 255) to (0, 1)
    rescale = Rescaling(scale=1. / 255)(inputs)
    previous_block_activation = rescale  # Set aside residual
    contraction = {}
    # Contraction path: four downsampling blocks, identical apart from the feature depth
    for f in [16, 32, 64, 128]:
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(previous_block_activation)
        x = Dropout(0.1)(x)
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        # store the pre-pooling activation for the skip connection to the expansive path
        contraction[f'conv{f}'] = x
        x = MaxPooling2D((2, 2))(x)
        previous_block_activation = x
    # bottleneck block
    c5 = Conv2D(160, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(previous_block_activation)
    c5 = Dropout(0.2)(c5)
    c5 = Conv2D(160, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
    previous_block_activation = c5
    # Expansive path: second half of the network, upsampling and concatenating skip connections
    for f in reversed([16, 32, 64, 128]):
        x = Conv2DTranspose(f, (2, 2), strides=(2, 2), padding='same')(previous_block_activation)
        x = concatenate([x, contraction[f'conv{f}']])
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        previous_block_activation = x
    # per-pixel class probabilities via a 1x1 convolution with softmax
    outputs = Conv2D(filters=n_classes, kernel_size=(1, 1), activation="softmax")(previous_block_activation)
    return Model(inputs=inputs, outputs=outputs)
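# shape sketch (assuming 160x160x3 input patches): the four pooling steps halve the
# resolution 160 -> 80 -> 40 -> 20 -> 10, the bottleneck operates at 10x10x160, and
# the four transposed convolutions restore 10 -> 20 -> 40 -> 80 -> 160, giving a
# (160, 160, n_classes) softmax output aligned pixel-for-pixel with the input patch.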
# build model
model = build_unet(img_shape=(img_height, img_width, img_channels))
model.summary()
# =======================================================
# add callbacks, compile model and fit training data
# save best model with maximum validation accuracy
checkpoint = ModelCheckpoint(model_checkpoint_filepath, monitor="val_accuracy", verbose=1, save_best_only=True, mode="max")
# stop model training early if validation loss doesn't decrease over 2 consecutive epochs
early_stopping = EarlyStopping(monitor="val_loss", patience=2, verbose=1, mode="min")
# log training console output to csv
csv_logger = CSVLogger(csv_logger_path, separator=",", append=False)
# create list of callbacks (early_stopping is defined above but not enabled here)
callbacks_list = [checkpoint, csv_logger]
# compile model
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy", jaccard_index])
# train and save model
model.fit(X_train, Y_train, epochs=20, batch_size=32, validation_data=(X_test, Y_test), callbacks=callbacks_list, verbose=1)
model.save(model_save_path)
print("model saved:", model_save_path)
# =====================================================
# load pre-trained model
# model = load_model(
#     '/Users/andrewdavies/Code/Python/MachineLearning/earth-observation/models/final_aerial_segmentation_2022-03-31 13_28_03_079442.hdf5',
#     custom_objects={'jaccard_index': jaccard_index}
# )
# =====================================================
# Predict
def rgb_encode_mask(mask):
    # initialize rgb image with equal spatial resolution
    rgb_encode_image = np.zeros((mask.shape[0], mask.shape[1], 3))
    # iterate over MaskColorMap
    for j, cls in enumerate(MaskColorMap):
        # convert single integer channel to RGB channels
        rgb_encode_image[(mask == j)] = np.array(cls.value) / 255.
    return rgb_encode_image
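# example (illustrative): a pixel with integer label 3 maps to MaskColorMap.Road, i.e.
# (110, 193, 228) / 255, so the decoded float RGB mask can go straight to plt.imshow.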
for _ in range(20):
    # choose random number from 0 to test set size
    test_img_number = np.random.randint(0, len(X_test))
    # extract test input image
    test_img = X_test[test_img_number]
    # ground truth test label converted from one-hot to integer encoding
    ground_truth = np.argmax(Y_test[test_img_number], axis=-1)
    # expand first dimension as U-Net requires (m, h, w, nc) input shape
    test_img_input = np.expand_dims(test_img, 0)
    # make prediction with model and remove extra dimension
    prediction = np.squeeze(model.predict(test_img_input))
    # convert softmax probabilities to integer values
    predicted_img = np.argmax(prediction, axis=-1)
    # convert integer encoding to rgb values
    rgb_image = rgb_encode_mask(predicted_img)
    rgb_ground_truth = rgb_encode_mask(ground_truth)
    # visualize model predictions
    display_images(
        [test_img, rgb_ground_truth, rgb_image],
        rows=1, titles=['Aerial', 'Ground Truth', 'Prediction']
    )