# From Appendix B in the paper: Neural ODE Solver Implementation using autograd
# From Appendix B in the paper
# Neural ODE Solver
# Implementation in autograd
import scipy.integrate
import autograd.numpy as np
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten
from autograd.builtins import tuple
# Wrap SciPy's ODE integrator as an autograd primitive so that a custom
# vector-Jacobian product can be registered for it.
odeint = primitive(scipy.integrate.odeint)
def grad_odeint_all(yt, func, y0, t, func_args, **kwargs):
    """Build the vector-Jacobian product of ``odeint`` for all of its arguments.

    Extended from "Scalable Inference of Ordinary Differential
    Equation Models of Biochemical Processes", Sec. 2.4.2
    Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017

    Args:
        yt: Forward solution; ``np.shape(yt)`` is unpacked as ``(T, D)``.
        func: Dynamics function, called as ``func(y, t, *func_args)``.
        y0: Initial state. Not read here; kept so the signature mirrors the
            positional arguments of ``odeint``.
        t: Time points, indexed as ``t[i]`` for ``i`` in ``range(T)``.
        func_args: Extra arguments of ``func``; flattened with ``flatten``.
        **kwargs: Solver options used as defaults for the backward solves.

    Returns:
        A function ``vjp_all(g, **kwargs)`` that maps the output cotangent
        ``g`` (indexed like ``yt``) to the tuple
        ``(None, vjp_y0, vjp_times, vjp_args)``: one entry per positional
        argument ``(func, y0, t, func_args)``.
    """
    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        #      y,      vjp_y,      vjp_t,    vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_fn, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_fn(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))

    def vjp_all(g, **vjp_kwargs):
        # Options given to the forward solve are the defaults for the backward
        # solves; keyword arguments passed here override them.
        solver_kwargs = dict(kwargs)
        solver_kwargs.update(vjp_kwargs)

        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        for i in range(T - 1, 0, -1):
            # Compute effect of moving current time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run augmented system backwards to the previous observation.
            # The two time points must be wrapped in a list: a second
            # positional argument to np.array would be taken as the dtype.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]),
                             tuple((flat_args,)), **solver_kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]

        # The entry for t[0] goes last so that the reversal below puts the
        # time cotangents in the same order as t.
        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)

    return vjp_all
def grad_argnums_wrapper(all_vjp_builder):
    """A generic autograd helper function.

    Takes a function that builds vjps for all arguments, and wraps it to
    return only required vjps.

    Args:
        all_vjp_builder: Callable invoked as
            ``all_vjp_builder(ans, *combined_args, **kwargs)``; it must return
            a function of ``g`` whose result can be indexed by argument number.

    Returns:
        A function ``build_selected_vjps(argnums, ans, combined_args, kwargs)``
        that returns a function of ``g`` producing a list with one vjp per
        entry of ``argnums``, in that order.
    """
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):
            # Return whichever vjps were asked for.
            all_vjps = vjp_func(g)
            return [all_vjps[argnum] for argnum in argnums]

        return chosen_vjps

    return build_selected_vjps
if __name__ == '__main__':
    # Attach the VJP builder to the odeint primitive and print whatever the
    # registration call returns.
    registration_result = defvjp_argnums(
        odeint, grad_argnums_wrapper(grad_odeint_all))
    print(registration_result)
from __future__ import absolute_import
from __future__ import print_function
from builtins import range
import matplotlib.pyplot as plt
import numpy as npo
import autograd.numpy as np
from autograd import grad
#from autograd.scipy.integrate import odeint
from autograd.builtins import tuple
from autograd.misc.optimizers import adam
import autograd.numpy.random as npr
import scipy.integrate
import autograd.numpy as np
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten
from autograd.builtins import tuple
# Wrap SciPy's ODE integrator as an autograd primitive so that a custom
# vector-Jacobian product can be registered for it.
odeint = primitive(scipy.integrate.odeint)
def grad_odeint(yt, func, y0, t, func_args, **kwargs):
    """Build the vector-Jacobian product of ``odeint`` for all of its arguments.

    Args:
        yt: Forward solution; ``np.shape(yt)`` is unpacked as ``(T, D)``.
        func: Dynamics function, called as ``func(y, t, *func_args)``.
        y0: Initial state. Not read here; kept so the signature mirrors the
            positional arguments of ``odeint``.
        t: Time points, indexed as ``t[i]`` for ``i`` in ``range(T)``.
        func_args: Extra arguments of ``func``; flattened with ``flatten``.
        **kwargs: Solver options forwarded to every backward ``odeint`` call.

    Returns:
        A function ``vjp_all(g)`` that maps the output cotangent ``g``
        (indexed like ``yt``) to the tuple
        ``(None, vjp_y0, vjp_times, vjp_args)``: one entry per positional
        argument ``(func, y0, t, func_args)``.
    """
    # Extended from "Scalable Inference of Ordinary Differential
    # Equation Models of Biochemical Processes", Sec. 2.4.2
    # Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017

    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        #      y,      vjp_y,      vjp_t,    vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_fn, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_fn(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))

    def vjp_all(g):
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        for i in range(T - 1, 0, -1):
            # Compute effect of moving measurement time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run augmented system backwards to the previous observation.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]),
                             tuple((flat_args,)), **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]

        # The entry for t[0] goes last so that the reversal below puts the
        # time cotangents in the same order as t.
        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)

    return vjp_all
def argnums_unpack(all_vjp_builder):
    """Adapt an "all arguments" vjp builder to autograd's per-argnum protocol.

    ``all_vjp_builder(ans, *combined_args, **kwargs)`` must return a function
    of ``g`` whose result is indexable by argument number. The wrapper
    returned here picks out only the entries named in ``argnums``.
    """
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        full_vjp = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):
            # Evaluate every vjp once, then keep the requested ones in order.
            every_vjp = full_vjp(g)
            selected = []
            for wanted in argnums:
                selected.append(every_vjp[wanted])
            return selected

        return chosen_vjps

    return build_selected_vjps
# Register grad_odeint (wrapped to select per-argument vjps) as the VJP
# builder of the odeint primitive.
defvjp_argnums(odeint, argnums_unpack(grad_odeint))
N = 30 # Dataset size: number of time points generated with np.linspace
D = 2 # Data dimension: input/output width of the network
max_T = 1.5  # Upper end of the time grid np.linspace(0., max_T, N)
# Two-dimensional damped oscillator
def func(y, t0, A):
    """Return the time derivative ``dot(y**3, A)`` of the true dynamics.

    Args:
        y: Current state.
        t0: Current time. Unused, but required by the ``odeint`` callback
            signature ``func(y, t, *args)``.
        A: Matrix applied to the element-wise cubed state.

    Returns:
        ``np.dot(y**3, A)``.
    """
    return np.dot(y**3, A)
def nn_predict(inputs, t, params):
    """Evaluate a fully connected network with ReLU between layers.

    Args:
        inputs: Input array; multiplied on the right by each weight matrix.
        t: Time. Unused, but required by the ``odeint`` callback signature.
        params: Non-empty iterable of ``(W, b)`` pairs, one per layer.

    Returns:
        The affine output ``dot(inputs, W) + b`` of the last layer; no
        nonlinearity is applied after the final layer.
    """
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        # ReLU; only feeds the next layer, the last layer's output stays linear.
        inputs = np.maximum(0, outputs)
    return outputs
def init_nn_params(scale, layer_sizes, rs=None):
    """Build a list of (weights, biases) tuples, one for each layer.

    Args:
        scale: Factor applied to the standard-normal draws.
        layer_sizes: Layer widths; consecutive pairs give each weight shape.
        rs: Random state providing ``randn``. When ``None``, a fresh
            ``RandomState(0)`` is created on every call, so repeated calls
            with the default return identical parameters.

    Returns:
        A list of ``(W, b)`` tuples with ``W`` of shape ``(insize, outsize)``
        and ``b`` of shape ``(outsize,)``.
    """
    if rs is None:
        # Created per call: a RandomState in the signature would be built once
        # and its state shared (and advanced) across all calls.
        rs = npr.RandomState(0)
    return [(rs.randn(insize, outsize) * scale,   # weight matrix
             rs.randn(outsize) * scale)           # bias vector
            for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])]
# Define neural ODE model.
def ode_pred(params, y0, t):
    """Integrate the network dynamics from `y0` over the time points `t`."""
    # The network parameters are the single extra argument of nn_predict.
    extra_args = tuple((params,))
    trajectory = odeint(nn_predict, y0, t, extra_args, rtol=0.01)
    return trajectory
def L1_loss(pred, targets):
    """Return the mean absolute difference between `pred` and `targets`."""
    residual = pred - targets
    absolute_error = np.abs(residual)
    return np.mean(absolute_error)
if __name__ == '__main__':
    # Script entry point: generate a trajectory from the true dynamics, then
    # fit the network dynamics to it with adam, plotting from the callback.

    # Generate data from true dynamics.
    true_y0 = np.array([2., 0.]).T
    t = np.linspace(0., max_T, N)
    true_A = np.array([[-0.1, 2.0], [-2.0, -0.1]])
    true_y = odeint(func, true_y0, t, args=(true_A,))

    def train_loss(params, iter):
        # Objective whose gradient is handed to adam; `iter` is accepted but
        # not used.
        pred = ode_pred(params, true_y0, t)
        return L1_loss(pred, true_y)

    # Set up figure
    fig = plt.figure(figsize=(12, 4), facecolor='white')
    ax_traj = fig.add_subplot(131, frameon=False)
    ax_phase = fig.add_subplot(132, frameon=False)
    ax_vecfield = fig.add_subplot(133, frameon=False)

    # Plots data and learned dynamics.
    # NOTE(review): the axes are never cleared and no show/draw call is
    # visible in this block, so each call adds new artists on top of the
    # previous ones -- confirm whether that is intended.
    def callback(params, iter, g):
        # `g` is accepted but not used.
        pred = ode_pred(params, true_y0, t)
        print("Iteration {:d} train loss {:.6f}".format(
            iter, L1_loss(pred, true_y)))

        # Trajectories over time: true data and prediction.
        ax_traj.plot(t, true_y[:, 0], '-', t, true_y[:, 1], 'g-')
        ax_traj.plot(t, pred[:, 0], '--', t, pred[:, 1], 'b--')
        ax_traj.set_xlim(t.min(), t.max())
        ax_traj.set_ylim(-2, 2)

        # Phase portrait: second state component against the first.
        ax_phase.set_title('Phase Portrait')
        ax_phase.plot(true_y[:, 0], true_y[:, 1], 'g-')
        ax_phase.plot(pred[:, 0], pred[:, 1], 'b--')
        ax_phase.set_xlim(-2, 2)
        ax_phase.set_ylim(-2, 2)

        ax_vecfield.set_title('Learned Vector Field')
        # vector field plot
        # Evaluate the network on a 21 x 21 grid over [-2, 2]^2 at time 0.
        y, x = npo.mgrid[-2:2:21j, -2:2:21j]
        dydt = nn_predict(np.stack([x, y], -1).reshape(21 * 21, 2), 0,
                          params).reshape(-1, 2)
        # Normalize each vector to unit length before plotting.
        mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
        dydt = (dydt / mag)
        dydt = dydt.reshape(21, 21, 2)
        ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
        ax_vecfield.set_xlim(-2, 2)
        ax_vecfield.set_ylim(-2, 2)

    # Train neural net dynamics to match data.
    init_params = init_nn_params(0.1, layer_sizes=[D, 150, D])
    optimized_params = adam(grad(train_loss), init_params,
                            num_iters=1000, callback=callback)
import tensorflow as tf
import autograd.numpy as np
from autograd import grad
from tensorflow.python.framework import function
# Fixed-seed 4x4 float32 sample, held in a TensorFlow variable on the CPU.
rng = np.random.RandomState(42)
x_np = rng.randn(4,4).astype(np.float32)
with tf.device('/cpu:0'):
    x = tf.Variable(x_np)
def tf_loss(a):
    """Return the sum of the squared entries of `a` as a TensorFlow op."""
    squared = tf.square(a)
    return tf.reduce_sum(squared)
def np_loss(a):
    """Return twice the sum of the squared entries of `a`.

    The factor 2 is built as a float32 scalar array before multiplying.
    """
    factor = np.array(2.).astype(np.float32)
    return factor * np.square(a).sum()
# autograd gradient of the NumPy loss.
grad_np_loss = grad(np_loss)
# TensorFlow loss on the variable and its symbolic gradient.
l = tf_loss(x)
g = tf.gradients(l, x)
# Wrap the NumPy loss and its autograd gradient as float32 py_func ops.
with tf.device('/cpu:0'):
    np_in_tf = tf.py_func(np_loss, [x], tf.float32)
    npgrad_in_tf = tf.py_func(grad_np_loss, [x], tf.float32)
def op_grad(x, grad):
    """Return a one-element list holding the NumPy gradient as a py_func op.

    `grad` is accepted but not used.
    """
    numpy_gradient = tf.py_func(grad_np_loss, [x], tf.float32)
    return [numpy_gradient]
def tf_replaced_grad_loss(a):
    """Return `tf_loss(a)`; this body only forwards to `tf_loss`."""
    loss = tf_loss(a)
    return loss
# Symbolic gradient of the wrapped loss with respect to x.
with tf.device('/cpu:0'):
    tf_np_grad = tf.gradients(tf_replaced_grad_loss(x),x)
# NOTE(review): only the headings are printed; no sess.run call or variable
# initialization is visible in this block -- confirm against the full script.
with tf.Session() as sess:
    print("Tensorflow gradient:\n")
    print("\nNumpy gradient (should be 2 times tf version):\n")
    print("\nNumpy gradient evaluated in Tensorflow:\n")
    print("\nNumpy gradient put in Tensorflow graph:\n")
