Skip to content

Instantly share code, notes, and snippets.

@shunsukeaihara
Created January 23, 2013 08:45
Show Gist options
  • Save shunsukeaihara/4603234 to your computer and use it in GitHub Desktop.
Save shunsukeaihara/4603234 to your computer and use it in GitHub Desktop.
Some color-correction algorithms in Python
# -*- coding: utf-8 -*-
import numpy as np
import Image
import sys
def from_pil(pimg):
    """Convert a PIL image into a writable RGB numpy array.

    The image is converted to RGB mode first, so palette, greyscale and RGBA
    inputs all come back with three channels in the last axis.

    ``np.array`` (a copy) is used instead of ``np.asarray`` plus flipping the
    WRITEABLE flag: the buffer handed out by PIL/NumPy may be read-only and
    backed by immutable memory, in which case setting ``writeable = True``
    raises ``ValueError``.  A copy is always safe to modify in place.

    :param pimg: a PIL ``Image`` (anything with a ``convert(mode=...)`` method
        whose result numpy can turn into an array).
    :return: a new, writable ndarray.
    """
    rgb = pimg.convert(mode='RGB')
    return np.array(rgb)
def to_pil(nimg):
    """Wrap a numpy image array in a PIL ``Image``.

    The data is cast with ``np.uint8`` first; this is a plain cast, so values
    are neither clipped nor rounded here.  Callers should pass data that is
    already in the 0..255 range.
    """
    as_bytes = np.uint8(nimg)
    return Image.fromarray(as_bytes)
def stretch_pre(nimg):
    """
    from 'Applicability Of White-Balancing Algorithms to Restoring Faded Colour Slides: An Empirical Evaluation'

    Shift every channel down so that its darkest pixel becomes 0.

    The per-channel minimum is subtracted from the whole channel.  Because
    every value is >= its channel minimum, unsigned dtypes cannot underflow.
    The input array is not modified; a new array with the same shape and
    dtype is returned.  Works for any number of channels.

    :param nimg: image array of shape (height, width, channels).
    :return: new array of the same shape and dtype.
    """
    lows = nimg.min(axis=(0, 1), keepdims=True)  # shape (1, 1, channels)
    return nimg - lows
def grey_world(nimg):
    """Grey-world white balance: scale R and B so their means match G's mean.

    Green is the reference channel and is left unchanged.  Scaled red and
    blue values are clipped at 255 and truncated to uint8.  A red or blue
    channel whose mean is 0 (entirely black) is left as is, rather than
    divided by zero, which would produce NaN gains that are then cast to
    integers.

    :param nimg: RGB image array of shape (height, width, 3).
    :return: balanced uint8 array of the same shape.
    """
    channels = nimg.transpose(2, 0, 1).astype(np.uint32)
    green_mean = np.average(channels[1])
    for idx in (0, 2):
        mean = np.average(channels[idx])
        if mean > 0:
            # float result is truncated when stored back into the uint32 channel
            channels[idx] = np.minimum(channels[idx] * (green_mean / mean), 255)
    return channels.transpose(1, 2, 0).astype(np.uint8)
def max_white(nimg):
    """Max-white (white-patch) balance: stretch each channel so its peak is white.

    Each channel is multiplied by ``255 / channel.max()``, clipped to
    [0, 255] and truncated to uint8.  Since the output is always 8-bit, the
    target white level is 255 for every input dtype (uint8, uint16, uint32,
    float, ...).  A channel whose maximum is not positive (e.g. entirely
    black) is left unscaled instead of being divided by zero.

    :param nimg: image array of shape (height, width, channels).
    :return: uint8 array of the same shape.
    """
    white = 255.0  # brightest uint8 value; the result is always cast to uint8
    # float64 avoids the int32 overflow a uint32 input would otherwise hit
    channels = nimg.transpose(2, 0, 1).astype(np.float64)
    for idx in range(channels.shape[0]):
        peak = channels[idx].max()
        if peak > 0:
            channels[idx] *= white / peak
    return np.clip(channels, 0, white).transpose(1, 2, 0).astype(np.uint8)
def stretch(nimg):
    """Shift each channel so its minimum is 0, then apply max-white scaling."""
    shifted = stretch_pre(nimg)
    return max_white(shifted)
def retinex(nimg):
    """Retinex-style balance: scale R and B so their maxima match G's maximum.

    Green is the reference channel and is left unchanged.  Scaled values are
    clipped at 255 and truncated to uint8.  A red or blue channel whose
    maximum is 0 (entirely black) is left as is instead of being divided by
    zero.

    :param nimg: RGB image array of shape (height, width, 3).
    :return: balanced uint8 array of the same shape.
    """
    channels = nimg.transpose(2, 0, 1).astype(np.uint32)
    green_peak = float(channels[1].max())
    for idx in (0, 2):
        peak = float(channels[idx].max())
        if peak > 0:
            channels[idx] = np.minimum(channels[idx] * (green_peak / peak), 255)
    return channels.transpose(1, 2, 0).astype(np.uint8)
def _retinex_adjust_channel(channel, sum_g, max_g):
    """Remap one channel with x' = mu*x**2 + nu*x so its sum and max match green's.

    Returns float values clipped to [0, 255], or *channel* unchanged when the
    2x2 system has no unique solution.
    """
    x = channel.astype(np.float64)  # float sums: no integer overflow
    x_max = x.max()
    # The determinant equals x_max * sum(x * (x - x_max)).  It vanishes exactly
    # when every pixel is 0 or x_max (including an all-black channel), so the
    # two constraints are dependent; keep the channel as is in that case.
    if np.all((x == 0) | (x == x_max)):
        return channel
    lhs = np.array([[np.sum(x ** 2), np.sum(x)],
                    [x_max ** 2, x_max]])
    mu, nu = np.linalg.solve(lhs, np.array([sum_g, max_g]))
    # the fitted quadratic may dip below 0 or exceed 255
    return np.clip(mu * x ** 2 + nu * x, 0, 255)


def retinex_adjust(nimg):
    """
    from 'Combining Gray World and Retinex Theory for Automatic White Balance in Digital Photography'

    For the red and blue channels, solve for a quadratic mapping
    x' = mu*x**2 + nu*x such that the channel's pixel sum equals green's sum
    (grey world) and its maximum equals green's maximum (retinex).  Green is
    the reference channel and is left unchanged.  Results are clipped to
    [0, 255] and truncated to uint8.

    :param nimg: RGB image array of shape (height, width, 3).
    :return: adjusted uint8 array of the same shape.
    """
    channels = nimg.transpose(2, 0, 1).astype(np.uint32)
    green = channels[1].astype(np.float64)
    sum_g, max_g = green.sum(), green.max()
    for idx in (0, 2):  # red, then blue
        channels[idx] = _retinex_adjust_channel(channels[idx], sum_g, max_g)
    return channels.transpose(1, 2, 0).astype(np.uint8)
def retinex_with_adjust(nimg):
    """Apply retinex() and then refine its output with retinex_adjust()."""
    balanced = retinex(nimg)
    return retinex_adjust(balanced)
def standard_deviation_weighted_grey_world(nimg, subwidth, subheight):
    """Standard-deviation-weighted grey-world white balance.

    The image is tiled into ``subheight`` x ``subwidth`` blocks.  Partial
    blocks at the right/bottom edges are ignored for the statistics, but
    every pixel is corrected.  For each channel, the block means are
    averaged with weights proportional to that block's standard deviation in
    the channel, so textured regions count more than flat ones.  Each channel
    is then scaled so that its weighted mean equals the average of all
    channels' weighted means.  Results are clipped to [0, 255] and truncated
    to uint8.

    Degenerate cases are handled explicitly:
    - block sizes larger than the image are clamped to the image size;
    - a channel with no variation in any block falls back to the unweighted
      mean of its block means;
    - a channel whose weighted mean is 0 is left unscaled.

    :param nimg: image array of shape (height, width, channels).
    :param subwidth: block width in pixels (> 0).
    :param subheight: block height in pixels (> 0).
    :return: balanced uint8 array with the same shape as ``nimg``.
    :raises ValueError: if ``subwidth`` or ``subheight`` is not positive.
    """
    if subwidth <= 0 or subheight <= 0:
        raise ValueError("subwidth and subheight must be positive")
    pixels = np.asarray(nimg, dtype=np.float64)
    if pixels.size == 0:
        return pixels.astype(np.uint8)
    height, width, ch = pixels.shape
    subheight = min(subheight, height)
    subwidth = min(subwidth, width)
    rows, cols = height // subheight, width // subwidth
    # A reshape of the cropped image gives (rows, subheight, cols, subwidth, ch)
    # block views without hand-computed byte strides.
    blocks = pixels[:rows * subheight, :cols * subwidth].reshape(
        rows, subheight, cols, subwidth, ch)
    block_mean = blocks.mean(axis=(1, 3))  # (rows, cols, ch)
    block_std = blocks.std(axis=(1, 3))    # population std, per block/channel
    std_sum = block_std.sum(axis=(0, 1))   # (ch,)
    textured = std_sum > 0
    weighted = ((block_std * block_mean).sum(axis=(0, 1))
                / np.where(textured, std_sum, 1.0))
    sdwa = np.where(textured, weighted, block_mean.mean(axis=(0, 1)))
    sdwa_avg = sdwa.mean()
    lit = sdwa > 0
    gains = np.where(lit, sdwa_avg / np.where(lit, sdwa, 1.0), 1.0)
    return np.clip(pixels * gains, 0, 255).astype(np.uint8)
if __name__ == "__main__":
    # Show the input image, then the result of each correction method in turn.
    if len(sys.argv) < 2:
        sys.exit("usage: %s IMAGE" % sys.argv[0])
    img = Image.open(sys.argv[1])
    img.show()
    to_pil(stretch(from_pil(img))).show()
    to_pil(grey_world(from_pil(img))).show()
    to_pil(retinex(from_pil(img))).show()
    to_pil(max_white(from_pil(img))).show()
    to_pil(retinex_with_adjust(from_pil(img))).show()
    to_pil(standard_deviation_weighted_grey_world(from_pil(img), 50, 50)).show()
@YustasDev
Copy link

Hi!
What is "Image" in "import Image"?

@Pizzanomicon
Copy link

@YustasDev It appears to be Pillow (the PIL fork). I would recommend replacing that line with:

try:
    from PIL import Image
except ImportError:
    import Image

I found this project while looking for white-balancing algorithms, but it looks like it needs a bit of work. I also recommend changing:

def from_pil(pimg):
    pimg = pimg.convert(mode='RGB')
    nimg = np.asarray(pimg)
    nimg.flags.writeable = True
    return nimg

to:

def from_pil(pimg):
    pimg = pimg.convert(mode='RGB')
    nimg = np.asarray(pimg)
    nimg2 = np.dstack((nimg[:,:,0],nimg[:,:,1],nimg[:,:,2])) # hacky!
    nimg2.setflags(write=True)
    return nimg2

That said, I'm not sure this project can handle anything other than 8-bit images as it stands.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment