# White-balance correction script, from the GitHub gist by @soravux
# (gist 18a13892e700b57a964a484e97c314ed, created November 23, 2017).
import numpy as np
from scipy.misc import imread, imsave
from matplotlib import pyplot as plt
def _load_linear(path):
    """Read an image file and approximately undo its sRGB encoding.

    Pixel values are divided by 255 (assumes 8-bit input, 0-255 -- TODO
    confirm for every input file) and raised to the power 2.2, a crude gamma
    approximation of the sRGB transfer curve ("poor man un-sRGB").
    Returns a float32 array.
    """
    encoded = imread(path).astype('float32') / 255.
    return encoded**2.2

# Get a linear version of the images
target_im = _load_linear("lighting_gt.png")
source_im = _load_linear("warped_0150.png")
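# Automatic white-point detection, kept for reference only. It is disabled
# because it does not work robustly, so the white point is picked manually below.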
# # Get a pixel that *seems* white
# target_luminance = target_im.dot([0.299, 0.587, 0.114])
# source_luminance = source_im.dot([0.299, 0.587, 0.114])
# target_white_gray = np.percentile(target_luminance[target_luminance < 1.], 80)
# src_white_gray = np.percentile(source_luminance[source_luminance < 1.], 80)
# target_idx_h, target_idx_w = np.where(target_luminance == target_white_gray)
# src_idx_h, src_idx_w = np.where(source_luminance == src_white_gray)
# Data coordinates of the most recent click inside the image axes, updated by
# `onclick`; -1 means "no click recorded yet".
mouse_x, mouse_y = -1, -1

def onclick(event):
    """Record the data coordinates of a mouse click on the displayed image.

    Used as a matplotlib ``button_press_event`` callback. The click position
    is stored in the module-level ``mouse_x`` / ``mouse_y``; the last click
    wins. Clicks outside the axes (where matplotlib reports ``xdata`` /
    ``ydata`` as ``None``) are ignored so they cannot wipe out a valid
    selection.
    """
    global mouse_x, mouse_y
    if event.xdata is None or event.ydata is None:
        return
    mouse_x = event.xdata
    mouse_y = event.ydata
def _pick_pixel(image):
    """Display a linear image and return the (row, col) of the last click on it.

    The image is gamma-encoded (``**(1/2.2)``) for display only. The click is
    read from the module-level ``mouse_x`` / ``mouse_y`` written by
    ``onclick``; they are reset first so that a click made on a previous
    figure is never reused.

    Returns integer pixel indices clamped to the image bounds.
    Raises SystemExit if the window is closed without a click on the image.
    """
    global mouse_x, mouse_y
    mouse_x, mouse_y = None, None
    fig = plt.figure()
    plt.imshow(image**(1./2.2))
    fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()
    if mouse_x is None or mouse_y is None:
        raise SystemExit("No point selected on the image; aborting.")
    # imshow centres pixel (r, c) on data coordinates (x=c, y=r), i.e. the
    # pixel spans [c-0.5, c+0.5); round (not truncate) to get the pixel under
    # the cursor, then clamp in case the click hit the outer half-pixel edge.
    row = min(max(int(round(mouse_y)), 0), image.shape[0] - 1)
    col = min(max(int(round(mouse_x)), 0), image.shape[1] - 1)
    return row, col

# Ask for point in target image
target_idx_h, target_idx_w = _pick_pixel(target_im)
# Ask for point in source image
src_idx_h, src_idx_w = _pick_pixel(source_im)
def _patch_mean(image, row, col, radius=2):
    """Return the per-channel mean of the (2*radius+1) x (2*radius+1) window at (row, col).

    The window is cropped at the image border. Start indices are clamped at 0
    because a negative slice start (e.g. ``row - 2`` for ``row < 2``) would
    count from the end of the axis, giving an empty window and a NaN mean.
    """
    top = max(row - radius, 0)
    left = max(col - radius, 0)
    window = image[top:row + radius + 1, left:col + radius + 1, :]
    return np.mean(window, axis=(0, 1))

# Get area of 5x5 pixels around the selected positions
target_white = _patch_mean(target_im, target_idx_h, target_idx_w)
src_white = _patch_mean(source_im, src_idx_h, src_idx_w)
# Normalize the colors (we don't want to change brightness, just color correction):
# each white estimate is scaled so its channels sum to 1, keeping only its
# chromaticity. Guard first: a black target patch would divide by zero (NaN),
# and a source channel at 0 would produce an infinite gain below.
if not (target_white.sum() > 0 and np.all(src_white > 0)):
    raise SystemExit("Selected white patch is black in at least one channel; "
                     "pick a brighter point.")
target_white /= target_white.sum()
src_white /= src_white.sum()
print("target white:", target_white)
print("source white:", src_white)
# Apply white balance correction: a diagonal matrix of per-channel gains that
# maps the source white chromaticity onto the target white chromaticity.
fix_matrix = np.diag(target_white/src_white)
source_fixed = source_im.dot(fix_matrix)
print("Correction matrix:\n", fix_matrix)
# Side-by-side preview: target with its picked point, source with its picked
# point, and the corrected source. Images are gamma-encoded for display.
_panels = [
    (221, target_im**(1./2.2), (target_idx_w, target_idx_h)),
    (222, source_im**(1./2.2), (src_idx_w, src_idx_h)),
    (223, np.clip(source_fixed, 0, 1)**(1./2.2), None),
]
for _position, _shown, _marker in _panels:
    plt.subplot(_position)
    plt.imshow(_shown)
    if _marker is not None:
        plt.scatter(*_marker, s=1)
    plt.axis('off')
plt.show()
# Save back as sRGB (same 2.2 gamma approximation used when loading).
# Round to the nearest 8-bit level before casting: astype('uint8') alone floors,
# which biases every pixel down and can turn an unchanged value v into v-1
# after the float32 gamma round trip.
source_fixed_sRGB = np.clip(np.rint(255.*source_fixed**(1./2.2)), 0, 255).astype('uint8')
imsave("warped_0150_fixed.png", source_fixed_sRGB)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment