Skip to content

Instantly share code, notes, and snippets.

@rossgoodwin
Forked from 1328/dsklfj
Created June 12, 2014 00:02
Show Gist options
  • Save rossgoodwin/f9024422e8718c2ae606 to your computer and use it in GitHub Desktop.
Save rossgoodwin/f9024422e8718c2ae606 to your computer and use it in GitHub Desktop.
import random
import time
import collections
import itertools
import numpy as np
from pprint import pprint
from PIL import Image
class Test(object):
    """Match the grey-level histogram of a source image to a target image.

    Switched to a class, to avoid globals.  Both images are converted to RGB
    on load and only the first channel is used as the grey value (the images
    are expected to be black & white).  After balance() the result is written
    back into ``self.srcImg`` as (v, v, v) pixels.
    """

    def __init__(self, source_image, target_image):
        """Open both images and precompute histograms, pixel map and bins.

        :param source_image: path/file of the image that will be modified.
        :param target_image: path/file whose histogram is to be matched.
        """
        # Normalise to RGB: a greyscale JPEG opens in mode 'L', whose pixel
        # access returns plain ints, which would break the ``[0]`` indexing
        # in _build_map and the (v, v, v) writes in rebuild_from_map.
        self.srcImg = Image.open(source_image).convert('RGB')
        self.tgtImg = Image.open(target_image).convert('RGB')
        self.srcPix = self.srcImg.load()
        self.tgtPix = self.tgtImg.load()
        # histogram() of an RGB image is the R, G and B histograms
        # concatenated; the first 256 values (R) suffice for B&W images.
        self.src_hist = self.srcImg.histogram()[:256]
        self.tgt_hist = self.tgtImg.histogram()[:256]
        self._build_map()
        self._build_bins()

    def _build_map(self):
        """Build the reverse map: grey value -> set of (x, y) source pixels.

        A defaultdict is used so that an empty grey level reads as an empty
        set instead of raising KeyError.
        """
        self.pxls_map = collections.defaultdict(set)
        width, height = self.srcImg.size
        for i, j in itertools.product(range(width), range(height)):
            self.pxls_map[self.srcPix[i, j][0]].add((i, j))

    def _build_bins(self):
        """Split grey levels into excess and deficit bins.

        ``self.excesses`` gets (level, surplus) for every level where the
        source has more pixels than the target, and ``self.deficits`` gets
        (level, shortfall) where it has fewer; both are in ascending level
        order.  They are deques because balance() consumes them from the left
        with popleft() rather than with (slow) slicing.
        """
        self.excesses = collections.deque()
        self.deficits = collections.deque()
        for level, (have, want) in enumerate(zip(self.src_hist, self.tgt_hist)):
            delta = have - want
            if delta > 0:
                self.excesses.append((level, delta))
            elif delta < 0:
                self.deficits.append((level, -delta))

    def shift_pixels(self, src, tgt, n=1):
        """Move ``n`` randomly chosen pixels from grey level ``src`` to ``tgt``.

        :raises ValueError: if level ``src`` holds fewer than ``n`` pixels
            (propagated from random.sample).
        """
        # random.sample() needs a sequence: sampling from a set was deprecated
        # in Python 3.9 and raises TypeError from 3.11, so snapshot it first.
        move = set(random.sample(list(self.pxls_map[src]), n))
        # Update the existing sets in place instead of building new ones.
        self.pxls_map[src] -= move
        self.pxls_map[tgt] |= move

    def smooth(self, src, tgt, n=1):
        """Walk ``n`` pixels from level ``src`` to ``tgt`` one level at a time.

        At every intermediate level a fresh random selection of ``n`` pixels
        is moved on, so the pixels that finally land on ``tgt`` are drawn from
        every level along the way, which smooths the result.
        """
        step = self.sign(tgt - src)
        # Empty range when src == tgt, so nothing is moved.
        for level in range(src, tgt, step):
            self.shift_pixels(level, level + step, n)

    def rebuild_from_map(self):
        """Write every mapped pixel back into the source image as (v, v, v)."""
        for val, pixls in self.pxls_map.items():
            grey = (val, val, val)
            for pxl in pixls:
                self.srcPix[pxl] = grey

    def balance(self):
        """Move pixels between grey levels until the histograms match.

        Excess and deficit bins are paired in ascending level order; each
        pairing moves as many pixels as both sides allow, via smooth().  The
        loop stops as soon as either side runs out of bins (one side can
        outlast the other only if the two images have different pixel
        counts).  Finally the pixel map is written back into ``self.srcImg``.
        """
        start = time.time()
        count = 0
        # If either side has no bins (e.g. identical histograms) there is
        # nothing to pair, and popleft() on an empty deque would raise
        # IndexError -- go straight to the rebuild instead.
        if self.excesses and self.deficits:
            # first we prime the loop
            src, excess = self.excesses.popleft()
            tgt, deficit = self.deficits.popleft()
            while True:
                frame = min(excess, deficit)
                if count % 10 == 0:
                    print('{} -> src {} target {} frame {}'.format(
                        self.pretty_time(time.time() - start),
                        src, tgt, frame))
                self.smooth(src, tgt, frame)
                # now increment the loop
                excess -= frame
                deficit -= frame
                count += 1
                if excess < 1:
                    if not self.excesses:
                        print('excess emoply:', self.deficits)
                        break
                    src, excess = self.excesses.popleft()
                if deficit < 1:
                    if not self.deficits:
                        print('defis emplty', self.excesses)
                        break
                    tgt, deficit = self.deficits.popleft()
        print('rebuilding pic')
        self.rebuild_from_map()

    @staticmethod
    def pretty_time(t):
        """Format a duration ``t`` in seconds as HH:MM:SS.ss."""
        h, t = divmod(t, 3600)
        m, t = divmod(t, 60)
        # Width 5 (not 4) so seconds below 10 are zero-padded: "01.50".
        return '{0:02}:{1:02}:{2:05.2f}'.format(int(h), int(m), t)

    @staticmethod
    def sign(x):
        """Return 1 if ``x`` is positive, otherwise -1 (zero maps to -1)."""
        return 1 if x > 0 else -1
# Driver: match bw_src.jpg's histogram to bw_tgt.jpg's, then display the
# balanced source image and write it out as a PNG.
t = Test(source_image='bw_src.jpg', target_image='bw_tgt.jpg')
t.balance()

t.srcImg.show()
t.srcImg.save('unionstest.png', format='PNG')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment