Skip to content

Instantly share code, notes, and snippets.

@xmodar
Created March 16, 2020 22:52
Show Gist options
  • Save xmodar/437cc4e7d17a3ab033ec026b36a2d3ea to your computer and use it in GitHub Desktop.
Save xmodar/437cc4e7d17a3ab033ec026b36a2d3ea to your computer and use it in GitHub Desktop.
A random color-toning transform written to follow TorchVision's transform conventions. Inspired by https://www.pyimagesearch.com/2014/06/30/super-fast-color-transfer-images/
import numpy as np
from PIL import Image
from skimage.color import lab2rgb, rgb2lab
class RandomColorToning:
    """Randomly re-tone an RGB image by matching its CIE-Lab channel statistics.

    On every call a target mean ("shift") and standard deviation ("scale")
    are drawn from normal distributions for each of the L, a and b channels.
    The image's Lab channels are then affinely remapped to have exactly those
    statistics (mean/std color transfer). The object is a plain callable with
    a ``__repr__``, so it can be used inside ``torchvision.transforms``
    pipelines such as ``Compose`` or ``RandomApply``.

    Args:
        scale_mean: mean of the sampled per-channel (L, a, b) standard deviation.
        scale_std: standard deviation of the sampled per-channel std.
        shift_mean: mean of the sampled per-channel (L, a, b) mean.
        shift_std: standard deviation of the sampled per-channel mean.

    Each argument may be a scalar or a length-3 sequence (NumPy broadcasting).

    NOTE(review): sampling uses NumPy's global RNG (``np.random``). Inside a
    multi-worker ``torch.utils.data.DataLoader`` that state may be duplicated
    across workers unless it is reseeded (e.g. in ``worker_init_fn``).
    """

    def __init__(self, scale_mean, scale_std, shift_mean, shift_std):
        self.scale_mean = scale_mean
        self.scale_std = scale_std
        self.shift_mean = shift_mean
        self.shift_std = shift_std

    def __call__(self, image):
        """Return a randomly re-toned copy of ``image`` as a PIL image.

        ``image`` must be an RGB image that ``rgb2lab`` accepts (a PIL RGB
        image or an array with 3 channels on the last axis).
        """
        mean = np.random.randn(3) * self.shift_std + self.shift_mean
        std = np.random.randn(3) * self.scale_std + self.scale_mean
        # A standard deviation cannot be negative; an unclipped negative draw
        # would mirror the channel (e.g. invert lightness) instead of toning it.
        std = np.maximum(std, 0.)
        return Image.fromarray(self.transfer(image, mean, std))

    @staticmethod
    def transfer(target, source, source_std=None):
        """Recolor ``target`` so its Lab channel statistics match ``source``.

        Args:
            target: RGB image to recolor.
            source: either an RGB image whose Lab mean/std are copied, or --
                when ``source_std`` is given -- the desired per-channel Lab mean.
            source_std: desired per-channel Lab standard deviation, or ``None``
                to take both mean and std from the ``source`` image.

        Returns:
            The recolored image as a ``uint8`` RGB array with the same spatial
            shape as ``target``.
        """
        target_lab = rgb2lab(target)
        if source_std is None:
            source, source_std = RandomColorToning._lab_stats(rgb2lab(source))
        toned = RandomColorToning._tone_lab(target_lab, source, source_std)
        # lab2rgb returns floats in [0, 1] (skimage clips the RGB result).
        output = lab2rgb(toned) * 255
        return output.round().astype(np.uint8)

    @staticmethod
    def _lab_stats(lab):
        """Return the per-channel mean and std of an ``(..., 3)`` array."""
        pixels = np.reshape(lab, (-1, 3))
        return pixels.mean(0), pixels.std(0)

    @staticmethod
    def _tone_lab(lab, mean, std, eps=1e-6):
        """Affinely remap Lab pixels to the given per-channel mean and std.

        A channel whose own std is below ``eps`` (e.g. a perfectly flat
        image) is treated as having std ``eps``. Without this guard the scale
        becomes ``inf`` and the output fills with NaN; with it, a flat
        channel maps to the requested mean. Lightness is then raised where
        needed so ``lab2rgb`` never produces a negative Z component. The
        input ``lab`` array is not modified.
        """
        lab_mean, lab_std = RandomColorToning._lab_stats(lab)
        scale = np.asarray(std) / np.maximum(lab_std, eps)
        toned = scale * (lab - lab_mean) + np.asarray(mean)
        # Lab->XYZ uses fz = (L + 16) / 116 - b / 200, so Z >= 0 requires
        # L >= (116 / 200) * b - 16; the tiny offset keeps fz strictly positive.
        lightness_bound = (116. / 200.) * toned[..., 2] + (1e-10 - 16.)
        toned[..., 0] = np.maximum(toned[..., 0], lightness_bound)
        return toned

    def __repr__(self):
        return (f'{type(self).__name__}('
                f'scale_mean={self.scale_mean}, '
                f'scale_std={self.scale_std}, '
                f'shift_mean={self.shift_mean}, '
                f'shift_std={self.shift_std})')
if __name__ == '__main__':
    from torchvision import transforms as T

    def _main():
        """Build and print an example augmentation pipeline using RandomColorToning."""
        image_height, image_width = 480, 270
        # These values are dataset dependent. They can be estimated from the
        # per-image Lab statistics of a dataset of RGB images, e.g.:
        #   lab = [rgb2lab(np.asarray(image)).reshape(-1, 3) for image in dataset]
        #   shift = [x.mean(0) for x in lab]  # per-image channel means
        #   scale = [x.std(0) for x in lab]   # per-image channel stds
        scale_mean = (15, 25, 25)  # np.mean(scale, 0)
        scale_std = (5, 10, 10)  # np.std(scale, 0)
        shift_mean = (70, 0, 0)  # np.mean(shift, 0)
        shift_std = (10, 20, 20)  # np.std(shift, 0)
        toner = RandomColorToning(scale_mean, scale_std, shift_mean, shift_std)
        transforms = T.Compose([
            # Resizing the short edge to 1.5 * min(h, w) can leave it shorter
            # than the crop (e.g. landscape inputs vs. a 480-px crop height),
            # so RandomCrop pads instead of raising in that case.
            T.Resize(int(min(image_height, image_width) * 1.5)),
            T.RandomCrop((image_height, image_width), pad_if_needed=True),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([toner], p=0.5),
            T.ToTensor(),
        ])
        print(transforms)

    _main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment