Last active
July 13, 2021 17:15
-
-
Save mayrajeo/3f893f7947688ac84ebffcb38a672a67 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastai.core import ItemBase | |
from fastai.vision import * | |
from fastai.basic_data import * | |
from fastai.imports import * | |
import rasterio as rio | |
import scipy | |
""" | |
TODO: separate transformations and such to different files | |
""" | |
class HSDataBunch(ImageDataBunch):
    """DataBunch for hyperspectral items.

    Subclasses `ImageDataBunch` so that its `normalize` can be reused, and overrides
    `show_batch` to render samples either as mean spectra or as RGB composites.
    """
    _square_show = True
    def show_batch(self, mode:str='spectral', rows:int=5, ds_type:DatasetType=DatasetType.Train, reverse:bool=False, **kwargs)->None:
        """Show a batch of data in `ds_type` on a few `rows`.

        mode: 'spectral' or 'rgb'; any other value prints an error and draws nothing.
        rows: grid side length when the item list shows square grids, otherwise the item count.
        reverse: take the samples from the end of the batch instead of the start.
        Extra `kwargs` are forwarded to the item list's `show_xys`.
        """
        if mode not in ('spectral', 'rgb'):
            print(f'Error! {mode} is invalid visualization type.')
            return
        xb, yb = self.one_batch(ds_type, True, True)
        if reverse:
            xb, yb = xb.flip(0), yb.flip(0)
        x_list, y_list = self.train_ds.x, self.train_ds.y
        wanted = rows ** 2 if x_list._square_show else rows
        # Never ask for more samples than one batch holds.
        n_items = min(wanted, self.dl(ds_type).batch_size)
        xs = [x_list.reconstruct(grab_idx(xb, idx)) for idx in range(n_items)]
        # TODO: get rid of has_arg if possible
        if has_arg(y_list.reconstruct, 'x'):
            # Some label types need the reconstructed input to rebuild themselves.
            ys = [y_list.reconstruct(grab_idx(yb, idx), x=img) for idx, img in enumerate(xs)]
        else:
            ys = [y_list.reconstruct(grab_idx(yb, idx)) for idx in range(n_items)]
        x_list.show_xys(xs, ys, mode, **kwargs)
class HSImage(Image):
    """fastai `Image` holding a hyperspectral tensor.

    Only `show` is customised: it draws either an RGB composite or the mean
    spectrum instead of the default image rendering.
    """
    def __init__(self, px:Tensor):
        super().__init__(px)

    def show(self, mode='rgb', ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=False, y:Any=None, **kwargs):
        """
        Show RGB render or average spectra of `self` on `ax`.

        mode: 'spectral' plots the mean spectrum (honouring `hide_axis`); any other
            value draws the RGB composite and always requests a hidden axis.
        y: optional target drawn on the same axes via `y.show(ax=..., **kwargs)`.
        title: optional axes title.
        """
        if mode == 'spectral':
            render, axis_off = show_spectra, hide_axis
        else:
            render, axis_off = show_rgb, True
        drawn_on = render(self, ax=ax, hide_axis=axis_off, figsize=figsize)
        if y is not None:
            y.show(ax=drawn_on, **kwargs)
        if title is not None:
            drawn_on.set_title(title)
class HSImageItemList(ItemList):
    """
    `ItemList` of hyperspectral images stored as ``.npy`` or ``.tif`` files.

    Channels are selected along the first axis of each loaded array. With
    ``dims=3`` a leading singleton axis is added so the item can be treated as
    volumetric data; otherwise it is kept as a multi-channel image. Adds plotting
    utilities (`show_xys`, `show_xyzs`) that draw RGB composites or mean spectra.
    """
    _bunch = HSDataBunch
    _square_show = True

    def __init__(self, items, dims=3, chans=list(range(461)), **kwargs):
        """
        items: paths of the image files.
        dims: 3 adds a leading singleton axis to every opened tensor; other values keep it unchanged.
        chans: indices selected along the first (channel) axis; None keeps every channel.
        """
        super().__init__(items, **kwargs)
        self.dims = dims
        # NOTE: the default list is built once and shared by all instances; it is only read here.
        self.chans = chans
        # Presumably so fastai's `ItemList.new` carries these over to derived lists.
        self.copy_new.append('dims')
        self.copy_new.append('chans')

    def open(self, fn)->HSImage:
        """Load `fn` as an `HSImage`.

        Supports ``.npy`` and ``.tif`` files; for any other extension an error is
        printed and None is returned.
        """
        # str() so that pathlib.Path items work as well as plain/NumPy strings.
        name = str(fn)
        # `cls` must be passed explicitly: the open_* helpers wrap the tensor in it.
        if name.endswith('.npy'): return open_npy(fn, cls=HSImage, dims=self.dims, chans=self.chans)
        if name.endswith('.tif'): return open_geotiff(fn, cls=HSImage, dims=self.dims, chans=self.chans)
        print('Invalid file format, only .npy and .tif are supported now')
        return None

    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr=None, suffix:str='.npy', **kwargs)->ItemList:
        "Create from `df[cols]`; each item becomes `path/[folder/]<value><suffix>`."
        suffix = suffix or ''
        res = super().from_df(df, path=path, cols=cols, **kwargs)
        pref = f'{res.path}{os.path.sep}'
        if folder is not None: pref += f'{folder}{os.path.sep}'
        res.items = np.char.add(np.char.add(pref, res.items.astype(str)), suffix)
        return res

    def get(self, i)->HSImage:
        "Open the `i`-th item."
        fn = super().get(i)
        res = self.open(fn)
        return res

    def reconstruct(self, t:Tensor)->HSImage:
        "Rebuild an `HSImage` from a (batch-sliced) tensor, cast to float."
        return HSImage(t.float())

    def show_xys(self, xs, ys, mode='rgb', imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys` (targets) on a square grid of `figsize`."
        rows = int(np.ceil(math.sqrt(len(xs))))
        axs = subplots(rows, rows, imgsize=imgsize, figsize=figsize)
        for x,y,ax in zip(xs, ys, axs.flatten()): x.show(ax=ax, y=y, mode=mode, **kwargs)
        # Blank out the grid cells left over when len(xs) is not a perfect square.
        for ax in axs.flatten()[len(xs):]: ax.axis('off')
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, mode='rgb', figsize:Tuple[int, int]=None, **kwargs)->None:
        "Show `xs` with ground truth `ys` (left column) and predictions `zs` (right column)."
        figsize = ifnone(figsize, (6, 3*len(xs)))
        # One row per sample and two columns; squeeze=False keeps `axs` 2-D even for one sample.
        fig, axs = plt.subplots(len(xs), 2, figsize=figsize, squeeze=False)
        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
        for i,(x,y,z) in enumerate(zip(xs,ys,zs)):
            x.show(ax=axs[i,0], y=y, mode=mode, **kwargs)
            x.show(ax=axs[i,1], y=z, mode=mode, **kwargs)
def open_npy(fn:PathOrStr, cls:type=None, dims:int=2, chans:Collection[int]=None):
    """Open a NumPy ``.npy`` array as a tensor wrapped in `cls`.

    fn: path of the ``.npy`` file.
    cls: callable applied to the loaded tensor; defaults to `HSImage` when None.
    dims: if 3, a leading singleton axis is added; other values keep the array's rank.
    chans: indices selected along the first axis; None keeps everything.
    """
    # Defaulting here (not in the signature) keeps the historical positional order
    # while letting callers such as `HSImageItemList.open` omit `cls`.
    if cls is None: cls = HSImage
    im = torch.from_numpy(np.load(fn))
    if chans is not None: im = im[chans]
    if dims == 3: im = im[None]
    return cls(im)
def open_geotiff(fn:PathOrStr, cls:type=None, dims:int=2, chans:Collection[int]=None):
    """Open a GeoTIFF with rasterio as a float32 tensor wrapped in `cls`.

    fn: path of the raster file.
    cls: callable applied to the loaded tensor; defaults to `HSImage` when None.
    dims: if 3, a leading singleton axis is added; other values keep the array's rank.
    chans: indices selected along the first axis (the band axis of rasterio's
        `read()` output, per rasterio docs); None keeps every band.
    """
    # Defaulting here keeps the historical positional order while letting callers
    # such as `HSImageItemList.open` omit `cls`.
    if cls is None: cls = HSImage
    with rio.open(fn) as src:
        data = src.read()
    # Conversion does not need the dataset handle, so do it after closing the file.
    im = torch.from_numpy(data.astype(np.float32))
    if chans is not None: im = im[chans]
    if dims == 3: im = im[None]
    return cls(im)
""" | |
Custom visualization functions here, need refactoring | |
For geospatial images (e.g. Sentinel-2 multispectral images), useful visualizations include:
Single channel composite | |
RGB-composite with possibilities to select used channels | |
Normalized spectral indices, such as NDVI: (NIR-RED)/(NIR+RED) | |
""" | |
def show_spectra(img:Tensor, ax:plt.Axes=None, figsize:tuple=(3,6), hide_axis:bool=False, cmap:str='binary', **kwargs)->plt.Axes:
    """Plot the spatially averaged spectrum of `img` on `ax` and return `ax`.

    The last channel is left out of the plot (the original note says not to add the
    CHM channel, presumably stored last). A new figure of `figsize` is created when
    `ax` is None. `cmap` is accepted but unused; extra `kwargs` go to `ax.plot`.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    arr = img.data.cpu().numpy()
    # 3-D input: average over the last two axes; otherwise also over the leading axis.
    reduce_axes = (0, -2, -1) if len(arr.shape) != 3 else (-2, -1)
    spectrum = np.mean(arr, axis=reduce_axes)
    ax.plot(spectrum[:-1], **kwargs)
    if hide_axis:
        ax.axis('off')
    ax.grid()
    return ax
def show_rgb(img:Tensor, rgb_chans:Tuple[int,int,int]=(82,49,28),
             ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=False, **kwargs)->plt.Axes:
    """Draw an RGB composite of `img` on `ax` and return `ax`.

    rgb_chans: channel indices used for R, G and B. Per the original note, the defaults
        follow the Sentinel-2 RGB bands but without overlap in the B and G bands.
    ax: axes to draw on; a new figure of `figsize` is created when None.
    hide_axis: additionally turn the axes frame off (ticks are always removed).
    Extra `kwargs` are accepted and ignored.
    """
    if ax is None: _, ax = plt.subplots(figsize=figsize)
    tempim = img.data.cpu().numpy()
    # Drop the leading singleton axis added for volumetric (dims=3) items.
    if tempim.ndim == 4: tempim = tempim[0]
    # Stack the chosen channels as the last axis -> (rows, cols, 3), in float64 for the stretch.
    im = np.stack([tempim[c] for c in rgb_chans], axis=-1).astype(np.float64)
    lo, hi = im.min(), im.max()
    # Min-max stretch to [0, 1] for display (quick hack); a constant image would give 0/0 = NaN.
    im = (im - lo) / (hi - lo) if hi > lo else np.zeros_like(im)
    ax.imshow(im)
    ax.set_xticks([])
    ax.set_yticks([])
    if hide_axis: ax.axis('off')
    return ax
""" | |
Custom transforms for n-dimensional image or volumetric data | |
""" | |
def _hs_dihedral(x, k:partial(uniform_int,0,7)):
    """Randomly flip `x` image based on `k`: one of the 8 dihedral transforms on its last two axes.

    Bit 0 of `k` flips axis -2 and bit 1 flips axis -1. Bit 2 then swaps the two
    axes. The `k` annotation is presumably what lets fastai draw a random value in [0, 7].
    """
    axes_to_flip = [axis for bit, axis in ((1, -2), (2, -1)) if k & bit]
    if axes_to_flip:
        x = torch.flip(x, axes_to_flip)
    if k & 4:
        x = x.transpose(-2, -1)
    return x.contiguous()

hs_dihedral = TfmPixel(_hs_dihedral)
from torch.distributions.normal import Normal | |
def _brightness_hsi(x, change:uniform):
    """Scale every value of `x` in place by ``1 + change`` and return `x`.

    Registered as a pixel transform: per the original author, making it a
    `TfmLighting` would be more natural but caused problems.
    """
    # TODO: Chen et al. also add random noise. The earlier sketch drew an 11x11
    # N(0.5, 0.5) patch and added 1/25 of it to all but the last channel.
    return x.mul_(1 + change)

brightness_hsi = TfmPixel(_brightness_hsi)
from random import sample | |
def _cutout_hs(x, n_holes:uniform_int=1, length:uniform_int=3, cutout_d_pct:float=1):
    """Zero `n_holes` square holes of side about `length` at random spatial locations.

    One random subset of ``int(cutout_d_pct * depth)`` bands (axis -3) is drawn up front
    and zeroed inside every hole. With cutout_d_pct=1.0 all bands are zeroed.
    Holes are clipped at the borders. `x` is modified in place and returned.
    """
    depth, height, width = x.shape[-3:]
    band_idx = sorted(sample(range(depth), int(cutout_d_pct * depth)))
    half = length / 2

    def clamp(value, upper):
        return int(np.clip(value, 0, upper))

    for _ in range(n_holes):
        cy = np.random.randint(0, height)
        cx = np.random.randint(0, width)
        top, bottom = clamp(cy - half, height), clamp(cy + half, height)
        left, right = clamp(cx - half, width), clamp(cx + half, width)
        x[..., band_idx, top:bottom, left:right] = 0
    return x

cutout_hs = TfmPixel(_cutout_hs, order=20)
def _cutout_chen(x, length:uniform_int=3, cutout_d_pct:float=0.1):
    """Cutout for 3-D data following Chen et al. (2019).

    A random ``int(cutout_d_pct * depth)`` bands (axis -3) are chosen. Each gets its own
    square hole of side about `length` at an independent random location, clipped at
    the borders. `x` is modified in place and returned.
    """
    depth, height, width = x.shape[-3:]
    half = length / 2

    def clamp(value, upper):
        return int(np.clip(value, 0, upper))

    for band in sorted(sample(range(depth), int(cutout_d_pct * depth))):
        cy, cx = np.random.randint(0, height), np.random.randint(0, width)
        top, bottom = clamp(cy - half, height), clamp(cy + half, height)
        left, right = clamp(cx - half, width), clamp(cx + half, width)
        x[..., band, top:bottom, left:right] = 0
    return x

cutout_chen = TfmPixel(_cutout_chen, order=20)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment