Skip to content

Instantly share code, notes, and snippets.

@rhee-elten
Created January 9, 2023 01:53
Show Gist options
  • Save rhee-elten/03ea159342323f412b8fe87975e48e6b to your computer and use it in GitHub Desktop.
Save rhee-elten/03ea159342323f412b8fe87975e48e6b to your computer and use it in GitHub Desktop.
# coding: utf-8
from __future__ import print_function, division, absolute_import
import sys
import os
import re
import fnmatch
from collections import OrderedDict as odict
from functools import lru_cache
import numpy as np
import cv2
#import numba
from PIL import Image
from os.path import dirname, basename, join
import base64
# imagehash==4.0 required
# to install: pip install ImageHash==4.0
from imagehash import phash, dhash
try:
    # get_ipython() is normally only defined when running under IPython /
    # Jupyter; in that case use the notebook-style progress bar.
    get_ipython()
    from tqdm import tqdm_notebook as tqdm
except (NameError, ImportError):
    # Plain interpreter (no get_ipython) or notebook tqdm not importable:
    # fall back to the console progress bar. Catching only these two keeps
    # KeyboardInterrupt / SystemExit from being swallowed.
    from tqdm import tqdm
def walk_tree(root, regexp=None, pattern=None, match_fn=None, count_fn=None):
    """Yield the path of every file under ``root`` that passes the filters.

    When none of ``regexp``, ``pattern`` and ``match_fn`` is given, every file
    is yielded. Otherwise a file is yielded as soon as one of the supplied
    filters accepts its full path ``join(dirpath, filename)``:

    * ``regexp``   -- ``re.match`` against the path (anchored at its start),
    * ``pattern``  -- ``fnmatch.fnmatch`` glob against the path,
    * ``match_fn`` -- a callable returning a truthy value for accepted paths.

    ``count_fn``, if callable, is called with each accepted path right before
    that path is yielded.
    """
    no_filter = regexp is None and pattern is None and match_fn is None
    for parent, _subdirs, names in os.walk(root):
        for name in names:
            path = join(parent, name)
            accepted = (
                no_filter
                or bool(regexp and re.match(regexp, path))
                or bool(pattern and fnmatch.fnmatch(path, pattern))
                or bool(callable(match_fn) and match_fn(path))
            )
            if not accepted:
                continue
            if callable(count_fn):
                count_fn(path)
            yield path
def read_image(im_path,
               flags=cv2.IMREAD_UNCHANGED,
               compress_gray_image=False):
    """
    Load an image file, or a ``.npy`` array, as a numpy array.

    im_path: path to the file. Paths ending in ``.npy`` are loaded with
        ``np.load`` and returned as stored; anything else is decoded with
        ``cv2.imdecode``.
    flags: decode flag passed to ``cv2.imdecode``, e.g.
        cv2.IMREAD_COLOR
        cv2.IMREAD_GRAYSCALE
        cv2.IMREAD_UNCHANGED
    compress_gray_image: *** IGNORED *** -- accepted only so existing callers
        keep working. The gray2compressedrgb() helper it once dispatched to
        is not defined in this module (calling it raised NameError).

    returns: the image array. For decoded images with a channel axis, that
        axis is reversed (OpenCV's BGR -> RGB; a 4-channel BGRA image becomes
        ARGB because the whole axis is flipped).
    raises: ValueError if OpenCV cannot decode the file contents.
    """
    if im_path.endswith('.npy'):
        # Raw arrays did not come from OpenCV, so no channel swap applies.
        return np.load(im_path)
    with open(im_path, 'rb') as f:
        im_encoded = f.read()
    im = cv2.imdecode(np.frombuffer(im_encoded, dtype=np.uint8), flags=flags)
    if im is None:
        # cv2.imdecode reports failure by returning None instead of raising.
        raise ValueError('cannot decode image: {}'.format(im_path))
    # BGR ==> RGB (OpenCV decodes colour images in BGR channel order).
    if im.ndim > 2:
        im = im[:, :, ::-1]
    return im
def deinterlace2(im, k=(0.25, 0.5, 0.25)):
    """Filter an image along its width axis with a 3-tap kernel.

    Each output pixel is ``k[0]*left + k[1]*center + k[2]*right``, taken
    along axis 1 (width). The first and last columns are edge-replicated, so
    the output has the same shape as the input. 2-D (gray) images and images
    with trailing channel axes are supported; each channel is filtered
    independently.

    Args:
        im: image array of shape (h, w) or (h, w, ...).
        k: the three kernel weights (left, center, right).

    Returns:
        uint8 array with the same shape as ``im``. The cast truncates rather
        than rounds, and out-of-range values are not clipped.
    """
    k_a, k_b, k_c = k
    # Pad one replicated column on each side of the width axis only.
    pad_width = [(0, 0), (1, 1)] + [(0, 0)] * (im.ndim - 2)
    padded = np.pad(np.asarray(im, dtype=np.float32), pad_width, mode='edge')
    im_new = k_a * padded[:, :-2] + k_b * padded[:, 1:-1] + k_c * padded[:, 2:]
    return im_new.astype(np.uint8)
def center_crop(im, crop_size):
    """Return the central ``crop_size`` x ``crop_size`` window of ``im``.

    The first two axes are taken as (height, width); any trailing axes
    (e.g. channels) are kept. The result is a slice (view) of ``im``.

    If the image is smaller than ``crop_size`` along an axis, that axis is
    kept whole. Without clamping, the negative start offset would wrap around
    in Python slicing and select a sliver from the far edge instead.
    """
    h, w = im.shape[:2]
    y1 = max((h - crop_size) // 2, 0)
    x1 = max((w - crop_size) // 2, 0)
    # Slicing past the end of an axis is harmless, so no upper clamp is needed.
    return im[y1:y1 + crop_size, x1:x1 + crop_size]
def image_to_data_uri(im, crop_size=None):
    """Encode ``im`` as PNG and return it as a ``data:image/png;base64,`` URI.

    When ``crop_size`` is given the image is center-cropped first. For arrays
    with a channel axis, that axis is reversed before ``cv2.imencode``, which
    undoes the channel flip performed by ``read_image``.
    """
    if crop_size is not None:
        im = center_crop(im, crop_size=crop_size)
    to_encode = im[:, :, ::-1] if im.ndim > 2 else im
    png_buffer = cv2.imencode('.png', to_encode)[1]
    encoded = base64.b64encode(png_buffer.tobytes()).decode('ascii')
    return 'data:image/png;base64,' + encoded
def make_image_data_uri(im_name, crop_size=150):
    """Read ``im_name`` unchanged from disk and return it as a PNG data URI.

    The image is center-cropped to ``crop_size`` (default 150) before encoding.
    """
    image = read_image(im_name, cv2.IMREAD_UNCHANGED)
    uri = image_to_data_uri(image, crop_size=crop_size)
    return uri
def hash_thumbnail_data_uri(hash_val, hash_val_2, as_uri=False):
    """Render two 64-bit hashes as a 64x32 uint8 gray thumbnail.

    The binary digits of ``hash_val`` followed by ``hash_val_2`` (MSB first,
    64 digits each) fill a 16x8 grid row by row, so ``hash_val`` occupies the
    top half. Each bit is blown up to a 4x4 pixel block: set bits become 160,
    clear bits become 96.

    Returns the image array, or a PNG data URI of it when ``as_uri`` is true.
    """
    bit_string = '{:064b}{:064b}'.format(hash_val, hash_val_2)
    bits = np.array([int(ch) for ch in bit_string], dtype=np.uint8).reshape(16, 8)
    # Upscale every bit to a 4x4 block: repeat rows, then columns.
    blocks = np.repeat(np.repeat(bits, 4, axis=0), 4, axis=1)
    thumb = blocks * 64 + (128 - 32)
    return image_to_data_uri(thumb) if as_uri else thumb
# def hash_to_title(hash_val):
# ss = '{:064b}'.format(hash_val)
# res = ''
# for rr in range(8):
# for cc in range(8):
# ch = ss[rr*8+cc]
# res += ch
# res += '\n'
# return res
#@numba.jit('u8(u1[:,:])')
def get_image_hash(im):
    """Return the ``imagehash.phash`` of ``im`` packed into a Python int.

    The boolean hash matrix is flattened in row-major order, and element ``i``
    becomes bit ``i`` of the result (the first element is the LSB). With
    imagehash's default 8x8 hash size (assumed here) this gives a 64-bit value.
    """
    flat_bits = phash(Image.fromarray(im)).hash.flatten()
    value = 0
    for position, bit in enumerate(flat_bits):
        if bit:
            value |= 1 << position
    return value
#@numba.jit('u8(u8)')
def get_bits_count(c):
    """Return the number of set bits in the low 64 bits of integer ``c``.

    Bits above position 63 are ignored. Negative ints are counted in their
    64-bit two's-complement form, so ``-1`` gives 64. Numpy integer scalars
    are accepted; non-integers raise TypeError.
    """
    import operator  # local: only needed to coerce numpy ints strictly

    # Mask to 64 bits and let the C-level bin()/str.count do the counting,
    # instead of a 64-iteration Python shift loop.
    return bin(operator.index(c) & 0xFFFFFFFFFFFFFFFF).count('1')
#@numba.jit('u8(u8,u8)')
def get_hash_distance(a, b):
    """Return the Hamming distance between integer hashes ``a`` and ``b``.

    Only the low 64 bits of ``a ^ b`` are considered, so differences above
    bit 63 are ignored (negative values count as 64-bit two's complement).
    """
    differing = int(a ^ b) & 0xFFFFFFFFFFFFFFFF
    return bin(differing).count('1')
def _make_bits_count_map(bits=16):
    """Return a table whose entry ``v`` is the number of set bits in ``v``."""
    values = np.arange(1 << bits)
    counts = np.zeros(1 << bits, dtype=int)
    for shift in range(bits):
        # Add bit number `shift` of every table index, for all indices at once.
        counts += (values >> shift) & 1
    return counts


# bit count map for 16-bit integers: _bits_count_map[v] == popcount(v).
# Built vectorised at import time rather than with 65536 Python-level calls.
_bits_count_map = _make_bits_count_map(16)
def array_bits_count(arr):
    """
    Count set bits element-wise in an array of integers of up to 64 bits.

    Each value is split into four 16-bit chunks. Each chunk's bit count is
    looked up in the module-level ``_bits_count_map`` table, and the four
    counts are summed.

    arr: integer array (the caller in this module passes uint64 XOR results).
    returns: integer array of the same shape, each element in [0, 64].
    """
    mask = 0xffff
    # For uint64 input the top chunk (>> 48) already fits in 16 bits.
    cnt1 = _bits_count_map[np.right_shift(arr, 48)]
    cnt2 = _bits_count_map[np.bitwise_and(np.right_shift(arr, 32), mask)]
    cnt3 = _bits_count_map[np.bitwise_and(np.right_shift(arr, 16), mask)]
    cnt4 = _bits_count_map[np.bitwise_and(arr, mask)]
    res = cnt1 + cnt2 + cnt3 + cnt4
    # Internal sanity check of the table lookup, not input validation.
    assert np.all(res >= 0) and np.all(
        res <= 64), ('internal error: invalid count:', res)
    # Return the already-computed sum instead of adding the four arrays again.
    return res
# Static head of result.html (opened <body>/<div>/<ol> are closed by the footer).
_REPORT_HTML_HEADER = '''
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
.report_images div {
padding: 0;
margin: 0;
cursor: pointer;
}
.report_image img {
padding: 0;
margin: 0;
vertical-align: bottom;
}
.report_image {
display: inline-block;
position: relative;
height: 150px;
}
.report_image_thumb {
position: absolute;
left: 5px;
top: 5px;
z-index: 100;
opacity: 0.4;
}
.report_image_thumb img {
margin: 2px;
border: 1px solid red;
border-radius: 3px;
}
.report_image > .report_image_thumb {
visibility: hidden;
}
.report_image:hover > .report_image_thumb {
visibility: visible;
}
</style>
</head>
<body>
<div>
<ol>
'''

_REPORT_HTML_FOOTER = '''
</ol>
</div>
</body>
</html>
'''

# One collision description; `left`/`right` must already be HTML-escaped.
_COLLISION_HTML = '''
<ul>
<li>different {kind:s}: distance={dist:d}</li>
<li>L: {left:s}</li>
<li>R: {right:s}</li>
</ul>
'''

# One reported pair: both images, their hash thumbnails and the XOR thumbnail.
_PAIR_HTML = '''
<li>
<div class="report_item">
<div class="report_descs">
<div class="report_desc_1">
{msg1:s}
</div>
<div class="report_desc_2">
{msg2:s}
</div>
</div>
<div class="report_images">
<div class="report_image_left report_image" title="{hh1:s}">
<img src="{data1:s}" style="width:150px"/>
<div class="report_image_thumb"><img src="{thum1:s}" style="width:32px"/></div>
</div>
<div class="report_image" style="width:48px" title="{hhdiff:s}">
<div class="report_image_thumb"><img src="{thumdiff:s}" style="width:32px"/></div>
</div>
<div class="report_image_right report_image" title="{hh2:s}">
<img src="{data2:s}" style="width:150px"/>
<div class="report_image_thumb"><img src="{thum2:s}" style="width:32px"/></div>
</div>
</div>
</div>
</li>
'''


def _compute_phashes(files, crop_size, crop_size_2):
    """Return (outer-crop, inner-crop) uint64 pHash arrays for ``files``."""
    phash_list = np.zeros(len(files), dtype=np.uint64)
    phash_list_2 = np.zeros(len(files), dtype=np.uint64)
    for i, fn in enumerate(tqdm(files)):
        im = read_image(fn, cv2.IMREAD_UNCHANGED)
        im = deinterlace2(im, k=(0.25, 0.5, 0.25))
        im = center_crop(im, crop_size=crop_size)
        # The second hash uses a tighter crop taken from the first crop.
        im2 = center_crop(im, crop_size=crop_size_2)
        phash_list[i] = get_image_hash(im)
        phash_list_2[i] = get_image_hash(im2)
    return phash_list, phash_list_2


def _hash_pair_hex(h, h_2):
    """Format two 64-bit hashes as 32 zero-padded hex digits."""
    # 016x (not 08x) so every hash has a fixed width and the concatenation
    # is unambiguous.
    return '{:016x}{:016x}'.format(int(h), int(h_2))


def _render_pair_html(f1, f2, h1, hx1, h2, hx2, msg1, msg2):
    """Return the <li> HTML fragment comparing files ``f1`` and ``f2``.

    ``h1``/``hx1`` and ``h2``/``hx2`` are the (outer, inner) hashes of each file.
    """
    return _PAIR_HTML.format(
        msg1=msg1,
        msg2=msg2,
        data1=make_image_data_uri(f1),
        data2=make_image_data_uri(f2),
        hh1=_hash_pair_hex(h1, hx1),
        hh2=_hash_pair_hex(h2, hx2),
        hhdiff=_hash_pair_hex(h1 ^ h2, hx1 ^ hx2),
        thum1=hash_thumbnail_data_uri(h1, hx1, as_uri=True),
        thum2=hash_thumbnail_data_uri(h2, hx2, as_uri=True),
        thumdiff=hash_thumbnail_data_uri(h1 ^ h2, hx1 ^ hx2, as_uri=True))


def check_similar_images(root_dir,
                         crop_size=138,
                         crop_size_2=64,
                         distance_threshold=4,
                         report_dir='.',
                         file_regexp=None,
                         file_pattern=None,
                         list_all=False,
                         max_files=None):
    """Find near-duplicate images under ``root_dir`` and write an HTML report.

    Each image is smoothed with ``deinterlace2`` and hashed twice with pHash:
    once on a ``crop_size`` center crop, and once on a ``crop_size_2`` crop
    taken from the middle of that first crop. A pair of files is "similar"
    when the two Hamming distances, summed, are <= ``distance_threshold``.
    Every unordered pair is compared, which is O(n^2) in the number of files.

    A similar pair is reported when the files' parent directories ("class")
    or grandparent directories ("split") differ, or always when ``list_all``
    is true. Cross-class and cross-split hits are also printed through
    ``tqdm.write`` and counted, and the totals are printed at the end.

    Args:
        root_dir: directory searched recursively (via ``walk_tree``).
        crop_size: side of the outer center crop, in pixels.
        crop_size_2: side of the inner crop, in pixels.
        distance_threshold: max summed Hamming distance of a reported pair.
        report_dir: existing directory where ``result.html`` is written.
        file_regexp: optional ``re.match`` filter on full file paths.
        file_pattern: optional ``fnmatch`` filter on full file paths.
        list_all: report every similar pair, not only cross-class/split ones.
        max_files: if given, only the first ``max_files`` sorted paths are used.
    """
    import html  # local import: only used to escape names in the report

    # e.g. file_regexp='.*\.[0-9]+.bmp$' excludes the _hf / _vf / _rot-???
    # augmented copies.
    files = sorted(walk_tree(root_dir, regexp=file_regexp, pattern=file_pattern))
    if max_files is not None:
        files = files[:max_files]
    phash_list, phash_list_2 = _compute_phashes(files, crop_size, crop_size_2)

    class_collisions = split_collisions = 0
    with open(join(report_dir, 'result.html'), 'w', encoding='utf-8') as fout:
        fout.write(_REPORT_HTML_HEADER)
        itr = tqdm(phash_list)
        for i, _ in enumerate(itr):
            # Compare file i only with later files, so each pair is seen once.
            xors = np.bitwise_xor(phash_list[i + 1:], phash_list[i])
            xors_2 = np.bitwise_xor(phash_list_2[i + 1:], phash_list_2[i])
            counts = array_bits_count(xors) + array_bits_count(xors_2)
            for i_nz in (counts <= distance_threshold).nonzero()[0]:
                dist = int(counts[i_nz])
                ii = i + 1 + int(i_nz)
                f1, f2 = files[i], files[ii]
                b1, b2 = basename(f1), basename(f2)
                # parent directory = "class", grandparent directory = "split"
                p1, p2 = basename(dirname(f1)), basename(dirname(f2))
                pp1 = basename(dirname(dirname(f1)))
                pp2 = basename(dirname(dirname(f2)))
                if not (list_all or p1 != p2 or pp1 != pp2):
                    continue
                msg1 = msg2 = ''
                if p1 != p2:
                    class_collisions += 1
                    str1 = '*** different class: distance={:d}\nL[{:d}]: {:s}/{:s}\nR[{:d}]: {:s}/{:s}'.format(
                        dist, i, p1, b1, ii, p2, b2)
                    itr.write(str1)
                    msg1 = _COLLISION_HTML.format(
                        kind='class', dist=dist,
                        left=html.escape('{:s}/{:s}'.format(p1, b1)),
                        right=html.escape('{:s}/{:s}'.format(p2, b2)))
                if pp1 != pp2:
                    split_collisions += 1
                    str2 = '*** different split: distance={:d}\nL[{:d}]: {:s}/{:s}/{:s}\nR[{:d}]: {:s}/{:s}/{:s}'.format(
                        dist, i, pp1, p1, b1, ii, pp2, p2, b2)
                    itr.write(str2)
                    msg2 = _COLLISION_HTML.format(
                        kind='split', dist=dist,
                        left=html.escape('{:s}/{:s}/{:s}'.format(pp1, p1, b1)),
                        right=html.escape('{:s}/{:s}/{:s}'.format(pp2, p2, b2)))
                fout.write(_render_pair_html(f1, f2,
                                             phash_list[i], phash_list_2[i],
                                             phash_list[ii], phash_list_2[ii],
                                             msg1, msg2))
        fout.write(_REPORT_HTML_FOOTER)
    print('total class collisions:', class_collisions)
    print('total split collisions:', split_collisions)
if __name__ == '__main__':
    # example:
    #   check_similar_images.py --root_dir images/orig-0524 --file_regexp '.*\.[0-9]+\.bmp$'
    # Kept as a comment rather than a bare string literal: the '\.' sequences
    # are invalid string escapes and trigger a warning on recent Pythons.
    import argparse
    parser = argparse.ArgumentParser(
        description='Find near-duplicate images and write result.html.')
    parser.add_argument('--root_dir', type=str, required=True,
                        help='directory searched recursively for image files')
    parser.add_argument('--crop_size', type=int, default=138,
                        help='side of the outer center crop used for the first hash')
    parser.add_argument('--crop_size_2', type=int, default=64,
                        help='side of the inner crop (taken from the outer crop) '
                             'used for the second hash')
    parser.add_argument('--distance_threshold', type=int, default=14,
                        help='max summed Hamming distance of both hashes for a '
                             'pair to be reported (CLI default 14; the '
                             'check_similar_images() default is 4)')
    parser.add_argument('--report_dir', type=str, default='.',
                        help='existing directory where result.html is written')
    parser.add_argument('--file_regexp', type=str,
                        help='re.match pattern applied to each full file path')
    parser.add_argument('--file_pattern', type=str,
                        help='fnmatch glob applied to each full file path')
    parser.add_argument('--list_all', action='store_true',
                        help='report every similar pair, not only cross-class/'
                             'cross-split ones')
    parser.add_argument('--max_files', type=int,
                        help='only process the first N files (after sorting)')
    args = parser.parse_args(sys.argv[1:])
    check_similar_images(args.root_dir,
                         crop_size=args.crop_size,
                         crop_size_2=args.crop_size_2,
                         distance_threshold=args.distance_threshold,
                         report_dir=args.report_dir,
                         file_regexp=args.file_regexp,
                         file_pattern=args.file_pattern,
                         list_all=args.list_all,
                         max_files=args.max_files)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment