Skip to content

Instantly share code, notes, and snippets.

@tonysyu
Created October 18, 2012 01:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tonysyu/3909400 to your computer and use it in GitHub Desktop.
Save tonysyu/3909400 to your computer and use it in GitHub Desktop.
Mock-up of a decorator that lets grayscale image filters handle RGB images
import functools

import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from skimage import data, color, filter, exposure
from skimage.util import dtype
# Range of the Lab lightness (L) channel as used by adapt_rgb below.
lab_range = (0, 100)


def adapt_rgb(image_filter):
    """Adapt a grayscale image filter so it also accepts RGB images.

    The returned wrapper takes an extra keyword argument, ``rgb_behavior``,
    that chooses how an RGB input is turned into 2-D data for
    ``image_filter``:

    ``'lightness'`` (default)
        Convert to Lab with ``color.rgb2lab``, rescale the L channel from
        ``lab_range`` to [0, 1], filter it, rescale the result back into
        ``lab_range`` and convert back with ``color.lab2rgb``.
    ``'hsv'``
        Convert with ``color.rgb2hsv``, filter the V channel (index 2) and
        convert back with ``color.hsv2rgb``.
    ``'each channel'``
        Filter every channel of the last axis independently and stack the
        results along that axis.

    Input for which ``color.is_rgb`` is false is passed directly to
    ``image_filter``. All other positional and keyword arguments are
    forwarded to ``image_filter`` unchanged.

    Raises
    ------
    ValueError
        If the input is RGB and ``rgb_behavior`` is not one of the values
        listed above.
    """
    @functools.wraps(image_filter)
    def image_filter_adapted(image, *args, **kwargs):
        # Pop first so the option is never forwarded to the wrapped filter.
        rgb_behavior = kwargs.pop('rgb_behavior', 'lightness')

        if not color.is_rgb(image):
            return image_filter(image, *args, **kwargs)

        if rgb_behavior == 'lightness':
            lab = color.rgb2lab(image)
            lightness = exposure.rescale_intensity(lab[:, :, 0],
                                                   in_range=lab_range,
                                                   out_range=(0, 1))
            lightness = image_filter(lightness, *args, **kwargs)
            in_range = dtype.dtype_range[lightness.dtype.type]
            # A non-negative output (which may contain exact zeros, e.g. edge
            # magnitudes) should map 0 to L=0. With a strict ``> 0`` test it
            # would fall back to the symmetric dtype range and map 0 to
            # mid-grey.
            if np.all(lightness >= 0):
                in_range = (0, in_range[1])
            lab[:, :, 0] = exposure.rescale_intensity(lightness,
                                                      in_range=in_range,
                                                      out_range=lab_range)
            return color.lab2rgb(lab)

        if rgb_behavior == 'hsv':
            hsv = color.rgb2hsv(image)
            hsv[:, :, 2] = image_filter(hsv[:, :, 2], *args, **kwargs)
            return color.hsv2rgb(hsv)

        if rgb_behavior == 'each channel':
            # Index channels along the last axis. Iterating ``image.T`` would
            # also transpose each channel's rows and columns.
            channels = [image_filter(image[..., i], *args, **kwargs)
                        for i in range(image.shape[-1])]
            return np.dstack(channels)

        raise ValueError("Unknown rgb_behavior: %r (expected 'lightness', "
                         "'hsv' or 'each channel')" % (rgb_behavior,))

    return image_filter_adapted
@adapt_rgb
def edges(image):
    """Detect edges in ``image`` using the Sobel operator (``filter.sobel``).

    Because of the ``adapt_rgb`` decorator, RGB images are accepted as well.
    Pass ``rgb_behavior=...`` to choose how colour input is handled.
    """
    edge_map = filter.sobel(image)
    return edge_map
@adapt_rgb
def smooth(image, sigma=10):
    """Blur ``image`` with ``scipy.ndimage.gaussian_filter``.

    Parameters
    ----------
    image : array
        Image to blur. RGB input is accepted via the ``adapt_rgb`` decorator
        (select handling with ``rgb_behavior=...``).
    sigma : scalar or sequence of scalars, optional
        Standard deviation passed to ``gaussian_filter``. Defaults to 10, the
        value this function previously hard-coded.

    Note: a non-default ``sigma`` only reaches this function if the
    decorating wrapper forwards extra arguments to it.
    """
    # filter.tv_denoise was considered as an alternative but is much slower.
    return gaussian_filter(image, sigma)
# Demo: compare every rgb_behavior for a smoothing filter (top row) and an
# edge filter (bottom row).
# NOTE(review): data.lena() is not available in newer scikit-image releases;
# data.astronaut() is the usual replacement. Confirm against the installed
# version.
demo_image = data.lena()  # load once instead of once per subplot
rgb_behaviors = ['each channel', 'hsv', 'lightness']

fig, axes = plt.subplots(ncols=len(rgb_behaviors), nrows=2)
for col, behavior in enumerate(rgb_behaviors):
    axes[0, col].imshow(smooth(demo_image, rgb_behavior=behavior))
    axes[1, col].imshow(edges(demo_image, rgb_behavior=behavior))
    axes[0, col].set_title(behavior)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment