-
-
Save geografif/a5d5868f73ed85b5ae8c868e9a88e984 to your computer and use it in GitHub Desktop.
Semantic Segmentation of MBRSC Aerial Imagery of Dubai using a TensorFlow U-Net model in Python. https://towardsdatascience.com/semantic-segmentation-of-aerial-imagery-using-u-net-in-python-552705238514
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# =====================================================
# define U-Net model architecture
def build_unet(img_shape):
    """Build a U-Net semantic-segmentation model.

    :param img_shape: (height, width, channels) of each input patch
    :return: uncompiled Keras Model with a per-pixel softmax over
        n_classes (module-level constant) output channels
    """
    # input layer shape is equal to patch image size
    inputs = Input(shape=img_shape)
    # rescale images from (0, 255) to (0, 1)
    # (dropped the redundant input_shape kwarg: it is only meaningful on a
    # model's first layer, and the shape is already fixed by Input above;
    # it also needlessly referenced module-level globals)
    rescale = Rescaling(scale=1. / 255)(inputs)
    previous_block_activation = rescale  # set aside residual
    # keep contraction outputs for the expansive-path skip connections
    contraction = {}
    # Contraction path: blocks are identical apart from the feature depth
    for f in [16, 32, 64, 128]:
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(previous_block_activation)
        x = Dropout(0.1)(x)
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        contraction[f'conv{f}'] = x
        x = MaxPooling2D((2, 2))(x)
        previous_block_activation = x
    # bottleneck block: no pooling, slightly higher dropout
    c5 = Conv2D(160, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(previous_block_activation)
    c5 = Dropout(0.2)(c5)
    c5 = Conv2D(160, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
    previous_block_activation = c5
    # Expansive path: upsample and concatenate with the matching contraction block
    for f in reversed([16, 32, 64, 128]):
        x = Conv2DTranspose(f, (2, 2), strides=(2, 2), padding='same')(previous_block_activation)
        x = concatenate([x, contraction[f'conv{f}']])
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        previous_block_activation = x
    # 1x1 conv maps features to per-pixel class probabilities
    outputs = Conv2D(filters=n_classes, kernel_size=(1, 1), activation="softmax")(previous_block_activation)
    return Model(inputs=inputs, outputs=outputs)
# build model
# (img_height, img_width, img_channels are derived from X.shape earlier in the script)
model = build_unet(img_shape=(img_height, img_width, img_channels))
# print layer-by-layer architecture summary to stdout
model.summary()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# =======================================================
# add callbacks, compile model and fit training data
# save best model with maximum validation accuracy
checkpoint = ModelCheckpoint(model_checkpoint_filepath, monitor="val_accuracy", verbose=1, save_best_only=True, mode="max")
# stop model training early if validation loss doesn't continue to decrease over 2 iterations
early_stopping = EarlyStopping(monitor="val_loss", patience=2, verbose=1, mode="min")
# log training console output to csv
# (bound to a new name so the module-level csv_logger *path* string is not shadowed)
csv_logger_callback = CSVLogger(csv_logger, separator=",", append=False)
# create list of callbacks (early_stopping deliberately left out)
callbacks_list = [checkpoint, csv_logger_callback]  # early_stopping
# compile model
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy", jaccard_index])
# train and save model
model.fit(X_train, Y_train, epochs=20, batch_size=32, validation_data=(X_test, Y_test), callbacks=callbacks_list, verbose=1)
model.save(model_save_path)
print("model saved:", model_save_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# jaccard similarity: the size of the intersection divided by the size of the union of two sets
def jaccard_index(y_true, y_pred):
    """Smoothed (differentiable) Jaccard index over flattened tensors."""
    smooth = 1.0
    true_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    # element-wise product counts positions where both are "on"
    overlap = K.sum(true_flat * pred_flat)
    union = K.sum(true_flat) + K.sum(pred_flat) - overlap
    return (overlap + smooth) / (union + smooth)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_images_and_patchify(directory_path, patch_size):
    """
    Load every jpg/png image in a directory and split each into
    non-overlapping square patches.

    :param patch_size: image patchify square size (pixels)
    :param directory_path: path to root directory containing training and test images
    :return: list of (patch_size, patch_size, 3) RGB patch arrays
    """
    # initialize empty list for images
    instances = []
    # iterate through files in directory (unused enumerate counter removed)
    for filepath in tqdm(os.listdir(directory_path)):
        # tuple form of endswith replaces the chained `or`; lower() also
        # accepts upper-case extensions such as .JPG
        if filepath.lower().endswith((".jpg", ".png")):
            # current image path
            img_path = rf"{directory_path}/{filepath}"
            # cv2 reads images as BGR
            image = cv2.imread(img_path)
            # convert image to RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            size_x = (image.shape[1] // patch_size) * patch_size  # get width to nearest size divisible by patch size
            size_y = (image.shape[0] // patch_size) * patch_size  # get height to nearest size divisible by patch size
            image = Image.fromarray(image)
            # Crop original image to size divisible by patch size from top left corner
            image = np.array(image.crop((0, 0, size_x, size_y)))
            # Extract patches from each image, step=patch_size means no overlap
            patch_img = patchify(image, (patch_size, patch_size, 3), step=patch_size)
            # patches form a grid: iterate vertical (j) then horizontal (k) axes
            for j in range(patch_img.shape[0]):
                for k in range(patch_img.shape[1]):
                    # use (j, k) indices to extract a single patched image
                    single_patch_img = patch_img[j, k]
                    # drop the extra dimension that patchify adds
                    instances.append(np.squeeze(single_patch_img))
    return instances
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# =====================================================
# load pre-trained model
# NOTE(review): '...' is a placeholder path — replace with the real checkpoint
# location before running. custom_objects must map every custom metric used at
# compile time so Keras can resolve it during deserialization.
model = load_model(
    '.../final_aerial_segmentation_2022-03-31 13_28_03_079442.hdf5',
    custom_objects={'jaccard_index': jaccard_index}
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# mask color codes: the RGB value that encodes each semantic class in the label masks
class MaskColorMap(Enum):
    Unlabelled = (155, 155, 155)
    Building = (60, 16, 152)
    Land = (132, 41, 246)
    Road = (110, 193, 228)
    Vegetation = (254, 221, 58)
    Water = (226, 169, 41)
def one_hot_encode_masks(masks, num_classes):
    """
    Convert RGB colour-coded masks to one-hot encoded labels.

    :param masks: Y_train patched mask dataset
    :param num_classes: number of classes
    :return: one-hot encoded label array
    """
    # integer-encoded masks, one per input mask
    integer_encoded_labels = []
    for mask in tqdm(masks):
        height, width, _ = mask.shape
        # start from an all-zero (class 0) single-channel mask
        encoded_image = np.zeros((height, width, 1)).astype(int)
        # write each class index wherever the mask matches that class colour
        for class_index, colour in enumerate(MaskColorMap):
            matches = np.all(mask == colour.value, axis=-1)
            encoded_image[matches] = class_index
        integer_encoded_labels.append(encoded_image)
    # expand integer labels to one-hot vectors
    return to_categorical(y=integer_encoded_labels, num_classes=num_classes)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import math | |
import os | |
from enum import Enum | |
import cv2 | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from PIL import Image | |
from keras import backend as K | |
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping | |
from keras.models import Model, load_model | |
from keras.utils import to_categorical | |
from patchify import patchify | |
from sklearn.model_selection import train_test_split | |
from keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose, Dropout | |
from keras.layers.experimental.preprocessing import Rescaling | |
from tqdm import tqdm | |
# =======================================================
# image preprocessing
def load_images_and_patchify(directory_path, patch_size):
    """
    Load every jpg/png image in a directory and split each into
    non-overlapping square patches.

    :param patch_size: image patchify square size (pixels)
    :param directory_path: path to root directory containing training and test images
    :return: list of (patch_size, patch_size, 3) RGB patch arrays
    """
    # initialize empty list for images
    instances = []
    # iterate through files in directory (unused enumerate counter removed)
    for filepath in tqdm(os.listdir(directory_path)):
        # tuple form of endswith replaces the chained `or`; lower() also
        # accepts upper-case extensions such as .JPG
        if filepath.lower().endswith((".jpg", ".png")):
            # current image path
            img_path = rf"{directory_path}/{filepath}"
            # cv2 reads images as BGR
            image = cv2.imread(img_path)
            # convert image to RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            size_x = (image.shape[1] // patch_size) * patch_size  # get width to nearest size divisible by patch size
            size_y = (image.shape[0] // patch_size) * patch_size  # get height to nearest size divisible by patch size
            image = Image.fromarray(image)
            # Crop original image to size divisible by patch size from top left corner
            image = np.array(image.crop((0, 0, size_x, size_y)))
            # Extract patches from each image, step=patch_size means no overlap
            patch_img = patchify(image, (patch_size, patch_size, 3), step=patch_size)
            # patches form a grid: iterate vertical (j) then horizontal (k) axes
            for j in range(patch_img.shape[0]):
                for k in range(patch_img.shape[1]):
                    # use (j, k) indices to extract a single patched image
                    single_patch_img = patch_img[j, k]
                    # drop the extra dimension that patchify adds
                    instances.append(np.squeeze(single_patch_img))
    return instances
def reshape_images(instances):
    """
    Reshape every image in the list to a single column vector, in place.

    :param instances: list of images
    :return: the same list with each entry reshaped to (-1, 1)
    """
    for index, image in enumerate(instances):
        instances[index] = image.reshape(-1, 1)
    return instances
def get_minimum_image_size(instances):
    """
    :param instances: list of images
    :return: smallest row count and smallest column count over all images
        (math.inf for an empty list, matching the original initial values)
    """
    # generator + min() replaces the manual running-minimum loop
    min_x = min((image.shape[0] for image in instances), default=math.inf)
    min_y = min((image.shape[1] for image in instances), default=math.inf)
    return min_x, min_y
def display_images(instances, rows=2, titles=None):
    """
    Display a list of images in a subplot grid.

    :param instances: list of images
    :param rows: number of rows in subplot
    :param titles: subplot titles, one per image (optional)
    :return:
    """
    n = len(instances)
    # just enough columns to fit all images in the requested number of rows;
    # the original formula `(n / rows) % rows == 0` was wrong and added a
    # spurious extra column whenever n/rows was odd (e.g. n=6, rows=2 -> 4)
    cols = math.ceil(n / rows)
    # iterate through images and display subplots
    for j, image in enumerate(instances):
        plt.subplot(rows, cols, j + 1)
        plt.title('') if titles is None else plt.title(titles[j])
        plt.axis("off")
        plt.imshow(image)
    # show the figure
    plt.show()
# =====================================================
# prepare training data input images
def get_training_data(root_directory):
    """Collect and patchify all input images and label masks under root_directory."""
    # accumulate patches per subdirectory kind
    datasets = {"images": [], "masks": []}
    # define image patch size
    patch_size = 160
    # walk through root directory
    for path, directories, _files in os.walk(root_directory):
        for subdirectory in directories:
            # 'images' holds training inputs, 'masks' holds label masks
            if subdirectory in datasets:
                patches = load_images_and_patchify(os.path.join(path, subdirectory), patch_size=patch_size)
                datasets[subdirectory].extend(patches)
    # return input images and masks as arrays
    return np.array(datasets["images"]), np.array(datasets["masks"])
# mask color codes: the RGB value that encodes each semantic class in the label masks
class MaskColorMap(Enum):
    Unlabelled = (155, 155, 155)
    Building = (60, 16, 152)
    Land = (132, 41, 246)
    Road = (110, 193, 228)
    Vegetation = (254, 221, 58)
    Water = (226, 169, 41)
def one_hot_encode_masks(masks, num_classes):
    """
    Convert RGB colour-coded masks to one-hot encoded labels.

    :param masks: Y_train patched mask dataset
    :param num_classes: number of classes
    :return: one-hot encoded label array
    """
    # integer-encoded masks, one per input mask
    integer_encoded_labels = []
    for mask in tqdm(masks):
        height, width, _ = mask.shape
        # start from an all-zero (class 0) single-channel mask
        encoded_image = np.zeros((height, width, 1)).astype(int)
        # write each class index wherever the mask matches that class colour
        for class_index, colour in enumerate(MaskColorMap):
            matches = np.all(mask == colour.value, axis=-1)
            encoded_image[matches] = class_index
        integer_encoded_labels.append(encoded_image)
    # expand integer labels to one-hot vectors
    return to_categorical(y=integer_encoded_labels, num_classes=num_classes)
# =====================================================
# output directories
# datetime for filename saving ('.' and ':' replaced — not filesystem-safe everywhere)
dt_now = str(datetime.datetime.now()).replace(".", "_").replace(":", "_")
# architecture diagram image path
model_img_save_path = f"{os.getcwd()}/models/final_aerial_segmentation_{dt_now}.png"
# final trained model path
model_save_path = f"{os.getcwd()}/models/final_aerial_segmentation_{dt_now}.hdf5"
# per-epoch checkpoint path template ({epoch}/{val_accuracy} filled in by ModelCheckpoint)
model_checkpoint_filepath = os.getcwd() + "/models/weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5"
# training log csv path (NOTE: this name is later rebound to a CSVLogger callback)
csv_logger = rf"{os.getcwd()}/logs/aerial_segmentation_log_{dt_now}.csv"
# =======================================================
# training metrics
# jaccard similarity: the size of the intersection divided by the size of the union of two sets
def jaccard_index(y_true, y_pred):
    """Smoothed (differentiable) Jaccard index over flattened tensors."""
    smooth = 1.0
    true_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    # element-wise product counts positions where both are "on"
    overlap = K.sum(true_flat * pred_flat)
    union = K.sum(true_flat) + K.sum(pred_flat) - overlap
    return (overlap + smooth) / (union + smooth)
# =====================================================
# get training data
# number of classes in segmentation dataset
n_classes = 6
# dataset directory
data_dir = r"/Users/andrewdavies/Code/Python/MachineLearning/earth-observation/data/semantic-segmentation-dataset"
# create (X, Y) training data
X, Y = get_training_data(root_directory=data_dir)
# extract X_train shape parameters
m, img_height, img_width, img_channels = X.shape
print('number of patched image training data:', m)
# display a random sample of image/mask pairs
display_count = 6
random_index = [np.random.randint(0, m) for _ in range(display_count)]
# interleave inputs and masks so each image appears next to its mask
sample_images = [x for z in zip(list(X[random_index]), list(Y[random_index])) for x in z]
display_images(sample_images, rows=2)
# convert RGB values to integer encoded labels for categorical_crossentropy
Y = one_hot_encode_masks(Y, num_classes=n_classes)
# split dataset into training and test groups
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.10, random_state=42)
# =====================================================
# define U-Net model architecture
def build_unet(img_shape):
    """Build a U-Net semantic-segmentation model.

    :param img_shape: (height, width, channels) of each input patch
    :return: uncompiled Keras Model with a per-pixel softmax over
        n_classes (module-level constant) output channels
    """
    # input layer shape is equal to patch image size
    inputs = Input(shape=img_shape)
    # rescale images from (0, 255) to (0, 1)
    # (dropped the redundant input_shape kwarg: it is only meaningful on a
    # model's first layer, and the shape is already fixed by Input above;
    # it also needlessly referenced module-level globals)
    rescale = Rescaling(scale=1. / 255)(inputs)
    previous_block_activation = rescale  # set aside residual
    # keep contraction outputs for the expansive-path skip connections
    contraction = {}
    # Contraction path: blocks are identical apart from the feature depth
    for f in [16, 32, 64, 128]:
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(previous_block_activation)
        x = Dropout(0.1)(x)
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        contraction[f'conv{f}'] = x
        x = MaxPooling2D((2, 2))(x)
        previous_block_activation = x
    # bottleneck block: no pooling, slightly higher dropout
    c5 = Conv2D(160, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(previous_block_activation)
    c5 = Dropout(0.2)(c5)
    c5 = Conv2D(160, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
    previous_block_activation = c5
    # Expansive path: upsample and concatenate with the matching contraction block
    for f in reversed([16, 32, 64, 128]):
        x = Conv2DTranspose(f, (2, 2), strides=(2, 2), padding='same')(previous_block_activation)
        x = concatenate([x, contraction[f'conv{f}']])
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        x = Dropout(0.2)(x)
        x = Conv2D(f, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
        previous_block_activation = x
    # 1x1 conv maps features to per-pixel class probabilities
    outputs = Conv2D(filters=n_classes, kernel_size=(1, 1), activation="softmax")(previous_block_activation)
    return Model(inputs=inputs, outputs=outputs)
# build model
# (img_height, img_width, img_channels are derived from X.shape earlier in the script)
model = build_unet(img_shape=(img_height, img_width, img_channels))
# print layer-by-layer architecture summary to stdout
model.summary()
# =======================================================
# add callbacks, compile model and fit training data
# save best model with maximum validation accuracy
checkpoint = ModelCheckpoint(model_checkpoint_filepath, monitor="val_accuracy", verbose=1, save_best_only=True, mode="max")
# stop model training early if validation loss doesn't continue to decrease over 2 iterations
early_stopping = EarlyStopping(monitor="val_loss", patience=2, verbose=1, mode="min")
# log training console output to csv
# (bound to a new name so the module-level csv_logger *path* string is not shadowed)
csv_logger_callback = CSVLogger(csv_logger, separator=",", append=False)
# create list of callbacks (early_stopping deliberately left out)
callbacks_list = [checkpoint, csv_logger_callback]  # early_stopping
# compile model
# (removed the undefined `iou_coefficient` from metrics — it is defined
# nowhere in this file and raised NameError at compile time)
model.compile(optimizer="adam", loss="categorical_crossentropy",
              metrics=["accuracy", jaccard_index])
# train and save model
model.fit(X_train, Y_train, epochs=20, batch_size=32, validation_data=(X_test, Y_test), callbacks=callbacks_list, verbose=1)
model.save(model_save_path)
print("model saved:", model_save_path)
# ===================================================== | |
# load pre-trained model | |
# model = load_model( | |
# '/Users/andrewdavies/Code/Python/MachineLearning/earth-observation/models/final_aerial_segmentation_2022-03-31 13_28_03_079442.hdf5', | |
# custom_objects={'iou_coefficient': iou_coefficient, 'jaccard_index': jaccard_index} | |
# ) | |
# =====================================================
# Predict
def rgb_encode_mask(mask):
    """Convert an integer-encoded mask back to a float RGB image in [0, 1]."""
    # rgb image with the same spatial resolution as the mask
    rgb_image = np.zeros((mask.shape[0], mask.shape[1], 3))
    # paint each class colour wherever the mask holds that class index
    for class_index, colour in enumerate(MaskColorMap):
        rgb_image[mask == class_index] = np.array(colour.value) / 255.
    return rgb_image
# visualize predictions on 20 randomly chosen test patches
for _ in range(20):
    # choose random number from 0 to test set size
    test_img_number = np.random.randint(0, len(X_test))
    # extract test input image
    test_img = X_test[test_img_number]
    # ground truth test label converted from one-hot to integer encoding
    ground_truth = np.argmax(Y_test[test_img_number], axis=-1)
    # expand first dimension as U-Net requires (m, h, w, nc) input shape
    test_img_input = np.expand_dims(test_img, 0)
    # make prediction with model and remove extra dimension
    prediction = np.squeeze(model.predict(test_img_input))
    # convert softmax probabilities to integer values
    predicted_img = np.argmax(prediction, axis=-1)
    # convert integer encoding to rgb values
    rgb_image = rgb_encode_mask(predicted_img)
    rgb_ground_truth = rgb_encode_mask(ground_truth)
    # visualize model predictions side by side
    display_images(
        [test_img, rgb_ground_truth, rgb_image],
        rows=1, titles=['Aerial', 'Ground Truth', 'Prediction']
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# =====================================================
# Predict
def rgb_encode_mask(mask):
    """Convert an integer-encoded mask back to a float RGB image in [0, 1]."""
    # rgb image with the same spatial resolution as the mask
    rgb_image = np.zeros((mask.shape[0], mask.shape[1], 3))
    # paint each class colour wherever the mask holds that class index
    for class_index, colour in enumerate(MaskColorMap):
        rgb_image[mask == class_index] = np.array(colour.value) / 255.
    return rgb_image
# visualize predictions on 20 randomly chosen test patches
for _ in range(20):
    # choose random number from 0 to test set size
    test_img_number = np.random.randint(0, len(X_test))
    # extract test input image
    test_img = X_test[test_img_number]
    # ground truth test label converted from one-hot to integer encoding
    ground_truth = np.argmax(Y_test[test_img_number], axis=-1)
    # expand first dimension as U-Net requires (m, h, w, nc) input shape
    test_img_input = np.expand_dims(test_img, 0)
    # make prediction with model and remove extra dimension
    prediction = np.squeeze(model.predict(test_img_input))
    # convert softmax probabilities to integer values
    predicted_img = np.argmax(prediction, axis=-1)
    # convert integer encoding to rgb values
    rgb_image = rgb_encode_mask(predicted_img)
    rgb_ground_truth = rgb_encode_mask(ground_truth)
    # visualize model predictions side by side
    display_images(
        [test_img, rgb_ground_truth, rgb_image],
        rows=1, titles=['Aerial', 'Ground Truth', 'Prediction']
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment