@bigsnarfdude · Last active December 17, 2018
Neural ODE solver implemented with autograd, from Appendix B of "Neural Ordinary Differential Equations" (https://arxiv.org/abs/1806.07366). The gist contains the adjoint-gradient implementation, a small training demo on a 2-D oscillator, and a TensorFlow/autograd interop example.
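For reference, the adjoint-state relations from the paper (Sec. 2 and Appendix B) that the code below implements: the adjoint a(t) = dL/dz(t) obeys its own ODE, solved backwards in time alongside the state, and the parameter gradient is accumulated along the way,

    a(t)      = dL/dz(t)
    da(t)/dt  = -a(t)^T \, \partial f(z(t), t, \theta) / \partial z
    dL/d\theta = -\int_{t_1}^{t_0} a(t)^T \, \partial f(z(t), t, \theta) / \partial \theta \, dt.

In the code, the state, a(t), the time gradient, and the parameter gradient are stacked into one augmented vector (y, vjp_y, vjp_t, vjp_args) and integrated backwards between observation times.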
"""
From Appendix B in the paper
https://arxiv.org/abs/1806.07366
Neural ODE Solver
Implementation in autograd
"""
import scipy.integrate
import autograd.numpy as np
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten
from autograd.builtins import tuple
odeint = primitive(scipy.integrate.odeint)
def grad_odeint_all(yt, func, y0, t, func_args, **kwargs):
    """
    Extended from "Scalable Inference of Ordinary Differential
    Equation Models of Biochemical Processes", Sec. 2.4.2.
    Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017
    https://arxiv.org/pdf/1711.08079.pdf
    """
    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        # y, vjp_y, vjp_t, vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))
    def vjp_all(g, **kwargs):
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        for i in range(T - 1, 0, -1):
            # Compute effect of moving the current measurement time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run the augmented system backwards to the previous observation.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from the current output.
            vjp_y = vjp_y + g[i - 1, :]

        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)
    return vjp_all
def grad_argnums_wrapper(all_vjp_builder):
    """
    A generic autograd helper function. Takes a function that
    builds vjps for all arguments, and wraps it to return only the required vjps.
    """
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):
            # Return whichever vjps were asked for.
            all_vjps = vjp_func(g)
            return [all_vjps[argnum] for argnum in argnums]
        return chosen_vjps
    return build_selected_vjps


if __name__ == '__main__':
    # Register the adjoint VJP so autograd can differentiate through odeint.
    defvjp_argnums(odeint, grad_argnums_wrapper(grad_odeint_all))
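    # A minimal usage sketch (not part of the original gist; linear_dynamics,
    # example_loss, A0 and the time grid are illustrative assumptions). With
    # the VJP registered above, autograd can differentiate a loss defined
    # through odeint with respect to the dynamics parameters.
    from autograd import grad

    def linear_dynamics(y, t, A):
        # dy/dt = y A, a simple linear system used only for this sketch.
        return np.dot(y, A)

    def example_loss(A):
        y0 = np.array([1.0, 0.0])
        ts = np.linspace(0.0, 1.0, 10)
        yt = odeint(linear_dynamics, y0, ts, tuple((A,)))
        return np.sum(yt[-1] ** 2)

    A0 = np.array([[-0.5, 1.0], [-1.0, -0.5]])
    print(grad(example_loss)(A0))  # dL/dA computed via the adjoint VJP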
from __future__ import absolute_import
from __future__ import print_function

from builtins import range

import matplotlib.pyplot as plt
import numpy as npo
import scipy.integrate

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad, make_vjp
# from autograd.scipy.integrate import odeint
from autograd.builtins import tuple
from autograd.extend import primitive, defvjp_argnums
from autograd.misc import flatten
from autograd.misc.optimizers import adam

odeint = primitive(scipy.integrate.odeint)
def grad_odeint(yt, func, y0, t, func_args, **kwargs):
    # Extended from "Scalable Inference of Ordinary Differential
    # Equation Models of Biochemical Processes", Sec. 2.4.2.
    # Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017
    # https://arxiv.org/abs/1711.08079
    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        # y, vjp_y, vjp_t, vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))
    def vjp_all(g):
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        for i in range(T - 1, 0, -1):
            # Compute effect of moving measurement time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run augmented system backwards to the previous observation.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]

        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)
    return vjp_all
def argnums_unpack(all_vjp_builder):
    # A generic autograd helper function. Takes a function that
    # builds vjps for all arguments, and wraps it to return only required vjps.
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):  # Returns whichever vjps were asked for.
            all_vjps = vjp_func(g)
            return [all_vjps[argnum] for argnum in argnums]
        return chosen_vjps
    return build_selected_vjps

defvjp_argnums(odeint, argnums_unpack(grad_odeint))
N = 30       # Dataset size
D = 2        # Data dimension
max_T = 1.5

# Two-dimensional damped oscillator
def func(y, t0, A):
    return np.dot(y**3, A)

def nn_predict(inputs, t, params):
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.maximum(0, outputs)
    return outputs

def init_nn_params(scale, layer_sizes, rs=npr.RandomState(0)):
    """Build a list of (weights, biases) tuples, one for each layer."""
    return [(rs.randn(insize, outsize) * scale,   # weight matrix
             rs.randn(outsize) * scale)           # bias vector
            for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])]

# Define neural ODE model.
def ode_pred(params, y0, t):
    return odeint(nn_predict, y0, t, tuple((params,)), rtol=0.01)

def L1_loss(pred, targets):
    return np.mean(np.abs(pred - targets))
if __name__ == '__main__':

    # Generate data from true dynamics.
    true_y0 = np.array([2., 0.]).T
    t = np.linspace(0., max_T, N)
    true_A = np.array([[-0.1, 2.0], [-2.0, -0.1]])
    true_y = odeint(func, true_y0, t, args=(true_A,))

    def train_loss(params, iter):
        pred = ode_pred(params, true_y0, t)
        return L1_loss(pred, true_y)

    # Set up figure.
    fig = plt.figure(figsize=(12, 4), facecolor='white')
    ax_traj = fig.add_subplot(131, frameon=False)
    ax_phase = fig.add_subplot(132, frameon=False)
    ax_vecfield = fig.add_subplot(133, frameon=False)
    plt.show(block=False)

    # Plots data and learned dynamics.
    def callback(params, iter, g):
        pred = ode_pred(params, true_y0, t)
        print("Iteration {:d} train loss {:.6f}".format(
            iter, L1_loss(pred, true_y)))

        ax_traj.cla()
        ax_traj.set_title('Trajectories')
        ax_traj.set_xlabel('t')
        ax_traj.set_ylabel('x,y')
        ax_traj.plot(t, true_y[:, 0], '-', t, true_y[:, 1], 'g-')
        ax_traj.plot(t, pred[:, 0], '--', t, pred[:, 1], 'b--')
        ax_traj.set_xlim(t.min(), t.max())
        ax_traj.set_ylim(-2, 2)
        ax_traj.xaxis.set_ticklabels([])
        ax_traj.yaxis.set_ticklabels([])
        ax_traj.legend()

        ax_phase.cla()
        ax_phase.set_title('Phase Portrait')
        ax_phase.set_xlabel('x')
        ax_phase.set_ylabel('y')
        ax_phase.plot(true_y[:, 0], true_y[:, 1], 'g-')
        ax_phase.plot(pred[:, 0], pred[:, 1], 'b--')
        ax_phase.set_xlim(-2, 2)
        ax_phase.set_ylim(-2, 2)
        ax_phase.xaxis.set_ticklabels([])
        ax_phase.yaxis.set_ticklabels([])

        ax_vecfield.cla()
        ax_vecfield.set_title('Learned Vector Field')
        ax_vecfield.set_xlabel('x')
        ax_vecfield.set_ylabel('y')
        ax_vecfield.xaxis.set_ticklabels([])
        ax_vecfield.yaxis.set_ticklabels([])

        # Vector field plot of the learned dynamics on a 21x21 grid.
        y, x = npo.mgrid[-2:2:21j, -2:2:21j]
        dydt = nn_predict(np.stack([x, y], -1).reshape(21 * 21, 2), 0,
                          params).reshape(-1, 2)
        mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
        dydt = (dydt / mag)
        dydt = dydt.reshape(21, 21, 2)

        ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
        ax_vecfield.set_xlim(-2, 2)
        ax_vecfield.set_ylim(-2, 2)

        fig.tight_layout()
        plt.draw()
        plt.pause(0.001)

    # Train neural net dynamics to match data.
    init_params = init_nn_params(0.1, layer_sizes=[D, 150, D])
    optimized_params = adam(grad(train_loss), init_params,
                            num_iters=1000, callback=callback)
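    # A small follow-up sketch, not in the original gist: integrate the learned
    # dynamics past the training horizon and compare against the true system.
    # (t_ext, true_ext and pred_ext are illustrative names introduced here.)
    t_ext = np.linspace(0., 2 * max_T, 2 * N)
    true_ext = odeint(func, true_y0, t_ext, args=(true_A,))
    pred_ext = ode_pred(optimized_params, true_y0, t_ext)
    print("Extrapolation L1 error: {:.6f}".format(L1_loss(pred_ext, true_ext)))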
import tensorflow as tf
import autograd.numpy as np
from autograd import grad
from tensorflow.python.framework import function

rng = np.random.RandomState(42)
x_np = rng.randn(4, 4).astype(np.float32)

with tf.device('/cpu:0'):
    x = tf.Variable(x_np)

def tf_loss(a):
    return tf.reduce_sum(tf.square(a))

def np_loss(a):
    return np.array(2.).astype(np.float32) * np.square(a).sum()

grad_np_loss = grad(np_loss)

l = tf_loss(x)
g = tf.gradients(l, x)

with tf.device('/cpu:0'):
    # Evaluate the numpy loss and its autograd gradient inside the TF graph.
    np_in_tf = tf.py_func(np_loss, [x], tf.float32)
    npgrad_in_tf = tf.py_func(grad_np_loss, [x], tf.float32)

@function.Defun()
def op_grad(x, grad):
    return [tf.py_func(grad_np_loss, [x], tf.float32)]

@function.Defun(grad_func=op_grad)
def tf_replaced_grad_loss(a):
    return tf_loss(a)

with tf.device('/cpu:0'):
    tf_np_grad = tf.gradients(tf_replaced_grad_loss(x), x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("Tensorflow gradient:\n")
    print(sess.run(g)[0])
    print("\nNumpy gradient (should be 2 times tf version):\n")
    print(grad_np_loss(x_np))
    print("\nNumpy gradient evaluated in Tensorflow:\n")
    print(sess.run(npgrad_in_tf))
    print("\nNumpy gradient put in Tensorflow graph:\n")
    print(sess.run(tf_np_grad)[0])
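    # Hedged sanity check, not in the original gist: with tf_loss(a) = sum(a**2)
    # and np_loss(a) = 2*sum(a**2), the gradients are 2*x and 4*x respectively,
    # so the values printed above should agree up to that factor.
    print("\nGradients consistent:",
          np.allclose(sess.run(g)[0], 2 * x_np) and
          np.allclose(sess.run(tf_np_grad)[0], 4 * x_np))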