Last active
February 5, 2021 23:11
-
-
Save currymj/59e00c474d847957e8576671fefcabb5 to your computer and use it in GitHub Desktop.
solving unconstrained Stackelberg problems using autograd to differentiate through a Newton solver
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax.numpy as jnp | |
from jax import grad, hessian, jit | |
from jax.scipy.linalg import solve | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# The goal here is to solve a Stackelberg game in a particularly grotesque way.
# To compute the follower best response, we take a few Newton steps using autograd to compute the Hessian and gradients.
# This is implemented in function `follower_bestresponse`.
# We then use autograd to differentiate through the Newton updates that compute the best response, compute gradients
# and Hessians, and take Newton steps on the leader's payoff.
# MacalHurter1997 from BOLIB bilevel optimization examples.
# These are quadratic, so we can expect Newton to work great -- in fact only one step is needed!
# This is sort of cheating -- it would be good to test on a problem that is harder for Newton.
def follower_objective(inner_y, outer_x):
    """Follower's quadratic cost (MacalHurter1997): 0.5*y**2 + 500*y - 50*x*y, summed."""
    quadratic = 0.5 * inner_y ** 2
    linear = 500 * inner_y
    coupling = 50 * outer_x * inner_y
    return jnp.sum(quadratic + linear - coupling)
def leader_objective(inner_y, outer_x):
    """Leader's cost: squared Euclidean distance of (x, y) from the point (1, 1)."""
    dx = outer_x - 1
    dy = inner_y - 1
    return jnp.sum(dx ** 2 + dy ** 2)
def newton_update(gradf, hessf, x):
    """Return the (undamped) Newton step direction H(x)^{-1} g(x) at point x."""
    g = gradf(x)
    H = hessf(x)
    return solve(H, g)
def newton(gradf, hessf, x0, steps=1, eta=1.0):
    """Run `steps` damped Newton iterations from x0 and return the final iterate.

    Each iteration moves by -eta * H(x)^{-1} g(x), where g and H are the
    supplied gradient and Hessian callables.
    """
    x = x0
    for _ in range(steps):
        # Same computation as newton_update(gradf, hessf, x), inlined here.
        direction = solve(hessf(x), gradf(x))
        x = x - eta * direction
    return x
@jit
def follower_bestresponse(outer_x):
    """Compute the follower's best response y*(x) to the leader's move `outer_x`.

    Uses Newton's method on the follower objective, with the gradient and
    Hessian obtained by autograd.  The follower objective is quadratic in y,
    so a single Newton step lands exactly on the optimum.
    """
    # Fix: the original referenced undefined names `grad_inner` / `hess_inner`
    # (a NameError at trace time); build the derivatives of the follower
    # objective w.r.t. its first argument y directly.
    inner_grad = lambda y: grad(follower_objective)(y, outer_x)
    inner_hess = lambda y: hessian(follower_objective)(y, outer_x)
    opt_inner = newton(inner_grad, inner_hess, jnp.array([1.0]), steps=1)
    return opt_inner
def leader_payout(outer_x):
    """Leader's objective evaluated at the follower's best response to `outer_x`."""
    best_y = follower_bestresponse(outer_x)
    return leader_objective(best_y, outer_x)
# Leader-side derivatives: autograd differentiates *through* the follower's
# Newton best-response computation inside leader_payout; jit-compiled.
grad_outer = jit(grad(leader_payout))
hess_outer = jit(hessian(leader_payout))
def stackelberg_soln():
    """Solve the Stackelberg game.

    Runs Newton's method on the leader payoff (whose gradient/Hessian
    differentiate through the follower's best-response Newton steps), then
    recovers the follower's move at the leader's optimum.
    Returns (leader_move, follower_move).
    """
    leader_move = newton(grad_outer, hess_outer, jnp.array([1.0]), steps=10)
    follower_move = follower_bestresponse(leader_move)
    return leader_move, follower_move
# correct output: (DeviceArray([10.016394], dtype=float32), DeviceArray([0.8196831], dtype=float32))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment