@SebastianGrans
Created September 30, 2021 16:08
Debayering a BayerRG (RGGB) image using PyTorch
# This only works for BayerRG color filter arrays. I.e. RGGB
# Inspired by this: https://github.com/cheind/pytorch-debayer
#
# I just wanted a minimal implementation to understand it better.
#
import torch
import numpy as np
import matplotlib.pyplot as plt
img_raw = ... # Load your image here. A (H, W) numpy array
# np.float was removed in NumPy 1.24; use an explicit float type instead.
img_raw = img_raw.astype(np.float32)
# Change this to your original bit depth. My picture was a BayerRG12
img_raw /= 2**12 # E.g. 2**8 for a BayerRG8
# Convert it to a torch tensor
# Torch expects image batches to be shaped (B, C, H, W),
# so we add two leading axes of size 1.
# Raw sensor data is often uint16, which older PyTorch versions cannot
# load with from_numpy; the NumPy float conversion above sidesteps that.
# Convolution requires float, and then finally we set it to use the GPU.
img_torch = torch.from_numpy(
    img_raw[None, None, :]
).to(torch.float32).to('cuda')
B, C, H, W = img_torch.shape
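kernels = torch.tensor([
    # 0: identity (the pixel already holds this color)
    [0, 0, 0],
    [0, 1, 0],
    [0, 0, 0],
    # 1: mean of the four plus-shaped neighbors
    [0, 0.25, 0],
    [0.25, 0, 0.25],
    [0, 0.25, 0],
    # 2: mean of the four diagonal neighbors
    [0.25, 0, 0.25],
    [0, 0, 0],
    [0.25, 0, 0.25],
    # 3: mean of the left and right neighbors
    [0, 0, 0],
    [0.5, 0, 0.5],
    [0, 0, 0],
    # 4: mean of the upper and lower neighbors
    [0, 0.5, 0],
    [0, 0, 0],
    [0, 0.5, 0],
]).reshape(5, 1, 3, 3).to('cuda')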
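# For each destination channel, a 2x2 tile giving the kernel index to use
# at each position of the RGGB cell, laid out as [[R, G], [G, B]].
indices = torch.tensor([
    # dest channel r
    [0, 3],
    [4, 2],
    # dest channel g
    [1, 0],
    [0, 1],
    # dest channel b
    [2, 4],
    [3, 0],
]).reshape(1, 3, 2, 2).to('cuda')
# Tile the 2x2 pattern over the whole image. H and W must be even.
channel_selection_map = indices.repeat(B, 1, H//2, W//2)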
# With a 3x3 kernel we need to pad the boundary. Here we pad
# by simply replicating the edge pixels.
# img_torch thus has the shape (1, 1, H+2, W+2)
img_torch = torch.nn.functional.pad(img_torch, (1,1,1,1), mode='replicate')
# Apply the convolution. The output, img_convolved, has
# the shape (1, 5, H, W): one channel per kernel.
img_convolved = torch.nn.functional.conv2d(img_torch, kernels)
# Now we build the final image: for each pixel, select the
# appropriate channels from dimension 1 of the tensor.
img_debayered = torch.gather(img_convolved, 1, channel_selection_map)
# Reshape it to (H, W, 3) like a normal image.
img_debayered = img_debayered[0].permute(1,2,0).to('cpu')
# Plot.
plt.imshow(img_debayered)
plt.show()
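
The kernel table and the 2x2 index tile above can be checked without PyTorch or a GPU. What follows is a minimal pure-Python sketch of the same scheme, not part of the original gist: the names `KERNELS`, `KERNEL_INDEX`, `debayer_rggb`, and `mosaic_rggb` are hypothetical, and the test image is a synthetic flat color chosen so the arithmetic is exact. It assumes the same replicate padding, implemented here by clamping coordinates.

```python
# Pure-Python reference for the RGGB scheme (illustrative only).
KERNELS = [
    [[0, 0, 0], [0, 1, 0], [0, 0, 0]],              # 0: identity
    [[0, .25, 0], [.25, 0, .25], [0, .25, 0]],      # 1: plus-shaped neighbors
    [[.25, 0, .25], [0, 0, 0], [.25, 0, .25]],      # 2: diagonal neighbors
    [[0, 0, 0], [.5, 0, .5], [0, 0, 0]],            # 3: left/right neighbors
    [[0, .5, 0], [0, 0, 0], [0, .5, 0]],            # 4: up/down neighbors
]

# KERNEL_INDEX[channel][row % 2][col % 2] -> which kernel to apply.
KERNEL_INDEX = [
    [[0, 3], [4, 2]],  # red
    [[1, 0], [0, 1]],  # green
    [[2, 4], [3, 0]],  # blue
]


def debayer_rggb(mosaic):
    """Debayer a list-of-lists RGGB mosaic into an (H, W, 3) nested list."""
    h, w = len(mosaic), len(mosaic[0])

    def px(y, x):
        # Replicate padding: clamp out-of-range coordinates to the edge.
        return mosaic[min(max(y, 0), h - 1)][min(max(x, 0), w - 1)]

    out = [[[0.0] * 3 for _ in range(w)] for _ in range(h)]
    for y in range(h):
        for x in range(w):
            for c in range(3):
                k = KERNELS[KERNEL_INDEX[c][y % 2][x % 2]]
                out[y][x][c] = sum(
                    k[dy][dx] * px(y + dy - 1, x + dx - 1)
                    for dy in range(3)
                    for dx in range(3)
                )
    return out


def mosaic_rggb(color, h, w):
    """Build the RGGB mosaic a sensor would record for a flat color."""
    r, g, b = color
    tile = [[r, g], [g, b]]
    return [[tile[y % 2][x % 2] for x in range(w)] for y in range(h)]


flat = mosaic_rggb((0.75, 0.5, 0.25), 6, 6)
rgb = debayer_rggb(flat)
print(rgb[2][3])  # interior pixel → [0.75, 0.5, 0.25]
print(rgb[0][0])  # corner pixel  → [0.75, 0.625, 0.5]
```

Interior pixels recover the flat color exactly. The corner does not, because replicating the edge pixel feeds a red sample into the green and blue averages, so border pixels from this padding mode are only approximate.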