Skip to content

Instantly share code, notes, and snippets.

@KMarkert
Created March 10, 2020 16:34
Show Gist options
  • Save KMarkert/83d75e7d17044ec3a48240341265792c to your computer and use it in GitHub Desktop.
Save KMarkert/83d75e7d17044ec3a48240341265792c to your computer and use it in GitHub Desktop.
import os
import fire
import numpy as np
from osgeo import gdal
from scipy import ndimage
import dask.array as da
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras import layers
from tensorflow.python.keras import backend as K
from numba import jit
from numba.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings
# Silence warnings for the whole process. The first call already ignores every
# warning category, so the two numba-specific filters below are redundant while
# it is in place. Note this also hides numpy RuntimeWarnings such as
# divide-by-zero / invalid-value warnings.
warnings.simplefilter('ignore')
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
@jit(forceobj=True)
def minMaxScaler(x):
    """Rescale ``x`` to [0, 1] using the min/max taken over its first two axes.

    For a (rows, cols, bands) array each band is scaled independently; for a 2D
    array a single global min/max is used. NaNs are ignored when computing the
    extrema and remain NaN in the output.

    A band whose values are all equal would otherwise produce 0/0 (NaN); its
    range is treated as 1 so that band maps to 0 instead.
    """
    imin = np.nanmin(x, axis=(0, 1))
    imax = np.nanmax(x, axis=(0, 1))
    span = imax - imin
    # Guard constant bands against division by zero.
    span = np.where(span == 0, 1, span)
    return (x - imin) / span
@jit(forceobj=True)
def blockshaped(arr,nrows,ncols):
    """Split a 2D array into non-overlapping ``nrows`` x ``ncols`` tiles.

    Returns an array of shape (n, nrows, ncols) with n * nrows * ncols ==
    arr.size. Tiles are ordered row-major (left to right along the first row
    of tiles, then the next row of tiles), and each tile keeps the spatial
    layout of its window in ``arr``.
    """
    height, width = arr.shape
    assert height % nrows == 0, "{} rows is not evenly divisble by {}".format(height, nrows)
    assert width % ncols == 0, "{} cols is not evenly divisble by {}".format(width, ncols)
    # (tile_row, r, tile_col, c) -> (tile_row, tile_col, r, c) -> (tile, r, c)
    grid = arr.reshape(height // nrows, nrows, -1, ncols)
    grid = grid.swapaxes(1, 2)
    return grid.reshape(-1, nrows, ncols)
@jit(forceobj=True)
def unblockshaped(arr, h, w):
    """Reassemble (n, nrows, ncols) tiles into a single (h, w) array.

    This is the inverse of blockshaped: tiles are expected in row-major tile
    order, and h * w must equal arr.size.
    """
    _, tile_h, tile_w = arr.shape
    # (tile_row, tile_col, r, c) -> (tile_row, r, tile_col, c) -> (h, w)
    stacked = arr.reshape(h // tile_h, -1, tile_h, tile_w)
    stacked = stacked.swapaxes(1, 2)
    return stacked.reshape(h, w)
@jit(forceobj=True)
def merge_img(arr,targetXY):
    """Reassemble (n_tiles, tile_rows, tile_cols, bands) tiles into one image.

    ``targetXY`` is the (rows, cols) size of the full image; despite the name,
    the row count comes first because it is passed as ``h`` to unblockshaped.
    Each band is merged separately, and the result has shape (rows, cols, bands).
    """
    height, width = targetXY
    _, _, _, n_bands = arr.shape
    planes = [unblockshaped(arr[..., band], height, width) for band in range(n_bands)]
    return np.stack(planes, axis=-1)
@jit(forceobj=True)
def chunk_img(arr, targetXY):
    """Split an image into non-overlapping tiles of size ``targetXY = (nrows, ncols)``.

    - 2D input (H, W): returns an array of shape (n, nrows, ncols).
    - 3D input (H, W, bands): each band is tiled with blockshaped and the
      results are stacked on a trailing axis, giving (n, nrows, ncols, bands).

    n * nrows * ncols == H * W. H and W must be exact multiples of nrows and
    ncols (blockshaped asserts this). Tiles are in row-major order.

    Raises:
        ValueError: if ``arr`` is neither 2D nor 3D.
    """
    nrows, ncols = targetXY
    ndim = len(arr.shape)
    if ndim == 2:
        return blockshaped(arr, nrows, ncols)
    if ndim == 3:
        tiles = [blockshaped(arr[:, :, i], nrows, ncols) for i in range(arr.shape[2])]
        return np.stack(tiles, axis=3)
    raise ValueError("chunk_img expects a 2D or 3D array, got shape {}".format(arr.shape))
def load_img(img_path,tileSize=256):
    """Read a multiband raster, pad it to whole tiles, append NDWI and min-max scale it.

    The raster is read with GDAL into (rows, cols, bands) layout, as float64.
    It is zero-padded so that both spatial dimensions become multiples of
    ``tileSize``, with the padding split floor/ceil between the leading and
    trailing sides. The pad per axis is ``tileSize - dim % tileSize``, so a
    dimension that is already a multiple of ``tileSize`` still gets one full
    extra tile. save_img computes the same offsets to trim results back.

    NDWI is computed as (band1 - band3) / (band1 + band3) with 0-based band
    indices. Judging by the NDWI/NDVI formulas used here, band 1 is green and
    band 3 is NIR (TODO confirm against the sensor). Pixels where the
    denominator is 0, including all zero-padded pixels, get NDWI 0 instead of
    NaN.

    Args:
        img_path: path to a GDAL-readable raster with at least 4 bands.
        tileSize: tile edge length in pixels.

    Returns:
        float64 array (padded_rows, padded_cols, bands + 1) scaled per channel
        by minMaxScaler.

    Raises:
        IOError: if GDAL cannot open ``img_path``.
    """
    ds = gdal.Open(img_path)
    if ds is None:
        raise IOError("could not open raster: {}".format(img_path))
    # Cast before any arithmetic: with unsigned-integer rasters, green - nir
    # would wrap around instead of going negative.
    img = np.moveaxis(ds.ReadAsArray().astype(np.float64), 0, 2)
    ds = None  # release the GDAL dataset

    yDiff, xDiff = tileSize - (np.array(img.shape[:2]) % tileSize)
    y1, y2 = int(np.floor(yDiff / 2)), int(np.ceil(yDiff / 2))
    x1, x2 = int(np.floor(xDiff / 2)), int(np.ceil(xDiff / 2))
    img_pad = np.pad(img, ((y1, y2), (x1, x2), (0, 0)), mode='constant')

    green, nir = img_pad[:, :, 1], img_pad[:, :, 3]
    denom = green + nir
    ndwi = np.zeros_like(denom)
    # Zero padding (and any pixel with green + nir == 0) would give 0/0 = NaN,
    # which the convolutions would then spread across whole tiles.
    np.divide(green - nir, denom, out=ndwi, where=denom != 0)

    stack = np.dstack([img_pad, ndwi])
    return minMaxScaler(stack)
def save_img(outpath,data,template,tileSize=256,noData=-1):
    """Write ``data`` as a Float32 GeoTIFF on the grid of the ``template`` raster.

    ``data`` is expected to be padded the way load_img pads the template:
    ``tileSize - dim % tileSize`` pixels per axis, split floor/ceil between the
    leading and trailing sides. That padding is trimmed off here, so the output
    has the template's size, geotransform and projection.

    Args:
        outpath: destination GeoTIFF path.
        data: (rows, cols) or (rows, cols, bands) array.
        template: path to the raster whose grid the output should match.
        tileSize: tile size that was used when padding ``data``.
        noData: nodata value set on every output band; pass None to skip.

    Raises:
        IOError: if ``template`` cannot be opened or ``outpath`` cannot be created.
        ValueError: if ``data`` is not 2D or 3D.
    """
    ds = gdal.Open(template)
    if ds is None:
        raise IOError("could not open template raster: {}".format(template))
    ySize, xSize = ds.RasterYSize, ds.RasterXSize
    gt = ds.GetGeoTransform()
    srs = ds.GetProjection()
    ds = None  # release the template dataset

    # Leading pad offsets, computed the same way as in load_img.
    yDiff, xDiff = tileSize - (np.array([ySize, xSize]) % tileSize)
    y1 = int(np.floor(yDiff / 2))
    x1 = int(np.floor(xDiff / 2))

    data = np.asarray(data)
    if data.ndim == 2:
        data = data[:, :, np.newaxis]
    elif data.ndim != 3:
        raise ValueError("data must be 2D or 3D, got shape {}".format(data.shape))
    # Equivalent to data[y1:-y2, x1:-x2] for data padded like load_img does.
    trimmed = data[y1:y1 + ySize, x1:x1 + xSize, :]
    yDim, xDim, nBands = trimmed.shape

    driver = gdal.GetDriverByName('GTiff')
    # Size the output from the trimmed array. Using the padded size would leave
    # an empty strip along the bottom and right edges.
    outDs = driver.Create(outpath, xDim, yDim, nBands, gdal.GDT_Float32)
    if outDs is None:
        raise IOError("could not create output raster: {}".format(outpath))
    outDs.SetGeoTransform(gt)
    outDs.SetProjection(srs)
    for b in range(nBands):
        band = outDs.GetRasterBand(b + 1)
        # `is not None` so that a nodata value of 0 is still written.
        if noData is not None:
            band.SetNoDataValue(noData)
        band.WriteArray(trimmed[:, :, b])
        band = None
    outDs.FlushCache()
    outDs = None  # dereferencing the dataset closes it and finalises the file
    return
def dice_loss(y_true, y_pred, smooth=1):
    """Soft Dice loss, with every sum reduced over the last axis only.

    One loss value is produced per element of the remaining axes. For
    (batch, rows, cols, classes) tensors that would mean per pixel, across the
    class channel (NOTE(review): confirm this reduction is intended).
    ``smooth`` is added to the numerator and denominator to avoid 0/0.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    denominator = K.sum(K.square(y_true), -1) + K.sum(K.square(y_pred), -1) + smooth
    return 1 - (2. * overlap + smooth) / denominator
def unet():
    """Build the U-Net graph used for inference.

    Architecture:
    - A 1x1 convolution stem with 16 filters.
    - Three encoder stages with 32, 64 and 128 filters. Each halves the
      spatial size with 2x2 max pooling.
    - A 256-filter bottleneck.
    - Three decoder stages. Each upsamples with a 2x2 stride-2 transposed conv
      and concatenates the matching encoder features.

    All convolutions except the final 1x1 classifier are followed by batch
    normalization and ReLU.

    The input is (None, None, 5). The 5 channels presumably are the 4 raster
    bands plus the NDWI added by load_img. Height and width are free but should
    be multiples of 8, so that the three pooling steps and the transposed convs
    produce matching shapes at each skip concatenation. The output is a
    2-channel per-pixel softmax at the input resolution.

    Returns:
        (inputs, outputs) Keras tensors, to be wrapped in a keras Model.
    """
    # A set of helper functions for defining a convolutional layer. We use batch
    # normalization to speed up training given our limited training data, therefore
    # we can't use vanilla conv2d(activation='relu', ...)
    def conv_block(input_tensor, num_filters):
        # Two (3x3 conv -> BN -> ReLU) stages.
        encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
        encoder = layers.BatchNormalization()(encoder)
        encoder = layers.Activation('relu')(encoder)
        encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(encoder)
        encoder = layers.BatchNormalization()(encoder)
        encoder = layers.Activation('relu')(encoder)
        return encoder

    def encoder_block(input_tensor, num_filters):
        # Returns (pooled output, pre-pool features used as the skip connection).
        encoder = conv_block(input_tensor, num_filters)
        encoder_pool = layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
        return encoder_pool, encoder

    def decoder_block(input_tensor, concat_tensor, num_filters):
        # 2x upsample, concatenate the skip features, BN/ReLU, then two
        # (3x3 conv -> BN -> ReLU) stages.
        decoder = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
        decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
        decoder = layers.BatchNormalization()(decoder)
        decoder = layers.Activation('relu')(decoder)
        return decoder

    inputs = layers.Input(shape=[None, None, 5])
    inputConv = layers.Conv2D(16, (1, 1), padding='same')(inputs)
    inputConv = layers.BatchNormalization()(inputConv)
    inputConv = layers.Activation('relu')(inputConv)
    # Spatial sizes are for a 320 x 320 input: a 256 px tile plus the 32 px
    # margin that make_prediction adds on every side.
    encoder0_pool, encoder0 = encoder_block(inputConv, 32)  # 160
    encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)  # 80
    encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)  # 40
    center = conv_block(encoder2_pool, 256)  # 40
    decoder2 = decoder_block(center, encoder2, 128)  # 80
    decoder1 = decoder_block(decoder2, encoder1, 64)  # 160
    decoder0 = decoder_block(decoder1, encoder0, 32)  # 320
    # Two-class per-pixel softmax, so each channel is in [0, 1].
    outputs = layers.Conv2D(2, (1, 1), activation='softmax')(decoder0)
    return inputs, outputs
def load_model(weights):
    """Build the U-Net, compile it, and restore trained weights.

    Args:
        weights: weights file path, passed straight to ``Model.load_weights``.

    Returns:
        The compiled keras Model with the weights loaded.
    """
    inputs, outputs = unet()
    net = keras.models.Model(inputs=[inputs], outputs=[outputs])
    # Compiled with the Dice loss and Adam. This presumably mirrors the training
    # setup; prediction itself does not use the loss or optimizer.
    net.compile(optimizer='adam', loss=dice_loss, metrics=['accuracy'])
    net.load_weights(weights)
    return net
def make_prediction(img_path,model_weights,outpath=None,verbose=False,tileSize=256):
    """Run tiled U-Net inference on a raster and write the smoothed probability map.

    Pipeline:
    1. Load, pad and scale the image (load_img).
    2. Cut it into tileSize x tileSize tiles.
    3. Reflect-pad each tile by 32 px on every side so predictions near tile
       edges have some context.
    4. Predict, then crop the margin back off.
    5. Reassemble the tiles and smooth with a Gaussian (sigma 2.5 px) over the
       two spatial axes only.
    6. Write the last softmax channel to a GeoTIFF aligned with ``img_path``.

    Args:
        img_path: raster to classify; also used as the output's georeferencing template.
        model_weights: weights file passed to load_model.
        outpath: output GeoTIFF path; defaults to 'pond_inference.tif'.
        verbose: print progress messages and show the Keras progress bar.
        tileSize: tile edge length in pixels.
    """
    margin = 32  # per-side context added around every tile
    if verbose:
        print("\nLoading image and tiling into examples...")
    img = load_img(img_path, tileSize=tileSize)
    examples = np.pad(chunk_img(img, (tileSize, tileSize)),
                      ((0, 0), (margin, margin), (margin, margin), (0, 0)),
                      mode='reflect')
    if verbose:
        print("Done loading data. Loading convolutional network...")
    model = load_model(model_weights)
    if verbose:
        print("Done loading network. Applying inference on image examples...")
    pred = model.predict(examples, verbose=int(verbose))
    # Drop the reflected margin so tiles line up again with chunk_img's grid.
    pred = pred[:, margin:-margin, margin:-margin]
    if verbose:
        print("Done applying inference. Merging examples back to an image...")
    final = merge_img(pred, img.shape[:2])
    # Smooth across rows/cols only. A non-zero sigma on the class axis would
    # blend the two softmax channels and squash probabilities toward 0.5.
    final = ndimage.gaussian_filter(final, sigma=(2.5, 2.5, 0))
    if outpath is None:
        outpath = 'pond_inference.tif'
    save_img(outpath, final[:, :, -1], img_path, tileSize=tileSize)
    if verbose:
        print("All done!")
    return
if __name__ == "__main__":
    # With no target, python-fire exposes this module's top-level names as CLI
    # commands, e.g. `python script.py make_prediction IMG WEIGHTS --verbose`.
    fire.Fire()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment