Skip to content

Instantly share code, notes, and snippets.

@currymj
Last active February 5, 2021 23:11
Show Gist options
  • Save currymj/59e00c474d847957e8576671fefcabb5 to your computer and use it in GitHub Desktop.
Save currymj/59e00c474d847957e8576671fefcabb5 to your computer and use it in GitHub Desktop.
solving unconstrained Stackelberg problems using autograd to differentiate through a Newton solver
import jax.numpy as jnp
from jax import grad, hessian, jit
from jax.scipy.linalg import solve
import numpy as np
import matplotlib.pyplot as plt
# The goal here is to solve a Stackelberg game in a particularly grotesque way.
# To compute the follower best response, we take a few Newton steps using autograd to compute the Hessian and gradients.
# This is implemented in function `follower_bestresponse`.
# We then use autograd to differentiate through the Newton updates that compute the best response, compute gradients
# and Hessians, and take Newton steps on the leader's payoff.
# Test problem: MacalHurter1997 from the BOLIB collection of bilevel optimization examples.
# The objectives are quadratic, so we can expect Newton to work very well -- in fact a single step is enough.
# This is sort of cheating: it would be good to also test on a problem that is harder for Newton.
def follower_objective(inner_y, outer_x):
    """Return the follower's (inner) cost, summed over all components.

    Evaluates ``0.5*y**2 + 500*y - 50*x*y`` elementwise and sums the result.

    Args:
        inner_y: The follower's decision variable (array-like with ``.sum()``).
        outer_x: The leader's decision variable, broadcastable against
            ``inner_y``.

    Returns:
        The scalar follower objective value.
    """
    quadratic_term = 0.5 * inner_y**2
    linear_term = 500 * inner_y
    # The only term that couples the follower's choice to the leader's move.
    coupling_term = 50 * outer_x * inner_y
    return (quadratic_term + linear_term - coupling_term).sum()
def leader_objective(inner_y, outer_x):
    """Return the leader's (outer) cost, summed over all components.

    Evaluates ``(x - 1)**2 + (y - 1)**2`` elementwise and sums the result.

    Args:
        inner_y: The follower's decision variable (array-like with ``.sum()``).
        outer_x: The leader's decision variable, broadcastable against
            ``inner_y``.

    Returns:
        The scalar leader objective value.
    """
    leader_term = (outer_x - 1) ** 2
    follower_term = (inner_y - 1) ** 2
    return (leader_term + follower_term).sum()
def newton_update(gradf, hessf, x):
    """Return the Newton step at ``x``.

    Solves the linear system ``hessf(x) @ step = gradf(x)`` rather than
    forming an explicit inverse of the Hessian.

    Args:
        gradf: Callable mapping ``x`` to the gradient vector at ``x``.
        hessf: Callable mapping ``x`` to the (square) Hessian matrix at ``x``.
        x: Point at which to evaluate the step.

    Returns:
        The step ``step`` such that ``x - step`` is the full Newton iterate.
    """
    hessian_at_x = hessf(x)
    gradient_at_x = gradf(x)
    return solve(hessian_at_x, gradient_at_x)
def newton(gradf, hessf, x0, steps=1, eta=1.0):
    """Run a fixed number of (damped) Newton iterations starting from ``x0``.

    Each iteration applies ``x <- x - eta * newton_update(gradf, hessf, x)``.
    There is no convergence test; exactly ``steps`` iterations are performed,
    so ``steps=0`` returns ``x0`` unchanged.

    Args:
        gradf: Callable mapping ``x`` to the gradient at ``x``.
        hessf: Callable mapping ``x`` to the Hessian at ``x``.
        x0: Initial iterate.
        steps: Number of Newton iterations to perform.
        eta: Step-size multiplier applied to every Newton step
            (``1.0`` is the full Newton step).

    Returns:
        The iterate after ``steps`` updates.
    """
    x = x0
    # The loop index is irrelevant; only the iteration count matters.
    for _ in range(steps):
        x = x - eta * newton_update(gradf, hessf, x)
    return x
@jit
def follower_bestresponse(outer_x):
    """Approximate the follower's best response to the leader's move.

    Takes a single Newton step on ``follower_objective`` with respect to the
    follower's variable, starting from ``y = [1.0]`` and holding ``outer_x``
    fixed. Because the computation is ordinary JAX code, it can itself be
    differentiated with respect to ``outer_x``.

    Args:
        outer_x: The leader's decision variable.

    Returns:
        The follower's decision after one Newton step.
    """
    # Build the follower's derivatives here so the function does not rely on
    # module-level ``grad_inner`` / ``hess_inner`` names. ``grad`` and
    # ``hessian`` differentiate with respect to the first positional argument
    # (``inner_y``) by default.
    grad_inner = grad(follower_objective)
    hess_inner = hessian(follower_objective)

    def inner_grad(y):
        return grad_inner(y, outer_x)

    def inner_hess(y):
        return hess_inner(y, outer_x)

    opt_inner = newton(inner_grad, inner_hess, jnp.array([1.0]), steps=1)
    return opt_inner
def leader_payout(outer_x):
    """Return the leader's cost when the follower best-responds to ``outer_x``.

    Args:
        outer_x: The leader's decision variable.

    Returns:
        ``leader_objective`` evaluated at ``outer_x`` and the follower's
        response computed by ``follower_bestresponse(outer_x)``.
    """
    follower_move = follower_bestresponse(outer_x)
    return leader_objective(follower_move, outer_x)
# Compiled gradient and Hessian of the leader's payoff with respect to the
# leader's move. Since leader_payout calls follower_bestresponse, these
# derivatives are taken through the follower's Newton step.
grad_outer, hess_outer = (
    jit(transform(leader_payout)) for transform in (grad, hessian)
)
def stackelberg_soln():
    """Solve the Stackelberg game by Newton's method on the leader's payoff.

    Runs 10 Newton iterations on ``leader_payout`` starting from ``x = [1.0]``
    using ``grad_outer`` and ``hess_outer``, then evaluates the follower's
    best response at the resulting leader move.

    Returns:
        A tuple ``(leader_move, follower_move)``.
    """
    leader_move = newton(grad_outer, hess_outer, jnp.array([1.0]), steps=10)
    follower_move = follower_bestresponse(leader_move)
    return leader_move, follower_move
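# Expected output of stackelberg_soln():
# (DeviceArray([10.016394], dtype=float32), DeviceArray([0.8196831], dtype=float32))
# Analytically, y*(x) = 50x - 500 and x* = 50102/5002 ~= 10.01639, y* ~= 0.81967 (up to float32 rounding).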
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment