@pervognsen pervognsen/costate.md
Last active Mar 21, 2019

I've been reading this much-publicized paper on neural ordinary differential equations:

https://arxiv.org/abs/1806.07366

I found their presentation of the costate/adjoint method to be severely lacking in intuition and notational clarity, so I decided to write up my own tutorial treatment. I'm familiar with this material from its original setting in optimal control theory.

You have a dynamical system described by an autonomous first-order ODE, x' = f(x), where the state x belongs to an n-dimensional vector space. There is a scalar value function V(x) defined over the state space. Given a particular path t -> x(t) satisfying the ODE, we may evaluate it at the terminal time T to get the terminal state x(T) and the terminal value V(x(T)).

Now we want to analyze the sensitivity of the terminal value to perturbations along the path. If I perturb the path at time t to a state off the original path, I will end up on a different integral curve of the ODE and as such the terminal state (and presumably the terminal value) will be different. In fact, even if I push or pull the state along the same integral curve, it's equivalent to advancing or retarding in time, so the terminal state will change.

Because of this "perturb the path and then flow forward to the terminal time" idea, it should be clear that this problem must be solved backward, starting from the terminal time and working back to the initial time.

The sensitivity along the state path will be described by a corresponding "costate" path. The costate p(t) is the gradient of the terminal value with respect to perturbations dx at x(t): the differential change in the terminal value will be p(t) dx. As usual with gradients, we think of p(t) as a row vector, so that the product p(t) dx is a scalar. (If you want more mathematical sophistication, think 'dual vector' instead of 'row vector'.)
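
To make the definition concrete, here is a minimal sketch (illustrative, not from the original) that estimates p(t) for a scalar state (n = 1) straight from the definition: perturb x(t) by ±eps, flow both perturbed states forward to T, and difference the terminal values. The helper names `flow` and `costate_fd`, the fixed-step Euler integrator and the example system are assumptions. It costs two full forward solves per state component and per time t, which is exactly the cost the costate method avoids.

```python
import math

def flow(f, x, t0, T, steps=20000):
    """Approximate the flow of x' = f(x) from t0 to T with many small forward Euler steps."""
    h = (T - t0) / steps
    for _ in range(steps):
        x = x + h * f(x)
    return x

def costate_fd(f, V, x_t, t, T, eps=1e-6):
    """Central-difference estimate of p(t): change in V(x(T)) per unit perturbation of x(t)."""
    plus = V(flow(f, x_t + eps, t, T))
    minus = V(flow(f, x_t - eps, t, T))
    return (plus - minus) / (2 * eps)

# Example (assumed): f(x) = -0.5 x, V(x) = x^2, x(0) = 1, T = 2.
# Here x(T) = x(0) e^{-1}, so V(x(T)) = x(0)^2 e^{-2} and the exact costate is p(0) = 2 e^{-2}.
p0 = costate_fd(lambda x: -0.5 * x, lambda x: x * x, 1.0, 0.0, 2.0)
exact = 2.0 * math.exp(-2.0)
```

For this example the estimate should be close to 2e^{-2} ≈ 0.271.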

At time T you're done flowing, so any perturbation just changes the value function directly. Hence p(T) = V'(x(T)). This serves as the base case for our backward induction. It is the final condition for the costate path in the same way that x(0) is the initial condition for the state path.

Otherwise, if we perturb x(t) by dx and then flow forward by dt time, we get x(t) + dx + f(x(t) + dx) dt = x(t) + dx + (f(x(t)) + f'(x(t)) dx) dt = x(t + dt) + (1 + f'(x(t)) dt) dx, where 1 denotes the n-by-n identity matrix. Hence the perturbation dx at time t propagates to a perturbation (1 + f'(x(t)) dt) dx at time t + dt. Since the terminal value only sees the perturbation through where it ends up, this gives the relation p(t) dx = p(t + dt) (1 + f'(x(t)) dt) dx. Writing p(t + dt) = p(t) + p'(t) dt, we can cancel p(t) dx and, neglecting the higher-order differential dt^2, get 0 = p(t) f'(x(t)) dt dx + p'(t) dt dx. Since this holds for every dx, we get the costate equation in its more conventional form:

p'(t) = -p(t) f'(x(t))

Note that f'(x(t)) is a Jacobian matrix and appears to the right of p(t) since p(t) is a row vector. If we wanted to write p(t) as a column vector, we would multiply by the Jacobian transpose on the left:

p'(t)^T = -f'(x(t))^T p(t)^T
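
As a consistency check (a worked example added here, not part of the original), take linear dynamics f(x) = Ax with a constant n-by-n matrix A. The costate equation then has a closed-form solution, and it agrees with differentiating the explicit solution x(T) = e^{A(T - t)} x(t) directly:

```latex
\begin{aligned}
f(x) &= A x, \qquad f'(x) = A, \qquad p'(t) = -p(t)\,A, \qquad p(T) = V'(x(T)) \\
p(t) &= V'(x(T))\, e^{A (T - t)}
  \quad\text{since}\quad \frac{d}{dt}\Big[V'(x(T))\, e^{A (T - t)}\Big] = -V'(x(T))\, e^{A (T - t)} A = -p(t)\,A \\
x(T) &= e^{A (T - t)}\, x(t)
  \quad\Rightarrow\quad \frac{\partial\, V(x(T))}{\partial\, x(t)} = V'(x(T))\, e^{A (T - t)} = p(t)
\end{aligned}
```

For n = 1 with A = a and V(x) = x^2, this gives p(t) = 2 x(T) e^{a(T - t)}.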

Once we know x(t) we can solve backward in time for p(t), starting at the terminal time t = T. If you solve for x(t) with forward Euler, you could solve for p(t) with backward Euler. Backward Euler is usually described as an implicit ODE method, but that's only true for the initial value problem; it's an explicit method when used for backward induction.
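
Here is a minimal sketch of pairing a forward Euler state solve with an explicit backward costate sweep, under illustrative assumptions (pendulum dynamics, V(x) = x[0], fixed step H; none of these come from the original). The backward step used here, p[k] = p[k+1] (1 + H f'(x[k])), is explicit because it only needs p[k+1] and the stored state x[k], and it is exactly the gradient of the Euler-discretized forward pass, so it can be checked against finite differences of that pass. Evaluating f' at x[k+1] instead is another explicit choice consistent with the costate equation, but it agrees with this one only to first order in H.

```python
import math

H = 0.01  # step size (illustrative assumption)
N = 300   # number of steps, so T = N * H = 3.0

def f(x):
    """Pendulum dynamics (illustrative): x = (angle, angular velocity)."""
    return [x[1], -math.sin(x[0])]

def jac(x):
    """Jacobian f'(x), stored as a list of rows."""
    return [[0.0, 1.0], [-math.cos(x[0]), 0.0]]

def forward(x0):
    """Forward Euler x[k+1] = x[k] + H f(x[k]); returns the stored path x[0..N]."""
    path = [list(x0)]
    for _ in range(N):
        x = path[-1]
        fx = f(x)
        path.append([x[i] + H * fx[i] for i in range(2)])
    return path

def backward(path, pN):
    """Costate sweep from p[N] = V'(x[N]) down to p[0] using
    p[k] = p[k+1] (1 + H f'(x[k])), with p a row vector multiplied on the right."""
    p = list(pN)
    for k in range(N - 1, -1, -1):
        J = jac(path[k])
        pJ = [sum(p[i] * J[i][j] for i in range(2)) for j in range(2)]
        p = [p[j] + H * pJ[j] for j in range(2)]
    return p

# V(x) = x[0] (the final angle), so p[N] = V'(x[N]) = (1, 0).
x0 = [1.0, 0.0]
p0 = backward(forward(x0), [1.0, 0.0])
```

Because the sweep differentiates the discrete forward pass exactly, `p0` should agree with central differences of `forward(x0)[-1][0]` up to finite-difference error.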

You don't have to remember the whole state path for the costate backward solve; once x(T) is known, you can backward solve for x(t) and p(t) in lockstep, since the state dynamics are time reversible under mild technical conditions on f (e.g. Lipschitz continuity). The downside is that you will not end up exactly at your initial state because of approximation errors in the ODE method and round-off error from finite-precision arithmetic. However, if you use an ODE method with high-order accuracy and small enough time steps, you can expect the backward solve to yield a similar state path. The other major downside is that the backward method is no longer explicit when you "forget" the state path; the backward method is only explicit if you know f' at the previous state in time (which is the next state for backward induction) at every step.
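
A minimal sketch of the lockstep backward solve, assuming a scalar state with illustrative dynamics x' = -x^3 and V(x) = x, forward Euler, and the explicit reconstruction step x_prev = x - h f(x) (a choice made here, not prescribed by the original). Only x(T) is kept after the forward pass. For this system the exact sensitivity dx(T)/dx(0) = (1 + 2 x0^2 T)^(-3/2) is known in closed form, so you can see both the drift of the reconstructed initial state and its effect on p(0).

```python
def f(x):
    """Illustrative dynamics x' = -x^3 (an assumption for this example)."""
    return -x ** 3

def df(x):
    """Derivative f'(x) = -3x^2."""
    return -3.0 * x ** 2

def forward_final(x0, h, n):
    """Forward Euler that keeps only the final state, so memory does not grow with n."""
    x = x0
    for _ in range(n):
        x = x + h * f(x)
    return x

def forgetful_backward(xT, pT, h, n):
    """Walk x and p backward in lockstep from t = T to t = 0.

    x_prev = x - h f(x) only approximately inverts the forward step
    x = x_prev + h f(x_prev); inverting it exactly would mean solving that
    equation for x_prev, which is implicit. The costate step
    p <- p (1 + h f'(x_prev)) then uses the reconstructed state.
    """
    x, p = xT, pT
    for _ in range(n):
        x_prev = x - h * f(x)
        p = p * (1.0 + h * df(x_prev))
        x = x_prev
    return x, p

# V(x) = x, so p(T) = 1. For the exact flow, dx(T)/dx(0) = (1 + 2 x0^2 T)^(-3/2).
x0, T = 1.0, 1.0
for n in (200, 400):
    h = T / n
    x0_rec, p0 = forgetful_backward(forward_final(x0, h, n), 1.0, h, n)
    print(n, abs(x0_rec - x0), p0, (1.0 + 2.0 * x0 ** 2 * T) ** -1.5)
```

Halving h should roughly halve the gap between the reconstructed and the true initial state, in line with Euler's first-order accuracy.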

However, as the paper notes, the forgetful approach has two advantages: memory consumption does not grow with the depth (number of time steps) of your neural ODE, and reconstructing the approximate state path backward lets you use an existing black-box ODE solver from an external library for the backpropagation. We just invoke the black-box solver on the time-reversed state equation x' = -f(x), together with the time-reversed costate equation p' = p f'(x), to solve for both backward in time. More generally, you can pull this trick any time the forward state equation is efficiently invertible; ODEs are generally time reversible, and when solved numerically they can be approximately inverted by running the ODE solver on the time-reversed equation.
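
A minimal sketch of that recipe, assuming a hand-written fixed-step RK4 (`rk4`, hypothetical) as the stand-in for an external black-box solver, illustrative pendulum dynamics, and V(x) = x[0]; none of these come from the original. After the forward solve keeps only x(T), a single solver call on the augmented state (x, p) runs backward in reversed time τ = T - t, using -f(x) for the state and the sign-flipped costate equation p f'(x).

```python
import math

def axpy(a, x, y):
    """Return y + a * x for list-valued x and y."""
    return [yi + a * xi for xi, yi in zip(x, y)]

def rk4(g, y, t0, t1, steps):
    """Fixed-step RK4 for the autonomous system y' = g(y) from t0 to t1.
    A hand-written stand-in for a black-box library solver (assumption)."""
    h = (t1 - t0) / steps
    for _ in range(steps):
        k1 = g(y)
        k2 = g(axpy(h / 2, k1, y))
        k3 = g(axpy(h / 2, k2, y))
        k4 = g(axpy(h, k3, y))
        y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
    return y

def f(x):
    """Pendulum dynamics (illustrative): x = (angle, angular velocity)."""
    return [x[1], -math.sin(x[0])]

def jac(x):
    """Jacobian f'(x) as a list of rows."""
    return [[0.0, 1.0], [-math.cos(x[0]), 0.0]]

def reversed_augmented(y):
    """Augmented system for y = (x, p) in reversed time tau = T - t:
    dx/dtau = -f(x) and dp/dtau = p f'(x), i.e. p' = -p f'(x) with the sign flipped."""
    x, p = y[:2], y[2:]
    J = jac(x)
    pJ = [sum(p[i] * J[i][j] for i in range(2)) for j in range(2)]
    return [-v for v in f(x)] + pJ

T, STEPS = 3.0, 3000
x0 = [1.0, 0.0]
xT = rk4(f, x0, 0.0, T, STEPS)                        # forward solve; only x(T) is kept
pT = [1.0, 0.0]                                       # V(x) = x[0], so p(T) = (1, 0)
y = rk4(reversed_augmented, xT + pT, 0.0, T, STEPS)   # one backward solver call
x0_rec, p0 = y[:2], y[2:]                             # reconstructed x(0) and costate p(0)
```

With these small steps, `p0` should agree closely with central differences of the forward solve, and `x0_rec` should land very near `x0`; with a coarse or low-order solver the reconstruction drift would be larger.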

There's a discrete-time version of everything we just did. If x[n+1] = f(x[n]) is the forward state equation, then the backward costate equation is p[n] = p[n+1] f'(x[n]). This is just old-fashioned discrete backpropagation. The negative sign in the continuous-time costate equation is a consequence of expressing it as a differential equation, which is why it doesn't appear here. You can write the discrete-time system as a difference equation if you like, and the negative sign in the costate equation reappears when it is written as a forward difference equation: with x[n+1] - x[n] = g(x[n]), it reads p[n+1] - p[n] = -p[n+1] g'(x[n]). If you solve the continuous-time costate equation with backward Euler, the increment from p(t) to p(t - dt) has a positive sign. Similarly, the backward difference form of the discrete-time costate equation has a positive sign. If you want to connect your discrete-time system to a corresponding continuous-time system governed by a differential equation or variational principle (e.g. discrete Lagrangian or Hamiltonian mechanics), you might want to write the dynamics as a difference equation. Otherwise, directly expressing x[n+1] as a function of x[n] is faster and simpler.
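
The discrete recurrence is ordinary reverse-mode differentiation through a composed map. A minimal sketch, assuming a tiny two-unit recurrence x[n+1] = tanh(W x[n]) with made-up weights `W` and value V(x) = x[0] + 2 x[1] (none of this is from the original). The Jacobian of tanh(W x) is diag(1 - tanh(W x)^2) W, and the row vector p is multiplied on the right by it at each step.

```python
import math

W = [[0.5, -1.0], [0.8, 0.3]]   # illustrative 2x2 weights (assumption)

def step(x):
    """Forward state equation x[n+1] = f(x[n]) = tanh(W x[n])."""
    return [math.tanh(sum(W[i][j] * x[j] for j in range(2))) for i in range(2)]

def forward(x0, n):
    """Return the path x[0..n]."""
    path = [list(x0)]
    for _ in range(n):
        path.append(step(path[-1]))
    return path

def backprop(path, pN):
    """Backward costate equation p[k] = p[k+1] f'(x[k]) with
    f'(x) = diag(1 - tanh(W x)^2) W, applied as a row vector times the Jacobian."""
    p = list(pN)
    for k in range(len(path) - 2, -1, -1):
        y = path[k + 1]                                   # y = f(x[k]) = tanh(W x[k])
        scaled = [p[i] * (1.0 - y[i] ** 2) for i in range(2)]
        p = [sum(scaled[i] * W[i][j] for i in range(2)) for j in range(2)]
    return p

# V(x) = x[0] + 2 x[1], so p[N] = (1, 2).
x0, n = [0.3, -0.2], 10
p0 = backprop(forward(x0, n), [1.0, 2.0])
```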

As an aside, I've seen people "explain" discrete backpropagation and reverse-mode automatic differentiation via the costate/adjoint picture we just went over. While our discussion of discrete-time systems shows a connection, in my opinion this is an extremely roundabout and confusing road to understanding what's going on. Backprop/RAD is just the row vector/dual vector version of the proper transformation law for gradients, as encapsulated in the simple slogan: "Vectors are pushed forward by differentiable functions, but dual vectors are pulled backward." The confusion only arises because people usually insist on writing the gradient as a column vector, and then you have to multiply on the left by the Jacobian transpose rather than on the right by the Jacobian, as explained earlier. The costate/adjoint method is an application of this basic fact, not the other way around.

There's a standard way to generalize this beyond its immediate setting of terminal value problems. Let's say you want to optimize a value function that depends on the whole path by integrating a scalar function L (called the Lagrangian) along the path: int(0, T) L(x(t), x'(t)) dt. This can be transformed into a terminal value problem by augmenting the state space with an extra coordinate v that serves as a value accumulator along the path. Then x becomes (x, v) and x' = f(x) is augmented with a dynamical equation for the value accumulator: v' = L(x, f(x)). This is a first-order ODE over the augmented state space and the augmented value function is V(x, v) = v, so we can apply our techniques from before. You can also combine a terminal value with a Lagrangian value by just summing them: V(x, v) = V(x) + v.
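
A minimal sketch of the accumulator trick for a scalar state, assuming x' = -x and L(x, x') = x^2 + x'^2 (both illustrative), discretized with forward Euler. Since nothing depends on v, the v-component of the augmented costate stays at its terminal value 1, and the x-component picks up the extra term d/dx L(x, f(x)) at each step. For this example the path integral has the closed form x0^2 (1 - e^{-2T}), which gives a check.

```python
import math

def f(x):
    """Illustrative dynamics x' = -x."""
    return -x

def df(x):
    """Derivative f'(x)."""
    return -1.0

def L(x, xdot):
    """Illustrative Lagrangian L(x, x') = x^2 + x'^2."""
    return x * x + xdot * xdot

def dL_dx(x, xdot):
    return 2.0 * x

def dL_dxdot(x, xdot):
    return 2.0 * xdot

def g(x):
    """Accumulator dynamics v' = L(x, f(x))."""
    return L(x, f(x))

def dg(x):
    """Chain rule: d/dx L(x, f(x)) = dL/dx + dL/dx' * f'(x)."""
    return dL_dx(x, f(x)) + dL_dxdot(x, f(x)) * df(x)

def cost_and_gradient(x0, T, n):
    """Forward Euler on the augmented state (x, v), then the backward costate sweep."""
    h = T / n
    xs, v = [x0], 0.0
    for _ in range(n):
        x = xs[-1]
        v += h * g(x)                     # v[k+1] = v[k] + h L(x[k], f(x[k]))
        xs.append(x + h * f(x))           # x[k+1] = x[k] + h f(x[k])
    px, pv = 0.0, 1.0                     # augmented value V(x, v) = v, so p(T) = (0, 1)
    for k in range(n - 1, -1, -1):
        # pv never changes: nothing depends on v, so its costate equation is pv' = 0.
        px = px * (1.0 + h * df(xs[k])) + pv * h * dg(xs[k])
    return v, px

# For this example the integral is x0^2 (1 - e^{-2T}), so its x0-derivative is 2 x0 (1 - e^{-2T}).
v, px = cost_and_gradient(1.5, 1.0, 1000)
exact_v = 1.5 ** 2 * (1.0 - math.exp(-2.0))
exact_px = 2.0 * 1.5 * (1.0 - math.exp(-2.0))
```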

You can handle both time-dependent and control-dependent effects within our existing framework by augmenting the state space appropriately. The augmented state is (x, u, s) with dynamics governed by x' = f(x, u, s), u' = k'(s), s' = 1, where k(t) is the externally specified set of control parameters as a function of time and s acts like a time accumulator that ticks at the same rate as t. This first-order ODE is autonomous since t does not appear explicitly. Write p_x, p_u and p_s for the components of the augmented costate corresponding to x, u and s. Then p_u(t) is the gradient of the terminal value with respect to perturbations of u(t) = k(t). Since u just tracks k, a perturbation of u(t) persists for the rest of the path, so p_u(t) measures the effect of shifting k by a constant over [t, T]. Its costate equation reduces to p_u' = -p_x ∂f/∂u(x, u, s), so the pointwise gradient with respect to k at time t is -p_u'(t) = p_x(t) ∂f/∂u, and for a piecewise-constant control the gradient with respect to its value on [t_i, t_{i+1}] is p_u(t_i) - p_u(t_{i+1}). This is what you would use for gradient descent in an optimal control problem ("continuous backprop"). There's no reason to forward integrate u and s; you can just evaluate s(t) = t and u(t) = k(t) on demand based on the current time. Since there's no feedback in the augmented dynamics for u (we're just coupling u to the external control k), the backward solve for the augmented costate corresponding to u is just integration/quadrature and doesn't need a proper ODE solver, but of course you need the ODE solver for the part of the augmented costate corresponding to x since x' = f(x, u, s) has feedback. If you like, you can do the backward solve for the costates corresponding to (x, s) using the backward ODE solver, and then do a final quadrature to compute the costates corresponding to u once the (x, s) costates are known, or you can compute both in lockstep as suggested by the augmented state space formulation, which is ideal if you're taking the "forgetful" approach to the state path for the backward solve.
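
A minimal sketch, assuming a scalar state with illustrative dynamics x' = f(x, u, s) = -x^3 + u cos(s), terminal value V(x) = (x - 1)^2 / 2, and a piecewise-constant control on M equal intervals (all assumptions; the names `value_and_gradient` and `interval` are hypothetical). Here s and u are evaluated on demand from the step index rather than integrated. The x-costate needs the stepwise backward recurrence, while the control gradient is a plain running sum (quadrature) of p_x ∂f/∂u over each interval; for the i-th interval that sum is the change p_u(t_i) - p_u(t_{i+1}) in the u-component of the augmented costate. The s-component is never needed, since neither the x- nor the u-costate equation involves it.

```python
import math

T, N, M = 2.0, 2000, 4   # horizon, Euler steps, control intervals (illustrative)
H = T / N

def f(x, u, s):
    """Illustrative controlled dynamics x' = f(x, u, s); s plays the role of time."""
    return -x ** 3 + u * math.cos(s)

def f_x(x, u, s):
    return -3.0 * x ** 2

def f_u(x, u, s):
    return math.cos(s)

def interval(step):
    """Index of the control interval containing Euler step `step`."""
    return step * M // N

def V(x):
    """Illustrative terminal value: half the squared distance to a target of 1."""
    return 0.5 * (x - 1.0) ** 2

def dV(x):
    return x - 1.0

def value_and_gradient(x0, k_values):
    """Terminal value and its gradient with respect to the M control values."""
    xs = [x0]
    for step in range(N):
        s, u = step * H, k_values[interval(step)]   # s(t) = t and u(t) = k(t), on demand
        xs.append(xs[-1] + H * f(xs[-1], u, s))
    px = dV(xs[-1])                                  # p_x(T) = V'(x(T))
    grad = [0.0] * M
    for step in range(N - 1, -1, -1):
        s, u, x = step * H, k_values[interval(step)], xs[step]
        grad[interval(step)] += px * H * f_u(x, u, s)   # quadrature of p_x * df/du
        px = px * (1.0 + H * f_x(x, u, s))              # costate recurrence for x
    return V(xs[-1]), grad

value, grad = value_and_gradient(0.5, [0.2, -0.1, 0.4, 0.3])
```

Since the sweep differentiates the Euler-discretized forward pass exactly, each entry of `grad` can be checked against a central difference in the corresponding control value.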
