Skip to content

Instantly share code, notes, and snippets.

@gfacciol
Last active April 30, 2021 10:55
Show Gist options
  • Save gfacciol/48d946731e8cdac0693b3b4b3ea6d6c1 to your computer and use it in GitHub Desktop.
Save gfacciol/48d946731e8cdac0693b3b4b3ea6d6c1 to your computer and use it in GitHub Desktop.
Pytorch implementation of Hamilton-Adams demosaicing
# Pytorch implementation of Hamilton-Adams demosaicing
# J. Hamilton Jr. and J. Adams Jr. Adaptive color plan interpolation
# in single sensor color electronic camera, 1997, US Patent 5,629,734.
#
# Copyright (c) 2021 Gabriele Facciolo
# based on code by Yu Guo and Qiyu Jin
import numpy as np
import torch
def mosaic_bayer(rgb, pattern):
    """
    generate a mosaic from a rgb image

    rgb     : tensor whose first two axes are spatial and whose last axis
              holds the 3 color channels; a 2D input (an already squeezed
              mosaic) is replicated over 3 channels before being masked
    pattern : the colors of the 2x2 Bayer cell in row-major order
              (top-left, top-right, bottom-left, bottom-right);
              can be: 'grbg', 'rggb', 'gbrg', 'bggr'

    returns (mosaic, mask)
      mask   : 3-channel tensor of zeros and ones, with a one where the
               channel is sampled by the pattern; it is allocated on the
               same device as rgb
      mosaic : rgb * mask

    raises ValueError when pattern is not made of exactly 4 letters taken
    from 'r', 'g', 'b'
    """
    channel_of = {'r': 0, 'g': 1, 'b': 2}
    letters = list(pattern)
    if len(letters) != 4 or any(c not in channel_of for c in letters):
        raise ValueError(
            "pattern must be 4 letters among 'r', 'g', 'b' "
            "(e.g. 'rggb'), got {!r}".format(pattern)
        )

    # handle the case when the input image is already a squeezed mosaic
    if len(rgb.shape) == 2:
        rgb = rgb.unsqueeze(2).repeat(1, 1, 3)

    # Generate mask (on the device of the input so that rgb * mask is valid)
    rows, cols = rgb.shape[0], rgb.shape[1]
    mask = torch.zeros((rows, cols, 3), device=getattr(rgb, 'device', None))
    for k, letter in enumerate(letters):
        # k-th letter sits at (row offset, column offset) of the 2x2 cell
        row0, col0 = divmod(k, 2)
        mask[row0::2, col0::2, channel_of[letter]] = 1

    # Generate mosaic
    mosaic = rgb * mask
    return mosaic, mask
def myconv(im, ker):
    """
    filter the 2d tensor image (im) with the 2d kernel (ker) and return a 2d tensor image
    pads the image to preserve the shape of the input tensor by replicating boundaries

    im  : 2D floating point tensor
    ker : 2D kernel (numpy array or tensor); the output has the shape of im
          when both sides of the kernel are odd

    the kernel is converted to the dtype and device of im, so the result
    keeps the dtype and device of the input image
    note: the filtering is done by torch.nn.functional.conv2d, which applies
    the kernel without flipping it
    """
    # bring the kernel to the dtype/device of the image
    ker = torch.as_tensor(ker, dtype=im.dtype, device=im.device)

    # compute image padding: half the kernel size on each side
    half_rows, half_cols = ker.shape[0] // 2, ker.shape[1] // 2
    pad = (half_cols, half_cols, half_rows, half_rows)  # left,right,top,bottom

    # pad and unsqueeze the data to the (batch, channel, rows, cols) layout
    padded = torch.nn.functional.pad(im.unsqueeze(0).unsqueeze(0), pad, mode='replicate')

    # apply conv, drop the batch and channel axes and return
    out = torch.nn.functional.conv2d(padded, weight=ker.unsqueeze(0).unsqueeze(0))
    return out.squeeze(0).squeeze(0)
# This function implements Algorithm 1
def hagreen_interpolation(mosaic, mask):
    """
    hamilton-adams green channel processing

    mosaic : 3-channel mosaic; its channels are summed to get the raw CFA
    mask   : 3-channel sampling mask; channel 1 marks the green samples
             (presumably zeros and ones, as built by mosaic_bayer)

    returns the 2D green channel: the raw value where green is sampled,
    the directional estimate elsewhere
    """
    # 1D kernels along a row; the column versions are their transposes
    avg_row = np.array([[1/2, 0, 1/2]])
    lap_row = np.array([[1, 0, -2, 0, 1]])
    grad_row = np.array([[1, 0, -1]])
    avg_col = avg_row.T
    lap_col = lap_row.T
    grad_col = grad_row.T

    # get the raw CFA data
    cfa = torch.sum(mosaic, axis=2)

    # horizontal and vertical estimates: neighbor average minus a quarter
    # of the second difference of the raw data
    est_h = myconv(cfa, avg_row) - myconv(cfa, lap_row/4)
    est_v = myconv(cfa, avg_col) - myconv(cfa, lap_col/4)

    # horizontal and vertical classifiers
    cost_h = torch.abs(myconv(cfa, grad_row)) + torch.abs(myconv(cfa, lap_row))
    cost_v = torch.abs(myconv(cfa, grad_col)) + torch.abs(myconv(cfa, lap_col))

    # direction selector in {-1, 0, 1}:
    #   cost_v > cost_h  -> est_h
    #   cost_v < cost_h  -> est_v
    #   equal            -> (est_h + est_v) / 2
    selector = torch.sign(cost_h - cost_v)
    weight_v = 1 + selector
    weight_h = 1 - selector
    estimate = weight_v * est_v / 2 + weight_h * est_h / 2

    # keep the estimate only where green is not sampled
    not_sampled = (mask == 0)[:, :, 1]
    sampled = mask[:, :, 1]
    return estimate * not_sampled + cfa * sampled
# This function implements Algorithm 2 (blue pixels)
def hablue_interpolation(green, mosaic, mask, pattern):
    """
    hamilton-adams blue channel processing

    green   : 2D tensor with the full green channel
    mosaic  : 3-channel mosaic; channel 2 holds the blue samples and is
              expected to be zero elsewhere (as built by mosaic_bayer)
    mask    : 3-channel sampling mask; channel 0 marks the red sites
    pattern : can be: 'grbg', 'rggb', 'gbrg', 'bggr'

    returns the 2D blue channel
    raises ValueError for any other pattern (the green-site masks would
    otherwise stay empty and the result would be silently wrong)
    """
    # offsets (row, col) inside the 2x2 cell of the green pixel lying on a
    # red row (Gr) and of the green pixel lying on a blue row (Gb)
    green_sites = {
        'grbg': ((0, 0), (1, 1)),
        'rggb': ((0, 1), (1, 0)),
        'gbrg': ((1, 1), (0, 0)),
        'bggr': ((1, 0), (0, 1)),
    }
    if pattern not in green_sites:
        raise ValueError(
            "pattern must be one of 'grbg', 'rggb', 'gbrg', 'bggr', "
            "got {!r}".format(pattern)
        )
    (gr_row, gr_col), (gb_row, gb_col) = green_sites[pattern]

    # mask (allocated on the device of the mosaic)
    rows, cols = mosaic.shape[0], mosaic.shape[1]
    maskGr = torch.zeros((rows, cols), device=mosaic.device)
    maskGb = torch.zeros((rows, cols), device=mosaic.device)
    maskGr[gr_row::2, gr_col::2] = 1
    maskGb[gb_row::2, gb_col::2] = 1
    maskR = mask[:, :, 0]

    # sums of the two neighbors: horizontal, vertical, both diagonals
    Kh = np.array([[1, 0, 1]])
    Kv = Kh.T
    Kp = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]])
    Kn = np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]])
    # second differences along the same four directions
    Deltap = np.array([[1, 0, 0], [0, -2, 0], [0, 0, 1]])
    Deltan = np.array([[0, 0, 1], [0, -2, 0], [1, 0, 0]])
    Deltah = np.array([[1, -2, 1]])
    Deltav = Deltah.T
    # first differences along the two diagonals
    Diffp = np.array([[-1, 0, 0], [0, 0, 0], [0, 0, 1]])
    Diffn = np.array([[0, 0, -1], [0, 0, 0], [1, 0, 0]])

    samples = mosaic[:, :, 2]
    # diagonal second differences of green feed both estimate and classifier
    green_dp = myconv(green, Deltap)
    green_dn = myconv(green, Deltan)

    # green sites: blue neighbors are horizontal on blue rows (Gb),
    # vertical on red rows (Gr)
    Bh = maskGb * (0.5 * myconv(samples, Kh) - 0.25 * myconv(green, Deltah))
    Bv = maskGr * (0.5 * myconv(samples, Kv) - 0.25 * myconv(green, Deltav))

    # red sites: blue neighbors are diagonal, two candidate directions
    Bp = maskR * (0.5 * myconv(samples, Kp) - 0.25 * green_dp)
    Bn = maskR * (0.5 * myconv(samples, Kn) - 0.25 * green_dn)
    CLp = maskR * (torch.abs(myconv(samples, Diffp)) + torch.abs(green_dp))
    CLn = maskR * (torch.abs(myconv(samples, Diffn)) + torch.abs(green_dn))

    # Bn when CLp > CLn; Bp when CLp < CLn; (Bp+Bn)/2 otherwise
    CLlocation = torch.sign(CLp - CLn)
    blue = (1 + CLlocation) * Bn / 2 + (1 - CLlocation) * Bp / 2

    # the four terms live on disjoint sites, so the sum merges them
    blue = blue + Bh + Bv + samples
    return blue
# This function implements Algorithm 2 (red pixels)
def hared_interpolation(green, mosaic, mask, pattern):
    """
    hamilton-adams red channel processing

    green   : 2D tensor with the full green channel
    mosaic  : 3-channel mosaic; channel 0 holds the red samples and is
              expected to be zero elsewhere (as built by mosaic_bayer)
    mask    : 3-channel sampling mask; channel 2 marks the blue sites
    pattern : can be: 'grbg', 'rggb', 'gbrg', 'bggr'

    returns the 2D red channel
    raises ValueError for any other pattern (the green-site masks would
    otherwise stay empty and the result would be silently wrong)
    """
    # offsets (row, col) inside the 2x2 cell of the green pixel lying on a
    # red row (Gr) and of the green pixel lying on a blue row (Gb)
    green_sites = {
        'grbg': ((0, 0), (1, 1)),
        'rggb': ((0, 1), (1, 0)),
        'gbrg': ((1, 1), (0, 0)),
        'bggr': ((1, 0), (0, 1)),
    }
    if pattern not in green_sites:
        raise ValueError(
            "pattern must be one of 'grbg', 'rggb', 'gbrg', 'bggr', "
            "got {!r}".format(pattern)
        )
    (gr_row, gr_col), (gb_row, gb_col) = green_sites[pattern]

    # mask (allocated on the device of the mosaic)
    rows, cols = mosaic.shape[0], mosaic.shape[1]
    maskGr = torch.zeros((rows, cols), device=mosaic.device)
    maskGb = torch.zeros((rows, cols), device=mosaic.device)
    maskGr[gr_row::2, gr_col::2] = 1
    maskGb[gb_row::2, gb_col::2] = 1
    maskB = mask[:, :, 2]

    # sums of the two neighbors: horizontal, vertical, both diagonals
    Kh = np.array([[1, 0, 1]])
    Kv = Kh.T
    Kp = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]])
    Kn = np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]])
    # second differences along the same four directions
    Deltap = np.array([[1, 0, 0], [0, -2, 0], [0, 0, 1]])
    Deltan = np.array([[0, 0, 1], [0, -2, 0], [1, 0, 0]])
    Deltah = np.array([[1, -2, 1]])
    Deltav = Deltah.T
    # first differences along the two diagonals
    Diffp = np.array([[-1, 0, 0], [0, 0, 0], [0, 0, 1]])
    Diffn = np.array([[0, 0, -1], [0, 0, 0], [1, 0, 0]])

    samples = mosaic[:, :, 0]
    # diagonal second differences of green feed both estimate and classifier
    green_dp = myconv(green, Deltap)
    green_dn = myconv(green, Deltan)

    # green sites: red neighbors are horizontal on red rows (Gr),
    # vertical on blue rows (Gb)
    Rh = maskGr * (0.5 * myconv(samples, Kh) - 0.25 * myconv(green, Deltah))
    Rv = maskGb * (0.5 * myconv(samples, Kv) - 0.25 * myconv(green, Deltav))

    # blue sites: red neighbors are diagonal, two candidate directions
    Rp = maskB * (0.5 * myconv(samples, Kp) - 0.25 * green_dp)
    Rn = maskB * (0.5 * myconv(samples, Kn) - 0.25 * green_dn)
    CLp = maskB * (torch.abs(myconv(samples, Diffp)) + torch.abs(green_dp))
    CLn = maskB * (torch.abs(myconv(samples, Diffn)) + torch.abs(green_dn))

    # Rn when CLp > CLn; Rp when CLp < CLn; (Rp+Rn)/2 otherwise
    CLlocation = torch.sign(CLp - CLn)
    red = (1 + CLlocation) * Rn / 2 + (1 - CLlocation) * Rp / 2

    # the four terms live on disjoint sites, so the sum merges them
    red = red + Rh + Rv + samples
    return red
def demosaic_HA(mosaic, pattern):
    """
    Hamilton-Adams demosaicing main function

    mosaic  : 2D tensor (squeezed mosaic) or 3D tensor with the 3 color
              channels on the last axis
    pattern : can be: 'grbg', 'rggb', 'gbrg', 'bggr'

    returns a 3D tensor with the red, green and blue channels stacked on
    the last axis; it keeps the dtype and device of the interpolated
    channels
    """
    # mosaic and mask (just to generate the mask)
    mosaic, mask = mosaic_bayer(mosaic, pattern)

    # green interpolation (implements Algorithm 1)
    green = hagreen_interpolation(mosaic, mask)

    # Red and Blue demosaicing (implements Algorithm 2)
    red = hared_interpolation(green, mosaic, mask, pattern)
    blue = hablue_interpolation(green, mosaic, mask, pattern)

    # result image: stack instead of filling a preallocated CPU buffer
    return torch.stack((red, green, blue), dim=2)
if __name__ == "__main__":
    # demo: mosaic an image file, demosaic it back and save the result
    import torch
    from skimage.io import imread, imsave

    bayer = 'grbg'

    # read the input image as a float32 tensor
    pixels = imread('Sans_bruit_13.PNG').astype('float32')
    image = torch.Tensor(pixels)

    # generate the mosaic (the mask is not needed here)
    cfa, _ = mosaic_bayer(image, bayer)

    # call the demosaicing
    restored = demosaic_HA(cfa, bayer)

    # limit the values to [0, 255] before the conversion to uint8
    restored = torch.clip(restored, 0, 255)
    imsave('test_HA2.png', restored.numpy().astype('uint8'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment