Created
January 9, 2023 01:53
-
-
Save rhee-elten/03ea159342323f412b8fe87975e48e6b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8
from __future__ import print_function, division, absolute_import

import base64
import fnmatch
import html
import os
import re
import sys
from collections import OrderedDict as odict
from functools import lru_cache
from os.path import dirname, basename, join

import cv2
#import numba
import numpy as np
# imagehash==4.0 required
# to install: pip install ImageHash==4.0
from imagehash import phash, dhash
from PIL import Image
# Pick a progress-bar implementation: get_ipython() only exists inside an
# IPython session, so its NameError means "plain console" -> use regular tqdm.
# Only the expected failures are caught; a bare except would also swallow
# KeyboardInterrupt/SystemExit.
try:
    get_ipython()
    from tqdm import tqdm_notebook as tqdm
except (NameError, ImportError):
    from tqdm import tqdm
def walk_tree(root, regexp=None, pattern=None, match_fn=None, count_fn=None):
    """Recursively yield the paths of files under *root* (via os.walk).

    With no filter given, every file is yielded. Otherwise a path is yielded
    when ANY supplied filter accepts it:
      regexp   -- re.match() against the full joined path (anchored at start)
      pattern  -- fnmatch.fnmatch() glob against the full joined path
      match_fn -- callable(path) returning a truthy value
    Empty-string regexp/pattern values count as "given" but never match.
    count_fn, if callable, is invoked with each path just before it is yielded.
    """
    no_filter = regexp is None and pattern is None and match_fn is None

    def _accepted(path):
        if no_filter:
            return True
        if regexp and re.match(regexp, path):
            return True
        if pattern and fnmatch.fnmatch(path, pattern):
            return True
        return bool(callable(match_fn) and match_fn(path))

    for current_dir, _subdirs, names in os.walk(root):
        for name in names:
            path = join(current_dir, name)
            if not _accepted(path):
                continue
            if callable(count_fn):
                count_fn(path)
            yield path
def read_image(im_path,
               flags=cv2.IMREAD_UNCHANGED,
               compress_gray_image=False):
    """Load an image file (or a ``.npy`` array) as a numpy array.

    Encoded images are decoded with OpenCV. Whenever the result has a
    channel axis (ndim > 2), that axis is reversed, turning OpenCV's BGR
    order into RGB; ``.npy`` arrays get the same reversal.

    Parameters
    ----------
    im_path : str or os.PathLike
        File to read. A name ending in ``.npy`` is loaded with ``np.load``.
    flags : int
        OpenCV imdecode flag, e.g. cv2.IMREAD_COLOR, cv2.IMREAD_GRAYSCALE
        or cv2.IMREAD_UNCHANGED (default). Not used for ``.npy`` files.
    compress_gray_image : bool
        Only honoured when ``flags == cv2.IMREAD_COLOR`` AND a module-level
        ``gray2compressedrgb`` helper exists: gray images are then converted
        with that helper instead of OpenCV's own gray->color conversion.
        Otherwise the option is ignored.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    ValueError
        If OpenCV cannot decode the file.
    """
    im_path = os.fspath(im_path)  # also accept pathlib.Path
    if im_path.endswith('.npy'):
        im = np.load(im_path)
    else:
        with open(im_path, 'rb') as f:
            buf = np.frombuffer(f.read(), dtype=np.uint8)
        # gray2compressedrgb() is not defined in this file; look it up lazily so
        # the option degrades to "ignored" (as documented) instead of NameError.
        gray_to_rgb = globals().get('gray2compressedrgb')
        use_custom_gray = (compress_gray_image and flags == cv2.IMREAD_COLOR
                           and gray_to_rgb is not None)
        # With the custom conversion, decode unchanged and convert gray ourselves.
        im = cv2.imdecode(buf, flags=cv2.IMREAD_UNCHANGED if use_custom_gray else flags)
        if im is None:  # imdecode signals failure by returning None
            raise ValueError('cannot decode image: {}'.format(im_path))
        # Color input is kept as decoded; only single-channel input is converted.
        if use_custom_gray and im.ndim == 2:
            im = gray_to_rgb(im, keep_size=True, reverse_channel_order=True)
    # BGR ==> RGB: reverse the channel axis when there is one.
    if im.ndim > 2:
        im = im[:, :, ::-1]
    return im
def deinterlace2(im, k=(0.25, 0.5, 0.25)):
    """Smooth *im* along its width axis (axis 1) with a 3-tap kernel.

    Each output pixel is ``k[0]*left + k[1]*center + k[2]*right``; the
    first and last columns are replicated at the borders. 2-D (gray) input
    and input with trailing channel axes are both supported (each channel is
    filtered independently).

    The weighted sum is computed in float32 and cast to uint8 by truncation
    (not rounding); the output has the same shape as *im*.
    """
    k_a, k_b, k_c = k
    # Edge-replicate one column on each side of axis 1 only.
    pad_width = [(0, 0), (1, 1)] + [(0, 0)] * (im.ndim - 2)
    padded = np.pad(np.asarray(im, dtype=np.float32), pad_width, mode='edge')
    smoothed = k_a * padded[:, :-2] + k_b * padded[:, 1:-1] + k_c * padded[:, 2:]
    return smoothed.astype(np.uint8)
def center_crop(im, crop_size):
    """Return the central crop_size x crop_size region of *im* (a slice view).

    Along an axis shorter than crop_size the whole extent is kept; the start
    offset is clamped to 0 because a negative start would wrap around and
    select the wrong pixels.
    """
    h, w = im.shape[:2]
    top = max((h - crop_size) // 2, 0)
    left = max((w - crop_size) // 2, 0)
    return im[top:top + crop_size, left:left + crop_size]
def image_to_data_uri(im, crop_size=None):
    """Encode an image array as a ``data:image/png;base64,...`` URI.

    im: image array; when it has a channel axis (ndim > 2) that axis is
        reversed before encoding, undoing read_image()'s BGR -> RGB flip so
        cv2.imencode receives OpenCV channel order.
    crop_size: if not None, center_crop() the image to crop_size first.

    Raises ValueError when cv2.imencode reports that encoding failed.
    """
    if crop_size is not None:
        im = center_crop(im, crop_size=crop_size)
    if im.ndim > 2:
        im = im[:, :, ::-1]
    ok, png = cv2.imencode('.png', im)
    if not ok:
        raise ValueError('cv2.imencode failed to encode the image as PNG')
    b64 = base64.b64encode(png.tobytes()).decode('ascii')
    return 'data:image/png;base64,' + b64
def make_image_data_uri(im_name, crop_size=150):
    """Read *im_name* from disk and return it as a PNG data URI.

    The image is loaded unchanged by read_image() and passed, together with
    *crop_size* (None disables cropping), to image_to_data_uri().
    """
    image = read_image(im_name, flags=cv2.IMREAD_UNCHANGED)
    uri = image_to_data_uri(image, crop_size=crop_size)
    return uri
def hash_thumbnail_data_uri(hash_val, hash_val_2, as_uri=False):
    """Render two 64-bit hashes as a 64x32 uint8 gray thumbnail.

    The 128 bits (hash_val first, most-significant bit first) form a 16x8
    grid; every bit is drawn as a 4x4 block of value 160 (bit set) or 96
    (bit clear). Returns the array, or its PNG data URI when *as_uri* is true.
    """
    bit_text = '{:064b}{:064b}'.format(hash_val, hash_val_2)
    grid = np.fromiter(map(int, bit_text), dtype=np.uint8, count=len(bit_text)).reshape(16, 8)
    # Upscale each bit 4x vertically and 4x horizontally.
    pixels = np.repeat(np.repeat(grid, 4, axis=0), 4, axis=1) * 64 + (128 - 32)
    return image_to_data_uri(pixels) if as_uri else pixels
# def hash_to_title(hash_val): | |
# ss = '{:064b}'.format(hash_val) | |
# res = '' | |
# for rr in range(8): | |
# for cc in range(8): | |
# ch = ss[rr*8+cc] | |
# res += ch | |
# res += '\n' | |
# return res | |
#@numba.jit('u8(u1[:,:])') | |
def get_image_hash(im):
    """Return the perceptual hash (imagehash.phash) of *im* as a Python int.

    Element k of the flattened (row-major) boolean hash matrix becomes bit k
    of the result, so element 0 is the least-significant bit.
    """
    bits = phash(Image.fromarray(im)).hash.flatten()
    return sum(1 << k for k, bit in enumerate(bits) if bit)
#@numba.jit('u8(u8)') | |
def get_bits_count(c):
    """Return the number of set bits in the low 64 bits of integer *c*.

    Bits above position 63 are ignored, and a negative *c* is counted in
    64-bit two's complement (e.g. -1 -> 64). Uses the built-in binary
    representation instead of a 64-step Python loop; this matters because
    the module builds a 65536-entry table at import time.
    """
    return bin(int(c) & 0xFFFFFFFFFFFFFFFF).count('1')
#@numba.jit('u8(u8,u8)') | |
def get_hash_distance(a, b):
    """Return the Hamming distance between hashes *a* and *b*.

    The two values are XOR-ed and the set bits of the result are counted
    with get_bits_count().
    """
    return get_bits_count(a ^ b)
# Lookup table: _bits_count_map[v] is the number of set bits of v for every
# 16-bit value v (0..65535); array_bits_count() indexes it with 16-bit chunks.
# Built vectorized (unpack each value's two bytes into 16 bits and sum)
# instead of 65536 Python-level calls at import time.
_bits_count_map = (
    np.unpackbits(np.arange(65536, dtype=np.uint16).view(np.uint8))
    .reshape(-1, 16)
    .sum(axis=1)
    .astype(int)
)
def array_bits_count(arr):
    """Count the set bits of each (up to 64-bit) integer in *arr*.

    Each value is split into four 16-bit chunks whose bit counts are looked
    up in the module-level ``_bits_count_map`` table and summed. Every chunk,
    including the top one, is masked to 16 bits, so signed inputs are counted
    in 64-bit two's complement instead of producing negative table indices.

    Returns an int array (or scalar) shaped like *arr*, values in 0..64.
    """
    mask = 0xffff
    res = (_bits_count_map[np.bitwise_and(np.right_shift(arr, 48), mask)]
           + _bits_count_map[np.bitwise_and(np.right_shift(arr, 32), mask)]
           + _bits_count_map[np.bitwise_and(np.right_shift(arr, 16), mask)]
           + _bits_count_map[np.bitwise_and(arr, mask)])
    # internal consistency check on the lookup table
    assert np.all(res >= 0) and np.all(
        res <= 64), ('internal error: invalid count:', res)
    return res
# HTML written before the first reported pair: styles and the opening list.
_REPORT_HEADER = '''
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
.report_images div {
padding: 0;
margin: 0;
cursor: pointer;
}
.report_image img {
padding: 0;
margin: 0;
vertical-align: bottom;
}
.report_image {
display: inline-block;
position: relative;
height: 150px;
}
.report_image_thumb {
position: absolute;
left: 5px;
top: 5px;
z-index: 100;
opacity: 0.4;
}
.report_image_thumb img {
margin: 2px;
border: 1px solid red;
border-radius: 3px;
}
.report_image > .report_image_thumb {
visibility: hidden;
}
.report_image:hover > .report_image_thumb {
visibility: visible;
}
</style>
</head>
<body>
<div>
<ol>
'''

# HTML for one reported pair, filled in with str.format().
_REPORT_ITEM = '''
<li>
<div class="report_item">
<div class="report_descs">
<div class="report_desc_1">
{msg1:s}
</div>
<div class="report_desc_2">
{msg2:s}
</div>
</div>
<div class="report_images">
<div class="report_image_left report_image" title="{hh1:s}">
<img src="{data1:s}" style="width:150px"/>
<div class="report_image_thumb"><img src="{thum1:s}" style="width:32px"/></div>
</div>
<div class="report_image" style="width:48px" title="{hhdiff:s}">
<div class="report_image_thumb"><img src="{thumdiff:s}" style="width:32px"/></div>
</div>
<div class="report_image_right report_image" title="{hh2:s}">
<img src="{data2:s}" style="width:150px"/>
<div class="report_image_thumb"><img src="{thum2:s}" style="width:32px"/></div>
</div>
</div>
</div>
</li>
'''

# HTML that closes the list and the document.
_REPORT_FOOTER = '''
</ol>
</div>
</body>
</html>
'''


def _compute_hashes(files, crop_size, crop_size_2):
    """Return (outer, inner) uint64 arrays of perceptual hashes, one entry per file."""
    outer = np.zeros(len(files), dtype=np.uint64)
    inner = np.zeros(len(files), dtype=np.uint64)
    for i, fn in enumerate(tqdm(files)):
        im = read_image(fn, cv2.IMREAD_UNCHANGED)
        im = deinterlace2(im, k=(0.25, 0.5, 0.25))
        im = center_crop(im, crop_size=crop_size)
        # the inner crop is taken from the already-cropped image
        im2 = center_crop(im, crop_size=crop_size_2)
        outer[i] = get_image_hash(im)
        inner[i] = get_image_hash(im2)
    return outer, inner


def _format_hash_pair(a, b):
    """Format two 64-bit hashes as one fixed-width 32-digit hex string."""
    # fixed width keeps the concatenation unambiguous ({:08x} was variable width)
    return '{:016x}{:016x}'.format(int(a), int(b))


def _collision_note(label, dist, left_parts, right_parts):
    """Return an HTML <ul> describing one collision; path parts are HTML-escaped."""
    return '''
<ul>
<li>{label:s}: distance={dist:d}</li>
<li>L: {left:s}</li>
<li>R: {right:s}</li>
</ul>
'''.format(label=label,
           dist=int(dist),
           left=html.escape('/'.join(left_parts)),
           right=html.escape('/'.join(right_parts)))


def check_similar_images(root_dir,
                         crop_size=138,
                         crop_size_2=64,
                         distance_threshold=4,
                         report_dir='.',
                         file_regexp=None,
                         file_pattern=None,
                         list_all=False,
                         max_files=None):
    """Find near-duplicate images under *root_dir* and write an HTML report.

    Files are collected with walk_tree(root_dir, regexp=file_regexp,
    pattern=file_pattern), sorted, and optionally truncated to *max_files*.
    Each image is read, smoothed with deinterlace2(), center-cropped to
    *crop_size* and hashed; a further central crop of *crop_size_2* is hashed
    as well. Every pair of files is compared (each row vectorized), and a pair
    is a candidate when the summed Hamming distance of both hashes is
    <= *distance_threshold*.

    Candidates are written to ``<report_dir>/result.html`` when *list_all* is
    set, or when the two files differ in parent directory ("class") or
    grandparent directory ("split"). Class and split collisions are also
    echoed to the console and their totals printed at the end. *report_dir*
    is created if it does not exist.
    """
    # e.g. file_regexp=r'.*\.[0-9]+.bmp$' excludes the _hf / _vf / _rot-??? variants
    files = sorted(walk_tree(root_dir, regexp=file_regexp, pattern=file_pattern))
    if max_files is not None:
        files = files[:max_files]
    phash_list, phash_list_2 = _compute_hashes(files, crop_size, crop_size_2)

    if report_dir:
        os.makedirs(report_dir, exist_ok=True)

    class_collisions = split_collisions = 0
    with open(join(report_dir, 'result.html'), 'w', encoding='utf-8') as fout:
        fout.write(_REPORT_HEADER)
        itr = tqdm(phash_list)
        for i, _ in enumerate(itr):
            # distances from file i to every later file, so each pair is visited once
            counts = (array_bits_count(np.bitwise_xor(phash_list[i + 1:], phash_list[i]))
                      + array_bits_count(np.bitwise_xor(phash_list_2[i + 1:], phash_list_2[i])))
            for i_nz in (counts <= distance_threshold).nonzero()[0]:
                dist = counts[i_nz]
                ii = i + 1 + i_nz
                f1, f2 = files[i], files[ii]
                b1, b2 = basename(f1), basename(f2)
                # parent directory = class, grandparent directory = split
                p1, p2 = basename(dirname(f1)), basename(dirname(f2))
                pp1, pp2 = basename(dirname(dirname(f1))), basename(dirname(dirname(f2)))
                if not (list_all or p1 != p2 or pp1 != pp2):
                    continue
                if p1 != p2:
                    class_collisions += 1
                    str1 = '*** different class: distance={:d}\nL[{:d}]: {:s}/{:s}\nR[{:d}]: {:s}/{:s}'.format(
                        dist, i, p1, b1, ii, p2, b2)
                    itr.write(str1)
                    msg1 = _collision_note('different class', dist, (p1, b1), (p2, b2))
                else:
                    msg1 = ''
                if pp1 != pp2:
                    split_collisions += 1
                    str2 = '*** different split: distance={:d}\nL[{:d}]: {:s}/{:s}/{:s}\nR[{:d}]: {:s}/{:s}/{:s}'.format(
                        dist, i, pp1, p1, b1, ii, pp2, p2, b2)
                    itr.write(str2)
                    msg2 = _collision_note('different split', dist, (pp1, p1, b1), (pp2, p2, b2))
                else:
                    msg2 = ''
                h1, hx1 = phash_list[i], phash_list_2[i]
                h2, hx2 = phash_list[ii], phash_list_2[ii]
                fout.write(_REPORT_ITEM.format(
                    msg1=msg1,
                    msg2=msg2,
                    data1=make_image_data_uri(f1),
                    data2=make_image_data_uri(f2),
                    hh1=_format_hash_pair(h1, hx1),
                    hh2=_format_hash_pair(h2, hx2),
                    hhdiff=_format_hash_pair(h1 ^ h2, hx1 ^ hx2),
                    thum1=hash_thumbnail_data_uri(h1, hx1, as_uri=True),
                    thum2=hash_thumbnail_data_uri(h2, hx2, as_uri=True),
                    thumdiff=hash_thumbnail_data_uri(h1 ^ h2, hx1 ^ hx2, as_uri=True)))
        fout.write(_REPORT_FOOTER)
    print('total class collisions:', class_collisions)
    print('total split collisions:', split_collisions)
if __name__ == '__main__':
    # Command-line entry point. Example:
    #   check_similar_images.py --root_dir images/orig-0524 --file_regexp '.*\.[0-9]+\.bmp$'
    # (kept as a comment: inside a normal string literal "\." is an invalid
    # escape sequence and triggers a SyntaxWarning on recent Pythons)
    import argparse
    parser = argparse.ArgumentParser(
        description='Report near-duplicate images under a directory tree '
                    '(writes result.html into --report_dir).')
    parser.add_argument('--root_dir', type=str, required=True,
                        help='directory searched recursively for images')
    parser.add_argument('--crop_size', type=int, default=138,
                        help='side length of the center crop that is hashed')
    parser.add_argument('--crop_size_2', type=int, default=64,
                        help='side length of the smaller inner crop that is also hashed')
    parser.add_argument('--distance_threshold', type=int, default=14,
                        help='report pairs whose summed hash distance is <= this value '
                             '(CLI default 14; the function default is 4)')
    parser.add_argument('--report_dir', type=str, default='.',
                        help='directory that receives result.html')
    parser.add_argument('--file_regexp', type=str,
                        help='keep files whose full path matches this regexp (re.match)')
    parser.add_argument('--file_pattern', type=str,
                        help='keep files whose full path matches this glob (fnmatch)')
    parser.add_argument('--list_all', action='store_true',
                        help='report every similar pair, not only cross-class/cross-split ones')
    parser.add_argument('--max_files', type=int,
                        help='only use the first N files (sorted by path)')
    args = parser.parse_args(sys.argv[1:])
    check_similar_images(args.root_dir,
                         crop_size=args.crop_size,
                         crop_size_2=args.crop_size_2,
                         distance_threshold=args.distance_threshold,
                         report_dir=args.report_dir,
                         file_regexp=args.file_regexp,
                         file_pattern=args.file_pattern,
                         list_all=args.list_all,
                         max_files=args.max_files)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment