@cshjin
Last active December 1, 2021 03:13
Solve a minimax problem in tensorflow
""" Solve a minimax problem
The problme is defined as
$$ \min_{x} \max_{y} f(x, y) = x^2(y+1) + y^2(x+1)$$
The first order gradient is
$$ \frac{\partial f}{\partial x} = 2x(y+1) + y^2 $$
$$ \frac{\partial f}{\partial y} = x^2 + 2y(x+1) $$
From the first order optimality condition, the alternatively solver
should solve the problem and converge to a stationary point.
Otherwise, add constraints to the problem if the solver diverges.
Copyright (c) H. J. 2019
"""
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # If required, add constraints to the variables.
    x = tf.Variable(2, dtype=tf.float32)
    y = tf.Variable(3, dtype=tf.float32)
    loss = x**2 * (y + 1) + y**2 * (x + 1)
    # Descend on x; ascend on y by descending -loss.
    opti_min = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=[x])
    opti_max = tf.train.GradientDescentOptimizer(0.1).minimize(-loss, var_list=[y])
    # Build the gradient ops once, inside the graph (calling tf.gradients
    # inside the session loop would add ops to the default graph, not g).
    grads = tf.gradients(loss, [x, y])

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(50):
        # Alternately solve the min-max problem.
        sess.run(opti_max)
        print("after max: {:.4f} {:.4f}".format(sess.run(x), sess.run(y)), end=" ")
        sess.run(opti_min)
        print("after min: {:.4f} {:.4f}".format(sess.run(x), sess.run(y)), end=" ")
        # Check the gradients of the variables.
        dfx, dfy = sess.run(grads)
        print("df/dx: {:.4f}".format(dfx), "df/dy: {:.4f}".format(dfy))
    # Output the final loss.
    print("loss {:.4f}".format(sess.run(loss)))
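The gist targets the TF1 graph/session API (`tf.train.GradientDescentOptimizer`, `tf.Session`), which is deprecated in TensorFlow 2. As a rough sketch (not part of the original gist), the same alternating scheme can be written with eager execution and `tf.GradientTape`, assuming `tf.keras.optimizers.SGD` for the plain gradient steps:

```python
import tensorflow as tf  # sketch assuming TF 2.x eager mode

x = tf.Variable(2.0)
y = tf.Variable(3.0)
opt_min = tf.keras.optimizers.SGD(learning_rate=0.1)
opt_max = tf.keras.optimizers.SGD(learning_rate=0.1)

def f():
    return x**2 * (y + 1) + y**2 * (x + 1)

for _ in range(50):
    # Ascent step on y: maximize f by applying SGD to -f.
    with tf.GradientTape() as tape:
        neg_loss = -f()
    opt_max.apply_gradients([(tape.gradient(neg_loss, y), y)])
    # Descent step on x.
    with tf.GradientTape() as tape:
        loss = f()
    opt_min.apply_gradients([(tape.gradient(loss, x), x)])

print(x.numpy(), y.numpy(), f().numpy())
```

As the docstring warns, the unconstrained iteration may diverge from some starting points; clipping the variables (e.g. `x.assign(tf.clip_by_value(x, -5.0, 5.0))` after each step) is one way to add the suggested constraints.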
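The hand-derived gradients in the docstring can be sanity-checked without TensorFlow at all, using central finite differences (a small standalone check, not from the original gist):

```python
# Verify the analytic gradients of f(x, y) = x^2(y+1) + y^2(x+1)
# against central finite differences.

def f(x, y):
    return x**2 * (y + 1) + y**2 * (x + 1)

def dfdx(x, y):
    # Analytic partial derivative with respect to x.
    return 2 * x * (y + 1) + y**2

def dfdy(x, y):
    # Analytic partial derivative with respect to y.
    return x**2 + 2 * y * (x + 1)

def central_diff(g, x, y, axis, h=1e-6):
    # Central difference along x (axis=0) or y (axis=1).
    if axis == 0:
        return (g(x + h, y) - g(x - h, y)) / (2 * h)
    return (g(x, y + h) - g(x, y - h)) / (2 * h)

x0, y0 = 2.0, 3.0
print(dfdx(x0, y0), central_diff(f, x0, y0, 0))  # both ~25
print(dfdy(x0, y0), central_diff(f, x0, y0, 1))  # both ~22
```

Both comparisons agree to within rounding error, confirming the formulas used in the gist's docstring.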