Skip to content

Instantly share code, notes, and snippets.

@TimSC
Last active April 22, 2020 00:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TimSC/debcf71eae41c5b54eaf44d587d7744c to your computer and use it in GitHub Desktop.
Using Keras to tackle the Inria aerial image labeling dataset
#Using Keras to tackle the Inria aerial image labeling dataset
# https://project.inria.fr/aerialimagelabeling/
#
# Depending on the switch in the __main__ block, this script either trains a
# keras_segmentation VGG-UNet on random crops of the training tiles and plots
# its learning curves, or loads a saved model ('aerial.h5') and displays
# predicted masks for the validation tiles.
import os
#Work around for https://github.com/tensorflow/tensorflow/issues/24496
# Ask TensorFlow to grow GPU memory on demand rather than reserving it all up
# front. Set before TensorFlow is imported (and so before it touches the GPU).
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# Work around for https://github.com/tensorflow/tensorflow/issues/33024
# Switches TensorFlow to TF1-style graph mode; this must happen before any
# model is constructed below.
import tensorflow.compat as compat
compat.v1.disable_eager_execution()
import zipfile
import imageio
import io
# NOTE(review): fcn_8 is imported but not used anywhere in the visible code.
from keras_segmentation.models.fcn import fcn_8
from keras_segmentation.models.unet import vgg_unet
import tensorflow.keras as keras
import numpy as np
from matplotlib import pyplot as plt
def AerialListFiles(pth):
    """Index the tiles of the Inria aerial image labeling zip archive.

    Looks up, for each of five training cities, tiles 1-36 as an image plus a
    ground-truth mask, and for each of five test cities tiles 1-36 as an image
    only (no ground truth is looked up for the test split).

    Args:
        pth: Path or binary file-like object of the dataset zip archive.

    Returns:
        ``(z, xData, xTest, yData, yTest)``:
        ``z`` is the open ``zipfile.ZipFile`` (left open; callers keep
        reading tiles from it), ``xData``/``yData`` are the training image and
        mask ``ZipInfo`` entries in matching order, ``xTest`` are the test
        image entries and ``yTest`` holds one ``None`` per test image.

    Raises:
        KeyError: if an expected tile is missing from the archive (the
            archive is closed before re-raising).
    """
    z = zipfile.ZipFile(pth)
    testLocs = ["bellingham", "bloomington", "innsbruck", "sfo", "tyrol-e"]
    trainLocs = ["austin", "chicago", "kitsap", "tyrol-w", "vienna"]
    tileNums = range(1, 37)  # tiles are numbered 1..36 within each city
    xData, xTest, yData, yTest = [], [], [], []
    try:
        # Map member name -> ZipInfo so each tile is a single dict lookup.
        zinfoDict = {zi.filename: zi for zi in z.infolist()}
        for loc in trainLocs:
            for i in tileNums:
                xData.append(zinfoDict["AerialImageDataset/train/images/{}{}.tif".format(loc, i)])
                yData.append(zinfoDict["AerialImageDataset/train/gt/{}{}.tif".format(loc, i)])
        for loc in testLocs:
            for i in tileNums:
                xTest.append(zinfoDict["AerialImageDataset/test/images/{}{}.tif".format(loc, i)])
                yTest.append(None)
    except KeyError:
        # Don't leak the open archive when an expected tile is missing.
        z.close()
        raise
    # Same order as the list declaration above: images, test images, masks,
    # test placeholders.
    return z, xData, xTest, yData, yTest
def SplitTrainAndValidation(xData, yData):
    """Hold out the first five tiles of every training city for validation.

    ``xData`` and ``yData`` are parallel lists laid out as five consecutive
    runs of 36 tiles (one run per city), as produced by AerialListFiles.
    Only the first 5 * 36 = 180 entries are consulted.

    Returns:
        ``(xTrain, xVal, yTrain, yVal)``: tiles 1-5 of each city go to the
        validation lists, tiles 6-36 to the training lists, in input order.
    """
    numLocs, tilesPerLoc, valTilesPerLoc = 5, 36, 5
    xTrain, xVal, yTrain, yVal = [], [], [], []
    for idx in range(numLocs * tilesPerLoc):
        # idx % tilesPerLoc is the zero-based tile position within its city.
        if idx % tilesPerLoc < valTilesPerLoc:
            xDest, yDest = xVal, yVal
        else:
            xDest, yDest = xTrain, yTrain
        xDest.append(xData[idx])
        yDest.append(yData[idx])
    return xTrain, xVal, yTrain, yVal
class AerialDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that serves random crops of aerial image tiles.

    Every batch samples ``filesInBatch`` tiles at random (with replacement)
    and takes ``cropsInFile`` random crops from each, giving
    ``filesInBatch * cropsInFile`` samples. The batch index Keras passes in is
    ignored; every call draws a fresh random batch.

    The image crop is ``cropSize`` pixels square. The matching label crop is
    the same window grown by ``cropMargin`` pixels on each side (shrunk when
    ``cropMargin`` is negative), i.e. ``cropSize + 2 * cropMargin`` square.
    With the defaults that is the central 64 x 64 of a 128 x 128 crop.
    """

    def __init__(self, z, dataX, dataY, batchesPerEpoch=50, filesInBatch=10, cropsInFile=10, cropSize=128, cropMargin=-32):
        """Store the archive handle, tile lists and sampling parameters.

        Args:
            z: Open ``zipfile.ZipFile`` the tiles are read from.
            dataX: Archive entries of the input images.
            dataY: Archive entries of the label masks, parallel to ``dataX``.
            batchesPerEpoch: Number of batches reported per epoch.
            filesInBatch: Tiles sampled for each batch.
            cropsInFile: Random crops taken from each sampled tile.
            cropSize: Side of the square image crop in pixels.
            cropMargin: Pixels added to (or, if negative, removed from) each
                side of the image crop to obtain the label crop.
        """
        self.z = z
        self.dataX = dataX
        self.dataY = dataY
        self.batchesPerEpoch = batchesPerEpoch
        self.filesInBatch = filesInBatch
        self.cropsInFile = cropsInFile
        self.cropSize = cropSize
        self.cropMargin = cropMargin

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.batchesPerEpoch

    @staticmethod
    def _ReadImage(z, info):
        """Read and decode one archive member, closing the member stream."""
        with z.open(info) as f:
            return imageio.imread(io.BytesIO(f.read()))

    def __getitem__(self, index):
        """Build one random batch; ``index`` is unused.

        Returns:
            ``(dataX, dataY)``: image crops converted to float32 and divided
            by 255, and labels of shape ``(samples, sizeWithMargin**2, 2)``
            one-hot encoded from ``mask > 128``.
        """
        dataX, dataY = [], []
        sizeWithMargin = self.cropSize + 2 * self.cropMargin
        # Keep the crop origin far enough from the border that a positive
        # margin still fits inside the tile.
        posMarginOrZero = max(0, self.cropMargin)
        for _ in range(self.filesInBatch):
            fileId = np.random.randint(len(self.dataX))
            # Use the archive given to the constructor, not a module global.
            img = self._ReadImage(self.z, self.dataX[fileId])
            gt = self._ReadImage(self.z, self.dataY[fileId])
            for _ in range(self.cropsInFile):
                # Random top-left corner of the image crop.
                r = np.random.randint(posMarginOrZero, img.shape[0] - self.cropSize - posMarginOrZero)
                c = np.random.randint(posMarginOrZero, img.shape[1] - self.cropSize - posMarginOrZero)
                imgc = img[r:r + self.cropSize, c:c + self.cropSize, :]
                gtc = gt[r - self.cropMargin:r + self.cropSize + self.cropMargin,
                         c - self.cropMargin:c + self.cropSize + self.cropMargin]
                # Rescale pixels and flatten the mask into per-pixel one-hot labels.
                imgc = np.array(imgc, dtype=np.float32) / 255.0
                gtc = gtc.reshape((sizeWithMargin * sizeWithMargin,)) > 128
                gtc = keras.utils.to_categorical(gtc, num_classes=2)
                dataX.append(imgc)
                dataY.append(gtc)
        return np.array(dataX), np.array(dataY)

    def on_epoch_end(self):
        """No per-epoch state to reset; batches are sampled independently."""
        pass
# plot diagnostic learning curves
def summarize_diagnostics(histories):
    """Plot loss and accuracy curves for one or more Keras History objects.

    The top panel shows training vs. validation cross-entropy loss and the
    bottom panel training vs. validation accuracy. Curves from every history
    are drawn onto the same two panels and shown in one window at the end.
    """
    panels = (
        (1, 'Cross Entropy Loss', 'loss', 'val_loss'),
        (2, 'Classification Accuracy', 'accuracy', 'val_accuracy'),
    )
    for hist in histories:
        curves = hist.history
        for position, title, trainKey, valKey in panels:
            plt.subplot(2, 1, position)
            plt.title(title)
            plt.plot(curves[trainKey], color='blue', label='train')
            # NOTE: this is the validation metric, despite the 'test' label.
            plt.plot(curves[valKey], color='orange', label='test')
    plt.show()
def PredictOnImages(z, xVal, model, patchSize=64, margin=32):
    """Predict and display a binary mask for each listed image.

    The image is tiled into ``patchSize`` x ``patchSize`` output cells. Each
    cell is predicted from an input window that extends ``margin`` pixels
    beyond it on every side, so the model input is
    ``patchSize + 2 * margin`` pixels square (128 with the defaults). This
    mirrors the default crops of AerialDataGenerator. Pixels within
    ``margin`` of the border, or in a trailing partial cell, stay 0.

    Args:
        z: Open ``zipfile.ZipFile`` holding the images.
        xVal: Archive entries of the images to predict.
        model: Model whose per-patch output reshapes to
            ``(patchSize, patchSize, 2)``; channel 1 is thresholded at 0.5.
        patchSize: Side of each predicted output cell in pixels.
        margin: Context pixels added around each cell on the model input.
    """
    for imgInfo in xVal:
        print (imgInfo)
        with z.open(imgInfo) as f:
            img = imageio.imread(io.BytesIO(f.read()))
        predImg = np.zeros((img.shape[0], img.shape[1]), dtype=np.int8)
        # +1 so a window that ends exactly on the image edge is still used.
        rowStarts = range(margin, img.shape[0] - patchSize - margin + 1, patchSize)
        colStarts = range(margin, img.shape[1] - patchSize - margin + 1, patchSize)
        for r in rowStarts:
            patches = [img[r - margin:r + patchSize + margin, c - margin:c + patchSize + margin, :]
                       for c in colStarts]
            if not patches:
                continue  # image too narrow for even one window
            # Scale the same way training crops are scaled (float32 / 255).
            batch = np.array(patches, dtype=np.float32) / 255.0
            result = model.predict(batch)
            for c, pred in zip(colStarts, result):
                prob = pred.reshape((patchSize, patchSize, 2))[:, :, 1]
                predImg[r:r + patchSize, c:c + patchSize] = prob > 0.5
        plt.imshow(predImg)
        plt.show()
if __name__=="__main__":
    import sys
    # The dataset zip may be passed as the first command-line argument;
    # otherwise the original hard-coded location is used.
    zipPath = sys.argv[1] if len(sys.argv) > 1 else "/home/tim/Downloads/aerialimagelabeling/NEW2-AerialImageDataset.zip"
    # Unpack each returned value under its own name (xTest/yTest unused for now).
    z, xData, xTest, yData, yTest = AerialListFiles(zipPath)
    xTrain, xVal, yTrain, yVal = SplitTrainAndValidation(xData, yData)
    # True: train a new model and plot its learning curves.
    # False: load 'aerial.h5' and show predicted masks for the validation tiles.
    doTraining = True
    if doTraining:
        trainGen = AerialDataGenerator(z, xTrain, yTrain)
        # Each validation batch is 100 crops from a single random tile.
        valGen = AerialDataGenerator(z, xVal, yVal, filesInBatch=1, cropsInFile=100)
        model = vgg_unet(2, input_height=128, input_width=128)
        print (type(model))
        print ("Compiling model")
        model.compile(optimizer="adadelta", loss='categorical_crossentropy', metrics=['accuracy'])
        print ("Fitting model")
        # NOTE(review): fit_generator is deprecated in TF 2.x in favour of
        # Model.fit; confirm Model.fit accepts a Sequence under
        # disable_eager_execution before switching. workers=0 runs the
        # generator on the main thread.
        history = model.fit_generator(
            trainGen,
            validation_data=valGen,
            validation_steps=50,
            epochs=20,
            workers=0
            )
        # Uncomment to write the 'aerial.h5' file the prediction branch loads.
        #keras.models.save_model(model, 'aerial.h5')
        summarize_diagnostics([history])
    else:
        model = keras.models.load_model('aerial.h5')
        PredictOnImages(z, xVal, model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment