Skip to content

Instantly share code, notes, and snippets.

@CasiaFan
Created June 25, 2018 09:03
Show Gist options
  • Save CasiaFan/639abfd185edaea123d7c54d5a87a418 to your computer and use it in GitHub Desktop.
Save CasiaFan/639abfd185edaea123d7c54d5a87a418 to your computer and use it in GitHub Desktop.
Add radial local brightness adjustment (spot light) for image augmentation
import cv2
import numpy as np
import random
from scipy.stats import norm
def generate_spot_light_mask(mask_size,
                             position=None,
                             max_brightness=255,
                             min_brightness=0,
                             mode="gaussian",
                             linear_decay_rate=None,
                             speedup=False):
    """
    Generate a light mask decayed radially from one or more spot-light centers.

    Args:
        mask_size: tuple of integers (w, h) defining the generated mask size.
        position: list of (x, y) tuples giving the spot-light centers. A single
            (x, y) tuple is also accepted. If None, one center is drawn
            uniformly at random on the mask grid.
        max_brightness: value assigned at an isolated center before the final
            inversion (see Return).
        min_brightness: value approached far away from every center.
        mode: decay profile. Only "gaussian" is implemented; "linear" is
            recognised but raises NotImplementedError.
        linear_decay_rate: reserved for the linear mode; currently unused.
        speedup: reserved for a `shrinkage then expansion` speed-up; currently
            unused.
    Return:
        light_mask: uint8 ndarray of shape (h, w). The Gaussian mask is clipped
            to [0, 255], median-blurred (kernel 5), and then inverted
            (255 - mask), so a center with max_brightness=255 ends up at 0.
    Raises:
        ValueError: if mode is neither "linear" nor "gaussian".
        NotImplementedError: if mode is "linear".
    """
    if mode not in ("linear", "gaussian"):
        raise ValueError("mode must be linear or gaussian")
    if mode == "linear":
        # No linear implementation exists here; previously this mode silently
        # produced a constant mask of 255.
        raise NotImplementedError("linear mode is not implemented")

    width, height = mask_size[0], mask_size[1]
    if position is None:
        # randint is inclusive on both ends; subtract 1 to stay on the grid.
        position = [(random.randint(0, width - 1), random.randint(0, height - 1))]
    elif len(position) == 2 and all(np.isscalar(v) for v in position):
        # Accept a bare (x, y) pair as a single center.
        position = [tuple(position)]

    # The standard deviation scales with the mask diagonal, so the falloff
    # looks similar regardless of image size.
    diagonal = np.sqrt(height ** 2 + width ** 2)
    dev = diagonal / 3.5
    mask = _decay_value_radically_norm_in_matrix(mask_size, position,
                                                 max_brightness, min_brightness, dev)
    # Clip before casting: negative values would otherwise wrap around in uint8.
    mask = np.clip(mask, 0, 255).astype(np.uint8)
    mask = cv2.medianBlur(mask, 5)
    # NOTE(review): the inversion makes the light center the darkest point of
    # the mask; kept from the original — confirm this is intended.
    return 255 - mask
def _decay_value_radically_norm_in_matrix(mask_size, centers, max_value, min_value, dev):
    """
    Vectorised form of `_decay_value_radically_norm` over a whole (h, w) grid.

    Each pixel gets min_value + (max_value - min_value) * sum_c g(d_c), where
    d_c is the Euclidean distance from the pixel to center c and
    g(d) = exp(-d**2 / (2 * dev**2)), i.e. the Gaussian pdf normalised so that
    g(0) == 1. Contributions from multiple centers add up, and the result is
    clipped to [0, 255].

    Args:
        mask_size: (w, h) of the grid.
        centers: iterable of (x, y) centers.
        max_value: value at an isolated center.
        min_value: value far away from every center.
        dev: Gaussian standard deviation in pixels.
    Return:
        float64 ndarray of shape (h, w) with values in [0, 255].
    """
    width, height = mask_size[0], mask_size[1]
    # The coordinate grid is independent of the center, so build it once.
    xv, yv = np.meshgrid(np.arange(width), np.arange(height))
    rate = np.zeros((height, width))
    for center in centers:
        dist = np.hypot(xv - center[0], yv - center[1])
        # norm.pdf(d, 0, dev) / norm.pdf(0, 0, dev) == exp(-d^2 / (2 dev^2))
        rate += np.exp(-0.5 * (dist / dev) ** 2)
    mask = rate * (max_value - min_value) + min_value
    return np.clip(mask, 0, 255)
def _decay_value_radically_norm(x, centers, max_value, min_value, dev):
    """
    Calculate the value at point x = (x, y), decayed from each center following
    a Gaussian profile. When multiple centers are given, their contributions
    sum up, and the accumulated value is clamped to [0, 255].
    NOTE: assumes the light at each center is identical (same brightness and
    same decay rate).

    Args:
        x: (x, y) point to evaluate.
        centers: iterable of (x, y) centers.
        max_value: value at an isolated center.
        min_value: value far away from every center.
        dev: Gaussian standard deviation in pixels.
    Return:
        Value in [0, 255] (the integer 0 or 255 when clamped).
    """
    rate = 0.0
    for center in centers:
        distance = np.hypot(center[0] - x[0], center[1] - x[1])
        # norm.pdf(d, 0, dev) / norm.pdf(0, 0, dev) == exp(-d^2 / (2 dev^2))
        rate += np.exp(-0.5 * (distance / dev) ** 2)
    x_value = rate * (max_value - min_value) + min_value
    # Clamp to the documented range; the lower bound was previously missing.
    if x_value > 255:
        return 255
    if x_value < 0:
        return 0
    return x_value
def add_spot_light(image, light_position=None, max_brightness=255, min_brightness=0,
                   mode='gaussian', linear_decay_rate=None, transparency=None):
    """
    Blend a spot-light mask into the V (brightness) channel of an image.

    Args:
        image: path to an image file readable by cv2.imread, or an already
            loaded BGR uint8 ndarray (h, w, 3).
        light_position: list of (x, y) spot-light centers; random if None.
        max_brightness, min_brightness, mode, linear_decay_rate: forwarded to
            generate_spot_light_mask.
        transparency: weight of the original V channel in the blend; the mask
            gets weight 1 - transparency. Drawn from U(0.5, 0.85) if None.
    Return:
        BGR uint8 ndarray with the same height and width as the input.
    Raises:
        FileNotFoundError: if image is a path that cv2.imread cannot read.
    """
    if transparency is None:
        transparency = random.uniform(0.5, 0.85)
    if isinstance(image, np.ndarray):
        frame = image
    else:
        frame = cv2.imread(image)
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        if frame is None:
            raise FileNotFoundError("could not read image: {}".format(image))
    height, width = frame.shape[:2]
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    mask = generate_spot_light_mask(mask_size=(width, height),
                                    position=light_position,
                                    max_brightness=max_brightness,
                                    min_brightness=min_brightness,
                                    mode=mode,
                                    linear_decay_rate=linear_decay_rate)
    blended = hsv[:, :, 2] * transparency + mask * (1 - transparency)
    # Clip guards against transparency outside [0, 1]; astype truncates exactly
    # like the original in-place assignment into the uint8 channel did.
    hsv[:, :, 2] = np.clip(blended, 0, 255).astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
@joihn
Copy link

joihn commented Jan 3, 2023

Thanks for this code.

The "speedup" part of the code seems to be missing, despite the argument being present.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment