Skip to content

Instantly share code, notes, and snippets.

@ambuc
Created October 20, 2021 03:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ambuc/3374d271444a303612d86d08801291be to your computer and use it in GitHub Desktop.
Save ambuc/3374d271444a303612d86d08801291be to your computer and use it in GitHub Desktop.
Simulated Annealing demo
#!/usr/bin/python3
from PIL import Image
from scipy import optimize
from skimage import metrics
import cairo
import cv2
import numpy as np
import random
import scipy
import skimage
import tempfile
# Data class to hold lists of points, transform them, and print them to different kinds of images.
class DrawingData():
    """A polyline drawing: a list of (x, y) points that can be moved and rendered.

    The points are joined in order by a thick black stroke on a white square
    canvas of side `_dimension` pixels, and can be written out as PNG or SVG.
    """

    def __init__(self, list_of_points):
        # Take a private copy so the instance owns its list of points.
        self._list_of_points = list(list_of_points)
        self._degrees_of_freedom = sum(len(p) for p in self._list_of_points)
        # The width and height (in pixels) of the image.
        self._dimension = 90

    def transform(self, l):
        """Replace every point with |l|(point), in place."""
        self._list_of_points = [l(p) for p in self._list_of_points]
        # Keep the cached count in sync in case |l| changed the point shapes.
        self._degrees_of_freedom = sum(len(p) for p in self._list_of_points)

    def degreesOfFreedom(self):
        """Return the number of degrees of freedom (coordinates) in the system."""
        return self._degrees_of_freedom

    def applyDeltas(self, list_of_deltas) -> "DrawingData":
        """Return a new drawing with one delta added to every x and y coordinate.

        Deltas are consumed from the END of |list_of_deltas|: the last delta
        moves the first point's x, the one before it the first point's y, and
        so on. This order must stay consistent so that annealing converges.
        Surplus deltas at the front of the list are ignored.

        The caller's list is left untouched (it is copied before consumption).

        Raises:
            IndexError: if fewer than two deltas per point are supplied.
        """
        remaining = list(list_of_deltas)
        needed = 2 * len(self._list_of_points)
        if len(remaining) < needed:
            raise IndexError(
                "applyDeltas needs %d deltas but got %d" % (needed, len(remaining)))
        moved_points = []
        for (x, y) in self._list_of_points:
            # Pop x's delta first, then y's, matching the documented order.
            dx = remaining.pop()
            dy = remaining.pop()
            moved_points.append((x + dx, y + dy))
        return DrawingData(moved_points)

    def mutate(self, n=5) -> "DrawingData":
        """Return a copy with every coordinate randomly displaced by up to |n| pixels."""
        # One delta per coordinate is exactly what applyDeltas consumes.
        deltas = [random.uniform(-n, n)
                  for _ in range(self.degreesOfFreedom())]
        return self.applyDeltas(deltas)

    def _draw(self, ct):
        """Paint a white background on |ct|, then stroke the polyline in black."""
        ct.set_source_rgb(1, 1, 1)
        ct.paint()
        if not self._list_of_points:
            # Nothing to stroke; leave a blank canvas rather than crash.
            return
        ct.set_line_width(10)
        ct.set_source_rgb(0, 0, 0)
        ct.move_to(*self._list_of_points[0])
        for x, y in self._list_of_points[1:]:
            ct.line_to(x, y)
        ct.stroke()

    def writeToPng(self, list_of_filepaths):
        """Render the drawing once and write it as PNG to every given target.

        Each entry of |list_of_filepaths| is handed to cairo's write_to_png,
        so it may be a filename or a writable binary file object.
        """
        with cairo.ImageSurface(cairo.Format.RGB24, self._dimension, self._dimension) as surface:
            self._draw(cairo.Context(surface))
            for filepath in list_of_filepaths:
                surface.write_to_png(filepath)

    def writeToSvg(self, list_of_filepaths):
        """Render the drawing as an SVG document to every given target."""
        for filepath in list_of_filepaths:
            # Use the same canvas size as the PNG output instead of a literal.
            with cairo.SVGSurface(filepath, self._dimension, self._dimension) as surface:
                surface.set_document_unit(cairo.SVGUnit.PX)
                self._draw(cairo.Context(surface))
def blur_image(image, s=(33, 33), f=10, n=0.5):
    """Return a weighted blend of |image| and a Gaussian-blurred copy of it.

    |s| and |f| are passed to cv2.GaussianBlur as its kernel size and sigma
    arguments. |n| is the weight given to the original image; the blurred
    copy gets the remaining 1 - |n|.

    Blending, rather than returning the fully blurred image, keeps a sharp
    cliff at the boundary of the unblurred shapes.
    """
    softened = cv2.GaussianBlur(image, s, f)
    softened_weight = 1-n
    return cv2.addWeighted(image, n, softened, softened_weight, 0.0)
def main():
    """Fit a randomly perturbed polyline back onto a target ampersand drawing.

    Renders the target, perturbs it, and repeatedly runs COBYLA to minimize
    the mean squared error between the blurred guess and the blurred target.
    Rounds repeat from a fresh random start until the error drops below the
    acceptance threshold. Every tenth frame of the winning round is kept and
    the first quarter of those frames is saved to 'out.gif'.

    Side effects: writes target.*, init_guess.*, guess.* and out.gif into the
    current working directory.
    """
    # Pointwise drawing of an ampersand. I sketched this on paper on an 8x8 grid, so...
    target_drawing = DrawingData([(6.3, 6.3), (5, 6), (2, 3), (1.5, 2), (2, 1), (3, .5), (4, .75), (4.5, 1.5),
                                  (4, 2.5), (1.75, 5), (2, 6), (2.5, 6.5), (3.5, 6.25), (5.5, 5), (7, 3), ])
    # ...we have to scale it up (x10, plus a small margin) to fit on the canvas.
    target_drawing.transform(lambda p: (3 + (p[0]*10), 7 + (p[1]*10)))
    target_drawing.writeToPng(['target.png'])
    target_drawing.writeToSvg(['target.svg'])

    # Keep track of the frame number, so that we can render a gif of only every
    # tenth frame. (Otherwise we store too much data in memory, and the GIF
    # takes too long to watch.)
    frame_num = 0
    frame_files = []

    def read_blurred(png_file):
        # Decode the PNG held in |png_file| and return its blurred pixels.
        png_file.seek(0)
        encoded = np.frombuffer(png_file.read(), np.uint8)
        return blur_image(cv2.imdecode(encoded, -1))

    # This is the function scipy.optimize.minimize is minimizing, so it has to
    # return fitness as a floating-point number.
    def evaluate(ts, initial_guess_drawing, target_blurred):
        nonlocal frame_num
        frame_num += 1
        guess_file = tempfile.SpooledTemporaryFile(suffix="png")
        guess_drawing = initial_guess_drawing.applyDeltas(ts.tolist())
        guess_drawing.writeToPng([guess_file, "guess.png"])
        guess_drawing.writeToSvg(["guess.svg"])
        # Blur the guess and return its mean squared error against the target.
        mse = skimage.metrics.mean_squared_error(
            read_blurred(guess_file), target_blurred)
        # Every tenth frame is kept for the GIF; the rest are released now.
        if frame_num % 10 == 0:
            frame_files.append(guess_file)
        else:
            guess_file.close()
        return mse

    # The target is rendered into a spooled temporary file for fast in-memory
    # access. It never changes, so decode and blur it once up front instead
    # of on every evaluation.
    with tempfile.SpooledTemporaryFile(suffix="png") as target_file:
        target_drawing.writeToPng([target_file])
        target_blurred = read_blurred(target_file)

    init_guess_drawing = target_drawing.mutate(n=5)
    init_guess_drawing.writeToPng(['init_guess.png'])
    init_guess_drawing.writeToSvg(['init_guess.svg'])

    # It typically takes 3-8 rounds before one converges. If a round
    # doesn't converge within 1000 iterations (see `maxiter` below), we
    # make another random guess and try some more.
    # NOTE: there is no cap on the number of rounds.
    num_rounds = 0
    while True:
        num_rounds += 1
        # Reset our list of frames: close and forget everything recorded by
        # the previous (failed) round so only the final round is animated.
        for stale_file in frame_files:
            stale_file.close()
        frame_files.clear()
        frame_num = 0
        # Make another random guess. These guesses are drawn from the
        # initial random guess, so they are one step further removed
        # from our source image.
        random_start_drawing = init_guess_drawing.mutate(n=5)
        r = scipy.optimize.minimize(
            fun=evaluate,
            # Since the set of parameters we're tuning are x/y
            # translations for the points in the image, our initial state
            # should be a vector of zeros: one per coordinate, which is
            # exactly how many deltas applyDeltas consumes.
            x0=[0.0] * random_start_drawing.degreesOfFreedom(),
            # COBYLA was experimentally selected. There might be better methods.
            method='COBYLA',
            args=(random_start_drawing, target_blurred),
            # Also experimentally selected.
            tol=0.000001,
            options={
                # Also experimentally selected.
                'maxiter': 1000,
            }
        )
        # If our image is within 20 pixels MSE (also experimentally
        # selected), it is close enough.
        if r.fun < 20:
            print("Num Trials: ", num_rounds)
            print("Success: ", r.success)
            print("Minimum error: ", r.fun)
            print("# Trials", r.nfev)
            print("Tuned parameters", r.x)
            break
        print("Minimum error: ", r.fun)

    try:
        # Open each of the spooled temp files as a PIL image.
        frames = []
        for frame_file in frame_files:
            frame_file.seek(0)
            frames.append(Image.open(frame_file))
        # Keep only the first quarter of the frames (but at least one). The
        # rest is fine-tuning, which is important but uninteresting to watch
        # in a GIF.
        frames = frames[:max(1, len(frames) // 4)]
        # Render the list of images as a GIF.
        if frames:
            frames[0].save('out.gif', save_all=True,
                           append_images=frames[1:], loop=0)
    finally:
        # And finally clean up the images manually.
        for frame_file in frame_files:
            frame_file.close()
        frame_files.clear()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment