Skip to content

Instantly share code, notes, and snippets.

@mayrajeo
Last active July 13, 2021 17:15
Show Gist options
  • Save mayrajeo/3f893f7947688ac84ebffcb38a672a67 to your computer and use it in GitHub Desktop.
Save mayrajeo/3f893f7947688ac84ebffcb38a672a67 to your computer and use it in GitHub Desktop.
from fastai.core import ItemBase
from fastai.vision import *
from fastai.basic_data import *
from fastai.imports import *
import rasterio as rio
import scipy
"""
TODO: separate transformations and such to different files
"""
class HSDataBunch(ImageDataBunch):
    """DataBunch for hyperspectral items.

    Derives from ImageDataBunch because `normalize` is defined there.
    """
    _square_show = True

    def show_batch(self, mode:str='spectral', rows:int=5, ds_type:DatasetType=DatasetType.Train, reverse:bool=False, **kwargs)->None:
        """Show a batch of data in `ds_type` on a few `rows`.

        `mode` selects the rendering and must be 'spectral' or 'rgb'; any other
        value prints an error message and returns without plotting.
        With `reverse=True` the batch is flipped along dim 0 before display.
        Remaining `kwargs` are forwarded to the item list's `show_xys`.
        """
        if mode not in ('spectral', 'rgb'):
            print(f'Error! {mode} is invalid visualization type.')
            return

        x,y = self.one_batch(ds_type, True, True)
        if reverse:
            x = x.flip(0)
            y = y.flip(0)

        inputs, targets = self.train_ds.x, self.train_ds.y

        # A square grid shows rows**2 items, otherwise one item per row;
        # never ask for more items than one batch holds.
        n_items = rows ** 2 if inputs._square_show else rows
        batch_size = self.dl(ds_type).batch_size
        if batch_size < n_items:
            n_items = batch_size

        xs = [inputs.reconstruct(grab_idx(x, idx)) for idx in range(n_items)]
        #TODO: get rid of has_arg if possible
        if has_arg(targets.reconstruct, 'x'):
            # Some target types need the matching input to rebuild themselves.
            ys = [targets.reconstruct(grab_idx(y, idx), x=item) for idx, item in enumerate(xs)]
        else:
            ys = [targets.reconstruct(grab_idx(y, idx)) for idx in range(n_items)]

        inputs.show_xys(xs, ys, mode, **kwargs)
class HSImage(Image):
    """Image subclass for hyperspectral data; only the `show` method differs."""

    def __init__(self, px:Tensor):
        super().__init__(px)

    def show(self, mode='rgb', ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=False, y:Any=None, **kwargs):
        """
        Draw the image either as its mean spectrum (`mode='spectral'`) or as an
        RGB render (any other `mode`).

        If `y` is given it is drawn on the same axes via `y.show(ax=ax, **kwargs)`;
        if `title` is given it is set as the axes title.
        """
        spectral = mode == 'spectral'
        render = show_spectra if spectral else show_rgb
        # The RGB render is always requested with hide_axis=True; only the
        # spectral plot follows the caller's `hide_axis`.
        ax = render(self, ax=ax, hide_axis=hide_axis if spectral else True, figsize=figsize)

        if y is not None:
            y.show(ax=ax, **kwargs)
        if title is not None:
            ax.set_title(title)
class HSImageItemList(ItemList):
    """
    Custom ItemList for N-dimensional images, loaded either as an image
    (`dims=2`) or as volumetric data with an extra leading axis (`dims=3`).
    Plotting utilities (`show_xys`, `show_xyzs`) are also provided.
    """
    _bunch = HSDataBunch
    _square_show = True

    # NOTE(review): the default `chans` list is built once at definition time
    # and shared by every instance that relies on it; it is never mutated here,
    # but callers should not modify `self.chans` in place.
    def __init__(self, items, dims=3, chans=list(range(461)), **kwargs):
        """Store `dims` and `chans` and register both in `copy_new`."""
        super().__init__(items, **kwargs)
        self.dims = dims
        self.chans = chans
        self.copy_new.append('dims')
        self.copy_new.append('chans')

    def open(self, fn)->HSImage:
        """Load `fn` as an HSImage; only `.npy` and `.tif` files are supported.

        For any other extension a message is printed and None is returned.
        """
        # `fn` may be a pathlib.Path; str() keeps the suffix test working.
        name = str(fn)
        # Both openers take the target class as a required argument.
        if name.endswith('.npy'):
            return open_npy(fn, cls=HSImage, dims=self.dims, chans=self.chans)
        if name.endswith('.tif'):
            return open_geotiff(fn, cls=HSImage, dims=self.dims, chans=self.chans)
        print('Invalid file format, only .npy and .tif are supported now')
        return None

    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr=None, suffix:str='.npy', **kwargs)->ItemList:
        """Build the list from `df[cols]`, turning each entry into
        `<path>/<folder>/<entry><suffix>` (`folder` is optional)."""
        suffix = suffix or ''
        res = super().from_df(df, path=path, cols=cols, **kwargs)
        pref = f'{res.path}{os.path.sep}'
        if folder is not None: pref += f'{folder}{os.path.sep}'
        res.items = np.char.add(np.char.add(pref, res.items.astype(str)), suffix)
        return res

    def get(self, i)->HSImage:
        """Return item `i` opened as an HSImage."""
        fn = super().get(i)
        return self.open(fn)

    def reconstruct(self, t:Tensor)->HSImage:
        """Wrap tensor `t` (cast to float) back into an HSImage."""
        return HSImage(t.float())

    def show_xys(self, xs, ys, mode='rgb', imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys` (targets) on a figure of `figsize`."
        rows = int(np.ceil(math.sqrt(len(xs))))
        axs = subplots(rows, rows, imgsize=imgsize, figsize=figsize)
        for x,y,ax in zip(xs, ys, axs.flatten()): x.show(ax=ax, y=y, mode=mode, **kwargs)
        # Blank out the unused cells of the square grid.
        for ax in axs.flatten()[len(xs):]: ax.axis('off')
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, mode='rgb', figsize:Tuple[int, int]=None, **kwargs)->None:
        "Show `xs` with ground truth `ys` (left column) and predictions `zs` (right column)."
        figsize = ifnone(figsize, (6, 3*len(xs)))
        # One row per sample and exactly two columns; squeeze=False keeps the
        # axes array 2-D so `axs[i, j]` also works for a single sample.
        fig, axs = plt.subplots(len(xs), 2, figsize=figsize, squeeze=False)
        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
        for i,(x,y,z) in enumerate(zip(xs,ys,zs)):
            x.show(ax=axs[i,0], y=y, mode=mode, **kwargs)
            x.show(ax=axs[i,1], y=z, mode=mode, **kwargs)
"Open function working with numpy arrays"
def open_npy(fn:PathOrStr, cls:type, dims:int=2, chans:Collection[int]=None):
    """Load the `.npy` array at `fn` and wrap it in `cls`.

    `chans`, when given, indexes the first axis of the loaded tensor.
    With `dims == 3` a leading axis of size 1 is added before wrapping.
    """
    tensor = torch.from_numpy(np.load(fn))
    if chans is not None:
        tensor = tensor[chans]
    return cls(tensor[None] if dims == 3 else tensor)
def open_geotiff(fn:PathOrStr, cls:type, dims:int=2, chans:Collection[int]=None):
    """Read the raster at `fn` with rasterio and wrap it in `cls`.

    The data is cast to float32 before conversion to a tensor.
    `chans`, when given, indexes the first axis of the tensor.
    With `dims == 3` a leading axis of size 1 is added before wrapping.
    """
    with rio.open(fn) as src:
        raster = src.read().astype(np.float32)
    tensor = torch.from_numpy(raster)
    if chans is not None:
        tensor = tensor[chans]
    return cls(tensor[None] if dims == 3 else tensor)
"""
Custom visualization functions here, need refactoring
For geospatial images(Sentinel-2 multispectral images), useful visualizations include:
Single channel composite
RGB-composite with possibilities to select used channels
Normalized spectral indices, such as NDVI: (NIR-RED)/(NIR+RED)
"""
def show_spectra(img:Tensor, ax:plt.Axes=None, figsize:tuple=(3,6), hide_axis:bool=False, cmap:str='binary', **kwargs)->plt.Axes:
    """Plot the mean spectrum of `img` and return the axes used.

    The mean is taken over the last two axes (and over the leading axis too
    when the array is not 3-D). The last value of the resulting spectrum is
    left out of the plot (per the original note, the CHM channel is not to be
    shown). `kwargs` go to `ax.plot`; `cmap` is accepted but not used.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    pixels = img.data.cpu().numpy()
    reduce_over = (-2, -1) if pixels.ndim == 3 else (0, -2, -1)
    spectrum = pixels.mean(axis=reduce_over)
    ax.plot(spectrum[:-1], **kwargs)
    if hide_axis:
        ax.axis('off')
    ax.grid()
    return ax
def show_rgb(img:Tensor, rgb_chans:Tuple[int,int,int]=(82,49,28),
             ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=False, **kwargs)->plt.Axes:
    """RGB rendition built from three bands of `img`; returns the axes used.

    `rgb_chans` gives the band indices used for red, green and blue (the
    defaults follow the original note: Sentinel-2 RGB bands without overlap
    in the B and G bands). A 4-D input is reduced to its first element.
    The composite is min-max scaled to [0, 1] for display only.
    Ticks are always removed; `hide_axis` and `kwargs` are accepted but unused.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    cube = img.data.cpu().numpy()
    if cube.ndim == 4:
        cube = cube[0]

    composite = np.zeros((cube.shape[1], cube.shape[2], 3))
    for slot in range(3):
        composite[..., slot] = cube[rgb_chans[slot]]

    # Quick-hack visualization scaling. A constant (or all-NaN) image has no
    # usable range; render it as zeros instead of dividing by zero.
    lowest = np.min(composite)
    value_range = np.max(composite) - lowest
    if value_range > 0:
        composite = (composite - lowest) / value_range
    else:
        composite = np.zeros_like(composite)

    ax.imshow(composite)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax
"""
Custom transforms for n-dimensional image or volumetric data
"""
def _hs_dihedral(x, k:partial(uniform_int,0,7)):
    """Apply one of the eight flip/transpose combinations to `x`, chosen by `k`.

    Bit 0 of `k` flips axis -2, bit 1 flips axis -1 and bit 2 swaps the last
    two axes afterwards. The result is returned as a contiguous tensor.
    """
    flip_axes = [axis for bit, axis in ((1, -2), (2, -1)) if k & bit]
    out = torch.flip(x, flip_axes) if flip_axes else x
    if k & 4:
        out = out.transpose(-2, -1)
    return out.contiguous()
# Public transform object: `_hs_dihedral` wrapped in `TfmPixel`
# (presumably fastai's pixel-level transform wrapper from the star imports — confirm).
hs_dihedral = TfmPixel(_hs_dihedral)
from torch.distributions.normal import Normal
def _brightness_hsi(x, change:uniform):
    """Scale `x` in place by `1 + change` and return it.

    Should be TfmLighting, but per the original author that "messes
    everything up", so it is registered as a pixel transform instead.
    """
    # Chen et al. additionally add random noise; a sketch of that was left
    # disabled in the original:
    #   noise = torch.zeros(11,11); noise.data.normal_(0.5, 0.5)
    #   x[:,:-1].add_(1/25*(noise))
    factor = 1 + change
    return x.mul_(factor)
# Public transform object: `_brightness_hsi` wrapped in `TfmPixel`
# (presumably fastai's pixel-level transform wrapper from the star imports — confirm).
brightness_hsi = TfmPixel(_brightness_hsi)
from random import sample
def _cutout_hs(x, n_holes:uniform_int=1, length:uniform_int=3, cutout_d_pct:float=1):
    """Zero out `n_holes` square patches of side `length` at random positions.

    Each patch is applied to the same randomly chosen subset of bands, whose
    size is `int(cutout_d_pct * n_bands)`; with `cutout_d_pct=1.0` every band
    is zeroed. Patches are clipped at the image border. `x` is modified in
    place and returned.
    """
    n_bands, height, width = x.shape[-3:]
    bands = sorted(sample(list(range(n_bands)), int(cutout_d_pct*n_bands)))
    half = length / 2

    def _span(center, limit):
        # Start/stop of the hole along one axis, clipped to [0, limit].
        return int(np.clip(center - half, 0, limit)), int(np.clip(center + half, 0, limit))

    for _ in range(n_holes):
        center_y = np.random.randint(0, height)
        center_x = np.random.randint(0, width)
        top, bottom = _span(center_y, height)
        left, right = _span(center_x, width)
        x[..., bands, top:bottom, left:right] = 0
    return x
# Public transform object: `_cutout_hs` wrapped in `TfmPixel` with order=20
# (presumably the order value controls where it runs among the transforms — confirm).
cutout_hs = TfmPixel(_cutout_hs, order=20)
def _cutout_chen(x, length:uniform_int=3, cutout_d_pct:float=0.1):
    """Cutout for 3d data, after Chen et al. 2019.

    Picks `int(cutout_d_pct * n_bands)` random bands and zeroes one square
    patch of side `length` in each, at an independent random position per
    band. Patches are clipped at the image border. `x` is modified in place
    and returned.
    """
    n_bands, height, width = x.shape[-3:]
    bands = sorted(sample(list(range(n_bands)), int(cutout_d_pct*n_bands)))
    half = length / 2

    def _span(center, limit):
        # Start/stop of the hole along one axis, clipped to [0, limit].
        return int(np.clip(center - half, 0, limit)), int(np.clip(center + half, 0, limit))

    for band in bands:
        center_y = np.random.randint(0, height)
        center_x = np.random.randint(0, width)
        top, bottom = _span(center_y, height)
        left, right = _span(center_x, width)
        x[..., band, top:bottom, left:right] = 0
    return x
# Public transform object: `_cutout_chen` wrapped in `TfmPixel` with order=20
# (presumably the order value controls where it runs among the transforms — confirm).
cutout_chen = TfmPixel(_cutout_chen, order=20)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment