Skip to content

Instantly share code, notes, and snippets.

@john-bradshaw
Last active October 15, 2017 16:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save john-bradshaw/14b520ebf078ee498929e8fe2d1a9bb1 to your computer and use it in GitHub Desktop.
Save john-bradshaw/14b520ebf078ee498929e8fe2d1a9bb1 to your computer and use it in GitHub Desktop.
Model-based morphing for a two-dimensional Gaussian (reproduction of Fig. 1a of Saul, L.K. and Jordan, M.I., 1997)
"""
Code for Figure 1a of
Saul, L.K. and Jordan, M.I., 1997. A variational principle for model-based morphing.
In Advances in Neural Information Processing Systems (pp. 267-273).
"""
import numpy as np
from scipy.integrate import odeint
from scipy.optimize import fsolve
from scipy import stats
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy import linalg as sla
from matplotlib.widgets import Slider, Button, RadioButtons
plt.style.use("ggplot")

# Hex colour strings for the plot.
END_POINTS, END_POINTS_OUTER, LINE = '#f47530', '#003859', '#f4a742'

# When True, func() may take a fast path that inverts the covariance diagonal
# elementwise instead of using a Cholesky factorisation.
DIAG_VAR = True
# Dimensionality of the position vector x.
NUM_DIM = 2
# Marker size used for the two end points.
MARKER_SIZE = 10
def func(x_concat, t, Var, ell):
    """Right-hand side of the second-order morphing ODE (equation 9 of the paper).

    The state is packed as ``x_concat = [x, v]``, where ``x`` is the position and
    ``v = dx/dt`` is the velocity, each of length ``d = len(x_concat) // 2``.
    Returns ``[dx/dt, dv/dt] = [v, a]`` with

        a = (x - v * (x^T Var^{-1} v)) / (0.5 * x^T Var^{-1} x + ell)

    Parameters
    ----------
    x_concat : array-like, shape (2 * d,)
        Concatenated position and velocity.
    t : float
        Time. It is not used, but ``odeint`` passes it as the second argument.
    Var : ndarray, shape (d, d)
        Covariance matrix. In the non-diagonal path it must be positive
        definite, because it is Cholesky-factorised.
    ell : float
        Constant added to the denominator of the acceleration.

    Returns
    -------
    ndarray, shape (2 * d,)
        Time derivative of ``x_concat``.
    """
    x_concat = np.asarray(x_concat)
    # Infer the dimension from the state itself so any d works, not only NUM_DIM.
    num_dim = x_concat.size // 2
    x = x_concat[:num_dim][:, None]
    v = x_concat[num_dim:][:, None]

    Var = np.asarray(Var)
    var_diag = np.diagonal(Var)
    if DIAG_VAR and np.array_equal(Var, np.diag(var_diag)):
        # Diagonal covariance: the precision Var^{-1} is the elementwise
        # reciprocal of the diagonal. This is cheap and stable as long as the
        # variances stay away from zero.
        prec = np.reciprocal(var_diag)[:, None]
        x_prec_x = x.T @ (prec * x)
        x_prec_v = x.T @ (prec * v)
    else:
        # General covariance: Cholesky solves instead of forming Var^{-1}.
        # This branch is also taken when DIAG_VAR is set but Var has
        # off-diagonal entries, so those entries are never silently dropped.
        chol = sla.cholesky(Var, lower=True)
        whitened_x = sla.solve_triangular(chol, x, lower=True)
        x_prec_x = whitened_x.T @ whitened_x
        x_prec_v = x.T @ sla.cho_solve((chol, True), v)

    a_coeff = 0.5 * x_prec_x + ell
    new_a = (x - v * x_prec_v) / a_coeff
    return np.squeeze(np.vstack((v, new_a)))
def create_colutions_func(x_0, x_1, tspan):
    """Build a solver for the morphing path between two fixed end points.

    Parameters
    ----------
    x_0, x_1 : array-like, shape (d,)
        Start and end points of the path.
    tspan : ndarray
        Time grid passed to ``odeint``. The path starts at ``x_0`` at
        ``tspan[0]`` and is required to reach ``x_1`` at ``tspan[-1]``.

    Returns
    -------
    callable
        ``solutions_func(ell, var)`` returns the ``odeint`` solution, an array of
        shape ``(len(tspan), 2 * d)``. Its first ``d`` columns are positions and
        its last ``d`` columns are velocities. The covariance used is
        ``var * I``.
    """
    x_0 = np.asarray(x_0)
    x_1 = np.asarray(x_1)
    num_dim = x_0.size

    def solutions_func(ell, var):
        # Solve the two-point boundary value problem by shooting: search for
        # the initial velocity whose trajectory ends at x_1.
        # Build the isotropic covariance in the problem's own dimension
        # rather than as a hard-coded 2x2 matrix.
        Var = var * np.eye(num_dim)

        def integrate(v_0):
            return odeint(func, np.concatenate((x_0, v_0)), tspan, args=(Var, ell))

        def objective(v_0):
            return integrate(v_0)[-1, :num_dim] - x_1

        v_0, _, ier, msg = fsolve(objective, x_1, full_output=True)
        if ier != 1:
            import warnings
            warnings.warn(
                "Shooting method did not converge (ell={}, var={}): {}".format(ell, var, msg),
                RuntimeWarning,
                stacklevel=2,
            )
        # Re-integrate from the full-precision (float64) initial velocity.
        # Downcasting it would move the endpoint away from what fsolve achieved.
        return integrate(v_0)

    return solutions_func
def _gaussian_pdf_grid(X, Y, var):
    """Evaluate the density of N(0, var * I) at every point of the 2-D mesh (X, Y).

    Returns an array with the same shape as ``X``.
    """
    mvn = stats.multivariate_normal(np.zeros(2), var * np.eye(2, dtype=float))
    points = np.column_stack((X.ravel(), Y.ravel()))
    return mvn.pdf(points).reshape(X.shape)


def _draw_scene(ax, X, Y, Z, levels, soln, x_0, x_1):
    """Draw the density contours, the morphing path and both end points on ``ax``.

    Returns the path's ``Line2D`` so callers can update its data later.
    """
    ax.contourf(X, Y, Z, levels, cmap=plt.get_cmap("Greens"))
    path_line, = ax.plot(soln[:, 0], soln[:, 1], '--', lw=3, color=LINE)
    for end_point in (x_0, x_1):
        ax.plot(end_point[0], end_point[1], 'o', markersize=MARKER_SIZE, color=END_POINTS)
    return path_line


def main():
    """Show the interactive morphing figure.

    The figure shows filled contours of N(0, var * I) and the morphing path
    from x_0 to x_1. Sliders control ``ell`` and ``var``. A radio button
    chooses whether contour levels are recomputed when ``var`` changes
    ('renormalise') or kept from the last renormalisation ('nop').
    """
    x_0 = np.array([-1., -1.5], dtype=np.float32)
    x_1 = np.array([-0.5, 1], dtype=np.float32)
    tspan = np.linspace(0, 1, 101)
    solution_func = create_colutions_func(x_0, x_1, tspan)

    initial_var = 0.2
    initial_ell = 1.
    num_levels = 20

    X, Y = np.meshgrid(np.linspace(-2., 0.3), np.linspace(-2, 1.5))

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    Z = _gaussian_pdf_grid(X, Y, initial_var)
    levels = np.linspace(np.min(Z), np.max(Z), num_levels)
    solution_line = _draw_scene(
        ax, X, Y, Z, levels, solution_func(initial_ell, initial_var), x_0, x_1)

    axell = plt.axes([0.1, 0.95, 0.65, 0.03])
    axvar = plt.axes([0.1, 0.9, 0.65, 0.03])
    axrenormalise = plt.axes([0.8, 0.88, 0.15, 0.1])
    brenormaliseval = RadioButtons(axrenormalise, ('renormalise', 'nop'))
    sell = Slider(axell, 'Ell', 0.01, 10.0, valinit=initial_ell)
    svar = Slider(axvar, 'Variance', 0.01, 10., valinit=initial_var)

    def update_ell(_val):
        # Only the path depends on ell, so re-solve and move the existing line.
        soln = solution_func(sell.val, svar.val)
        solution_line.set_data(soln[:, 0], soln[:, 1])
        fig.canvas.draw_idle()

    def update_var(_val):
        # The density changes with var, so redraw the whole scene.
        nonlocal solution_line, levels
        ell_val, var_val = sell.val, svar.val
        Z_new = _gaussian_pdf_grid(X, Y, var_val)
        soln = solution_func(ell_val, var_val)
        if brenormaliseval.value_selected == 'renormalise':
            levels = np.linspace(np.min(Z_new), np.max(Z_new), num_levels)
        # Clear only after the solve succeeded, so a failure keeps the old plot.
        ax.clear()
        solution_line = _draw_scene(ax, X, Y, Z_new, levels, soln, x_0, x_1)
        fig.canvas.draw_idle()

    sell.on_changed(update_ell)
    svar.on_changed(update_var)
    brenormaliseval.on_clicked(update_var)
    plt.show()
if __name__ == '__main__':
    # Launch the interactive figure only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment