Last active
January 7, 2022 01:31
Random Affine Crop in the style of Albumentations for a Rasterio Dataset with minimal dependencies
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Utility functions for managing 3x3 matrices for cv2.warpAffine in pure numpy | |
import numpy as np | |
def identity():
    '''Return a new 3x3 float64 identity matrix (the "do nothing" transform).'''
    return np.identity(3, dtype=np.float64)
def affine(A=None, t=None):
    '''
    Build a 3x3 affine matrix.

    A : 2x2 array-like, opt
        Linear part (top-left block); left as identity when omitted.
    t : 2-vector, opt
        Translation (last column); left as zero when omitted.
    '''
    mat = identity()
    if A is not None:
        mat[:2, :2] = A
    if t is not None:
        mat[:2, 2] = t
    return mat
def rotate(theta):
    '''
    Rotation by `theta` radians about the origin. The result is
    counter-clockwise when the y axis points down, as it does in pixel
    coordinates.
    '''
    c, s = np.cos(theta), np.sin(theta)
    linear = [[c, s],
              [-s, c]]
    return affine(A=linear)
def rotate_around(rx, ry, theta):
    '''
    Rotate counter-clockwise by `theta` about the pivot (rx, ry).

    The pivot is moved to the origin, the rotation is applied, and then the
    pivot is moved back.
    Note: rotating around (0, 0) in OpenCV actually rotates around the
    centre of the first pixel, not its corner.
    '''
    to_origin = translate(-rx, -ry)
    back = translate(rx, ry)
    return concatenate([to_origin, rotate(theta), back])
def scale(sx, sy=None):
    '''Axis-aligned scale by `sx` along x and `sy` along y (sy defaults to sx).'''
    sy = sx if sy is None else sy
    return affine(A=[[sx, 0],
                     [0, sy]])
def scale_around(srx, sry, sx, sy=None):
    '''
    Scale by (sx, sy) about the registration point (srx, sry) rather than
    the origin; sy defaults to sx.
    '''
    to_origin = translate(-srx, -sry)
    back = translate(srx, sry)
    return concatenate([to_origin, scale(sx, sy), back])
def translate(tx, ty):
    '''Pure translation by (tx, ty).'''
    offset = (tx, ty)
    return affine(t=offset)
def concatenate(matrices):
    '''
    Compose a sequence of 3x3 transforms into one matrix.

    The first matrix in `matrices` is applied first: the product is built
    as I @ m[-1] @ ... @ m[0].
    '''
    result = identity()
    for step in reversed(matrices):
        result = np.matmul(result, step)
    return result
def homogeneous(coords):
    '''
    Append a trailing 1 to every point: [N, ..., {x, y}] -> [N, ..., {x, y, 1}].
    '''
    pad = np.ones(coords.shape[:-1] + (1,))
    return np.concatenate((coords, pad), axis=-1)
def unhomogeneous(coords):
    '''Inverse of homogeneous(): divide each point by its w component, then drop w.'''
    w = coords[..., 2:]
    normalised = coords / w
    return normalised[..., :2]
def transform(mat, coords):
    '''
    Apply the 3x3 matrix `mat` to a batch of 2D points shaped [..., 2].
    Returns the transformed points, also shaped [..., 2].
    '''
    points = homogeneous(np.asarray(coords))
    moved = points @ mat.T
    return unhomogeneous(moved)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This was created for handling large TIF files efficiently. | |
# There is one main class in this file: RandomAffineCrop | |
# This can do most of what Albumentations.Affine does but with a major | |
# distinction. | |
# You can call RandomAffineCrop.apply() on a rasterio.Dataset, and it will | |
# only load the minimal amount of data required to provide a crop. | |
# This is in contrast with Albumentations.Affine which requires an image | |
# to be fully loaded in memory. | |
import decimal | |
import math | |
import cv2 | |
import numpy as np | |
import albumentations as A | |
import spatial | |
import mat3 | |
def safe_int(v, atol=7):
    '''
    Convert `v` to int after rounding it to `atol` decimal places.

    Simply casting to int truncates 0.999999999999 to 0; rounding first
    absorbs floating point imprecision. The rounded value is then truncated
    toward zero, as int() does.

    Besides Python int/float/Decimal, numpy scalars (e.g. np.float32,
    np.int64) are accepted. decimal.Decimal cannot take those directly.
    '''
    if not isinstance(v, (int, float, decimal.Decimal)) and hasattr(v, 'item'):
        # numpy scalars that do not subclass int/float -> exact Python value
        v = v.item()
    return int(round(decimal.Decimal(v), atol))
class RandomAffineCrop(A.DualTransform):
    '''
    Random affine warp followed by a fixed-size crop, in the style of
    albumentations.Affine, that reads only the source pixels the crop needs.

    Applies the following operations in order:
        translation
        scaling
        rotation
    This is represented as a params dictionary:
        {
            'translate': (dx, dy),
            'scale': (srx, sry, scale),
            'rotate': (rx, ry, angle)
        }
    - dx, dy are normalised to the IMAGE: spatial.get_pixel_transform
      multiplies them by (image size - crop size + 1)
    - srx, sry are normalised to CROP dimensions
    - rx, ry are normalised to CROP dimensions
    - angle is in radians

    You can generate a dictionary like this with random values by using the
    get_params*() methods. `apply()` then behaves as if the whole image were
    warped and cropped at crop_size. In reality it only loads the minimum
    image data from disk.

    Each distribution may be provided as a callable that takes no
    parameters. (In the descriptions below `f` refers to `float`.)

    Parameters
    ----------
    crop_size : int or (int, int)
        Size of crop in pixels, (width, height).
    translate : (f, f) or callable or None, opt
        if 2-tuple, interpreted as uniform range for both dimensions
        if callable, must return (dx, dy)
        Stored as the sampler `self.translate_dist`.
    scale : f or (f, f, f) or callable or None, opt
        if float, interpreted as symmetric range limit.
            e.g. 2 means: select scale between 0.5 and 2, with equal
            chance to sample <1 as to sample >1.
            The scale registration point defaults to crop centre.
        if 3-tuple, interpreted as (srx, sry, symmetric range limit)
            e.g. (0.5, 0.5, 2)
            (both above examples perform the same sampling)
        if callable, must return (srx, sry, scale)
        Stored as the sampler `self.scale_dist`.
    rotate : (f, f) or (f, f, (f, f)) or callable or None, opt
        if 2-tuple interpreted as angle (radians) range.
            e.g. (-math.pi/2, math.pi/2)
            the rotation centre defaults to the centre of the crop
        if 3-tuple interpreted as (rx, ry, angle (radians) range)
            e.g. (0.5, 0.5, (-math.pi/2, math.pi/2))
            (both above examples perform the same sampling)
        if callable, must return (rx, ry, angle)
        Stored as the sampler `self.rotate_dist`.
    always_apply, p :
        Passed through to albumentations.DualTransform.
    None for translate/scale/rotate disables that operation.
    '''

    # Keys of the params dict that describe the affine transform itself.
    _AFFINE_KEYS = ('translate', 'scale', 'rotate')

    def __init__(
        self,
        crop_size,
        translate=spatial.translate_uniform_dist,
        scale=spatial.scale_uniform_dist,
        rotate=spatial.rotate_uniform_dist,
        always_apply: bool = False, p: float = 0.5
    ):
        super().__init__(always_apply, p)
        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)
        self.crop_size = crop_size
        self.translate_dist = spatial.translate_sampler_fnc(translate)
        self.scale_dist = spatial.scale_sampler_fnc(scale)
        self.rotate_dist = spatial.rotate_sampler_fnc(rotate)

    @classmethod
    def _affine_params(cls, params):
        '''Keep only the affine keys; callers (e.g. Albumentations) may pass other keys too.'''
        return {key: params[key] for key in cls._AFFINE_KEYS if key in params}

    def get_params(self):
        '''
        Gets a set of random affine parameters sampled from the distributions
        this object was initialised with.
        '''
        return {
            'translate': self.translate_dist(),
            'scale': self.scale_dist(),
            'rotate': self.rotate_dist(),
        }

    def get_params_safe(self, image_size, max_attempts=100):
        '''
        Gets a set of params sampled from the configured distributions while
        ensuring the crop lies within the bounds of the image.

        Uses rejection sampling: it redraws until the crop window is wholly
        inside the image, up to `max_attempts` times.

        Parameters
        ----------
        image_size : (int, int)
            (width, height) of the image.
        max_attempts : int, opt
            Number of draws before giving up.

        Raises
        ------
        RuntimeError
            If no valid sample was found.
        '''
        # TODO: How can you make this safe, efficiently?
        # Central square can be sampled normally.
        # Triangles on the sides: trigonometry to determine how much are
        # valid crop locations
        # Select between by area
        # In practice, for large images, this is very unlikely to loop more than once
        for _ in range(max_attempts):
            params = self.get_params()
            M = self.get_pixel_transform(image_size, **params)
            if self.check_sample_inside(M, image_size):
                return params
        raise RuntimeError("Couldn't sample params for crop within image. "
                           'Check image and crop sizes')

    def check_sample_inside(self, M, image_size):
        ''' Returns True if this transform will sample wholly within the image. '''
        iw, ih = image_size
        xlo, ylo, xhi, yhi = self.get_crop_bounds(M)
        # xhi/yhi are exclusive upper bounds (they are used as slice stops when
        # reading the crop), so a window ending exactly at the edge is inside.
        return xlo >= 0 and ylo >= 0 and xhi <= iw and yhi <= ih

    def get_pixel_transform(self, img_size, translate=(0, 0), scale=(0, 0, 1), rotate=(0, 0, 0)):
        '''3x3 matrix mapping image pixel coordinates to crop pixel coordinates.'''
        return spatial.get_pixel_transform(self.crop_size, translate, scale, rotate, img_size)

    def get_crop_pixel_transform(self, M):
        return spatial.get_crop_pixel_transform(M, self.crop_size)

    def get_crop_bounds(self, M):
        '''Integer (xlo, ylo, xhi, yhi) image window needed to render the crop.'''
        return spatial.get_crop_bounds(M, self.crop_size)

    def fetch(self, img, interpolate, M=None, **params):
        '''Render the crop from `img` using interpolation flag `interpolate`.'''
        # shape[1::-1] gives (width, height) for both (H, W) and (H, W, C) shapes.
        img_size = img.shape[1::-1]
        if M is None:
            M = self.get_pixel_transform(img_size, **self._affine_params(params))
        return spatial.fetch(img, M, self.crop_size, interpolate)

    def apply(self, img, rows=None, cols=None, **params):
        return self.fetch(img, interpolate=cv2.INTER_LINEAR, **params)

    def apply_to_mask(self, img, rows=None, cols=None, **params):
        # Nearest-neighbour so mask labels are never blended.
        return self.fetch(img, interpolate=cv2.INTER_NEAREST, **params)

    def apply_to_bbox(self, bbox, rows, cols, M=None, **params):
        '''
        Applies the transform to the bbox.

        Takes normalised axis-aligned coordinates (xlo, ylo, xhi, yhi)
        relative to the image and returns normalised axis-aligned coordinates
        relative to the crop.

        If there's a rotation, the returned box is the minimal axis-aligned
        box covering the rotated corners, so its area grows. For example,
        rotating 45 degrees and then -45 degrees won't give you the same bbox.
        '''
        if M is None:
            M = self.get_pixel_transform((cols, rows), **self._affine_params(params))
        xlo, ylo, xhi, yhi = bbox
        # Denormalise to image pixels, then move all four corners.
        pxlo, pxhi = xlo * cols, xhi * cols
        pylo, pyhi = ylo * rows, yhi * rows
        corners = mat3.transform(M, [(pxlo, pylo), (pxhi, pylo), (pxlo, pyhi), (pxhi, pyhi)])
        # Re-align with the axes and normalise to the crop size.
        npxlo, npylo = corners.min(axis=0)
        npxhi, npyhi = corners.max(axis=0)
        cw, ch = self.crop_size[0], self.crop_size[1]
        return npxlo / cw, npylo / ch, npxhi / cw, npyhi / ch

    def apply_to_keypoint(self, keypoint, rows, cols, M=None, **params):
        '''
        Transform an (x, y, angle, scale) keypoint.

        A keypoint is treated as a vector: it starts at (x, y) and points along
        `angle` with length `scale`.
        https://albumentations.ai/docs/getting_started/keypoints_augmentation/
        '''
        if M is None:
            M = self.get_pixel_transform((cols, rows), **self._affine_params(params))
        x, y, angle, kp_scale = keypoint
        # Encode the direction as a second point. y is negated because the
        # keypoint angle is counter-clockwise while pixel y points down.
        vx = math.cos(angle) * kp_scale
        vy = -math.sin(angle) * kp_scale
        (nx, ny), (tip_x, tip_y) = mat3.transform(M, [[x, y], [x + vx, y + vy]])
        # Vector relative to the transformed point, back to polar form.
        nvx = tip_x - nx
        nvy = tip_y - ny
        nscale = math.sqrt(nvx ** 2 + nvy ** 2)
        # Albumentations treats a positive keypoint angle as counter-clockwise.
        nangle = -math.atan2(nvy, nvx)
        # To match with `albumentations.Rotate`, we take the int part.
        return safe_int(nx), safe_int(ny), nangle, nscale
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import cv2 | |
import numpy as np | |
import mat3 | |
# Default uniform ranges
# Normalised translation range. get_pixel_transform builds the matrix that
# maps image pixels to crop pixels, so a negative shift places the crop origin
# at a positive offset inside the image.
TRANSLATE_DEFAULT_RANGE = (-1, 0)
# Symmetric multiplicative limit: scale is drawn between 1/1.1 and 1.1.
SCALE_DEFAULT_RANGE = 1.1
# Rotation range in radians (plus/minus 90 degrees).
ROTATE_DEFAULT_RANGE = (-math.pi / 2, math.pi / 2)
# Sampling functions | |
def translate_uniform_dist(v=TRANSLATE_DEFAULT_RANGE):
    '''
    Sample a normalised (dx, dy) pair. Both values are drawn independently
    and uniformly from the range `v` = (low, high).
    '''
    dxdy = np.random.uniform(*v, (2,))
    return tuple(dxdy)
def multiplicatively_symmetrical_uniform(v):
    '''
    Sample a multiplicative factor between 1/v and v.

    A plain uniform draw over (0.5, 2) is more likely to land above 1 than
    below it. Instead, m is drawn uniformly from (1, v), and a fair coin
    decides whether to return m or 1/m. Landing in (1, v) is therefore
    exactly as likely as landing in (1/v, 1).
    '''
    magnitude = np.random.uniform(1, v)
    # The sign of a standard normal draw serves as the fair coin.
    keep_upper = np.random.randn() > 0
    if keep_upper:
        return magnitude
    return 1 / magnitude
def scale_uniform_dist(srx=0.5, sry=0.5, v=SCALE_DEFAULT_RANGE):
    '''
    Return (srx, sry, factor). The factor comes from
    multiplicatively_symmetrical_uniform(v).
    '''
    factor = multiplicatively_symmetrical_uniform(v)
    return srx, sry, factor
def rotate_uniform_dist(rx=0.5, ry=0.5, v=ROTATE_DEFAULT_RANGE):
    '''Return (rx, ry, angle), with angle drawn uniformly from the radian range `v`.'''
    angle = np.random.uniform(*v)
    return rx, ry, angle
def translate_sampler_fnc(translate=translate_uniform_dist):
    '''
    Turn a translation spec into a parameterless sampler returning (dx, dy).

      (low, high) tuple/list -> uniform range shared by both axes
      None                   -> always (0, 0)
      callable               -> used as-is
    '''
    if translate is None:
        return lambda: (0, 0)
    if isinstance(translate, (tuple, list)):
        if len(translate) != 2:
            raise ValueError('Translate as a tuple must be a 2-tuple')
        return lambda: translate_uniform_dist(translate)
    if callable(translate):
        return translate
    raise ValueError('translate must be one of: (float, float), None or callable')
def scale_sampler_fnc(scale=scale_uniform_dist):
    '''
    Turn a scale spec into a parameterless sampler returning (srx, sry, scale).

      number      -> symmetric multiplicative limit with the registration
                     point at the crop centre (e.g. 2 -> between 0.5 and 2)
      tuple/list  -> positional args to scale_uniform_dist,
                     i.e. (srx, sry, limit)
      None        -> always (0, 0, 1), i.e. no scaling
      callable    -> used as-is

    Raises
    ------
    ValueError
        For any other spec.
    '''
    import numbers

    if isinstance(scale, (tuple, list)):
        return lambda: scale_uniform_dist(*scale)
    # Accept ints (e.g. scale=2, as documented) and numpy reals, not only
    # Python floats. bool is an int subclass but not a meaningful limit.
    if isinstance(scale, numbers.Real) and not isinstance(scale, bool):
        return lambda: scale_uniform_dist(v=scale)
    if scale is None:
        return lambda: (0, 0, 1)
    if callable(scale):
        return scale
    raise ValueError('Scale must be one of: float, (float, float, float) or callable')
def rotate_sampler_fnc(rotate=rotate_uniform_dist):
    '''
    Turn a rotation spec into a parameterless sampler returning (rx, ry, angle).

      (low, high)           -> angle range in radians; rotation about the
                               crop centre (rotate_uniform_dist defaults)
      (rx, ry, (low, high)) -> crop-normalised rotation centre and angle range
      number                -> symmetric angle range (-number, number)
      None                  -> always (0, 0, 0), i.e. no rotation
      callable              -> used as-is

    Raises
    ------
    ValueError
        For tuples of other lengths or unsupported specs.
    '''
    import numbers

    if isinstance(rotate, (tuple, list)):
        if len(rotate) == 2:
            # Short form: only the angle range is given.
            return lambda: rotate_uniform_dist(v=rotate)
        if len(rotate) == 3:
            return lambda: rotate_uniform_dist(*rotate)
        raise ValueError('Rotate as a tuple must be a 2-tuple or 3-tuple')
    if isinstance(rotate, numbers.Real) and not isinstance(rotate, bool):
        # A bare number used to be passed straight through as the range, so
        # np.random.uniform(*v) failed when sampling; treat it as a limit.
        limit = abs(rotate)
        return lambda: rotate_uniform_dist(v=(-limit, limit))
    if rotate is None:
        return lambda: (0, 0, 0)
    if callable(rotate):
        return rotate
    raise ValueError('Rotate must be one of: (float, float), '
                     '(float, float, (float, float)) or callable')
def affine_sampler_fnc(translate=translate_uniform_dist,
                       scale=scale_uniform_dist,
                       rotate=rotate_uniform_dist):
    '''
    Build a parameterless sampler for a params dict with the keys
    'translate', 'scale' and 'rotate', usable as get_pixel_transform(**params).

    Each argument accepts the same specs as the matching *_sampler_fnc.
    '''
    samplers = {
        'translate': translate_sampler_fnc(translate),
        'scale': scale_sampler_fnc(scale),
        'rotate': rotate_sampler_fnc(rotate),
    }

    def do_sample():
        # Draw in insertion order: translate, scale, rotate.
        return {name: draw() for name, draw in samplers.items()}

    return do_sample
# Matrix functions | |
def _translate_norm_matrix(img_size, crop_size, translate):
    '''
    Translation matrix for a normalised shift (dx, dy).

    Each component is scaled by (image size - crop size + 1) along its axis,
    the number of integer crop offsets that fit in the image.
    '''
    (iw, ih), (cw, ch) = img_size, crop_size
    dx, dy = translate
    span_x = iw - cw + 1
    span_y = ih - ch + 1
    return mat3.translate(span_x * dx, span_y * dy)
def get_pixel_transform(crop_size, translate=(0, 0), scale=(0, 0, 1), rotate=(0, 0, 0),
                        img_size=None):
    '''
    Build the 3x3 matrix that moves image coordinates to their positions in
    the transformed (crop) space. Translation is applied first, then
    scaling, then rotation.

    crop_size : (cw, ch)
    translate : (dx, dy). These are pixels when img_size is None; otherwise
                they are normalised by (image size - crop size + 1).
    scale     : (srx, sry, factor), with (srx, sry) as fractions of the crop
    rotate    : (rx, ry, angle), with (rx, ry) as fractions of the crop and
                the angle in radians
    img_size  : (iw, ih), opt

    Example: with img_size=None, translate=(20, 40) and no scale or rotation,
    a bbox at (10, 10) moves to (30, 50).
    '''
    cw, ch = crop_size
    srx, sry, factor = scale
    rx, ry, angle = rotate
    if img_size is not None:
        shift = _translate_norm_matrix(img_size, crop_size, translate)
    else:
        shift = mat3.translate(*translate)
    # Registration/rotation points are given relative to the crop, so convert
    # them to pixels using the crop size.
    zoom = mat3.scale_around(srx * cw, sry * ch, factor)
    spin = mat3.rotate_around(rx * cw, ry * ch, angle)
    return mat3.concatenate([shift, zoom, spin])
def get_crop_pixel_transform(M, crop_size):
    '''
    Re-express M for coordinates inside the fetched crop window.

    Window-local coordinates (origin at the window's top-left, which sits at
    (xlo, ylo) in the image) are first shifted back to image coordinates and
    then mapped by M.
    '''
    xlo, ylo = get_crop_bounds(M, crop_size)[:2]
    window_to_image = mat3.translate(xlo, ylo)
    return mat3.concatenate([window_to_image, M])
def get_crop_bounds(M, crop_size):
    '''
    Return the minimal integer pixel bounds (xlo, ylo, xhi, yhi) needed to
    obtain the crop's pixel data.

    The crop's corners are mapped back into image space with inv(M), and
    their axis-aligned hull is rounded outwards.
    '''
    # Corners of the output crop, with its origin at the top-left.
    unit_square = np.array([(0., 0), (1, 0), (1, 1), (0, 1)])
    crop_corners = unit_square * crop_size
    image_corners = mat3.transform(np.linalg.inv(M), crop_corners)
    xmin, ymin = image_corners.min(axis=0)
    xmax, ymax = image_corners.max(axis=0)
    return (math.floor(xmin), math.floor(ymin),
            math.ceil(xmax), math.ceil(ymax))
def fetch(img, M, crop_size, interpolate=cv2.INTER_LINEAR):
    '''
    Render the crop described by M from `img`, reading only the needed window.

    img         : rasterio dataset or np.ndarray (channels last when 3-D)
    M           : 3x3 image->crop matrix, e.g. from get_pixel_transform
    crop_size   : (width, height) of the output
    interpolate : OpenCV interpolation flag
    '''
    window = get_crop_bounds(M, crop_size)
    # M re-expressed relative to the fetched window.
    Q = get_crop_pixel_transform(M, crop_size)
    patch = read_crop(img, window, ch_axis=2)
    # warpAffine treats (0, 0) as the centre of the top-left pixel. This is
    # not a bug; it is necessary for proper sampling. Our matrices use the
    # pixel corner, so shift by half a pixel on the way in and out;
    # otherwise the image is misaligned with bboxes.
    to_corner = mat3.translate(0.5, 0.5)
    to_centre = mat3.translate(-0.5, -0.5)
    P = mat3.concatenate([to_corner, Q, to_centre])
    return cv2.warpAffine(patch, P[:2], crop_size,
                          flags=interpolate, borderMode=cv2.BORDER_CONSTANT)
def get_crop_slices(img_shp, window):
    '''
    Slices for filling a crop from an image when cropping with `window`,
    accounting for the window being partly or wholly off the edge.

    Parameters
    ----------
    img_shp : sequence of int
        Image extent per axis, e.g. (height, width).
    window : sequence of (start, stop)
        Window extent along the same axes.
        *Note:* negative values are interpreted as-is, not as "from the end".

    Returns
    -------
    (img_slices, crop_slices)
        `crop[crop_slices] = img[img_slices]` copies the overlapping region.
    '''
    img_shp, window = np.array(img_shp), np.array(window)
    start = window[:, 0]
    end = window[:, 1]
    window_shp = end - start
    # Overlap expressed in window (crop) coordinates.
    crop_low = np.clip(0 - start, a_min=0, a_max=window_shp)
    crop_high = window_shp - np.clip(end - img_shp, a_min=0, a_max=window_shp)
    crop_slices = tuple(slice(low, high) for low, high in zip(crop_low, crop_high))
    # The same overlap in image coordinates.
    start = np.clip(start, a_min=0, a_max=img_shp)
    end = np.clip(end, a_min=0, a_max=img_shp)
    img_slices = tuple(slice(low, high) for low, high in zip(start, end))
    return img_slices, crop_slices


def _normalise_window(window, size=None):
    '''Convert any window form accepted by read_crop to ((ylo, yhi), (xlo, xhi)) ints.'''
    if len(window) == 4:
        # Bounds form, as returned by get_crop_bounds.
        xlo, ylo, xhi, yhi = window
        return (int(ylo), int(yhi)), (int(xlo), int(xhi))
    if len(window) != 2:
        raise ValueError('window must have 2 or 4 elements')
    ywin, xwin = window
    if isinstance(ywin, slice) and isinstance(xwin, slice):
        if None in (ywin.start, ywin.stop, xwin.start, xwin.stop):
            raise ValueError('window slices must have explicit start and stop')
        if ywin.step not in (None, 1) or xwin.step not in (None, 1):
            raise ValueError('window slices must not have a step')
        return (int(ywin.start), int(ywin.stop)), (int(xwin.start), int(xwin.stop))
    if np.ndim(ywin) == 0 and np.ndim(xwin) == 0:
        # Position form: top-left (y, x) plus an explicit size.
        if size is None:
            raise ValueError('size is required when window is a (y, x) position')
        ysize, xsize = size
        y, x = int(ywin), int(xwin)
        return (y, y + int(ysize)), (x, x + int(xsize))
    (ylo, yhi), (xlo, xhi) = ywin, xwin
    return (int(ylo), int(yhi)), (int(xlo), int(xhi))


def read_crop(img, window, ch_axis=0, size=None):
    '''
    Gets from `img` the crop defined by `window`, zero-filling wherever the
    window extends beyond the image boundaries.

    Only the overlapping pixels are read from `img`. Negative indices are
    considered as-is, not "from the end".

    Parameters
    ----------
    img : rasterio.Dataset or np.ndarray
        Dataset/img to pull raster data from. An object with `read` and
        `count` attributes is read band by band; a 3-D array has its
        channels on `ch_axis`.
    window : one of
        (xlo, ylo, xhi, yhi)        bounds, as returned by get_crop_bounds
        ((ylo, yhi), (xlo, xhi))    per-axis (start, stop) pairs
        (slice_y, slice_x)          slices with explicit start/stop
        (y, x)                      top-left pixel position (requires `size`)
        Values are coerced to int.
    ch_axis : int, opt
        Which axis to place the channels in output (and, for 3-D arrays,
        where they are in `img`).
    size : (int, int), opt
        Size of crop (height, width) when `window` is a (y, x) position.

    Returns
    -------
    np.ndarray
        Crop raster data, with the dtype of `img`.

    Raises
    ------
    ValueError
        If `window` is not in one of the accepted forms.
    '''
    window = _normalise_window(window, size)
    (ylo, yhi), (xlo, xhi) = window
    dtype = img.dtypes[0] if hasattr(img, 'dtypes') else img.dtype
    # The window may extend beyond the image boundaries. Rather than stacking
    # what was read (which wouldn't be full size), allocate zeros at the full
    # crop size and paste the usable part into them.
    crop_shp = [yhi - ylo, xhi - xlo]
    img_shp = list(img.shape)
    if hasattr(img, 'count'):
        crop_shp.insert(ch_axis, img.count)
    elif len(img_shp) == 3:
        # pop drops the channel axis so img_shp holds only the spatial extent
        crop_shp.insert(ch_axis, img_shp.pop(ch_axis))
    crop = np.zeros(crop_shp, dtype=dtype)

    img_slices, crop_slices = get_crop_slices(img_shp, window)
    if any(s.stop <= s.start for s in img_slices):
        # No overlap with the image: nothing to read, the crop stays zero.
        return crop

    if hasattr(img, 'read') and hasattr(img, 'count'):
        for band in range(img.count):
            band_crop_slices = list(crop_slices)
            band_crop_slices.insert(ch_axis, band)
            # Band indexes in `read` are 1-based.
            crop[tuple(band_crop_slices)] = img.read(band + 1, window=img_slices)
    else:
        if len(img.shape) == 3:
            crop_slices = list(crop_slices)
            img_slices = list(img_slices)
            crop_slices.insert(ch_axis, slice(None))
            img_slices.insert(ch_axis, slice(None))
        crop[tuple(crop_slices)] = img[tuple(img_slices)]
    return crop
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This was created for handling large TIF files efficiently. | |
# There is one main class in this file: RandomAffineCrop | |
# This can do most of what Albumentations.Affine does but with a major | |
# distinction. | |
# You can call RandomAffineCrop.apply() on a rasterio.Dataset, and it will | |
# only load the minimal amount of data required to provide a crop. | |
# This is in contrast with Albumentations.Affine which requires an image | |
# to be fully loaded in memory. | |
import decimal | |
import math | |
import cv2 | |
import numpy as np | |
import albumentations as A | |
import spatial | |
import mat3 | |
def safe_int(v, atol=7):
    '''
    Convert `v` to int after rounding it to `atol` decimal places.

    Simply casting to int truncates 0.999999999999 to 0; rounding first
    absorbs floating point imprecision. The rounded value is then truncated
    toward zero, as int() does.

    Besides Python int/float/Decimal, numpy scalars (e.g. np.float32,
    np.int64) are accepted. decimal.Decimal cannot take those directly.
    '''
    if not isinstance(v, (int, float, decimal.Decimal)) and hasattr(v, 'item'):
        # numpy scalars that do not subclass int/float -> exact Python value
        v = v.item()
    return int(round(decimal.Decimal(v), atol))
class RandomAffineCrop(A.DualTransform):
    '''
    Random affine warp followed by a fixed-size crop, in the style of
    albumentations.Affine, that reads only the source pixels the crop needs.

    Applies the following operations in order:
        translation
        scaling
        rotation
    This is represented as a params dictionary:
        {
            'translate': (dx, dy),
            'scale': (srx, sry, scale),
            'rotate': (rx, ry, angle)
        }
    - dx, dy are normalised to the IMAGE: spatial.get_pixel_transform
      multiplies them by (image size - crop size + 1)
    - srx, sry are normalised to CROP dimensions
    - rx, ry are normalised to CROP dimensions
    - angle is in radians

    You can generate a dictionary like this with random values by using the
    get_params*() methods. `apply()` then behaves as if the whole image were
    warped and cropped at crop_size. In reality it only loads the minimum
    image data from disk.

    Each distribution may be provided as a callable that takes no
    parameters. (In the descriptions below `f` refers to `float`.)

    Parameters
    ----------
    crop_size : int or (int, int)
        Size of crop in pixels, (width, height).
    translate : (f, f) or callable or None, opt
        if 2-tuple, interpreted as uniform range for both dimensions
        if callable, must return (dx, dy)
        Stored as the sampler `self.translate_dist`.
    scale : f or (f, f, f) or callable or None, opt
        if float, interpreted as symmetric range limit.
            e.g. 2 means: select scale between 0.5 and 2, with equal
            chance to sample <1 as to sample >1.
            The scale registration point defaults to crop centre.
        if 3-tuple, interpreted as (srx, sry, symmetric range limit)
            e.g. (0.5, 0.5, 2)
            (both above examples perform the same sampling)
        if callable, must return (srx, sry, scale)
        Stored as the sampler `self.scale_dist`.
    rotate : (f, f) or (f, f, (f, f)) or callable or None, opt
        if 2-tuple interpreted as angle (radians) range.
            e.g. (-math.pi/2, math.pi/2)
            the rotation centre defaults to the centre of the crop
        if 3-tuple interpreted as (rx, ry, angle (radians) range)
            e.g. (0.5, 0.5, (-math.pi/2, math.pi/2))
            (both above examples perform the same sampling)
        if callable, must return (rx, ry, angle)
        Stored as the sampler `self.rotate_dist`.
    always_apply, p :
        Passed through to albumentations.DualTransform.
    None for translate/scale/rotate disables that operation.
    '''

    # Keys of the params dict that describe the affine transform itself.
    _AFFINE_KEYS = ('translate', 'scale', 'rotate')

    def __init__(
        self,
        crop_size,
        translate=spatial.translate_uniform_dist,
        scale=spatial.scale_uniform_dist,
        rotate=spatial.rotate_uniform_dist,
        always_apply: bool = False, p: float = 0.5
    ):
        super().__init__(always_apply, p)
        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)
        self.crop_size = crop_size
        self.translate_dist = spatial.translate_sampler_fnc(translate)
        self.scale_dist = spatial.scale_sampler_fnc(scale)
        self.rotate_dist = spatial.rotate_sampler_fnc(rotate)

    @classmethod
    def _affine_params(cls, params):
        '''Keep only the affine keys; callers (e.g. Albumentations) may pass other keys too.'''
        return {key: params[key] for key in cls._AFFINE_KEYS if key in params}

    def get_params(self):
        '''
        Gets a set of random affine parameters sampled from the distributions
        this object was initialised with.
        '''
        return {
            'translate': self.translate_dist(),
            'scale': self.scale_dist(),
            'rotate': self.rotate_dist(),
        }

    def get_params_safe(self, image_size, max_attempts=100):
        '''
        Gets a set of params sampled from the configured distributions while
        ensuring the crop lies within the bounds of the image.

        Uses rejection sampling: it redraws until the crop window is wholly
        inside the image, up to `max_attempts` times.

        Parameters
        ----------
        image_size : (int, int)
            (width, height) of the image.
        max_attempts : int, opt
            Number of draws before giving up.

        Raises
        ------
        RuntimeError
            If no valid sample was found.
        '''
        # TODO: How can you make this safe, efficiently?
        # Central square can be sampled normally.
        # Triangles on the sides: trigonometry to determine how much are
        # valid crop locations
        # Select between by area
        # In practice, for large images, this is very unlikely to loop more than once
        for _ in range(max_attempts):
            params = self.get_params()
            M = self.get_pixel_transform(image_size, **params)
            if self.check_sample_inside(M, image_size):
                return params
        raise RuntimeError("Couldn't sample params for crop within image. "
                           'Check image and crop sizes')

    def check_sample_inside(self, M, image_size):
        ''' Returns True if this transform will sample wholly within the image. '''
        iw, ih = image_size
        xlo, ylo, xhi, yhi = self.get_crop_bounds(M)
        # xhi/yhi are exclusive upper bounds (they are used as slice stops when
        # reading the crop), so a window ending exactly at the edge is inside.
        return xlo >= 0 and ylo >= 0 and xhi <= iw and yhi <= ih

    def get_pixel_transform(self, img_size, translate=(0, 0), scale=(0, 0, 1), rotate=(0, 0, 0)):
        '''3x3 matrix mapping image pixel coordinates to crop pixel coordinates.'''
        return spatial.get_pixel_transform(self.crop_size, translate, scale, rotate, img_size)

    def get_crop_pixel_transform(self, M):
        return spatial.get_crop_pixel_transform(M, self.crop_size)

    def get_crop_bounds(self, M):
        '''Integer (xlo, ylo, xhi, yhi) image window needed to render the crop.'''
        return spatial.get_crop_bounds(M, self.crop_size)

    def fetch(self, img, interpolate, M=None, **params):
        '''Render the crop from `img` using interpolation flag `interpolate`.'''
        # shape[1::-1] gives (width, height) for both (H, W) and (H, W, C) shapes.
        img_size = img.shape[1::-1]
        if M is None:
            M = self.get_pixel_transform(img_size, **self._affine_params(params))
        return spatial.fetch(img, M, self.crop_size, interpolate)

    def apply(self, img, rows=None, cols=None, **params):
        return self.fetch(img, interpolate=cv2.INTER_LINEAR, **params)

    def apply_to_mask(self, img, rows=None, cols=None, **params):
        # Nearest-neighbour so mask labels are never blended.
        return self.fetch(img, interpolate=cv2.INTER_NEAREST, **params)

    def apply_to_bbox(self, bbox, rows, cols, M=None, **params):
        '''
        Applies the transform to the bbox.

        Takes normalised axis-aligned coordinates (xlo, ylo, xhi, yhi)
        relative to the image and returns normalised axis-aligned coordinates
        relative to the crop.

        If there's a rotation, the returned box is the minimal axis-aligned
        box covering the rotated corners, so its area grows. For example,
        rotating 45 degrees and then -45 degrees won't give you the same bbox.
        '''
        if M is None:
            M = self.get_pixel_transform((cols, rows), **self._affine_params(params))
        xlo, ylo, xhi, yhi = bbox
        # Denormalise to image pixels, then move all four corners.
        pxlo, pxhi = xlo * cols, xhi * cols
        pylo, pyhi = ylo * rows, yhi * rows
        corners = mat3.transform(M, [(pxlo, pylo), (pxhi, pylo), (pxlo, pyhi), (pxhi, pyhi)])
        # Re-align with the axes and normalise to the crop size.
        npxlo, npylo = corners.min(axis=0)
        npxhi, npyhi = corners.max(axis=0)
        cw, ch = self.crop_size[0], self.crop_size[1]
        return npxlo / cw, npylo / ch, npxhi / cw, npyhi / ch

    def apply_to_keypoint(self, keypoint, rows, cols, M=None, **params):
        '''
        Transform an (x, y, angle, scale) keypoint.

        A keypoint is treated as a vector: it starts at (x, y) and points along
        `angle` with length `scale`.
        https://albumentations.ai/docs/getting_started/keypoints_augmentation/
        '''
        if M is None:
            M = self.get_pixel_transform((cols, rows), **self._affine_params(params))
        x, y, angle, kp_scale = keypoint
        # Encode the direction as a second point. y is negated because the
        # keypoint angle is counter-clockwise while pixel y points down.
        vx = math.cos(angle) * kp_scale
        vy = -math.sin(angle) * kp_scale
        (nx, ny), (tip_x, tip_y) = mat3.transform(M, [[x, y], [x + vx, y + vy]])
        # Vector relative to the transformed point, back to polar form.
        nvx = tip_x - nx
        nvy = tip_y - ny
        nscale = math.sqrt(nvx ** 2 + nvy ** 2)
        # Albumentations treats a positive keypoint angle as counter-clockwise.
        nangle = -math.atan2(nvy, nvx)
        # To match with `albumentations.Rotate`, we take the int part.
        return safe_int(nx), safe_int(ny), nangle, nscale
Minimal functional usage example:
import rasterio
import spatial
crop_size = (224, 224)
tif = rasterio.open('some.tif')
sampler = spatial.affine_sampler_fnc()
M = spatial.get_pixel_transform(crop_size, img_size=tif.shape[::-1], **sampler())  # rasterio shape is (height, width); img_size is (width, height)
crop = spatial.fetch(tif, M, crop_size)
Note: spatial.py has no dependency on Albumentations.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Minimal usage example: