Last active
April 30, 2021 10:55
-
-
Save gfacciol/48d946731e8cdac0693b3b4b3ea6d6c1 to your computer and use it in GitHub Desktop.
Pytorch implementation of Hamilton-Adams demosaicing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# PyTorch implementation of Hamilton-Adams demosaicing
# J. Hamilton Jr. and J. Adams Jr., "Adaptive color plan interpolation
# in single sensor color electronic camera", 1997, US Patent 5,629,734.
#
# Copyright (c) 2021 Gabriele Facciolo
# based on code by Yu Guo and Qiyu Jin
import numpy as np | |
import torch | |
def mosaic_bayer(rgb, pattern):
    """
    Generate a Bayer mosaic from an RGB image.

    rgb:     torch tensor (or numpy array) of shape (H, W, 3), or an already
             squeezed (H, W) mosaic, which is replicated into 3 channels
             before masking.
    pattern: colors of the 2x2 Bayer cell in row-major order
             (top-left, top-right, bottom-left, bottom-right), e.g.
             'grbg', 'rggb', 'gbrg', 'bggr'.

    Returns (mosaic, mask), both of shape (H, W, 3). mask is 1 where a
    channel is sampled and 0 elsewhere, and mosaic = rgb * mask. The mask
    lives on the same device as rgb.

    Raises ValueError if pattern is not an arrangement of one 'r', two 'g'
    and one 'b'.
    """
    channel_of = {'r': 0, 'g': 1, 'b': 2}
    cell = list(pattern)
    if len(cell) != 4 or sorted(cell) != ['b', 'g', 'g', 'r']:
        raise ValueError(
            f"invalid Bayer pattern {pattern!r}: expected one 'r', two 'g' "
            f"and one 'b', e.g. 'grbg', 'rggb', 'gbrg', 'bggr'")

    # accept numpy input too; a tensor is returned unchanged (no copy)
    rgb = torch.as_tensor(rgb)
    height, width = rgb.shape[0], rgb.shape[1]
    # handle the case when the input image is already a squeezed mosaic
    if rgb.dim() == 2:
        rgb = rgb.unsqueeze(2).repeat(1, 1, 3)

    # allocate on rgb's device so that GPU inputs can be masked
    mask = torch.zeros((height, width, 3), device=rgb.device)
    cell_offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    for (dy, dx), color in zip(cell_offsets, cell):
        mask[dy::2, dx::2, channel_of[color]] = 1

    mosaic = rgb * mask
    return mosaic, mask
def myconv(im, ker):
    """
    Filter the 2-D tensor image `im` with the small 2-D kernel `ker`.

    im:  2-D floating-point torch tensor (H, W).
    ker: 2-D kernel (numpy array, nested list or tensor). Its height and
         width must be odd for the output to keep the shape of `im`.

    The image border is extended by replicating the edge pixels before
    filtering. torch's conv2d computes a cross-correlation (the kernel is
    not flipped). This equals a true convolution for kernels that are
    symmetric under a 180-degree rotation, and for antisymmetric kernels
    it differs only in sign.

    The kernel is cast to the dtype and device of `im`, so float64 and GPU
    images are supported, not only float32 CPU ones.

    Returns a 2-D tensor with the same shape as `im`.
    """
    ker = torch.as_tensor(ker, dtype=im.dtype, device=im.device)
    pad_rows, pad_cols = (int(s) // 2 for s in ker.shape)
    # F.pad orders the padding as (left, right, top, bottom) for the last two dims
    padded = torch.nn.functional.pad(
        im[None, None], (pad_cols, pad_cols, pad_rows, pad_rows), mode='replicate')
    # conv2d expects (N, C, H, W) input and (out_C, in_C, kH, kW) weights
    return torch.nn.functional.conv2d(padded, weight=ker[None, None])[0, 0]
# This function implements Algorithm 1 (green channel)
def hagreen_interpolation(mosaic, mask):
    """
    Hamilton-Adams interpolation of the green channel.

    mosaic, mask: (H, W, 3) masked image and 0/1 sampling mask, as returned
    by mosaic_bayer.

    At each non-green pixel two estimates are formed. One is horizontal: the
    mean of the left/right neighbours minus a quarter of the second
    difference with step 2. The other is the same thing vertically. The
    estimate along the direction with the smaller cost |first difference| +
    |second difference| is kept; on a tie the two are averaged.

    Returns the (H, W) green channel; sampled green values are kept as-is.
    """
    # collapse the 3-channel mosaic back to the single-plane CFA data
    cfa = torch.sum(mosaic, dim=2)

    average_h = np.array([[1 / 2, 0, 1 / 2]])
    laplacian_h = np.array([[1, 0, -2, 0, 1]])
    gradient_h = np.array([[1, 0, -1]])

    estimate_h = myconv(cfa, average_h) - myconv(cfa, laplacian_h / 4)
    estimate_v = myconv(cfa, average_h.T) - myconv(cfa, laplacian_h.T / 4)
    cost_h = torch.abs(myconv(cfa, gradient_h)) + torch.abs(myconv(cfa, laplacian_h))
    cost_v = torch.abs(myconv(cfa, gradient_h.T)) + torch.abs(myconv(cfa, laplacian_h.T))

    # direction = +1: vertical is smoother (cost_v < cost_h), use estimate_v
    # direction = -1: horizontal is smoother (cost_h < cost_v), use estimate_h
    # direction =  0: tie, average of the two estimates
    direction = torch.sign(cost_h - cost_v)
    estimate = (1 + direction) * estimate_v / 2 + (1 - direction) * estimate_h / 2

    green_sites = mask[:, :, 1]
    return estimate * (green_sites == 0) + cfa * green_sites
# This function implements Algorithm 2 (blue pixels)
def hablue_interpolation(green, mosaic, mask, pattern):
    """
    Hamilton-Adams blue channel processing.

    green:   (H, W) tensor with the full green channel
             (e.g. the output of hagreen_interpolation).
    mosaic:  (H, W, 3) masked image and
    mask:    (H, W, 3) 0/1 sampling mask, as returned by mosaic_bayer.
    pattern: Bayer pattern, one of 'grbg', 'rggb', 'gbrg', 'bggr'.

    Blue is filled in as follows:
      * at green pixels on blue rows (Gb), from the left/right blue neighbours;
      * at green pixels on red rows (Gr), from the up/down blue neighbours;
      * at red pixels, from the diagonal pair with the smaller
        gradient-plus-curvature cost (the mean of both diagonals on a tie).
    Each estimate is corrected with the second difference of green along the
    same direction. Sampled blue values pass through unchanged.

    Returns the (H, W) blue channel.
    Raises ValueError if pattern is not one of the four supported patterns.
    """
    # (row, col) offsets inside the 2x2 Bayer cell of the Gr and Gb pixels
    green_sites = {
        'grbg': ((0, 0), (1, 1)),
        'rggb': ((0, 1), (1, 0)),
        'gbrg': ((1, 1), (0, 0)),
        'bggr': ((1, 0), (0, 1)),
    }
    if pattern not in green_sites:
        # without this check maskGr/maskGb would silently stay all-zero and
        # blue would never be interpolated at green pixels
        raise ValueError(
            f"unsupported Bayer pattern {pattern!r}, "
            f"expected one of {sorted(green_sites)}")
    (gr_row, gr_col), (gb_row, gb_col) = green_sites[pattern]

    # green-site masks, allocated on the input's device
    size_rawq = mosaic.shape
    maskGr = torch.zeros((size_rawq[0], size_rawq[1]), device=mosaic.device)
    maskGb = torch.zeros((size_rawq[0], size_rawq[1]), device=mosaic.device)
    maskGr[gr_row::2, gr_col::2] = 1
    maskGb[gb_row::2, gb_col::2] = 1
    maskR = mask[:, :, 0]

    Kh = np.array([[1, 0, 1]])
    Kv = Kh.T
    Kp = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]])  # main diagonal pair
    Kn = np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]])  # anti-diagonal pair
    Deltap = np.array([[1, 0, 0], [0, -2, 0], [0, 0, 1]])
    Deltan = np.array([[0, 0, 1], [0, -2, 0], [1, 0, 0]])
    Deltah = np.array([[1, -2, 1]])
    Deltav = Deltah.T
    Diffp = np.array([[-1, 0, 0], [0, 0, 0], [0, 0, 1]])
    Diffn = np.array([[0, 0, -1], [0, 0, 0], [1, 0, 0]])

    blue_cfa = mosaic[:, :, 2]
    # NOTE(review): bilinear interpolation of the B-G difference would weight
    # the green second difference by 1/2. The 1/4 weight is kept as in the
    # original code; confirm against Algorithm 2 of the reference.
    Bh = maskGb * (0.5 * myconv(blue_cfa, Kh) - 0.25 * myconv(green, Deltah))
    Bv = maskGr * (0.5 * myconv(blue_cfa, Kv) - 0.25 * myconv(green, Deltav))
    Bp = maskR * (0.5 * myconv(blue_cfa, Kp) - 0.25 * myconv(green, Deltap))
    Bn = maskR * (0.5 * myconv(blue_cfa, Kn) - 0.25 * myconv(green, Deltan))
    CLp = maskR * (torch.abs(myconv(blue_cfa, Diffp)) + torch.abs(myconv(green, Deltap)))
    CLn = maskR * (torch.abs(myconv(blue_cfa, Diffn)) + torch.abs(myconv(green, Deltan)))

    # pick the diagonal with the smaller cost; average on a tie
    CLlocation = torch.sign(CLp - CLn)
    blue = (1 + CLlocation) * Bn / 2 + (1 - CLlocation) * Bp / 2
    blue = blue + Bh + Bv + blue_cfa
    return blue
# This function implements Algorithm 2 (red pixels)
def hared_interpolation(green, mosaic, mask, pattern):
    """
    Hamilton-Adams red channel processing.

    green:   (H, W) tensor with the full green channel
             (e.g. the output of hagreen_interpolation).
    mosaic:  (H, W, 3) masked image and
    mask:    (H, W, 3) 0/1 sampling mask, as returned by mosaic_bayer.
    pattern: Bayer pattern, one of 'grbg', 'rggb', 'gbrg', 'bggr'.

    Red is filled in as follows:
      * at green pixels on red rows (Gr), from the left/right red neighbours;
      * at green pixels on blue rows (Gb), from the up/down red neighbours;
      * at blue pixels, from the diagonal pair with the smaller
        gradient-plus-curvature cost (the mean of both diagonals on a tie).
    Each estimate is corrected with the second difference of green along the
    same direction. Sampled red values pass through unchanged.

    Returns the (H, W) red channel.
    Raises ValueError if pattern is not one of the four supported patterns.
    """
    # (row, col) offsets inside the 2x2 Bayer cell of the Gr and Gb pixels
    green_sites = {
        'grbg': ((0, 0), (1, 1)),
        'rggb': ((0, 1), (1, 0)),
        'gbrg': ((1, 1), (0, 0)),
        'bggr': ((1, 0), (0, 1)),
    }
    if pattern not in green_sites:
        # without this check maskGr/maskGb would silently stay all-zero and
        # red would never be interpolated at green pixels
        raise ValueError(
            f"unsupported Bayer pattern {pattern!r}, "
            f"expected one of {sorted(green_sites)}")
    (gr_row, gr_col), (gb_row, gb_col) = green_sites[pattern]

    # green-site masks, allocated on the input's device
    size_rawq = mosaic.shape
    maskGr = torch.zeros((size_rawq[0], size_rawq[1]), device=mosaic.device)
    maskGb = torch.zeros((size_rawq[0], size_rawq[1]), device=mosaic.device)
    maskGr[gr_row::2, gr_col::2] = 1
    maskGb[gb_row::2, gb_col::2] = 1
    maskB = mask[:, :, 2]

    Kh = np.array([[1, 0, 1]])
    Kv = Kh.T
    Kp = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]])  # main diagonal pair
    Kn = np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]])  # anti-diagonal pair
    Deltap = np.array([[1, 0, 0], [0, -2, 0], [0, 0, 1]])
    Deltan = np.array([[0, 0, 1], [0, -2, 0], [1, 0, 0]])
    Deltah = np.array([[1, -2, 1]])
    Deltav = Deltah.T
    Diffp = np.array([[-1, 0, 0], [0, 0, 0], [0, 0, 1]])
    Diffn = np.array([[0, 0, -1], [0, 0, 0], [1, 0, 0]])

    red_cfa = mosaic[:, :, 0]
    # NOTE(review): bilinear interpolation of the R-G difference would weight
    # the green second difference by 1/2. The 1/4 weight is kept as in the
    # original code; confirm against Algorithm 2 of the reference.
    Rh = maskGr * (0.5 * myconv(red_cfa, Kh) - 0.25 * myconv(green, Deltah))
    Rv = maskGb * (0.5 * myconv(red_cfa, Kv) - 0.25 * myconv(green, Deltav))
    Rp = maskB * (0.5 * myconv(red_cfa, Kp) - 0.25 * myconv(green, Deltap))
    Rn = maskB * (0.5 * myconv(red_cfa, Kn) - 0.25 * myconv(green, Deltan))
    CLp = maskB * (torch.abs(myconv(red_cfa, Diffp)) + torch.abs(myconv(green, Deltap)))
    CLn = maskB * (torch.abs(myconv(red_cfa, Diffn)) + torch.abs(myconv(green, Deltan)))

    # pick the diagonal with the smaller cost; average on a tie
    CLlocation = torch.sign(CLp - CLn)
    red = (1 + CLlocation) * Rn / 2 + (1 - CLlocation) * Rp / 2
    red = red + Rh + Rv + red_cfa
    return red
def demosaic_HA(mosaic, pattern):
    """
    Hamilton-Adams demosaicing main function.

    mosaic:  torch tensor or numpy array, either a 2-D (H, W) CFA image or a
             3-D (H, W, 3) image (rows, columns, channels), e.g. the output
             of mosaic_bayer. Only the samples selected by the Bayer
             pattern are used.
    pattern: one of 'grbg', 'rggb', 'gbrg', 'bggr'.

    Returns an (H, W, 3) tensor holding the demosaicked red, green and blue
    channels.
    """
    # numpy input is accepted as documented; a tensor is returned unchanged
    mosaic = torch.as_tensor(mosaic)
    # mosaic and mask (just to generate the mask)
    mosaic, mask = mosaic_bayer(mosaic, pattern)
    # green interpolation (implements Algorithm 1)
    green = hagreen_interpolation(mosaic, mask)
    # red and blue demosaicing (implements Algorithm 2)
    red = hared_interpolation(green, mosaic, mask, pattern)
    blue = hablue_interpolation(green, mosaic, mask, pattern)
    # stacking keeps the channels' dtype and device, unlike filling a
    # default (float32, CPU) zeros buffer
    return torch.stack((red, green, blue), dim=2)
if __name__ == "__main__":
    # Usage: python this_script.py [input_image] [output_image] [pattern]
    # Defaults reproduce the original hard-coded run.
    import sys
    from skimage.io import imread, imsave

    in_path = sys.argv[1] if len(sys.argv) > 1 else 'Sans_bruit_13.PNG'
    out_path = sys.argv[2] if len(sys.argv) > 2 else 'test_HA2.png'
    pattern = sys.argv[3] if len(sys.argv) > 3 else 'grbg'

    rgb = imread(in_path).astype('float32')
    if rgb.ndim == 3:
        # drop an alpha channel, if present: mosaic_bayer masks exactly 3 channels
        rgb = rgb[:, :, :3]
    rgb = torch.as_tensor(rgb)
    # generate the mosaic
    mosaic, _ = mosaic_bayer(rgb, pattern)
    # call the demosaicing
    rgb_dem = demosaic_HA(mosaic, pattern)
    # assumes an 8-bit input range; round (not truncate) when quantizing
    rgb_dem = torch.round(torch.clip(rgb_dem, 0, 255))
    imsave(out_path, rgb_dem.numpy().astype('uint8'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment