@DavidYKay
Last active June 26, 2023 00:52
Simple color balance algorithm using Python 2.7.8 and OpenCV 2.4.10. Ported from: http://www.morethantechnical.com/2015/01/14/simplest-color-balance-with-opencv-wcode/
import cv2
import math
import numpy as np
import sys

def apply_mask(matrix, mask, fill_value):
    masked = np.ma.array(matrix, mask=mask, fill_value=fill_value)
    return masked.filled()

def apply_threshold(matrix, low_value, high_value):
    low_mask = matrix < low_value
    matrix = apply_mask(matrix, low_mask, low_value)
    high_mask = matrix > high_value
    matrix = apply_mask(matrix, high_mask, high_value)
    return matrix

def simplest_cb(img, percent):
    assert img.shape[2] == 3
    assert percent > 0 and percent < 100
    half_percent = percent / 200.0
    channels = cv2.split(img)
    out_channels = []
    for channel in channels:
        assert len(channel.shape) == 2
        # find the low and high percentile values (based on the input percentile)
        height, width = channel.shape
        vec_size = width * height
        flat = channel.reshape(vec_size)
        assert len(flat.shape) == 1
        flat = np.sort(flat)
        n_cols = flat.shape[0]
        low_val = flat[math.floor(n_cols * half_percent)]
        high_val = flat[math.ceil(n_cols * (1.0 - half_percent))]
        print "Lowval: ", low_val
        print "Highval: ", high_val
        # saturate below the low percentile and above the high percentile
        thresholded = apply_threshold(channel, low_val, high_val)
        # scale the channel
        normalized = cv2.normalize(thresholded, thresholded.copy(), 0, 255, cv2.NORM_MINMAX)
        out_channels.append(normalized)
    return cv2.merge(out_channels)

if __name__ == '__main__':
    img = cv2.imread(sys.argv[1])
    out = simplest_cb(img, 1)
    cv2.imshow("before", img)
    cv2.imshow("after", out)
    cv2.waitKey(0)
@oak-tree

oak-tree commented May 8, 2017

Note that in Python 2.7, math.floor(n_cols * half_percent) returns a float. I wrapped it with int().


ghost commented Dec 12, 2018

Lines 41 and 42 should be replaced by:
low_val = flat[int(math.floor(n_cols * half_percent))]
high_val = flat[int(math.ceil( n_cols * (1.0 - half_percent)))]

@umarkhalidAI

Traceback (most recent call last):
File "C:/Users/Umar/PycharmProjects/WhiteBalance/Algo1_simple.py", line 57, in
out = simplest_cb(img, 1)
File "C:/Users/Umar/PycharmProjects/WhiteBalance/Algo1_simple.py", line 20, in simplest_cb
assert img.shape[2] == 3
AttributeError: 'NoneType' object has no attribute 'shape'

I am getting the above error while running the code. I am reading the image on line 56 like this: img = cv2.imread("pic1")

@JennyLouise

@umarkhalidcs, with cv2.imread("pic1"), if it can't find the image file it will just return None instead of throwing an error. Check that it's looking for the right file path.
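For example, something like this fails fast instead of crashing later inside simplest_cb (the file name here is just a placeholder):

import cv2

img = cv2.imread("pic1.jpg")   # placeholder path; point this at your actual file
if img is None:
    raise IOError("Could not read image; check the file path and extension")
out = simplest_cb(img, 1)      # simplest_cb from the gist above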

@JackDesBwa

Thanks for the transposition to python, it saved me a lot of time.

However, I ran it on an embedded processor, and the implementation is particularly inefficient: it took several seconds to process two 5-megapixel images on my target. I analyzed the code and saw two points that could improve the efficiency:

  • Sort: Sorting is an "expensive" O(N log N) operation which is only used to find two values. The same values can be found with a histogram, which reads each value only once, plus two partial accumulations (because we know the cut points are close to the ends of the histogram). The two operations could even be merged, but I did not go that deep. In addition, the implementation I used is probably not the fastest one.

  • Levels: The way the levels are clipped is very complicated, with lots of copies, allocations, modifications of the pixels, re-reading of the matrix to normalize, etc. But in fact, the algorithm already has all the information (min/max) of the linear curve to apply and can simply generate a LUT of 256 values to apply to the channels.

With these two corrections (I first wrote "optimizations", but I am not sure we can call them that when the initial code is so inefficient), the algorithm is way faster (the sort-vs-histogram change alone is theoretically at least a 40× improvement), so the sensation shifted from "have to wait" to "it's fast".

Optimizations might likely be pushed further, but this version was fast enough for my application.
Here is the resulting code. As you can see, it is more concise too:

"""Apply Simplest Color Balance algorithm
Reimplemented based on https://gist.github.com/DavidYKay/9dad6c4ab0d8d7dbf3dc"""
def simplest_cb(img, percent):
    out_channels = []
    channels = cv2.split(img)
    totalstop = channels[0].shape[0] * channels[0].shape[1] * percent / 200.0
    for channel in channels:
        bc = np.bincount(channel.ravel(), minlength=256)
        lv = np.searchsorted(np.cumsum(bc), totalstop)
        hv = 255-np.searchsorted(np.cumsum(bc[::-1]), totalstop)
        lut = np.array(tuple(
            0 if i < lv else 255 if i > hv else round((i - lv) / (hv - lv) * 255)
            for i in np.arange(0, 256)), dtype="uint8")
        out_channels.append(cv2.LUT(channel, lut))
    return cv2.merge(out_channels)

@jbonyun

jbonyun commented Sep 12, 2019

@JackDesBwa, I was inspired by your improvements, but I have some further suggestions (for future readers).

First, if you're in Python 2.7, you need to float-ify some variables in the calculation of the LUT values. Otherwise, the LUT values become all 0's and 255's. I believe Python 3 avoids this problem, but it's a gotcha for any 2.7 users.
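To illustrate the gotcha with made-up numbers:

# Python 2.7: integer division truncates toward zero
i, lv, hv = 100, 10, 240
(i - lv) / (hv - lv) * 255              # 0, because 90 / 230 == 0
float(i - lv) / float(hv - lv) * 255    # ~99.8, the intended value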

Second, I found that np.bincount was slower than what you're replacing. Using numpy 1.16.5 on an x64 architecture, it takes about twice as long as the original for me. But I didn't want to give up on the idea, so I looked around for something equivalent. The answer? cv2.calcHist. This runs much faster than the original (and much, much faster than np.bincount). I'm not sure why, but that's what I'm getting.

So my proposed code is:

def simplest_cb(img, percent):
    out_channels = []
    channels = cv2.split(img)
    totalstop = channels[0].shape[0] * channels[0].shape[1] * percent / 200.0
    for channel in channels:
        bc = cv2.calcHist([channel], [0], None, [256], (0,256), accumulate=False)
        lv = np.searchsorted(np.cumsum(bc), totalstop)
        hv = 255-np.searchsorted(np.cumsum(bc[::-1]), totalstop)
        lut = np.array([0 if i < lv else (255 if i > hv else round(float(i-lv)/float(hv-lv)*255)) for i in np.arange(0, 256)], dtype="uint8")
        out_channels.append(cv2.LUT(channel, lut))
    return cv2.merge(out_channels)

@JackDesBwa

JackDesBwa commented Sep 12, 2019

For the python 2.7 version, you are totally right. As I have been using python3 for a long while (version 3.7 at the time of writing), I easily forgot that this old version is still very much alive.

For the histogram computation, I read that bincount was faster and I blindly used it without timing it myself. In the hurry of my project, it was good enough.

Also, the huge theoretical improvement I mentioned assumed a simple single-core CPU; the picture might change completely with vectorization, multiple cores, or even a GPU.

So let's time it on 3 targets I have here:
{1} Intel i7-8550U 4×2 cores (x86_64 CPU) with integrated UHD Graphics 620 (GPU)
{2} Intel i7-8550U 4×2 cores (x86_64 CPU) with GeForce GTX 1050 (GPU)
{3} Broadcom BCM2837 4 cores (armv8 CPU) with integrated VideoCore IV (GPU)
The third is orders of magnitude less powerful, and it was the embedded target of my project.
Also, {1} and {2} use OpenCV 4.1.1, while the system of {3} is limited to version 2.4.9 in my current setup. I use python3 in all cases.

The timing is measured with this image https://raspi.tv/wp-content/uploads/2013/05/RasPiCam.jpg and percent=1 using the timeit API, reporting the minimum of 10 attempts, each doing 20 conversions. The method is not perfect, so there is still some variability between attempts, but the order of magnitude is reliable.
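In essence, each measurement looks like this (a sketch; img holds the test image loaded beforehand):

import timeit

# minimum over 10 attempts, each converting the image 20 times
best = min(timeit.repeat(lambda: simplest_cb(img, 1), repeat=10, number=20))
print(best)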

Here are the results:
{1} Original: 7.61 s, my version: 0.76 s, your version: 0.17 s
{2} Original: 7.48 s, my version: 0.75 s, your version: 0.16 s
{3} Original: 101.59 s, my version: 15.51 s, your version: 7.85 s

Note that you have to divide by the number of conversions (20) to get the time needed to convert one image. This is consistent with the times observed in my actual application.

The OpenCV histogram version is indeed faster. The source I read was outdated about the bincount speed-up, or perhaps it compared it with another numpy histogram method rather than the OpenCV one.
Anecdotally, there is no difference between the integrated graphics and the bigger GPU for such a simple task, but the relative speed-up seems to depend on the hardware nonetheless (it might come from software versions too).

As I continue on my project, I improved the snippet slightly further by merging the cumsum computations. The searchsorted call used just after performs a binary search, so distinguishing the direction between the high and low case was unnecessary. I also changed the way the LUT is computed to a more readable form that does not hurt speed. The timing difference is very tiny (about 30 ms on {3} under the same conditions, and even a bit better with my real images), but just as importantly, the structure of the algorithm is expressed more clearly in my opinion.

def simplest_cb(img, percent=1):
    out_channels = []
    cumstops = (
        img.shape[0] * img.shape[1] * percent / 200.0,
        img.shape[0] * img.shape[1] * (1 - percent / 200.0)
    )
    for channel in cv2.split(img):
        cumhist = np.cumsum(cv2.calcHist([channel], [0], None, [256], (0,256)))
        low_cut, high_cut = np.searchsorted(cumhist, cumstops)
        lut = np.concatenate((
            np.zeros(low_cut),
            np.around(np.linspace(0, 255, high_cut - low_cut + 1)),
            255 * np.ones(255 - high_cut)
        ))
        out_channels.append(cv2.LUT(channel, lut.astype('uint8')))
    return cv2.merge(out_channels)
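Usage is the same as with the original gist, for example (file names are just placeholders):

import cv2

img = cv2.imread("input.jpg")        # placeholder path
out = simplest_cb(img, percent=1)
cv2.imwrite("balanced.jpg", out)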

@rahulmistry751

What is the use of the line below?
half_percent = percent / 200.0
Why is percent divided by 200?

@JackDesBwa

Why is percent divided by 200?

The percent parameter is expressed in natural language (the number 5 means 5%), so there is a division by 100 to express it as a fraction of 1, which is easier to work with (you do not need to normalize at each computation). The division by 2 takes its half: half is used to cut the darks, half to cut the brights.
When you combine the two divisions, you get a division by 200.
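For example, with percent = 2:

half_percent = 2 / 200.0   # 0.01
# the darkest 1% and the brightest 1% of pixel values are clipped
# in each channel, i.e. 2% of the pixels in total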

@bulatnv
Copy link

bulatnv commented Dec 1, 2020

@DavidYKay, @JackDesBwa, @jbonyun.
Thank you, guys. Excellent work. Thumbs up!
