# Source: GitHub gist suricactus/d7ec0e1718271f15a432f033d6d2e7af
# Created: October 29, 2019 09:13
# Note: GitHub flagged this file as containing bidirectional Unicode text that
# may be interpreted or compiled differently than it appears. Review it in an
# editor that reveals hidden Unicode characters.
from typing import (Union, Dict, Any)

import numpy as np
import pandas as pd
import geopandas as gpd
from geopandas.geodataframe import GeoDataFrame
import rasterio
import rasterio.features
from rasterio.io import DatasetWriter
from rasterio.enums import ColorInterp
from rasterio.plot import show, show_hist
from skimage import exposure
import mpl_toolkits.mplot3d
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Colour interpretation assigned to the dataset's bands, in band order
# (one entry per band, eight in total).
# NOTE(review): ``yellow`` appears twice in a row -- confirm this matches the
# intended band layout of the input image.
WV2_COLORBANDS = (
    ColorInterp.blue,
    ColorInterp.green,
    ColorInterp.red,
    ColorInterp.yellow,
    ColorInterp.yellow,
    ColorInterp.undefined,
    ColorInterp.undefined,
    ColorInterp.undefined,
)

# Column names used when the stretched three-band image is turned into a
# DataFrame.
BAND_NAMES = [
    'Red',
    'Blue',
    'Green',
]
def hist_stretch(in_arr, bands=None, clip_extremes=(0, 100)):
    """General purpose histogram stretching of a band-first image array.

    Args:
        in_arr: Array whose first axis enumerates the bands.
        bands: Iterable of zero-based indices into the first axis of
            ``in_arr``; the output keeps this order. ``None`` selects every
            band.
        clip_extremes: Either a ``(low, high)`` pair of percentiles, or a
            single number ``p`` that is expanded to ``(p, 100 - p)``.

    Returns:
        ``np.ndarray`` with one rescaled band per entry of ``bands``, stacked
        along the first axis.
    """
    if bands is None:
        # Bands are looked up with ``in_arr[band]`` (zero-based), so the
        # default has to start at 0 to include the first band.
        bands = range(in_arr.shape[0])

    if isinstance(clip_extremes, (int, float)):
        # A single number means "clip this many percent at both ends".
        clip_extremes = (clip_extremes, 100 - clip_extremes)

    stretched = []
    for band in bands:
        arr = in_arr[band]
        percentile_min, percentile_max = np.percentile(arr, clip_extremes)
        stretched.append(
            exposure.rescale_intensity(
                arr, in_range=(percentile_min, percentile_max)
            )
        )
    return np.array(stretched)
def make_data_frame(arr, col_names):
    """Turn an image array into a DataFrame with zeros replaced by NaN.

    An array with more than two dimensions is unpacked as three axes, the
    first being the band axis: every band becomes one column and every pixel
    one row. Any other array is flattened into a single run of values.

    Args:
        arr: Input array.
        col_names: Column labels passed to ``pd.DataFrame``.

    Returns:
        ``pd.DataFrame`` in which every value equal to 0 is masked out.
    """
    if arr.ndim > 2:
        band_count, dim_a, dim_b = arr.shape
        values = arr.reshape(band_count, dim_a * dim_b).T
    else:
        values = arr.reshape(-1)

    frame = pd.DataFrame(values, columns=col_names)
    # ``where`` keeps entries for which the condition holds and blanks the rest.
    return frame.where(frame != 0)
def prepare_dataset(filename: Union[str, GeoDataFrame], img_source: DatasetWriter) -> np.ndarray:
    """Rasterize labelled geometries into a class-id mask on the image grid.

    Args:
        filename: Either a path that ``geopandas.read_file`` can open (its
            columns are then renamed by ``normalize_labels_df``), or a
            GeoDataFrame that already has ``id``, ``colors`` and ``geometry``
            columns.
        img_source: Open rasterio dataset. Its ``shape`` and ``transform``
            define the grid of the returned mask.

    Returns:
        2-D ``uint8`` array in which each pixel covered by a geometry holds
        that geometry's ``id``; uncovered pixels keep rasterio's default
        fill value.

    Raises:
        TypeError: If ``filename`` is neither a ``str`` nor a GeoDataFrame.

    Note:
        The GeoDataFrame is modified in place: ``id`` is cast to ``uint8`` and
        ``red``/``green``/``blue`` columns are added from ``colors``.
    """
    shapefile: GeoDataFrame
    if isinstance(filename, str):
        shapefile = normalize_labels_df(gpd.read_file(filename))
    elif isinstance(filename, GeoDataFrame):
        shapefile = filename
    else:
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        raise TypeError(
            'filename must be a path or a GeoDataFrame, got {}'.format(
                type(filename).__name__)
        )

    # ``uint8`` matches the dtype of the rasterized mask below.
    shapefile['id'] = shapefile['id'].astype('uint8')
    # Colour channels span 0-255, which does not fit a signed 8-bit integer.
    shapefile[['red', 'green', 'blue']] = shapefile['colors'].str.split(
        ',', expand=True).astype('uint8')

    mask = rasterio.features.rasterize(
        zip(shapefile['geometry'], shapefile['id'].tolist()),
        # Take the grid from the dataset argument rather than from a
        # module-level image array.
        out_shape=img_source.shape,
        transform=img_source.transform,
        dtype='uint8'
    )
    return mask
def count_occurances(arr: np.ndarray) -> Dict[Any, int]:
    """Map every distinct value in ``arr`` to the number of times it occurs."""
    values, tallies = np.unique(arr, return_counts=True)
    return {value: tally for value, tally in zip(values, tallies)}
def get_classified_pixels(img, mask, labels: GeoDataFrame):
    """Collect the band values of every labelled pixel into one DataFrame.

    Args:
        img: Band-first image array; it is indexed as ``img[:, i, j]``.
        mask: 2-D array of class ids, indexed with the same ``(i, j)`` pairs
            as the last two axes of ``img``.
        labels: Frame with an ``id`` column; one lookup is done per row.

    Returns:
        ``pd.DataFrame`` with one row per matching pixel and the columns
        ``class_id``, ``x``, ``y``, ``b1`` ... ``bN``. ``x`` is the index
        along the first axis of ``mask`` and ``y`` the index along its second
        axis.
    """
    bands_count = img.shape[0]
    # The empty int32 seed keeps the result well-shaped when nothing matches.
    chunks = [np.empty((bands_count + 3, 0), dtype='int32')]

    for _, label in labels.iterrows():
        class_id = label['id']
        idx_x, idx_y = np.where(mask == class_id)

        # Allocate with the index dtype explicitly: letting the dtype follow
        # the id scalar (e.g. int8) would overflow pixel indices.
        class_and_x_and_y = np.empty((3, idx_x.size), dtype=idx_x.dtype)
        class_and_x_and_y[0] = class_id
        class_and_x_and_y[1] = idx_x
        class_and_x_and_y[2] = idx_y

        chunks.append(
            np.concatenate((class_and_x_and_y, img[:, idx_x, idx_y]), axis=0)
        )

    # One concatenation at the end instead of re-copying the result per label.
    label_band_vals = np.concatenate(chunks, axis=1)

    columns = ['class_id', 'x', 'y',
               *['b{}'.format(i + 1) for i in range(bands_count)]]
    return pd.DataFrame(label_band_vals.transpose(), columns=columns)
def normalize_labels_df(df: GeoDataFrame):
    """Return a copy of ``df`` with the label columns renamed.

    ``CLASS_ID`` becomes ``id``, ``CLASS_NAME`` becomes ``class`` and
    ``CLASS_CLRS`` becomes ``colors``; every other column keeps its name.
    """
    renamed_columns = {
        'CLASS_ID': 'id',
        'CLASS_NAME': 'class',
        'CLASS_CLRS': 'colors',
    }
    return df.rename(columns=renamed_columns)
# ---------------------------------------------------------------------------
# Script: load the image, rasterize the label sets, plot the training pixels.
# ---------------------------------------------------------------------------
img_source: DatasetWriter

# Opened in 'r+' mode; presumably required for the colorinterp assignment.
# Everything that reads from ``img_source`` stays inside the ``with`` block so
# the dataset is still open when it is used.
with rasterio.open('./src/input/WV2_Zundert', 'r+') as img_source:
    img_source.colorinterp = WV2_COLORBANDS
    original_img = img_source.read()

    # Three-band composite with 2 % clipped at each end of the histogram.
    img = hist_stretch(original_img, bands=(7, 5, 3), clip_extremes=2)
    img_dataframe = make_data_frame(img, BAND_NAMES)

    training_df = normalize_labels_df(
        gpd.read_file('./src/input/training_set.shp'))
    testing_df = normalize_labels_df(
        gpd.read_file('./src/input/test_set.shp'))

    training_mask = prepare_dataset(training_df, img_source)
    testing_mask = prepare_dataset(testing_df, img_source)

# One row per class id, ordered by id.
labels = (
    training_df[['id', 'class']]
    .copy()
    .sort_values('id')
    .drop_duplicates('id')
)

training_px_df = get_classified_pixels(img, training_mask, labels)
testing_px_df = get_classified_pixels(img, testing_mask, labels)

# 3-D scatter of the three stretched bands, coloured by class id.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(
    training_px_df['b1'],
    training_px_df['b2'],
    training_px_df['b3'],
    c=training_px_df['class_id'],
)

#%%
# 2-D alternative (first two bands only), kept disabled:
# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.scatter(training_px_df['b1'],
#            training_px_df['b2'])
plt.show()
# for class_id in np.unique(mask_training): | |
# if class_id == 0: | |
# continue | |
# idx_v, idx_h = np.where(mask_training == class_id) | |
# temp_result = img[:, idx_v, idx_h] | |
# shape = (1, temp_result.shape[1]) | |
# temp_result = np.append(temp_result, np.full(shape, class_id), axis=0) | |
# result = np.append(result, temp_result, axis=1) | |
# result = result.transpose() | |
# # show(mask_testing) | |
# # show_hist(img, bins=50, histtype='stepfilled', | |
# # lw=0.0, stacked=False, alpha=0.3) | |
# (GitHub gist page footer: "Sign up for free to join this conversation on
# GitHub. Already have an account? Sign in to comment" -- not part of the script.)