Skip to content

Instantly share code, notes, and snippets.

@a-maumau
Created July 26, 2018 07:38
Show Gist options
  • Save a-maumau/6d0a50ec15be89851e9fc65629111d07 to your computer and use it in GitHub Desktop.
Save a-maumau/6d0a50ec15be89851e9fc65629111d07 to your computer and use it in GitHub Desktop.
Parse the SpaceNet dataset.
import os
### geopandas must be imported earlier than osgeo ###
import geopandas as gpd
from osgeo import gdal, ogr, osr
import numpy as np
import scipy
from scipy.misc import bytescale
import osmnx
import cv2
import skimage
from skimage import exposure
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
# Band-stacking and contrast-stretch helpers (adapted from NVIDIA's stretch code)
def retrieve_bands(ds, x_size, y_size, bands):
    """Stack the requested raster bands of a dataset into one 3-D array.

    Args:
        ds: a GDAL dataset, or any object whose ``GetRasterBand(i)`` result
            provides ``ReadAsArray()``.
        x_size: raster width in pixels (GDAL ``RasterXSize``, i.e. columns).
        y_size: raster height in pixels (GDAL ``RasterYSize``, i.e. rows).
        bands: 1-based band indices to read, in output channel order.

    Returns:
        float64 array of shape ``(y_size, x_size, len(bands))``.
    """
    # ReadAsArray yields (rows, cols) == (YSize, XSize), so allocate the
    # stack row-major; an (x, y) allocation only works for square rasters.
    stack = np.zeros([y_size, x_size, len(bands)])
    for channel, band_index in enumerate(bands):
        stack[:, :, channel] = ds.GetRasterBand(band_index).ReadAsArray()
    return stack
def contrast_stretch(np_img, p1_clip=2, p2_clip=90):
    """Percentile-clip each channel and rescale it to the full uint8 range.

    For every channel, values at or below its ``p1_clip`` percentile map to
    0, values at or above its ``p2_clip`` percentile map to 255, and values
    in between are scaled linearly (rounded to nearest).

    Args:
        np_img: array of shape (rows, cols, channels), any numeric dtype.
        p1_clip: lower percentile (0-100) used as the black point.
        p2_clip: upper percentile (0-100) used as the white point.

    Returns:
        uint8 array with the same shape as ``np_img``.
    """
    rows, cols, n_channels = np_img.shape
    return_stack = np.zeros([rows, cols, n_channels], dtype=np.uint8)
    for b in range(n_channels):
        cur_b = np_img[:, :, b].astype(np.float64)
        p1_pix, p2_pix = np.percentile(cur_b, (p1_clip, p2_clip))
        # Clip to the percentile window *before* scaling. Using the
        # percentiles as an output range and then min-max byte-scaling
        # cancels out and would leave the percentiles with no effect.
        span = p2_pix - p1_pix
        if span == 0:
            span = 1.0  # flat channel: avoid division by zero -> all zeros
        scaled = (np.clip(cur_b, p1_pix, p2_pix) - p1_pix) * (255.0 / span)
        # +0.5 then truncation rounds to the nearest integer in [0, 255].
        return_stack[:, :, b] = (scaled + 0.5).astype(np.uint8)
    return return_stack
# The code below is adapted from https://github.com/CosmiQ/apls/blob/master/src/apls_tools.py
def create_buffer_geopandas(geoJsonFileName,
                            bufferDistanceMeters=2,
                            bufferRoundness=1, projectToUTM=True):
    '''
    Create a buffer around the lines of the geojson.

    Args:
        geoJsonFileName: path of a file readable by ``gpd.read_file`` whose
            features carry a ``road_type`` property.
        bufferDistanceMeters: buffer distance around each line; metres when
            ``projectToUTM`` is True, otherwise units of the file's CRS.
        bufferRoundness: passed as the second positional argument of
            ``GeoSeries.buffer`` (its ``resolution``).
        projectToUTM: project to UTM (``osmnx.project_gdf``) before
            buffering and back to the input CRS afterwards.

    Returns:
        A GeoDataFrame of buffer polygons dissolved by ``class``. If the file
        has no features, the empty input GeoDataFrame is returned, so callers
        can test ``len(result) == 0``.
    '''
    inGDF = gpd.read_file(geoJsonFileName)
    # Check emptiness first: an empty file may have no 'road_type' column,
    # and the former ``return [], []`` had len() == 2, so callers testing
    # len(...) == 0 never recognised it as empty.
    if len(inGDF) == 0:
        return inGDF
    # set a few columns that we will need later
    inGDF['type'] = inGDF['road_type'].values
    inGDF['class'] = 'highway'
    inGDF['highway'] = 'highway'
    # Transform gdf Roadlines into UTM so that Buffer makes sense
    if projectToUTM:
        tmpGDF = osmnx.project_gdf(inGDF)
    else:
        tmpGDF = inGDF
    gdf_utm_buffer = tmpGDF
    # perform Buffer to produce polygons from Line Segments
    gdf_utm_buffer['geometry'] = tmpGDF.buffer(bufferDistanceMeters,
                                               bufferRoundness)
    # merge all buffers sharing a 'class' value (every row is 'highway')
    gdf_utm_dissolve = gdf_utm_buffer.dissolve(by='class')
    gdf_utm_dissolve.crs = gdf_utm_buffer.crs
    if projectToUTM:
        return gdf_utm_dissolve.to_crs(inGDF.crs)
    return gdf_utm_dissolve
def gdf_to_array(gdf, im_file, output_raster, burnValue=150):
    '''
    Rasterize ``gdf``'s geometries onto the pixel grid of ``im_file``.

    Writes a single-band Byte GeoTIFF (GTiff driver, whatever the file
    extension) to ``output_raster`` with the size, geotransform and
    projection of ``im_file``. Pixels inside the geometries are set to
    ``burnValue``; the rest keep the new file's fill value 0, which is also
    declared as the band's NoData value. Geometries are not reprojected, so
    they must already be in the raster's CRS.

    Returns:
        None; the result is the file written to ``output_raster``.
    '''
    NoData_value = 0  # -9999
    gdata = gdal.Open(im_file)
    # set target info
    target_ds = gdal.GetDriverByName('GTiff').Create(
        output_raster, gdata.RasterXSize, gdata.RasterYSize, 1, gdal.GDT_Byte)
    target_ds.SetGeoTransform(gdata.GetGeoTransform())
    # set raster info
    raster_srs = osr.SpatialReference()
    raster_srs.ImportFromWkt(gdata.GetProjectionRef())
    target_ds.SetProjection(raster_srs.ExportToWkt())
    band = target_ds.GetRasterBand(1)
    band.SetNoDataValue(NoData_value)
    gdata = None  # reference image no longer needed

    # In-memory OGR layer holding the geometries to burn.
    outdriver = ogr.GetDriverByName('MEMORY')
    outDataSource = outdriver.CreateDataSource('memData')
    outLayer = outDataSource.CreateLayer("states_extent", raster_srs,
                                         geom_type=ogr.wkbMultiPolygon)
    burnField = "burn"
    outLayer.CreateField(ogr.FieldDefn(burnField, ogr.OFTInteger))
    featureDefn = outLayer.GetLayerDefn()
    for geomShape in gdf['geometry'].values:
        outFeature = ogr.Feature(featureDefn)
        outFeature.SetGeometry(ogr.CreateGeometryFromWkt(geomShape.wkt))
        outFeature.SetField(burnField, burnValue)
        outLayer.CreateFeature(outFeature)
        outFeature = None
    # burn
    gdal.RasterizeLayer(target_ds, [1], outLayer, burn_values=[burnValue])

    # Flush and close explicitly so the file is complete on disk before the
    # caller reads it back, rather than relying on garbage collection.
    band = None
    target_ds.FlushCache()
    target_ds = None
    outLayer = None
    outDataSource = None
def get_road_buffer(geoJson, im_vis_file, output_raster,
                    buffer_meters=2, burnValue=1, bufferRoundness=6,
                    plot_file='', figsize=(6, 6), fontsize=6,
                    dpi=800, show_plot=False,
                    verbose=False):
    '''
    Get buffer around roads defined by geojson and image files.

    Calls create_buffer_geopandas() and gdf_to_array() to write a road mask
    aligned with ``im_vis_file`` to ``output_raster``, then reads it back.
    Assumes im_vis_file is an 8-bit RGB file.

    Args:
        geoJson: path of the road GeoJSON.
        im_vis_file: reference image; defines the mask's size.
        output_raster: path the mask is written to.
        buffer_meters: buffer distance passed as ``bufferDistanceMeters``.
        burnValue: pixel value for road pixels; background is 0.
        bufferRoundness: forwarded to create_buffer_geopandas().
        plot_file: if non-empty, a 4-panel diagnostic figure is saved here.
        figsize, fontsize, dpi: figure options; a non-square figsize lays the
            panels out in one row, a square one as a 2 x 2 grid.
        show_plot: if False the figure is closed after saving.
        verbose: currently unused.

    Returns:
        (mask_gray, gdf_buffer): the mask as read back by
        ``cv2.imread(output_raster, 0)`` and the buffer returned by
        create_buffer_geopandas().

    Raises:
        IOError: if there are no roads and ``im_vis_file`` cannot be read.
    '''
    gdf_buffer = create_buffer_geopandas(geoJson,
                                         bufferDistanceMeters=buffer_meters,
                                         bufferRoundness=bufferRoundness,
                                         projectToUTM=True)
    # "No roads" may arrive as an empty GeoDataFrame or as the legacy
    # ``([], [])`` tuple (whose len() is 2); treat both as empty.
    no_roads = isinstance(gdf_buffer, tuple) or len(gdf_buffer) == 0
    # create label image
    if no_roads:
        reference = cv2.imread(im_vis_file, 0)
        if reference is None:
            raise IOError("could not read image: {}".format(im_vis_file))
        cv2.imwrite(output_raster, np.zeros(reference.shape, dtype=np.uint8))
    else:
        gdf_to_array(gdf_buffer, im_vis_file, output_raster,
                     burnValue=burnValue)
    # load mask
    mask_gray = cv2.imread(output_raster, 0)

    if plot_file:
        _plot_road_buffer(geoJson, im_vis_file, mask_gray, buffer_meters,
                          plot_file, figsize, fontsize, dpi, show_plot)
    return mask_gray, gdf_buffer


def _plot_road_buffer(geoJson, im_vis_file, mask_gray, buffer_meters,
                      plot_file, figsize, fontsize, dpi, show_plot):
    '''Save a 4-panel figure: GeoJSON roads, image, mask, image + mask.'''
    import copy  # local: only needed to copy the shared colormap below

    # plot all in a line for a non-square figsize, else a 2 x 2 grid
    if figsize[0] != figsize[1]:
        fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=figsize)
    else:
        fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2, figsize=figsize)

    # road lines; fall back to the mask if the GeoJSON cannot be plotted
    try:
        gdfRoadLines = gpd.read_file(geoJson)
        gdfRoadLines.plot(ax=ax0, marker='o', color='red')
    except Exception:
        ax0.imshow(mask_gray)
    ax0.axis('off')
    ax0.set_aspect('equal')
    ax0.set_title('Roads from GeoJson', fontsize=fontsize)

    # first show raw image
    im_vis = cv2.imread(im_vis_file, 1)
    img_mpl = cv2.cvtColor(im_vis, cv2.COLOR_BGR2RGB)
    ax1.imshow(img_mpl)
    ax1.axis('off')
    ax1.set_title('8-bit RGB Image', fontsize=fontsize)

    # plot mask
    ax2.imshow(mask_gray)
    ax2.axis('off')
    ax2.set_title('Roads Mask (' + str(np.round(buffer_meters))
                  + ' meter buffer)', fontsize=fontsize)

    # plot combined: image with the mask overlaid
    ax3.imshow(img_mpl)
    # zeros become NaN so only road pixels are drawn
    z = mask_gray.astype(float)
    z[z == 0] = np.nan
    # Copy before set_over: plt.cm.gray is the shared global colormap.
    # Values above vmax (e.g. burnValue=1) are drawn in lime.
    palette = copy.copy(plt.cm.gray)
    palette.set_over('lime', 0.9)
    ax3.imshow(z, cmap=palette, alpha=0.66,
               norm=matplotlib.colors.Normalize(vmin=0.5, vmax=0.9,
                                                clip=False))
    ax3.set_title('8-bit RGB Image + Buffered Roads', fontsize=fontsize)
    ax3.axis('off')

    plt.tight_layout()
    plt.savefig(plot_file, dpi=dpi)
    if not show_plot:
        plt.close(fig)
min_percent = 5
max_percent = 90
mask_output_dir = "masks"
jpg_output_dir = "jpgs"
# One diagnostic figure, overwritten for every tile.
plot_file = 'mask_plot.png'
# Please rewrite these to match your dataset directory.
tif_image_dir = "AOI_2_Vegas_Roads_Train/RGB-PanSharpen"
geojson_file_dir = "AOI_2_Vegas_Roads_Train/geojson/spacenetroads"


def _process_tile(name):
    '''Write the road mask (mask_output_dir) and stretched JPEG for one tile.'''
    stem = os.path.splitext(name)[0]
    image_name = os.path.join(tif_image_dir, name)
    # Please rewrite this mapping to match your dataset's file naming.
    geojson_file = os.path.join(
        geojson_file_dir,
        stem.replace("RGB-PanSharpen_AOI_2_Vegas_img",
                     "spacenetroads_AOI_2_Vegas_img") + ".geojson")
    jpg_name = os.path.join(jpg_output_dir, stem + ".jpg")
    mask_raster = os.path.join(mask_output_dir, stem + ".png")

    # get_road_buffer writes mask_raster (and plot_file) itself, so its
    # return values are not needed here.
    get_road_buffer(geojson_file, image_name, mask_raster,
                    bufferRoundness=6,
                    plot_file=plot_file,
                    figsize=(6, 6),
                    fontsize=8,
                    dpi=200, show_plot=False,
                    verbose=False)

    ds = gdal.Open(image_name)
    # rows x cols x 3 from bands 1, 2, 3 (band 2 was previously read twice)
    channel = np.dstack([ds.GetRasterBand(i).ReadAsArray() for i in (1, 2, 3)])
    ds = None  # close the source raster
    band = contrast_stretch(channel, min_percent, max_percent)
    Image.fromarray(np.uint8(band)).save(jpg_name, quality=100)


for output_dir in (mask_output_dir, jpg_output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

# Only GeoTIFF tiles; sorted for a deterministic processing order.
image_list = sorted(f for f in os.listdir(tif_image_dir)
                    if f.lower().endswith(".tif"))
for name in image_list:
    print(name)
    try:
        _process_tile(name)
    except Exception as err:
        # Best effort: report why the tile failed and continue. Catching
        # Exception (not everything) keeps Ctrl-C working.
        print("skip {} ({})".format(os.path.join(tif_image_dir, name), err))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment