Skip to content

Instantly share code, notes, and snippets.

@rossgoodwin
Forked from 1328/histo match
Created June 11, 2014 22:45
Show Gist options
  • Save rossgoodwin/25d867a961b5d475219e to your computer and use it in GitHub Desktop.
Save rossgoodwin/25d867a961b5d475219e to your computer and use it in GitHub Desktop.
import random
import time
import collections
import itertools
import numpy as np
from pprint import pprint
from PIL import Image
class Test(object):
    """Match the grey-level histogram of a source image to a target image.

    Images are treated as black-and-white: only the first 256 histogram bins
    (the first band) are compared, and every modified pixel is written back
    with identical R, G and B values. Pixels are moved between grey levels one
    level at a time, choosing which pixels move at random at each step.

    (Implemented as a class to avoid globals.)
    """

    def __init__(self, source_image, target_image):
        """Load both images and prepare the pixel map and imbalance bins.

        :param source_image: path of the image that will be modified.
        :param target_image: path of the image whose histogram is matched.
        """
        # Normalise to RGB so pixel access always yields (r, g, b) tuples:
        # _build_map reads channel 0 and shift_pixels writes 3-tuples, both of
        # which fail on single-band ("L") or palette ("P") images.
        self.srcImg = Image.open(source_image).convert('RGB')
        self.tgtImg = Image.open(target_image).convert('RGB')
        self.srcPix = self.srcImg.load()
        self.tgtPix = self.tgtImg.load()
        # Histograms of the images; only the first 256 values (first band)
        # are used since the images are B&W.
        self.src_hist = self.srcImg.histogram()[:256]
        self.tgt_hist = self.tgtImg.histogram()[:256]
        self._build_map()
        self._build_bins()

    def _build_map(self):
        """Build a reverse map from grey level to source pixel coordinates.

        ``self.pxls_map[v]`` is the set of every ``(x, y)`` whose first
        channel equals ``v``. A defaultdict avoids special-casing empty levels.
        """
        self.pxls_map = collections.defaultdict(set)
        width, height = self.srcImg.size
        for i, j in itertools.product(range(width), range(height)):
            self.pxls_map[self.srcPix[i, j][0]].add((i, j))

    def _build_bins(self):
        """Collect the grey levels where the source has too many / too few pixels.

        ``self.excesses`` receives ``(level, surplus)`` pairs and
        ``self.deficits`` receives ``(level, shortfall)`` pairs, both in
        ascending level order. Deques are used because ``balance`` consumes
        them from the left without slicing.
        """
        self.excesses = collections.deque()
        self.deficits = collections.deque()
        for level in range(256):
            delta = self.src_hist[level] - self.tgt_hist[level]
            if delta > 0:
                self.excesses.append((level, delta))
            elif delta < 0:
                self.deficits.append((level, -delta))

    def shift_pixels(self, src, tgt, n = 1):
        """Move up to ``n`` randomly chosen pixels from level ``src`` to ``tgt``.

        Fewer than ``n`` pixels are moved if level ``src`` holds fewer. Pixel
        data, the reverse map and the source histogram are all updated.
        """
        candidates = self.pxls_map[src]
        n = min(n, len(candidates))
        # random.sample() rejects sets since Python 3.11, so sample from a
        # list snapshot; the snapshot also makes mutating the set below safe.
        for pxl in random.sample(list(candidates), n):
            self.srcPix[pxl] = (tgt, tgt, tgt)
            # and alter the map
            candidates.remove(pxl)
            self.pxls_map[tgt].add(pxl)
        # Keep the *source* histogram in sync with the pixels just moved
        # (the moved pixels now belong to level tgt of the source image).
        self.src_hist[src] -= n
        self.src_hist[tgt] += n

    def smooth(self, src, tgt, n = 1):
        """Move ``n`` pixels from level ``src`` to ``tgt`` in steps of +1/-1.

        At every intermediate level a fresh random selection of ``n`` pixels
        is shifted onward (for smoothing), rather than moving one fixed group.
        """
        step = self.sign(tgt - src)
        for _ in range(abs(tgt - src)):
            nxt = src + step
            self.shift_pixels(src, nxt, n)
            src = nxt

    def balance(self):
        """Shift pixels until the source histogram matches the target's.

        Surplus and deficit levels are paired in ascending order, and each
        pairing moves ``min(surplus, shortfall)`` pixels. Stops when either
        side runs out, printing what remains on the other side. Progress is
        printed every 10 moves.
        """
        start = time.time()
        count = 0
        if not self.excesses or not self.deficits:
            # Nothing to pair up (e.g. histograms already match); popleft()
            # on an empty deque would raise IndexError.
            return
        # first we prime the loop
        src, excess = self.excesses.popleft()
        tgt, deficit = self.deficits.popleft()
        while True:
            # Move as many pixels as both the current surplus and the current
            # shortfall allow.
            frame = min(excess, deficit)
            if count % 10 == 0:
                print('{} -> src {} target {} frame {}'.format(
                    self.pretty_time(time.time() - start),
                    src, tgt, frame))
            self.smooth(src, tgt, frame)
            excess -= frame
            deficit -= frame
            count += 1
            # Advance whichever side(s) were fully consumed.
            if excess < 1:
                if not self.excesses:
                    print('excess empty:', self.deficits)
                    return
                src, excess = self.excesses.popleft()
            if deficit < 1:
                if not self.deficits:
                    print('deficits empty', self.excesses)
                    return
                tgt, deficit = self.deficits.popleft()

    @staticmethod
    def pretty_time(t):
        """Format ``t`` seconds as ``HH:MM:SS.ss`` (all fields zero-padded)."""
        h, t = divmod(t, 3600)
        m, t = divmod(t, 60)
        # Width 5 = two integer digits + '.' + two decimals.
        return '{0:02}:{1:02}:{2:05.2f}'.format(int(h), int(m), t)

    @staticmethod
    def sign(x):
        """Return 1 if ``x`` is positive, otherwise -1 (including for 0)."""
        if x > 0:
            return 1
        return -1
# Script entry point: balance bw_src.jpg against bw_tgt.jpg, display the
# result and save it as PNG. Guarded so importing this module (e.g. to reuse
# the Test class) does not open files or pop up image windows.
if __name__ == '__main__':
    t = Test('bw_src.jpg', 'bw_tgt.jpg')
    t.balance()
    t.srcImg.show()
    t.srcImg.save('sets_test_their_frame.png', format='PNG')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment