Last active
December 10, 2018 11:38
-
-
Save cosacog/9b07a0f748e00e4031c2cd854255495a to your computer and use it in GitHub Desktop.
Simple image processing: adjust mean level and gamma. 説明は下の方にあります。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
from PIL import Image, ImageDraw, ImageFilter, ImageOps, ImageStat | |
import numpy as np | |
import matplotlib.pyplot as plt | |
class ImageProcForVisualExperiment():
    '''
    Simple pixel-level image processing for visual experiments (Pillow based).

    attributes:
        path_img: path the image was loaded from (None if loaded otherwise)
        img: the loaded PIL image ('RGB' for colour input, otherwise gray)
        img_array: numpy array of img
        size: (width, height) of img
        isRGB: True for a colour image
        L: gray image
        R, G, B: (colour images only) RGB images in which the other two
                 channels are zeroed

    Channel-wise adjustments with copy=False update the matching L/R/G/B
    attribute; for gray images they update img and L together.
    '''

    def __init__(self, path_img=None):
        '''
        params:
            path_img: optional path of an image file to load
        '''
        self.path_img = path_img
        if self.path_img is not None:
            # The context manager releases the file handle; load_image()
            # copies the pixel data first, so nothing is read lazily later.
            with Image.open(path_img) as img:
                self.load_image(img)

    def load_image(self, img):
        '''
        Store a copy of img and derive the gray / per-channel images.

        params:
            img: PIL image. Colour modes other than 'RGB' (RGBA, P, CMYK, ...)
                 are converted to 'RGB' (alpha is dropped); '1' and 'LA'
                 are converted to 'L'.
        '''
        if Image.getmodebase(img.mode) == 'RGB':
            if img.mode != 'RGB':
                img = img.convert('RGB')
        elif img.mode in ('1', 'LA', 'La'):
            img = img.convert('L')
        self.img = img.copy()
        self.img_array = np.array(self.img)
        self.size = self.img.size
        # Only a 3-d (h, w, ch) array is colour. Testing shape[-1] alone would
        # read the *width* of a 2-d gray image and misclassify it as RGB.
        self.isRGB = self.img_array.ndim == 3 and self.img_array.shape[-1] >= 3
        if self.isRGB:
            # ref. https://wwld.jp/2017/05/14/image-grayscale.html
            # Pillow's convert('L') uses L = 0.299 R + 0.587 G + 0.114 B
            self.L = self.img.convert('L')
            self._set_red_channel()
            self._set_green_channel()
            self._set_blue_channel()
        else:
            # gray scale: the image itself is the gray channel
            self.L = self.img

    def adjust_contrast(self, contrast=0.9, col='L', copy=True):
        '''
        Adjust contrast while keeping the mean pixel level.

        usage:
            img = self.adjust_contrast(contrast=0.9, col='L')
        params:
            contrast: target contrast (max - min) / (max + min), 0.0 - 1.0
            col: 'L' (gray), 'R', 'G' or 'B' (case-insensitive)
            copy: if True (default) return the adjusted Image (for 'R'/'G'/'B'
                  an RGB image holding only that channel); if False replace
                  the internal image data for col and return None
        raises:
            ValueError: invalid col, or contrast outside 0.0 - 1.0
        '''
        self._check_col(col)
        if not 0.0 <= contrast <= 1.0:
            raise ValueError("contrast must be between 0.0 and 1.0.")
        img, _ = self._get_1ch_img(col=col, keep_color=False)
        imgOut = self._adjust_contrast(img, contrast=contrast)
        if copy:
            return self._get_rgb_image_by2d_image(imgOut, col)
        # _set_image expects the single-channel image, not its RGB rendering
        self._set_image(imgOut, col)

    def adjust_gamma(self, gamma=(1.0, 1.0, 1.0), gain=1.0, copy=True):
        '''
        Apply per-channel gamma correction to an RGB image.

        usage:
            img, imgR, imgG, imgB, imgL = self.adjust_gamma(gamma=(1.0, 1.0, 1.0), gain=1.0, copy=True)
        params:
            gamma: (r, g, b) gammas, or a single number used for all three;
                   each level x maps to min(255, 255 * (x/255) ** (1/gamma) * gain)
            gain: multiplier applied after the gamma curve
            copy: if False, instance image data will be replaced
        return:
            None if copy=False
            (img, imgR, imgG, imgB, imgL): Images if copy=True (default)
        raises:
            ValueError: the loaded image is not RGB (use adjust_gray_gamma)
        '''
        if not self.isRGB:
            raise ValueError("adjust_gamma needs an RGB image; use adjust_gray_gamma for gray images.")
        imgOut = self.img.point(self.gamma_table(gamma, gain))
        imgOutL = imgOut.convert('L')
        imgOutR = self._get_red_channel(imgOut)
        imgOutG = self._get_green_channel(imgOut)
        imgOutB = self._get_blue_channel(imgOut)
        if copy:
            print('if you want to save this, use *.save(filename). ')
            return (imgOut, imgOutR, imgOutG, imgOutB, imgOutL)
        self.img = imgOut
        self.img_array = np.array(self.img)
        self.R = imgOutR
        self.G = imgOutG
        self.B = imgOutB
        self.L = imgOutL

    def adjust_gray_gamma(self, gamma=1.0, gain=1.0, copy=True):
        '''
        Apply gamma correction to the gray image.

        usage:
            imgL = self.adjust_gray_gamma(gamma=1.0, gain=1.0, copy=True)
        params:
            gamma: level x maps to min(255, 255 * (x/255) ** (1/gamma) * gain)
            gain: multiplier applied after the gamma curve
            copy: if False, instance image data will be replaced
        return:
            None if copy=False
            imgL: Image if copy=True (default)
        '''
        imgOut = self.L.point(self.gamma_table_for_L(gamma, gain))
        if copy:
            print('if you want to save this, use *.save(filename). ')
            return imgOut
        self.L = imgOut
        if not self.isRGB:
            # gray image: img and L are the same picture, keep them in sync
            self.img = imgOut
            self.img_array = np.array(imgOut)

    def adjust_mean_level(self, mean_level=0.5, col='L', copy=True):
        '''
        Move the mean pixel level of one channel with a gamma curve.
        Only supports each of L, R, G, B.

        params:
            mean_level: target mean on a 0.0 - 1.0 scale (exclusive)
            col: 'L' (gray), 'R', 'G' or 'B' (case-insensitive)
            copy: if True (default) return the adjusted Image; if False
                  replace the internal image data for col and return None
        return:
            Image if copy=True; None if copy=False or mean_level is invalid
        raises:
            ValueError: invalid col
        '''
        if (mean_level <= 0.0) or (mean_level >= 1.0):
            print("mean level must be between 0.0 and 1.0. Try again.")
            return
        self._check_col(col)
        # _get_1ch_img returns (image, label); only the image is needed
        img, _ = self._get_1ch_img(col=col, keep_color=False)
        imgOut = self._adjust_mean_level(img, mean_level=mean_level)
        if copy:
            return self._get_rgb_image_by2d_image(imgOut, col)
        self._set_image(imgOut, col)

    def adjust_mean_and_contrast(self, mean_level=0.5, contrast=0.9, col='L', copy=True):
        '''
        Adjust mean level, then contrast (keeping that mean).
        Only supports each of L, R, G, B.

        params:
            mean_level: target mean on a 0.0 - 1.0 scale (exclusive)
            contrast: target contrast (max - min) / (max + min) (exclusive)
            col: 'L' (gray), 'R', 'G' or 'B' (case-insensitive)
            copy: if True (default) return the adjusted Image (for 'R'/'G'/'B'
                  an RGB image holding only that channel, like the other
                  adjust_* methods); if False replace the internal image data
        return:
            Image if copy=True; None if copy=False or a parameter is invalid
        '''
        if (mean_level <= 0.0) or (mean_level >= 1.0):
            print("Pixel mean level must be between 0.0 and 1.0. Try again.")
            return
        if (contrast <= 0.0) or (contrast >= 1.0):
            print("Contrast must be between 0.0 and 1.0. Try again.")
            return
        self._check_col(col)
        imgInput, _ = self._get_1ch_img(col=col, keep_color=False)
        img_mean_adjusted = self._adjust_mean_level(imgInput, mean_level=mean_level)
        img_contrast_adjusted = self._adjust_contrast(img_mean_adjusted, contrast=contrast)
        if copy:
            return self._get_rgb_image_by2d_image(img_contrast_adjusted, col)
        self._set_image(img_contrast_adjusted, col)

    def copy(self):
        '''
        Return an independent copy, including any per-channel adjustments
        already stored in L / R / G / B.
        '''
        img_copy = ImageProcForVisualExperiment()
        img_copy.path_img = self.path_img
        img_copy.load_image(self.img)
        # load_image() re-derives the channels from img; keep our own instead
        img_copy.L = self.L.copy()
        if self.isRGB:
            img_copy.R = self.R.copy()
            img_copy.G = self.G.copy()
            img_copy.B = self.B.copy()
        return img_copy

    def get_stats(self, col='L'):
        '''
        Print min, max, mean and contrast of one channel ('L', 'R', 'G', 'B').
        Returns None.
        '''
        self._check_col(col)
        img, _ = self._get_1ch_img(col=col, keep_color=False)
        min_p, max_p, mean_p, contrast_p = self._get_stats(img)
        print("min:%d, max:%d, mean:%0.3f, contrast:%0.3f" %(min_p, max_p, mean_p, contrast_p))

    def plot_sorted_pixel_levels(self, col='L'):
        '''
        Plot the pixel levels of one channel sorted in ascending order.

        params:
            col: 'L', 'R', 'G' or 'B'
        return:
            matplotlib Figure; its title shows mean, contrast, min and max
        '''
        self._check_col(col)
        img, col_txt = self._get_1ch_img(col=col, keep_color=False)
        levels = np.sort(np.array(img).ravel())
        min_p, max_p, mean_p, contrast_p = self._get_stats(img)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.plot(levels)
        ax.set_xlabel('pixel count')
        ax.set_ylabel('pixel level (0-255)')
        ax.set_ylim((0, 255))
        ax.set_xlim((0, levels.size))
        ttl_plt = "color: {0}, mean:{1:0.2f}, contrast:{2:0.2f}, min:{3}, max:{4}".format(
            col_txt, mean_p, contrast_p, min_p, max_p
        )
        ax.set_title(ttl_plt)
        return fig

    def save(self, path_save, col=None):
        '''
        Save an image and print its stats.

        params:
            path_save: output path; needs an extension (e.g. .png)
            col: None (default) saves the whole image; 'R', 'G', 'B' or 'L'
                 saves that channel image
        return:
            None. Prints mean, contrast, min, max level (of the gray version
            of the whole image when col is None).
        raises:
            ValueError: col is not None and not one of 'R', 'G', 'B', 'L'
        '''
        if col is None:
            img, col_txt = self.img, 'Original'
            img_1ch = self.img.convert('L') if self.isRGB else self.img
        else:
            if str(col).upper() not in ['R', 'G', 'B', 'L']:
                raise ValueError('col option must be either of "R","G","B" or "L".')
            img, col_txt = self._get_1ch_img(col=col, keep_color=True)
            img_1ch, _ = self._get_1ch_img(col=col, keep_color=False)
        # get_stats() only prints, so compute the numbers directly
        min_p, max_p, mean_p, contrast_p = self._get_stats(img_1ch)
        img.save(path_save)
        print("{0} image was saved to {1}.\n".format(
            col_txt, path_save
        ))
        print("Image mean:{0:0.2f}, contrast:{1:0.2f}, min:{2}, max:{3}.".format(
            mean_p, contrast_p, min_p, max_p))

    def show(self, col=None):
        '''
        Display the image (col=None) or one channel.

        params:
            col: 'R', 'G', 'B', 'L' (for gray), case-insensitive
        '''
        print('If you want to show red channel, use "col="R" option, etc.')
        if col is None:
            img = self.img
        elif isinstance(col, str) and col.upper() in ['R', 'G', 'B', 'L']:
            img, _ = self._get_1ch_img(col=col, keep_color=True)
        else:
            print('col option must be either "R", "G", "B" or "L".')
            return
        img.show()
        im_size = img.size
        print("img width: {0}, height:{1}".format(im_size[0], im_size[1]))

    def stretch_img(self, col='L'):
        '''
        Stretch one channel to the full 0-255 range (contrast 1.0).
        The mean level is not kept.

        params:
            col: 'L', 'R', 'G' or 'B'
        return:
            Image (for 'R'/'G'/'B' an RGB image holding only that channel)
        '''
        self._check_col(col)
        img, _ = self._get_1ch_img(col=col, keep_color=False)
        imgOut = self._stretch_img(img)
        return self._get_rgb_image_by2d_image(imgOut, col)

    @staticmethod
    def _check_col(col):
        '''Raise ValueError unless col is one of 'L', 'R', 'G', 'B' (any case).'''
        if not isinstance(col, str) or col.upper() not in ('L', 'R', 'G', 'B'):
            raise ValueError("col must be either of 'L', 'R', 'G', 'B'.")

    def _adjust_contrast(self, img, contrast=0.9):
        '''
        Set the contrast of a single-channel image, keeping its mean level.

        Steps: stretch to 0-255, restore the original mean with a gamma
        curve, then map 0-255 linearly onto [min_p, max_p] through the mean.
        params:
            img: single-channel (2-d) Image
            contrast: target contrast, 0.0 - 1.0
        return:
            imgOut: single-channel (2-d) Image
        '''
        _, _, mean_p, _ = self._get_stats(img)
        img_stretched = self._stretch_img(img)
        # _adjust_mean_level works on a 0.0-1.0 scale, mean_p is 0-255
        img_mean_adjusted = self._adjust_mean_level(img_stretched, mean_level=mean_p / 255.0)
        min_p, max_p = self._calc_minmax_with_mean_contrast(mean_p, contrast)
        return img_mean_adjusted.point(self._contrast_table(mean_p, min_p, max_p))

    @staticmethod
    def _contrast_table(mean_p, min_p, max_p):
        '''
        256-entry lookup table mapping 0 -> min_p, mean_p -> mean_p and
        255 -> max_p, linear on each side of mean_p (values clipped to 0-255).
        '''
        x = np.arange(256, dtype=float)
        below = x <= mean_p
        tbl = np.empty(256)
        if mean_p > 0:
            tbl[below] = min_p + x[below] * (mean_p - min_p) / mean_p
        else:
            tbl[below] = min_p
        if mean_p < 255:
            above = ~below
            tbl[above] = mean_p + (x[above] - mean_p) * (max_p - mean_p) / (255.0 - mean_p)
        return [int(v) for v in np.clip(tbl, 0, 255)]

    def _adjust_mean_level(self, img, mean_level=0.5, copy=True):
        '''
        Move the mean level of a single-channel image with a gamma curve.

        params:
            img: single-channel (2-d) Image
            mean_level: target mean on a 0.0 - 1.0 scale
            copy: unused; kept for backward compatibility
        return:
            imgOut: single-channel (2-d) Image (img itself if already at target)
        '''
        meanLinit = np.array(img).mean() / 255.0
        if meanLinit < mean_level:
            # need a brighter image
            return self._seek_gamma_when_mean_is_lower(img, meanLinit, mean_level)
        if meanLinit > mean_level:
            return self._seek_gamma_when_mean_is_higher(img, meanLinit, mean_level)
        return img

    def _calc_minmax_with_mean_contrast(self, mean_p, contrast):
        '''
        Compute (min_p, max_p) so that (max_p - min_p)/(max_p + min_p) equals
        contrast and the straight line 0 -> min_p, 255 -> max_p passes
        through (mean_p, mean_p). Values are truncated to int.
        '''
        if contrast >= 1.0:
            # limit of the formula below (which would divide by zero)
            return (0, 255)
        min_p = 255.*mean_p*(1-contrast)/(255 - 255*contrast + 2*contrast*mean_p)
        max_p = (1+contrast)/(1-contrast)*min_p
        return (int(min_p), int(max_p))

    def _get_1ch_img(self, col='L', keep_color=True):
        '''
        Pick the image of one channel.

        params:
            col: 'L', 'R', 'G' or 'B' (ignored for gray images)
            keep_color: for 'R'/'G'/'B', return the RGB image with the other
                        channels zeroed (True) or the bare 2-d channel (False)
        return:
            (img, col_txt); img is None if col is not recognised for an RGB image
        '''
        if not self.isRGB:
            return (self.img, 'Gray')
        key = col.upper()
        if key == 'L':
            return (self.L, 'Gray')
        names = {'R': 'Red', 'G': 'Green', 'B': 'Blue'}
        if key not in names:
            return (None, '')
        img = getattr(self, key)
        if not keep_color:
            img = Image.fromarray(np.array(img)[:, :, 'RGB'.index(key)])
        return (img, names[key])

    def _get_1ch_img_2darray(self, col='L'):
        '''
        Get a single-channel (2-d) image of col.
        '''
        img, _ = self._get_1ch_img(col=col, keep_color=False)
        return img

    def _get_rgb_image_by2d_image(self, img, col):
        '''
        Get an RGB image holding a 2-d image in the channel col.

        params:
            img: single-channel (2-d) Image
            col: 'L', 'R', 'G' or 'B' (case-insensitive)
        return:
            img unchanged for 'L' or for gray images; otherwise an RGB Image
            with the other two channels zero
        '''
        if col.upper() == 'L' or not self.isRGB:
            return img
        img_arry = np.asarray(img)
        # uint8: Image.fromarray cannot build an RGB image from float64
        img_arry_rgb = np.zeros(img_arry.shape[:2] + (3,), dtype=np.uint8)
        img_arry_rgb[:, :, ['R', 'G', 'B'].index(col.upper())] = img_arry
        return Image.fromarray(img_arry_rgb)

    @staticmethod
    def _keep_channel(img, idx):
        '''Copy of RGB img with every channel except idx set to 0.'''
        img_copy = np.array(img)
        img_copy[:, :, [c for c in range(3) if c != idx]] = 0
        return Image.fromarray(img_copy)

    def _get_red_channel(self, img):
        return self._keep_channel(img, 0)

    def _get_green_channel(self, img):
        return self._keep_channel(img, 1)

    def _get_blue_channel(self, img):
        return self._keep_channel(img, 2)

    def _get_stats(self, img):
        '''
        Get min, max, mean, contrast of the first band of img.
        Assumes a 2-d image (h, w), not RGB (h, w, 3).
        Contrast is (max - min) / (max + min), 0.0 for an all-zero image.
        '''
        stat = ImageStat.Stat(img)
        min_pix, max_pix = stat.extrema[0]
        mean_pix = stat.mean[0]
        total = max_pix + min_pix
        contrast_pix = (max_pix - min_pix) / total if total else 0.0
        return (min_pix, max_pix, mean_pix, contrast_pix)

    def _set_red_channel(self):
        self.R = self._get_red_channel(self.img)

    def _set_green_channel(self):
        self.G = self._get_green_channel(self.img)

    def _set_blue_channel(self):
        self.B = self._get_blue_channel(self.img)

    def _set_image(self, img, col):
        '''
        Store a processed channel image back into the instance.

        params:
            img: single-channel (2-d) Image; an RGB rendering of the channel
                 is also accepted, in which case only that channel is used
            col: 'L', 'R', 'G' or 'B'
        '''
        if not self.isRGB:
            # gray image: img and L are the same picture, keep them in sync
            self.img = img
            self.img_array = np.array(img)
            self.L = img
            return
        key = col.upper()
        if key == 'L':
            self.L = img
            return
        idx = ['R', 'G', 'B'].index(key)
        new_ch = np.array(img)
        if new_ch.ndim == 3:
            new_ch = new_ch[:, :, idx]
        orig_img_arry = np.array(getattr(self, key))
        orig_img_arry[:, :, idx] = new_ch
        setattr(self, key, Image.fromarray(orig_img_arry))

    def _seek_gamma_when_mean_is_lower(self, img, meanLinit, mean_level):
        '''Brighten img (gamma > 1) so its mean approaches mean_level.'''
        return self._seek_gamma(img, mean_level, brighten=True)

    def _seek_gamma_when_mean_is_higher(self, img, meanLinit, mean_level):
        '''Darken img (gamma < 1) so its mean approaches mean_level.'''
        return self._seek_gamma(img, mean_level, brighten=False)

    def _seek_gamma(self, img, mean_level, brighten, max_gamma=11.0, tol=1e-6):
        '''
        Bisection search for the gamma bringing img's mean (0.0-1.0 scale)
        closest to mean_level.

        The search variable t runs over [1, max_gamma]; the applied gamma is
        t when brightening and 1/t when darkening, so the mean moves
        monotonically toward mean_level as t grows. If even t = max_gamma
        does not reach the target, the image at max_gamma is returned.
        '''
        def apply(t):
            return img.point(self.gamma_table_for_L(t if brighten else 1.0 / t))

        def mean_of(im):
            return np.array(im).mean() / 255.0

        def reached(m):
            return m >= mean_level if brighten else m <= mean_level

        lo, hi = 1.0, float(max_gamma)
        img_hi = apply(hi)
        if not reached(mean_of(img_hi)):
            # target unreachable within the search range
            return img_hi
        while hi - lo > tol:
            mid = 0.5 * (lo + hi)
            if reached(mean_of(apply(mid))):
                hi = mid
            else:
                lo = mid
        img_lo, img_hi = apply(lo), apply(hi)
        if abs(mean_of(img_lo) - mean_level) < abs(mean_of(img_hi) - mean_level):
            return img_lo
        return img_hi

    def _stretch_img(self, img):
        '''
        Stretch a single-channel image linearly to the full 0-255 range.

        params:
            img: single-channel (2-d) Image
        return:
            imgOut: single-channel (2-d) Image; img itself when it already
                    spans 0-255 or is flat (nothing to stretch)
        '''
        min_p, max_p, _, _ = self._get_stats(img)
        min_p, max_p = int(min_p), int(max_p)
        # _adjust_contrast relies on the full 0-255 range, so min 0 alone
        # (contrast 1.0) is not enough to skip stretching
        if (min_p == 0 and max_p == 255) or max_p <= min_p:
            return img
        return img.point(self._stretch_table(min_p, max_p))

    @staticmethod
    def _stretch_table(min_p, max_p):
        '''
        256-entry lookup table: <= min_p -> 0, >= max_p -> 255, linear between.
        Requires integer 0 <= min_p < max_p <= 255.
        '''
        arry_tbl = np.zeros(256)
        arry_tbl[min_p:max_p + 1] = np.linspace(0, 255, max_p - min_p + 1)
        arry_tbl[max_p + 1:] = 255
        return [int(x) for x in arry_tbl]

    @staticmethod
    def gamma_table(gamma=(1.0, 1.0, 1.0), gain=1.0):
        '''
        768-entry (r, g, b) lookup table for Image.point on an RGB image.
        ref. https://qiita.com/pashango2/items/145d858eff3c505c100a

        params:
            gamma: (r, g, b) gammas, or a single number used for all three
            gain: multiplier after the gamma curve; usually only gamma is needed
        '''
        if np.isscalar(gamma):
            gamma = (gamma, gamma, gamma)
        return (ImageProcForVisualExperiment.gamma_table_for_L(gamma[0], gain)
                + ImageProcForVisualExperiment.gamma_table_for_L(gamma[1], gain)
                + ImageProcForVisualExperiment.gamma_table_for_L(gamma[2], gain))

    @staticmethod
    def gamma_table_for_L(gamma=1.0, gain=1.0):
        '''
        256-entry lookup table: x -> min(255, int(255 * (x/255) ** (1/gamma) * gain)).
        ref. https://qiita.com/pashango2/items/145d858eff3c505c100a
        Usually only gamma is needed; gamma > 1 brightens, gamma < 1 darkens.
        '''
        tbl = [min(255, int((x / 255.) ** (1. / gamma) * gain * 255.)) for x in range(256)]
        return tbl
if __name__ == '__main__':
    # Demo: download a sample gray-gradient PNG, then adjust its mean level,
    # contrast and gamma, plotting the sorted pixel levels after each step.
    import os
    import shutil
    import tempfile

    import requests

    # png download
    url_sample_gray_gradient = "https://bit.ly/2L80GAq"  # sample png
    # timeout: do not hang forever on the network; raise_for_status: do not
    # save an HTTP error page under the .png name
    with requests.get(url_sample_gray_gradient, stream=True, timeout=30) as res:
        res.raise_for_status()
        # res.raw is the undecoded byte stream; let urllib3 undo any
        # Content-Encoding (e.g. gzip) before copying it to disk
        res.raw.decode_content = True
        filepath = os.path.join(tempfile.mkdtemp(), "gray_gradient.png")
        with open(filepath, "wb") as fp:
            shutil.copyfileobj(res.raw, fp)
    # image load
    img_sample = ImageProcForVisualExperiment(path_img=filepath)
    img_sample.get_stats(col='L')
    fig1 = img_sample.plot_sorted_pixel_levels(col='L')
    # adjust mean level and contrast
    img_sample.adjust_mean_and_contrast(col='L', mean_level=0.5, contrast=0.8, copy=False)
    fig2 = img_sample.plot_sorted_pixel_levels(col='L')
    # adjust gamma
    img_sample.adjust_gray_gamma(gamma=2.2, copy=False)
    fig3 = img_sample.plot_sorted_pixel_levels(col='L')
    # without this the figures are never displayed in a non-interactive run
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
比較的シンプルな画像処理のスクリプト:pillowをベースにしてます
使い方:導入
1. 導入のための準備
このサイトのスクリプトをダウンロードして読み込むためのモジュールを導入します。
"import_gist" というのはカスタムのモジュール(スクリプト)で、端末上から pip を使って導入できます。
注意:たぶんpython3でしか走りません。
参照: githubで"cosacog/import_gist"を検索してください。
2. pythonでの作業
普通にダウンロードして読み込んでもよいですが、1 の作業をしておけば、以下の方法でも使えるようになります。
これで準備完了です。
使い方:本番
少し力尽きてしまい十分にテストできていないので、不具合があればお知らせください。
1. 表示
2. 輝度+コントラスト調整
輝度といってもいわゆるピクセルレベルを調整するだけです。
画面上の輝度はモニタ、PCの特性(通常はWindowsでガンマ 2.2)の影響を受けるので、その辺ご注意ください。
引数
返り値
3. ガンマの補正
主にWindowsのガンマ(2.2くらい)を直線的な変化(1.0)に調整するのを想定しています。
引数
返り値
注意点
4. 画像の平均グレーレベルとか取り出し方
引数
返り値
なし。min, max, mean, contrast をコンソールに出力します。
5. 画像の保存
6. 画像データ(pillowの画像として)の取り出し方
ここから先は pillow の文法に沿った作業になります。
7. おまけ機能:ピクセルレベルをソートしてプロット
min, max, mean, contrastがイメージとしてつかみやすいのではないかと思い作りました。
引数
col: 'L', 'R', 'G', 'B'のいずれか
返り値
matplotlib の Figure オブジェクトを返します(プロットが作成されます)。横軸はピクセル数(ピクセルレベルの小さい順に並べたもの)、縦軸はグレーレベル(0-255)です。
タイトルにmin, max, mean, contrastを表記してます。
8. 最後に
dir(mdl.ImageProcForVisualExperiment)
などとすると、使えるメソッド(関数)の一覧がわかります。ヘルプドキュメントはほとんど書いていないので、あしからずご了承ください。