Skip to content

Instantly share code, notes, and snippets.

@suricactus
Created October 29, 2019 09:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save suricactus/d7ec0e1718271f15a432f033d6d2e7af to your computer and use it in GitHub Desktop.
Save suricactus/d7ec0e1718271f15a432f033d6d2e7af to your computer and use it in GitHub Desktop.
from typing import (Union, Dict, Any)
import numpy as np
import pandas as pd
import geopandas as gpd
from geopandas.geodataframe import GeoDataFrame
import rasterio
import rasterio.features
from rasterio.io import DatasetWriter
from rasterio.enums import ColorInterp
from rasterio.plot import show, show_hist
from skimage import exposure
import mpl_toolkits.mplot3d
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Colour interpretation assigned to the input raster's bands, in band order
# (the script writes it to ``img_source.colorinterp``).
# NOTE(review): ColorInterp.yellow appears twice and red precedes both
# yellows; confirm this matches the actual band order of the WV2 file.
WV2_COLORBANDS = (
    ColorInterp.blue,
    ColorInterp.green,
    ColorInterp.red,
    ColorInterp.yellow,
    ColorInterp.yellow,
) + (ColorInterp.undefined,) * 3

# Column labels for the three stretched bands handed to make_data_frame.
# NOTE(review): order is Red, Blue, Green -- confirm it matches the band
# indices passed to hist_stretch.
BAND_NAMES = ['Red', 'Blue', 'Green']
def hist_stretch(in_arr, bands=None, clip_extremes=(0, 100)):
    """General purpose histogram stretching.

    Each selected band is rescaled with ``exposure.rescale_intensity``, using
    the requested percentiles of that band as the input intensity range.

    Parameters
    ----------
    in_arr : np.ndarray
        Band-first array; ``in_arr[band]`` selects a single band.
    bands : iterable of int, optional
        0-based indices into the first axis of ``in_arr``. The output keeps
        this order. Defaults to every band.
    clip_extremes : (number, number) or number
        Lower/upper percentiles used as the input range. A single number
        ``p`` is shorthand for ``(p, 100 - p)``, i.e. clip ``p`` percent at
        each end of the histogram.

    Returns
    -------
    np.ndarray
        The stretched bands stacked along a new first axis.
    """
    total_bands = in_arr.shape[0]
    if bands is None:
        # Band indices are 0-based (``in_arr[band]``), so "all bands" must
        # start at 0; starting at 1 would silently drop the first band.
        bands = range(total_bands)
    # Accept any real scalar (Python or NumPy int/float) as the symmetric
    # shorthand; previously a float such as 2.5 fell through and made the
    # percentile unpacking below fail.
    if isinstance(clip_extremes, (int, float, np.integer, np.floating)):
        clip_extremes = (0 + clip_extremes, 100 - clip_extremes)

    stretched = []
    for band in bands:
        arr = in_arr[band]
        percentile_min, percentile_max = np.percentile(arr, clip_extremes)
        stretched.append(
            exposure.rescale_intensity(arr, in_range=(percentile_min, percentile_max))
        )
    return np.array(stretched)
def make_data_frame(arr, col_names):
    """Flatten a band-first image into a pixels-by-bands DataFrame.

    Parameters
    ----------
    arr : np.ndarray
        Either a band-first array ``(bands, rows, cols, ...)``, which becomes
        one column per band and one row per pixel, or a 2-D/1-D array, which
        is flattened into a single column.
    col_names : list of str
        Column labels; one per band (or exactly one for the flattened case).

    Returns
    -------
    pandas.DataFrame
        Pixel values in C (row-major) order, with every cell equal to 0
        replaced by NaN.
    """
    if arr.ndim > 2:
        # (bands, rows, cols, ...) -> (bands, pixels) -> (pixels, bands).
        # A direct reshape avoids the extra full copy flatten() used to make.
        data = arr.reshape(arr.shape[0], -1).T
    else:
        data = arr.ravel()
    df = pd.DataFrame(data, columns=col_names)
    # Zero cells are treated as missing (each cell independently, not whole rows).
    return df.where(df != 0)
def prepare_dataset(filename: Union[str, GeoDataFrame], img_source: DatasetWriter) -> np.ndarray:
    """Rasterize labelled polygons into a class-id mask on *img_source*'s grid.

    Parameters
    ----------
    filename : str or GeoDataFrame
        Path passed to ``gpd.read_file`` (its columns are then renamed by
        ``normalize_labels_df``), or an already-loaded GeoDataFrame that has
        ``id``, ``colors`` and ``geometry`` columns.
    img_source : DatasetWriter
        Open rasterio dataset whose ``height``, ``width`` and ``transform``
        define the output grid.

    Returns
    -------
    np.ndarray
        ``uint8`` mask of shape ``(img_source.height, img_source.width)``.
        Each pixel covered by a polygon holds that polygon's ``id``; other
        pixels get ``rasterize``'s default fill.

    Notes
    -----
    When a GeoDataFrame is passed, it is modified in place: ``id`` is cast to
    ``int8`` and ``red``/``green``/``blue`` columns are added from ``colors``.
    """
    shapefile: GeoDataFrame
    if isinstance(filename, str):
        shapefile = normalize_labels_df(gpd.read_file(filename))
    else:
        shapefile = filename
    assert isinstance(shapefile, GeoDataFrame)

    shapefile['id'] = shapefile['id'].astype('int8')
    # Colour components are 0-255 and need an unsigned 8-bit type;
    # 'int8' overflows for any component above 127.
    shapefile[['red', 'green', 'blue']] = shapefile['colors'].str.split(
        ',', expand=True).astype('uint8')

    mask = rasterio.features.rasterize(
        zip(shapefile['geometry'], shapefile['id']),
        # Take the grid size from the dataset argument rather than the
        # module-level ``img`` global, so the function works on its own.
        out_shape=(img_source.height, img_source.width),
        transform=img_source.transform,
        dtype='uint8',
    )
    return mask
def count_occurances(arr: np.ndarray) -> Dict[Any, int]:
    """Return a mapping from each distinct value in *arr* to its frequency.

    Keys come out in ascending order, as produced by ``np.unique``.
    """
    return dict(zip(*np.unique(arr, return_counts=True)))
def get_classified_pixels(img, mask, labels: GeoDataFrame):
    """Collect the band values of every labelled pixel into one DataFrame.

    Parameters
    ----------
    img : np.ndarray
        Band-first image; ``img[:, rows, cols]`` yields per-pixel band values.
    mask : np.ndarray
        2-D class-id array on the same grid as ``img`` (for example the
        output of ``prepare_dataset``).
    labels : GeoDataFrame
        Rows whose ``id`` column lists the class ids to extract, in order.

    Returns
    -------
    pandas.DataFrame
        One row per matching pixel, with columns ``class_id``, ``x``, ``y``
        and ``b1`` ... ``bN``. ``x``/``y`` are the row/column indices returned
        by ``np.where(mask == id)``.
    """
    bands_count = img.shape[0]
    # Seed with an empty int32 block so the result dtype, and the empty
    # result when there are no labels or pixels, matches the old
    # incremental np.append version.
    chunks = [np.empty((bands_count + 3, 0), dtype='int32')]
    for _, label in labels.iterrows():
        label_id = label['id']
        idx_x, idx_y = np.where(mask == label_id)
        # vstack promotes the id to the index dtype, so row/col indices are
        # never written into a narrow integer type (e.g. an int8 id), which
        # would overflow and corrupt the coordinates.
        class_and_x_and_y = np.vstack(
            (np.full(idx_x.shape, label_id), idx_x, idx_y))
        chunks.append(
            np.concatenate((class_and_x_and_y, img[:, idx_x, idx_y]), axis=0))
    # A single concatenate avoids re-copying the accumulated array for every
    # label, as np.append inside the loop did.
    label_band_vals = np.concatenate(chunks, axis=1)
    columns = ['class_id', 'x', 'y', *['b{}'.format(i + 1) for i in range(bands_count)]]
    return pd.DataFrame(label_band_vals.transpose(), columns=columns)
def normalize_labels_df(df: GeoDataFrame):
    """Return a copy of *df* with the label attribute columns renamed.

    ``CLASS_ID`` -> ``id``, ``CLASS_NAME`` -> ``class`` and
    ``CLASS_CLRS`` -> ``colors``. Columns not listed are left unchanged.
    """
    column_aliases = dict(
        CLASS_ID='id',
        CLASS_NAME='class',
        CLASS_CLRS='colors',
    )
    renamed = df.rename(columns=column_aliases)
    return renamed
# --- Script: load the WV2 image, build training/testing pixel tables, plot ---
# NOTE(review): the original indentation of this section was lost; here the
# dataset is kept open through the prepare_dataset calls, which read
# img_source.transform -- confirm against the author's version.
img_source: DatasetWriter
# Opened in 'r+' (read/write) mode so that colorinterp can be assigned.
# NOTE(review): that assignment presumably persists into the input file on
# disk -- confirm this modification of the source data is intended.
with rasterio.open('./src/input/WV2_Zundert', 'r+') as img_source:
    img_source.colorinterp = WV2_COLORBANDS
    # All bands, band axis first (hist_stretch indexes in_arr[band]).
    original_img = img_source.read()
    # Stretch three bands (0-based indices 7, 5, 3); clip_extremes=2 means
    # the 2nd..98th percentile range is used for each band.
    img = hist_stretch(original_img, bands=(7, 5, 3), clip_extremes=2)
    img_dataframe = make_data_frame(img, BAND_NAMES)
    training_df = normalize_labels_df(gpd.read_file('./src/input/training_set.shp'))
    testing_df = normalize_labels_df(gpd.read_file('./src/input/test_set.shp'))
    # Class-id masks on the image grid for each label layer.
    training_mask = prepare_dataset(training_df, img_source)
    testing_mask = prepare_dataset(testing_df, img_source)
# One row per class id, ordered by id.
# NOTE(review): testing pixels are extracted using the training label list;
# this assumes both shapefiles share the same class ids.
labels = training_df[['id', 'class']].copy().sort_values('id').drop_duplicates('id')
training_px_df = get_classified_pixels(img, training_mask, labels)
testing_px_df = get_classified_pixels(img, testing_mask, labels)
# 3-D scatter of the three stretched bands (b1..b3), coloured by class id.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(training_px_df['b1'],
           training_px_df['b2'],
           training_px_df['b3'],
           c=training_px_df['class_id'])
#%%
# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.scatter(training_px_df['b1'],
# training_px_df['b2'])
plt.show()
# Scratch code (commented out): a loop over mask classes similar to
# get_classified_pixels. `mask_training` and `result` are not defined in the
# current script.
# for class_id in np.unique(mask_training):
# if class_id == 0:
# continue
# idx_v, idx_h = np.where(mask_training == class_id)
# temp_result = img[:, idx_v, idx_h]
# shape = (1, temp_result.shape[1])
# temp_result = np.append(temp_result, np.full(shape, class_id), axis=0)
# result = np.append(result, temp_result, axis=1)
# result = result.transpose()
# # show(mask_testing)
# # show_hist(img, bins=50, histtype='stepfilled',
# # lw=0.0, stacked=False, alpha=0.3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment