Last active
October 15, 2017 16:13
-
-
Save john-bradshaw/14b520ebf078ee498929e8fe2d1a9bb1 to your computer and use it in GitHub Desktop.
Model-based morphing for a two-dimensional Gaussian (reproduction of Fig. 1a of Saul, L.K. and Jordan, M.I., 1997)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Code for Figure 1a of
Saul, L.K. and Jordan, M.I., 1997. A variational principle for model-based morphing.
In Advances in Neural Information Processing Systems (pp. 267-273).
""" | |
import numpy as np | |
from scipy.integrate import odeint | |
from scipy.optimize import fsolve | |
from scipy import stats | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.mplot3d import Axes3D | |
from scipy import linalg as sla | |
from matplotlib.widgets import Slider, Button, RadioButtons | |
# Apply the "ggplot" style sheet to everything plotted by this script.
plt.style.use("ggplot")

# Plot colours.
END_POINTS = "#f47530"        # colour of the two boundary-point markers drawn in main()
END_POINTS_OUTER = "#003859"  # not referenced by any code visible in this file
LINE = "#f4a742"              # colour of the dashed solution path drawn in main()

# Marker size used for the two boundary points.
MARKER_SIZE = 10

# Settings read by func().
DIAG_VAR = True  # True: invert only the diagonal of Var; False: use Cholesky-based solves
NUM_DIM = 2      # number of leading state entries that hold the position
def func(x_concat, t, Var, ell):
    """Right-hand side of the first-order ODE system handed to ``odeint``.

    The second-order ODE (equation 9 of the paper cited in the module
    docstring) is rewritten as a first-order system whose state stacks the
    position and the velocity.

    Args:
        x_concat: 1-D state vector; the first ``NUM_DIM`` entries are the
            position, the remaining entries the velocity.
        t: Integration time. Unused in the body, but part of the callback
            signature ``odeint`` expects.
        Var: Square matrix whose inverse weights the two quadratic forms.
            When ``DIAG_VAR`` is true only its diagonal is read.
        ell: Scalar added to half the position quadratic form to make the
            denominator of the acceleration.

    Returns:
        1-D array holding the velocity followed by the acceleration.
    """
    # Work with column vectors so the ``@`` products below give 1x1 arrays.
    pos = x_concat[:NUM_DIM, np.newaxis]
    vel = x_concat[NUM_DIM:, np.newaxis]

    if not DIAG_VAR:
        # General case: never form the inverse explicitly; go through the
        # lower Cholesky factor of Var instead.
        chol = sla.cholesky(Var, lower=True)
        whitened_pos = sla.solve_triangular(chol, pos, lower=True)
        quad_pos_pos = whitened_pos.T @ whitened_pos
        quad_pos_vel = pos.T @ sla.cho_solve((chol, True), vel)
    else:
        # Diagonal case: the inverse is the element-wise reciprocal of the
        # diagonal (any off-diagonal entries of Var are ignored).
        inv_var = np.diag(np.reciprocal(np.diagonal(Var)))
        quad_pos_pos = pos.T @ inv_var @ pos
        quad_pos_vel = pos.T @ inv_var @ vel

    denominator = 0.5 * quad_pos_pos + ell
    accel = (pos - vel * quad_pos_vel) / denominator

    # Stack [velocity; acceleration] and flatten back to the 1-D layout.
    return np.squeeze(np.vstack((vel, accel)))
def create_colutions_func(x_0, x_1, tspan):
    """Build a solver for the boundary value problem joining ``x_0`` to ``x_1``.

    Args:
        x_0: Start position (1-D array-like).
        x_1: Target position (1-D array-like, same length as ``x_0``).
        tspan: Increasing sequence of time points passed to ``odeint``; the
            first entry is where ``x_0`` holds and the last is where the
            path should reach ``x_1``.

    Returns:
        A function ``solutions_func(ell, var)`` that returns the ``odeint``
        output for the solved path: one row per entry of ``tspan``, with the
        position in the first ``NUM_DIM`` columns and the velocity in the
        remaining columns.
    """
    # Identity sized from the end points rather than hard-coded to 2x2;
    # built once because it does not depend on (ell, var).
    identity = np.eye(np.size(x_0), dtype=np.float32)

    def solutions_func(ell, var):
        """Solve the boundary value problem for one (ell, var) pair by shooting."""
        Var = var * identity

        def endpoint_miss(u2_0):
            # Integrate forward from x_0 with trial initial velocity u2_0 and
            # report how far the final position lands from x_1.
            trajectory = odeint(func, np.concatenate((x_0, u2_0)), tspan, args=(Var, ell))
            return trajectory[-1, :NUM_DIM] - x_1

        # Shooting method: find the initial velocity that zeroes the miss.
        # x_1 is used as the starting guess for that velocity.
        u2_0 = fsolve(endpoint_miss, x_1)

        # Integrate from exactly the state the root finder converged on.
        # Down-casting it to float32 first would perturb the initial velocity
        # and return a path that differs from the one whose miss was zeroed.
        return odeint(func, np.concatenate((x_0, u2_0)), tspan, args=(Var, ell))

    return solutions_func
def _gaussian_density_on_grid(X, Y, var):
    """Evaluate a zero-mean Gaussian pdf with covariance ``var * I`` on the mesh (X, Y)."""
    grid_points = np.concatenate((X.flatten()[:, None], Y.flatten()[:, None]), axis=1)
    dim = grid_points.shape[1]
    mvn = stats.multivariate_normal(np.zeros(dim), var * np.eye(dim, dtype=float))
    return mvn.pdf(grid_points).reshape(*X.shape)


def _draw_scene(ax, X, Y, Z, levels, soln, x_0, x_1):
    """Draw the filled density contours, the dashed path and both end points; return the path line."""
    ax.contourf(X, Y, Z, levels, cmap=plt.get_cmap("Greens"))
    path_line, = ax.plot(soln[:, 0], soln[:, 1], '--', lw=3, color=LINE)
    for end_point in (x_0, x_1):
        ax.plot(end_point[0], end_point[1], 'o', markersize=MARKER_SIZE, color=END_POINTS)
    return path_line


def main():
    """Show the interactive figure: Gaussian density contours plus the solved path.

    Two sliders control ``ell`` and the variance. A radio button selects
    whether the contour levels are recomputed from the current density
    ('renormalise') or kept from the last time they were computed ('nop').
    Blocks in ``plt.show()`` until the window is closed.
    """
    x_0 = np.array([-1., -1.5], dtype=np.float32)
    x_1 = np.array([-0.5, 1], dtype=np.float32)
    tspan = np.linspace(0, 1, 101)
    solution_func = create_colutions_func(x_0, x_1, tspan)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    initial_var = 0.2
    initial_ell = 1.

    # Fixed evaluation mesh for the density; reused by every redraw.
    X, Y = np.meshgrid(np.linspace(-2., 0.3), np.linspace(-2, 1.5))

    Z = _gaussian_density_on_grid(X, Y, initial_var)
    levels = np.linspace(np.min(Z), np.max(Z), 20)
    path_line = _draw_scene(
        ax, X, Y, Z, levels, solution_func(initial_ell, initial_var), x_0, x_1
    )

    # Widget axes are positioned in figure coordinates above the main plot.
    axell = plt.axes([0.1, 0.95, 0.65, 0.03])
    axvar = plt.axes([0.1, 0.9, 0.65, 0.03])
    axrenormalise = plt.axes([0.8, 0.88, 0.15, 0.1])
    brenormaliseval = RadioButtons(axrenormalise, ('renormalise', 'nop'))
    sell = Slider(axell, 'Ell', 0.01, 10.0, valinit=initial_ell)
    svar = Slider(axvar, 'Variance', 0.01, 10., valinit=initial_var)

    # Mutable state shared by the callbacks: the current path artist and the
    # contour levels last computed in 'renormalise' mode.
    state = {"line": path_line, "levels": levels}

    def update_ell(_val):
        # Only the path depends on ell, so update the existing line in place.
        soln = solution_func(sell.val, svar.val)
        state["line"].set_data(soln[:, 0], soln[:, 1])
        fig.canvas.draw_idle()

    def update_var(_val):
        # The density itself changes, so clear the axes and redraw everything.
        ax.clear()
        var_val = svar.val
        Z_new = _gaussian_density_on_grid(X, Y, var_val)
        soln = solution_func(sell.val, var_val)
        if brenormaliseval.value_selected == 'renormalise':
            state["levels"] = np.linspace(np.min(Z_new), np.max(Z_new), 20)
        # ax.clear() discarded the old line; track the replacement so that
        # update_ell keeps editing the artist that is actually on screen.
        state["line"] = _draw_scene(ax, X, Y, Z_new, state["levels"], soln, x_0, x_1)
        fig.canvas.draw_idle()

    sell.on_changed(update_ell)
    svar.on_changed(update_var)
    brenormaliseval.on_clicked(update_var)
    plt.show()
# Launch the interactive figure only when this file is run as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment