Skip to content

Instantly share code, notes, and snippets.

@brikeats
Last active February 7, 2017 09:36
Show Gist options
  • Save brikeats/b454ff20f3e4d060198f to your computer and use it in GitHub Desktop.
Save brikeats/b454ff20f3e4d060198f to your computer and use it in GitHub Desktop.
A library of python functions that I've found helpful for image processing.
import time
import itertools
import platform
import subprocess
from functools import partial
from scipy import optimize, ndimage
from scipy.integrate import simps
from scipy.interpolate import splev, splprep
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from numpy import ma
from matplotlib.colors import ListedColormap
import cv2
from pykalman import KalmanFilter
import pims
from skimage import feature, filters, measure
"""
Some functions that I've found helpful. I'm sure this is reinventing the wheel,
but whatever.
"""
def isplit(iterable, splitters):
    """
    Break a sequence into runs of consecutive non-splitter items.

    Every item that is found in `splitters` acts as a delimiter and is
    dropped; back-to-back delimiters do not produce empty runs.
    Adapted from:
    http://stackoverflow.com/questions/4322705/split-a-list-into-nested-lists-on-a-value

    Args:
        iterable: The items to split.
        splitters: Container of delimiter values (membership is tested
            with ``in``).

    Returns:
        list of lists: The runs of non-delimiter items, in order.
    """
    def is_splitter(item):
        return item in splitters

    runs = []
    for delimiter_run, members in itertools.groupby(iterable, is_splitter):
        if delimiter_run:
            continue
        runs.append(list(members))
    return runs
def partition(list_, num):
    """
    Partition a sequence into `num` contiguous parts, as evenly as possible.

    When the length is not divisible by `num`, the leading parts each get
    one extra element. If `num` exceeds the length, the trailing parts are
    empty.

    Args:
        list_: Any sliceable sequence with a length.
        num: Number of parts (converted with ``int``).

    Returns:
        list: `num` slices of `list_`, in order.

    Raises:
        ZeroDivisionError: If `num` is 0.
    """
    num = int(num)
    # Floor division keeps the sizes integral on both Python 2 and 3.
    base_size, remainder = divmod(len(list_), num)
    parts = []
    start = 0
    for part_num in range(num):
        # The first `remainder` parts absorb the leftover elements.
        end = start + base_size + (1 if part_num < remainder else 0)
        parts.append(list_[start:end])
        start = end
    return parts
def partition_indices(list_, num):
    """
    Compute the slice bounds that split a sequence into `num` even parts.

    Uses the same rule as an even contiguous partition: when the length is
    not divisible by `num`, the leading parts are one element longer.

    Args:
        list_: Any object with a length.
        num: Number of parts (converted with ``int``).

    Returns:
        list of (start, end) tuples: Half-open index ranges, one per part.
        A real list is returned (not a lazy ``zip``) so the result can be
        indexed and iterated more than once on Python 3.

    Raises:
        ZeroDivisionError: If `num` is 0.
    """
    num = int(num)
    # Floor division keeps the sizes integral on both Python 2 and 3.
    base_size, remainder = divmod(len(list_), num)
    bounds = []
    start = 0
    for part_num in range(num):
        end = start + base_size + (1 if part_num < remainder else 0)
        bounds.append((start, end))
        start = end
    return bounds
def hist_match(im, ref_im):
    """
    Match the histogram of `im` to that of `ref_im`, channel by channel.

    Each pixel value is mapped to its percentile in the image's own
    cumulative distribution, and that percentile is then mapped back to a
    value through the reference image's cumulative distribution.

    Args:
        im (numpy array): Image to adjust; 2D, or 3D with channels last.
        ref_im (numpy array): Reference image with the same number of
            channels as `im`.

    Returns:
        numpy array: The adjusted image, squeezed. It is allocated with
        ``np.empty_like(im)``, so it has the dtype of `im` and the
        interpolated values are cast to that dtype.
    """
    # `cumulative_distribution` was never imported at module level; bring
    # it in here so the function does not fail with a NameError.
    from skimage.exposure import cumulative_distribution

    # Give single-channel images an explicit channel axis so the loop
    # below handles both cases.
    if np.squeeze(im).ndim == 2:
        im = np.expand_dims(im, axis=2)
        ref_im = np.expand_dims(ref_im, axis=2)
    out_im = np.empty_like(im)
    for chan_num in range(im.shape[2]):
        chan, ref_chan = im[..., chan_num], ref_im[..., chan_num]
        cdf, bin_centers = cumulative_distribution(chan)
        ref_cdf, ref_bin_centers = cumulative_distribution(ref_chan)
        # value -> percentile (own CDF) -> value (inverse reference CDF)
        chan_percentiles = np.interp(chan.flat, bin_centers, cdf)
        flat_out = np.interp(chan_percentiles, ref_cdf, ref_bin_centers)
        out_im[..., chan_num] = flat_out.reshape(chan.shape)
    return np.squeeze(out_im)
def read_config_file(fn):
    """
    Read a simple config file. More complex configs should be in xml or yaml.

    Each entry is written as "key=value" or "key value". Everything after
    a '#' is treated as a comment; blank and comment-only lines are
    skipped. Only the first '=' (or the first run of whitespace) separates
    the key from the value, so values may themselves contain '=' or
    spaces. Values are converted to int or float when possible; otherwise
    they are kept as strings with any double quotes removed.

    Args:
        fn (str): Path of the config file.

    Returns:
        dict: Mapping from key to the parsed value.
    """
    config = dict()
    with open(fn) as f:
        for raw_line in f:
            # Drop any comment, then surrounding whitespace and the newline.
            line = raw_line.split('#', 1)[0].strip()
            if not line:
                continue
            # Split only once so the value keeps any further separators.
            if '=' in line:
                line_parts = line.split('=', 1)
            else:
                line_parts = line.split(None, 1)
            if len(line_parts) != 2:
                print('Could not parse line %s , skipping...' % line)
                continue
            key, val_str = line_parts[0].strip(), line_parts[1].strip()
            # Cast to the most specific type that accepts the text.
            for cast in (int, float):
                try:
                    config[key] = cast(val_str)
                    break
                except ValueError:
                    continue
            else:
                config[key] = val_str.replace('"', '')
    return config
def print_image_properties(im):
    """
    Print a short summary of an image array to stdout.

    The summary lists the size of the first two dimensions, the number of
    channels (size of the third dimension, or 1 if there is none), the
    dtype, and the min, max, mean and standard deviation over all values.

    Args:
        im (numpy array): The image to describe.

    Raises:
        TypeError: If `im` is not a numpy array.
    """
    if not isinstance(im, np.ndarray):
        raise TypeError('print_image_properties only handles 2D or 3D numpy arrays')
    nchan = im.shape[2] if len(im.shape) > 2 else 1
    # Single-argument print(...) calls behave the same on Python 2 and 3.
    print('')
    print('image size: %i x %i' % (im.shape[0], im.shape[1]))
    print('num. channels: %i' % nchan)
    print('dtype: %s' % im.dtype)
    print('min, max: %.1f, %.1f' % (np.min(im), np.max(im)))
    print('mean, stdev: %.1f, %.1f' % (np.mean(im), np.std(im)))
def vidshow(frames, start_frame=0, end_frame=-1, fps=10, **kwargs):
    """
    Play a stack of frames; similar to imshow, but for arrays with a
    time dimension.

    Args:
        frames (numpy array): 3D (time, rows, cols) grayscale frames or
            4D (time, rows, cols, 3) colour frames.
        start_frame (int): Index of the first frame to show.
        end_frame (int): Index one past the last frame to show. The
            default of -1 (or None) plays through the final frame; the
            original code sliced with -1 and dropped the last frame.
        fps (number): Playback rate in frames per second.
        **kwargs: Forwarded to ``pyplot.imshow``.

    Raises:
        TypeError: If `frames` is not a numpy array.
        IndexError: If `frames` is not 3D or 4D, if a 4D array does not
            have 3 channels, or if the requested range is empty.
    """
    if not isinstance(frames, np.ndarray):
        raise TypeError('vidshow requires a 3D or 4D numpy array')
    num_dims = len(frames.shape)
    if num_dims == 4:
        if frames.shape[3] != 3:
            raise IndexError('vidshow only knows how to display 3-channel frames')
    elif num_dims != 3:
        raise IndexError('vidshow requires a 3D or 4D numpy array')

    if end_frame is None or end_frame == -1:
        frames = frames[start_frame:]
    else:
        frames = frames[start_frame:end_frame]
    if len(frames) == 0:
        raise IndexError('vidshow: no frames in the requested range')

    plt.gray()
    im = plt.imshow(frames[0], **kwargs)
    for frame in frames:
        im.set_data(frame)
        # pause() also redraws the figure and services GUI events.
        plt.pause(1. / fps)
    plt.show()
def volshow(vol, sl_dim=0, mask=None, **kwargs):
    """
    Display slices of a 3-dimensional numpy array interactively.

    Optionally, a mask of the same shape can be drawn on top of the
    volume in a single semi-transparent colour. This is intended to be a
    3D equivalent of pyplot.imshow.

    Usage:
        the up and down arrows flip through the slices
        'd' toggles the slice dimension (eg, axial->sagittal->coronal)
        'q' or escape closes the figure and exits the function

    In an ipython notebook session (Ubuntu, firefox), the line
    `%matplotlib tk` was needed to get the interactivity to work.

    Args:
        vol (numpy array): The volume to display.
        sl_dim (int): Dimension to slice along; negative values count
            from the last dimension.
        mask (numpy array, optional): Overlay mask; entries > 0 are drawn.
        **kwargs: ``alpha`` and ``color`` style the mask overlay; all
            other keywords are forwarded to ``pyplot.imshow`` for the
            volume. The colormap defaults to 'gray'.

    Raises:
        ValueError: If `mask` does not have the same shape as `vol`.
    """
    num_dims = len(vol.shape)
    sl_dim = sl_dim % num_dims  # lets sl_dim=-1 address the last dimension
    if mask is not None:
        mask = np.asarray(mask)
        if mask.shape != vol.shape:
            raise ValueError('mask must have the same shape as vol')

    class Slicer(object):
        """Holds the figure state and reacts to key presses."""

        def __init__(self, vol, sl_dim, mask=None, alpha=0.5, color='red', **kwargs):
            self.vol = vol
            self.slice_dim = sl_dim
            self.color = color
            self.alpha = alpha
            # Remember the imshow options so that replot() draws the
            # volume the same way as the initial view.
            self.imshow_kwargs = dict(kwargs)
            self.imshow_kwargs.setdefault('cmap', 'gray')
            # Start in the middle of every dimension; '//' keeps the
            # indices integral on Python 3.
            self.slice_nums = [sz // 2 for sz in vol.shape]
            self.slice_num = self.slice_nums[self.slice_dim]
            self.fig, self.ax = plt.subplots()
            self.cid = self.fig.canvas.mpl_connect('key_press_event', self.on_key)
            self.plot = plt.imshow(self.current_slice(self.vol), **self.imshow_kwargs)
            self.has_mask = mask is not None
            if self.has_mask:
                mask = mask > 0
                # Masked-out entries are not drawn, so only the object shows.
                self.mask_vol = np.ma.masked_where(~mask, mask)
                self.mask_vol_plot = self.show_mask()
            self.update_title()
            plt.axis('off')
            plt.draw()
            plt.show()
            try:
                self.fig.canvas.start_event_loop(timeout=-1)
            except Exception:
                # Some backends raise here (e.g. a TclError) once the
                # window has been closed; there is nothing left to do.
                pass

        def current_slice(self, volume):
            """Return the 2D slice of `volume` at the current position."""
            # NumPy requires a tuple (not a list) for multi-axis indexing
            # and allows at most one Ellipsis, so use explicit slices.
            index = [slice(None)] * len(self.vol.shape)
            index[self.slice_dim] = self.slice_num
            return np.squeeze(volume[tuple(index)])

        def show_mask(self):
            """Draw the mask slice over the current image."""
            return plt.imshow(self.current_slice(self.mask_vol), alpha=self.alpha,
                              cmap=ListedColormap([self.color]))

        def update_title(self):
            """Show the 1-based slice position in the title."""
            plt.title('Slice %i of %i' % (self.slice_num + 1, self.vol.shape[self.slice_dim]))

        def replot(self):
            # Create a new imshow plot. Used when toggling the slice
            # dimension, because the slice shape may change.
            self.plot.remove()
            self.plot = plt.imshow(self.current_slice(self.vol), **self.imshow_kwargs)
            if self.has_mask:
                self.mask_vol_plot.remove()
                self.mask_vol_plot = self.show_mask()
            self.update_title()
            plt.draw()

        def redraw(self):
            # Refresh the data to display a new slice (same dimension).
            self.plot.set_data(self.current_slice(self.vol))
            if self.has_mask:
                self.mask_vol_plot.set_data(self.current_slice(self.mask_vol))
            self.update_title()
            plt.draw()

        def prev_slice(self, _):
            self.slice_num = max([0, self.slice_num - 1])
            self.redraw()

        def next_slice(self, _):
            self.slice_num = min([self.slice_num + 1, self.vol.shape[self.slice_dim] - 1])
            self.redraw()

        def toggle_slice_dim(self, _):
            # Save the current slice index so it is restored when this
            # dimension is selected again.
            self.slice_nums[self.slice_dim] = self.slice_num
            self.slice_dim = (self.slice_dim + 1) % len(self.vol.shape)
            self.slice_num = self.slice_nums[self.slice_dim]
            self.replot()

        def on_key(self, event):
            if event.key in ['q', 'Q', 'escape']:
                plt.close()
            elif event.key == 'up':
                self.next_slice(event)
            elif event.key == 'down':
                self.prev_slice(event)
            elif event.key == 'd':
                self.toggle_slice_dim(event)

    # Forward the caller's keyword arguments (previously dropped).
    Slicer(vol, sl_dim, mask, **kwargs)
def volcompare(vol1, vol2, sl_dim=0, **kwargs):
    """
    Display matching slices of two 3-dimensional numpy arrays side by side.

    Usage:
        the up and down arrows flip through the slices
        'q' or escape closes the figure and exits the function

    In an ipython notebook session (Ubuntu, firefox), the line
    `%matplotlib tk` was needed to get the interactivity to work.

    Args:
        vol1 (numpy array): Volume shown on the left.
        vol2 (numpy array): Volume shown on the right.
        sl_dim (int): Dimension to slice along; negative values count
            from the last dimension.
        **kwargs: Forwarded to ``imshow`` for both volumes.

    Raises:
        ValueError: If the volumes differ in size along `sl_dim`.
    """
    sl_dim = sl_dim % len(vol1.shape)  # lets sl_dim=-1 address the last dimension
    if vol1.shape[sl_dim] != vol2.shape[sl_dim]:
        raise ValueError('The volumes must have the same number of slices.')

    class Slicer(object):
        """Holds the figure state and reacts to key presses."""

        def __init__(self, vol1, vol2, sl_dim, **kwargs):
            self.vol1 = vol1
            self.vol2 = vol2
            self.slice_dim = sl_dim
            # Start in the middle; '//' keeps the index integral on Python 3.
            self.slice_num = vol1.shape[sl_dim] // 2
            self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2)
            self.cid = self.fig.canvas.mpl_connect('key_press_event', self.on_key)
            self.plot1 = self.ax1.imshow(self.current_slice(self.vol1), **kwargs)
            self.ax1.axis('off')
            self.plot2 = self.ax2.imshow(self.current_slice(self.vol2), **kwargs)
            self.ax2.axis('off')
            self.update_title()
            plt.draw()
            plt.show()
            try:
                self.fig.canvas.start_event_loop(timeout=-1)
            except Exception:
                # Some backends raise here (e.g. a TclError) once the
                # window has been closed; there is nothing left to do.
                pass

        def current_slice(self, volume):
            """Return the 2D slice of `volume` at the current position."""
            # NumPy requires a tuple (not a list) for multi-axis indexing
            # and allows at most one Ellipsis, so use explicit slices.
            index = [slice(None)] * len(volume.shape)
            index[self.slice_dim] = self.slice_num
            return np.squeeze(volume[tuple(index)])

        def update_title(self):
            """Show the 1-based slice position above the right-hand image."""
            self.ax2.set_title('Slice %i of %i' % (self.slice_num + 1,
                                                   self.vol1.shape[self.slice_dim]))

        def redraw(self):
            # Refresh the data to display a new slice (same dimension).
            self.plot1.set_data(self.current_slice(self.vol1))
            self.plot2.set_data(self.current_slice(self.vol2))
            self.update_title()
            plt.draw()

        def prev_slice(self, _):
            self.slice_num = max([0, self.slice_num - 1])
            self.redraw()

        def next_slice(self, _):
            self.slice_num = min([self.slice_num + 1, self.vol1.shape[self.slice_dim] - 1])
            self.redraw()

        def on_key(self, event):
            if event.key in ['q', 'Q', 'escape']:
                plt.close()
            elif event.key == 'up':
                self.next_slice(event)
            elif event.key == 'down':
                self.prev_slice(event)

    # The original passed an undefined `mask` here and dropped kwargs.
    Slicer(vol1, vol2, sl_dim, **kwargs)
def flip_dim(a, axis=0):
    """
    Reverse an array along one axis.

    Like numpy.fliplr or numpy.flipud, but works on an arbitrary
    dimension.

    Args:
        a (numpy array): The array to flip.
        axis (int): Axis to reverse; negative values count from the end.

    Returns:
        numpy array: `a` indexed with a reversed slice along `axis`.
    """
    idx = [slice(None)] * len(a.shape)
    idx[axis] = slice(None, None, -1)
    # Multi-axis indices must be tuples; NumPy no longer accepts a list
    # of slices here.
    return a[tuple(idx)]
def tiff_to_ndarray(fn):
    """
    Read every frame of a tiff stack into one 3D numpy array.

    The whole movie is held in memory at once, so there must be enough
    RAM for it.

    Args:
        fn: Path of the tiff stack, passed to ``pims.TiffStack``.

    Returns:
        numpy array: One entry per frame along the first axis. Each frame
        has its two axes swapped and is then flipped left-to-right before
        being stored.
    """
    stack = pims.TiffStack(fn)
    frame_count = len(stack)
    dims = stack.frame_shape
    movie = np.empty((frame_count, dims[0], dims[1]), dtype=stack.pixel_type)
    position = 0
    for image in stack:
        # NOTE(review): the swapped frame is stored into a (dims[0], dims[1])
        # slot, which presumably only fits if frame_shape is reported in
        # the swapped order or the frames are square -- confirm.
        reoriented = np.swapaxes(image, 0, 1)
        movie[position, :, :] = np.fliplr(reoriented)
        position += 1
    return movie
def imshow_overlay(im, mask, alpha=0.5, color='red', **kwargs):
    """
    Show a semi-transparent single-colour mask over an image.

    Args:
        im: Image to display underneath, passed to ``pyplot.imshow``.
        mask (array-like): Overlay mask; entries > 0 are drawn.
        alpha (float): Opacity of the overlay.
        color: Matplotlib colour used for the overlay (default red).
        **kwargs: Forwarded to ``pyplot.imshow`` for `im`.
    """
    # The opacity is supplied through imshow's `alpha` argument, so there
    # is no need to edit the colormap's private lookup table.
    overlay_cmap = ListedColormap([color])
    mask = np.asarray(mask) > 0
    # Entries that are masked out are not drawn, so the image shows through.
    mask = ma.masked_where(~mask, mask)
    plt.imshow(im, **kwargs)
    plt.imshow(mask, alpha=alpha, cmap=overlay_cmap)
    ax = plt.gca()
    ax.patch.set_alpha(alpha)
class AviReader(object):
    """
    Read a video file as an iterable, sliceable sequence of frames.

    Integer indexing returns a single frame. Slicing returns a new reader
    restricted to the requested window; the reader being sliced is left
    unchanged. All readers derived from one another share the same
    underlying ``cv2.VideoCapture``, so they should not be iterated
    concurrently.

    Attributes:
        fn: The path the reader was opened with.
        first_frame, last_frame: Half-open window of frame numbers.
        num_frames: Number of frames in the window.
        frame_rate: Frames per second reported by OpenCV, as an int.
        frame_size: (height, width) reported by OpenCV, as ints.
    """

    def __init__(self, fn):
        self.cap = cv2.VideoCapture(fn)
        self.first_frame = 0
        self.last_frame = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fn = fn
        self.num_frames = len(self)
        self.frame_rate = int(self.cap.get(cv2.CAP_PROP_FPS))
        self.frame_size = (int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                           int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)))

    def __len__(self):
        return self.last_frame - self.first_frame

    def __iter__(self):
        # Rewind the capture to the start of this reader's window.
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.first_frame)
        return self

    def __next__(self):
        current_frame = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
        if current_frame >= self.last_frame:
            raise StopIteration
        ok, frame = self.cap.read()
        if not ok:
            # A failed read yields no frame; stop instead of returning it.
            raise StopIteration
        return frame

    next = __next__  # Python 2 iterator protocol

    def __str__(self):
        # frame_size and frame_rate are plain attributes, not methods.
        repr_str = 'AviReader instance from ' + self.fn + ': '
        repr_str += str(len(self)) + ' frames of shape ' + str(self.frame_size)
        repr_str += ', ' + str(self.frame_rate) + ' fps'
        return repr_str

    def __getitem__(self, index):
        """
        Return one frame (int index) or a windowed reader (slice).

        Indices are relative to this reader's window and may be negative.

        Raises:
            IndexError: If an int index is out of range or the frame
                cannot be read.
            ValueError: If a slice has a step other than 1.
            TypeError: If `index` is neither an int nor a slice.
        """
        if isinstance(index, slice):
            if index.step not in (None, 1):
                raise ValueError('AviReader slices do not support a step')
            # indices() resolves missing and negative bounds against the
            # current window length.
            start, stop, _ = index.indices(len(self))
            view = self.__class__.__new__(self.__class__)
            view.__dict__.update(self.__dict__)
            view.first_frame = self.first_frame + start
            view.last_frame = self.first_frame + max(start, stop)
            view.num_frames = len(view)
            return view
        if isinstance(index, int):
            if index < 0:
                index = len(self) + index
            if not 0 <= index < len(self):
                raise IndexError('frame index out of range')
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, index + self.first_frame)
            ok, frame = self.cap.read()
            if not ok:
                raise IndexError('could not read frame %i' % index)
            return frame
        raise TypeError('Avi indices should be integer or slices')
class TifReader(object):
    """
    Read a multi-frame tif file as an iterable, sliceable sequence of frames.

    Integer indexing returns a single frame as a numpy array. Slicing
    returns a new reader restricted to the requested window; the reader
    being sliced is left unchanged. Readers derived from one another
    share the same open PIL image.

    Attributes:
        fn: The path the reader was opened with.
        first_frame, last_frame: Half-open window of frame numbers.
        num_frames: Number of frames in the window.
        frame_size: The ``size`` attribute of the PIL image.
    """

    def __init__(self, fn):
        self.fn = fn
        self.im = Image.open(fn)
        self.first_frame = 0
        # Count the frames by seeking forward until PIL reports the end.
        idx = 0
        while True:
            try:
                self.im.seek(idx)
            except EOFError:
                self.last_frame = idx
                break
            idx += 1
        self.num_frames = self.last_frame
        self._total_frames = self.num_frames
        self.frame_size = self.im.size

    def __len__(self):
        return self.last_frame - self.first_frame

    def __iter__(self):
        self.iter_frame = 0
        return self

    def __next__(self):
        if self.iter_frame + self.first_frame >= self.last_frame:
            raise StopIteration
        self.im.seek(self.first_frame + self.iter_frame)
        self.iter_frame += 1
        return np.array(self.im)

    next = __next__  # Python 2 iterator protocol

    def __str__(self):
        repr_str = self.__class__.__name__ + ' instance from ' + self.fn + ': '
        repr_str += str(len(self)) + ' frames of shape ' + str(self.frame_size)
        return repr_str

    def __getitem__(self, index):
        """
        Return one frame (int index) or a windowed reader (slice).

        Indices are relative to this reader's window and may be negative.

        Raises:
            IndexError: If an int index is out of range.
            ValueError: If a slice has a step other than 1.
            TypeError: If `index` is neither an int nor a slice.
        """
        if isinstance(index, slice):
            if index.step not in (None, 1):
                raise ValueError('TifReader slices do not support a step')
            # indices() resolves missing and negative bounds against the
            # current window length.
            start, stop, _ = index.indices(len(self))
            view = self.__class__.__new__(self.__class__)
            view.__dict__.update(self.__dict__)
            view.first_frame = self.first_frame + start
            view.last_frame = self.first_frame + max(start, stop)
            view.num_frames = len(view)
            return view
        if isinstance(index, int):
            if index < 0:
                index = len(self) + index
            if not 0 <= index < len(self):
                raise IndexError('frame index out of range')
            self.im.seek(index + self.first_frame)
            return np.array(self.im)
        raise TypeError('Tif indices should be integers or slices')

    @property
    def shape(self):
        """(rows, cols, number of frames), taken from the first frame."""
        sz = self[0].shape
        return (sz[0], sz[1], len(self))
def write_video(frames, filename, fps=20):
    """
    Write a numpy array of frames to a video file with avconv or ffmpeg.

    The frames are rescaled to 8 bits: the minimum is shifted to 0, values
    above the 99.9th percentile are clipped, and the result is stretched
    to 0..255. The caller's array is not modified. A 3D array is written
    with the 'gray' pixel format and a 4D array with 'rgb24'. Each frame
    is flipped left-to-right before being written.

    On Mac systems, copy the ffmpeg binaries to your PATH
    (http://ffmpegmac.net/).

    Args:
        frames (numpy array): The frames, indexed by time along axis 0.
        filename (str): Output file path; overwritten if it exists.
        fps (number): Frames per second of the output.

    Raises:
        NotImplementedError: On Windows.
        OSError: If neither avconv nor ffmpeg can be run.
    """
    if platform.system() == 'Windows':
        err_str = 'Don\'t know how to write a movie for %s platform' % platform.system()
        raise NotImplementedError(err_str)
    pix_fmt = 'rgb24' if len(frames.shape) == 4 else 'gray'

    # Normalize on a float copy: this leaves the caller's data untouched
    # and works for bool and integer input, which cannot be scaled by a
    # float in place.
    scaled = np.array(frames, dtype=float)
    scaled -= scaled.min()
    # Take the percentile after the shift so that the clip level and the
    # scale factor refer to the same value range.
    max_pix_val = np.percentile(scaled, 99.9)
    np.minimum(scaled, max_pix_val, out=scaled)
    if max_pix_val > 0:
        scaled *= 255. / max_pix_val
    frames = scaled.astype(np.uint8)

    # Figure out which av program is installed.
    program_name = ''
    for candidate in ('avconv', 'ffmpeg'):
        try:
            subprocess.check_call([candidate, '-h'])
        except (OSError, subprocess.CalledProcessError):
            continue
        program_name = candidate
        break
    if not program_name:
        raise OSError('Can\'t find avconv or ffmpeg')

    # Prepare pipe to av converter program.
    # NOTE(review): '-s' is given as shape[1] x shape[2]; verify that this
    # matches the width x height order the converter expects for
    # non-square frames.
    size_str = '%ix%i' % (frames.shape[1], frames.shape[2])
    cmd = [program_name,
           '-y',  # (optional) overwrite output file if it exists
           '-f', 'rawvideo',
           '-vcodec', 'rawvideo',
           '-s', size_str,  # size of one frame
           '-pix_fmt', pix_fmt,
           '-r', str(fps),  # frames per second
           '-i', '-',  # input comes from a pipe
           '-an',  # no audio
           '-qscale', '1',
           '-vcodec', 'mjpeg',
           filename]
    pipe = subprocess.Popen(cmd, stdin=subprocess.PIPE)
    try:
        for frame in frames:
            # tobytes() replaces the removed ndarray.tostring().
            pipe.stdin.write(np.fliplr(frame).tobytes())
    finally:
        # Always close the pipe and reap the child, even if a write fails.
        pipe.stdin.close()
        pipe.wait()
def label_im_to_color(im, cmap='jet'):
    """
    Colour an image by passing its rescaled values through a colormap.

    The values are shifted and scaled to the range 0..1 before the
    colormap is applied. An image whose values are all equal maps
    entirely to the low end of the colormap.

    Args:
        im (array-like): Image of labels or intensities.
        cmap: Name of a matplotlib colormap, or a colormap object.

    Returns:
        The result of calling the colormap on the rescaled image.
    """
    im = np.asarray(im).astype(float)
    im -= np.min(im)
    peak = np.max(im)
    # Skip the division for constant images to avoid 0/0 -> NaN.
    if peak > 0:
        im /= peak
    # pyplot.get_cmap is the supported lookup; plt.cm.get_cmap has been
    # removed from recent matplotlib releases.
    cmap = plt.get_cmap(cmap)
    return cmap(im)
class KalmanSmoother2D(object):
    """
    Smooth a track of 2D positions with a Kalman smoother.

    The state vector is (x, y, vx, vy). The transition model adds the
    velocity to the position at each step (dt = 1), and only the two
    position components are observed.
    """

    def __init__(self, x_noise, y_noise, smoothness_x=1, smoothness_y=1):
        """
        Args:
            x_noise (float): Standard deviation of the x measurements.
            y_noise (float): Standard deviation of the y measurements.
            smoothness_x, smoothness_y: Accepted for compatibility.
                NOTE(review): these are currently unused; the process
                noise is fixed by sigma_ax = sigma_ay = 1 below.
        """
        dt = 1
        # Transition model: position += velocity * dt.
        F = np.eye(4)
        F[0, 2] = dt
        F[1, 3] = dt
        # Observation model: only x and y are measured.
        H = np.zeros((2, 4))
        H[0, 0] = 1
        H[1, 1] = 1
        # Measurement noise covariance (independent x and y).
        R = np.zeros((2, 2))
        R[0, 0] = x_noise * x_noise
        R[1, 1] = y_noise * y_noise
        # Process noise: acts on the two velocity components only, with
        # no cross terms, so the covariance is diagonal.
        sigma_ax, sigma_ay = 1, 1
        G = np.zeros((4, 1))
        G[2] = sigma_ax * dt
        G[3] = sigma_ay * dt
        Q = np.diag((G ** 2).ravel())
        # initialize filter
        self.kf = KalmanFilter()
        self.kf.transition_matrices = F
        self.kf.observation_matrices = H
        self.kf.transition_covariance = Q
        self.kf.observation_covariance = R
        # default initial state
        # TODO: maybe use first measurement as default?
        self.kf.initial_state_mean = np.zeros((4,))
        self.kf.initial_state_covariance = np.zeros((4, 4))
        # TODO: get innovations?

    def set_initial_state(self, initial_mean, initial_covariance=None):
        """
        Set the state mean and covariance the smoother starts from.

        Args:
            initial_mean (array-like): (x, y, vx, vy), or just (x, y), in
                which case the initial velocity is taken to be zero.
            initial_covariance (array-like, optional): 4x4 covariance of
                the initial state. Defaults to a fresh all-zero matrix.
        """
        initial_mean = np.asarray(initial_mean)
        if initial_covariance is None:
            # Build the default per call: an array in the signature would
            # be shared, and mutated, across every call and instance.
            initial_covariance = np.zeros((4, 4))
        if initial_mean.shape[0] == 2:
            print('initial velocity unspecified, assuming v0 = 0')
            initial_mean = np.array([initial_mean[0], initial_mean[1], 0, 0])
        self.kf.initial_state_mean = initial_mean
        self.kf.initial_state_covariance = initial_covariance

    def set_measurements(self, measurements):
        """Run the smoother over `measurements` and store the results."""
        self.smooth_means, self.smooth_covs = self.kf.smooth(measurements)

    def get_smoothed_measurements(self):
        """Return the smoothed positions (first two state columns)."""
        return self.smooth_means[:, 0:2]

    def get_velocities(self):
        """Return the smoothed velocities (remaining state columns)."""
        return self.smooth_means[:, 2:]

    def get_covariances(self):
        """Return the smoothed state covariances."""
        return self.smooth_covs
def gray2rgb(im):
    """
    Convert a 2D grayscale image to a 3-channel 8-bit image.

    The image is divided by its maximum, scaled to 0..255, rounded, and
    copied into three identical channels. An image whose maximum is 0 is
    left unscaled.

    Args:
        im (array-like): 2D image.

    Returns:
        numpy array: uint8 array with a trailing axis of length 3.

    Raises:
        ValueError: If `im` is not 2-dimensional.
    """
    # The builtin float replaces np.float, which NumPy has removed.
    im = np.asarray(im).astype(float)
    if im.ndim != 2:
        raise ValueError('gray2rgb expects a 2D array')
    peak = im.max()
    # Skip the division when the maximum is 0 to avoid 0/0 -> NaN.
    if peak != 0:
        im /= peak
    im = np.round(255 * im).astype(np.uint8)
    return np.repeat(im[:, :, np.newaxis], 3, axis=2)
def enhance_ridges(frame, mask=None):
    """
    Detect ridges (larger hessian eigenvalue).

    The frame is blurred with a Gaussian, its Hessian is computed, and
    the absolute value of the first eigenvalue image is returned.

    Args:
        frame (2D numpy array): Input image.
        mask: Accepted for compatibility; currently unused.

    Returns:
        numpy array: Absolute value of the first Hessian eigenvalue at
        each pixel.
    """
    # scikit-image renamed `gaussian_filter` to `gaussian`; use whichever
    # name the installed version provides.
    gaussian = getattr(filters, 'gaussian', None)
    if gaussian is None:
        gaussian = filters.gaussian_filter
    blurred = gaussian(frame, 2)
    hessian_elems = feature.hessian_matrix(blurred, sigma=4.5, mode='nearest')
    try:
        # Current signature: one sequence holding the Hessian elements.
        eigvals = feature.hessian_matrix_eigvals(hessian_elems)
    except TypeError:
        # Older signature: the elements as separate arguments.
        eigvals = feature.hessian_matrix_eigvals(*hessian_elems)
    return np.abs(eigvals[0])
def mask_to_boundary_pts(mask, pt_spacing=10):
    """
    Convert a binary image containing a single object to a set of 2D
    points spread along the object's contour.

    The first contour found in the mask is interpolated with a periodic
    spline. The contour length is measured on a dense resampling, and the
    spline is then sampled at evenly spaced parameter values so that
    neighbouring points are roughly `pt_spacing` apart.

    Args:
        mask (2D numpy array): Binary image of the object.
        pt_spacing (number): Approximate distance between output points.

    Returns:
        (N+1, 2)-shaped numpy array: The sampled contour points, where N
        is the contour length divided by `pt_spacing`, rounded.
        NOTE(review): the spline is periodic and is sampled at both
        parameter ends (0 and 1), so the first and last points presumably
        coincide -- confirm whether callers expect that.
    """
    # interpolate boundary
    boundary_pts = measure.find_contours(mask, 0)[0]
    tck, u = splprep(boundary_pts.T, u=None, s=0.0, per=1)
    u_new = np.linspace(u.min(), u.max(), 1000)
    x_new, y_new = splev(u_new, tck, der=0)
    # The length of the dense polyline is the sum of its segment lengths.
    contour_length = np.sum(np.hypot(np.diff(x_new), np.diff(y_new)))
    # get equi-spaced points along spline-interpolated boundary
    N = int(round(contour_length / pt_spacing))
    u_equidist = np.linspace(0, 1, N + 1)
    x_equidist, y_equidist = splev(u_equidist, tck, der=0)
    # column_stack builds the (N+1, 2) array directly; np.array(zip(...))
    # yields a 0-d object array on Python 3.
    return np.column_stack((x_equidist, y_equidist))
def snake_energy(flattened_pts, edge_dist, alpha, beta):
    """
    Compute the energy associated with a proposed contour.

    The contour is defined by N 2-dimensional points. The energy is made
    up of external energy, which is derived from the supplied distance
    image, and internal energy, which is computed only from the shape of
    the contour. Note that the image interpolation uses only 1st-order
    splines, which increases speed at the expense of accuracy.

    The current implementation was created for a closed contour. An open
    contour formulation should replace the periodic 'np.roll' calls by
    non-periodic end-off shifts.

    Args:
        flattened_pts ((2*N,)-shaped numpy array): A flattened list of the
            contour points, ordered so that adjacent points are
            consecutive in the list. Can be created by calling
            arr_2d.ravel() on an ordered (N,2)-shaped array of points.
        edge_dist (2D numpy array): Distance transform of binary edge
            detector.
        alpha (float): The relative weight given to unevenly spaced
            points. A higher value encourages evenly-spaced points.
            Should be > 0.
        beta (float): The weight given to local curvature. A higher value
            encourages flat contours.

    Returns:
        float: Image energy. (lower is better)
    """
    # -1 lets numpy infer N; len/2 would be a float on Python 3.
    pts = np.reshape(flattened_pts, (-1, 2))

    # external energy (favors low values of distance image)
    dist_vals = ndimage.map_coordinates(edge_dist, [pts[:, 0], pts[:, 1]], order=1)
    external_energy = np.sum(dist_vals)

    # spacing energy (favors equi-distant points)
    prev_pts = np.roll(pts, 1, axis=0)
    next_pts = np.roll(pts, -1, axis=0)
    displacements = pts - prev_pts
    point_distances = np.sqrt(displacements[:, 0]**2 + displacements[:, 1]**2)
    mean_dist = np.mean(point_distances)
    spacing_energy = np.sum((point_distances - mean_dist)**2)

    # curvature energy (favors smooth curves): squared magnitude of the
    # discrete second difference at every point
    curvature_1d = prev_pts - 2*pts + next_pts
    curvature = (curvature_1d[:, 0]**2 + curvature_1d[:, 1]**2)
    curvature_energy = np.sum(curvature)

    return external_energy + alpha*spacing_energy + beta*curvature_energy
def fit_snake(pts, edge_dist, alpha=0.5, beta=0.25, nits=100, point_plot=None):
    """
    Fit an active contour model (aka snakes) based on some initial points
    and a feature image.

    Given a list of points as a starting point, it evolves the points
    until they sit at a minimum of the energy function 'snake_energy'.
    This function is not especially good at avoiding local minima, and it
    does not adapt the number of points in the contour. Therefore, it is
    most useful for "polishing up" an already good initial guess.

    Args:
        pts ((N,2)-shaped numpy array): A list of the contour points,
            ordered so that adjacent points are consecutive in the list
            (ie, in clockwise or counter-clockwise order).
        edge_dist (2D numpy array): Distance transform of binary edge
            detector.
        alpha (float): The weight given to unevenly spaced points. A
            higher value encourages evenly-spaced points. Should be > 0.
        beta (float): The weight given to local curvature. A higher value
            encourages flat contours.
        nits (int): Maximum number of optimizer iterations.
        point_plot (matplotlib.lines.Line2D, optional): A matplotlib line
            object for the given points. The Line2D data will be updated
            on each iteration to provide an animation of the optimization.

    Returns:
        (N,2)-shaped numpy array: The points after minimization.
    """
    if point_plot:
        def callback_function(new_pts):
            # The iteration count is kept as an attribute on the function.
            callback_function.nits += 1
            # The flattened points alternate between the two coordinates.
            y = new_pts[0::2]
            x = new_pts[1::2]
            point_plot.set_data(x, y)
            plt.title('%i iterations' % callback_function.nits)
            point_plot.figure.canvas.draw()
            time.sleep(0.1)
        callback_function.nits = 0
    else:
        callback_function = None

    # optimize
    cost_function = partial(snake_energy, alpha=alpha, beta=beta, edge_dist=edge_dist)
    options = {'disp': False}
    options['maxiter'] = nits  # FIXME: check convergence
    method = 'BFGS'  # 'BFGS', 'CG', or 'Powell'. 'Nelder-Mead' has very slow convergence
    res = optimize.minimize(cost_function, pts.ravel(), method=method,
                            options=options, callback=callback_function)
    # -1 lets numpy infer N; len/2 would be a float on Python 3.
    optimal_pts = np.reshape(res.x, (-1, 2))
    return optimal_pts
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment