GitHub Gist cshjin/0b73f20011843cad0ab2e1f8f95c4dae (last active December 1, 2021 03:13)
Solve a minimax problem in TensorFlow
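The script uses the objective f(x, y) = x^2(y+1) + y^2(x+1). For fixed x, this is a quadratic in y, so the inner maximization has a closed form. That form gives a concrete target for the alternating solver and shows when constraints are needed. The short derivation below is worked out from that definition of f. It is offered as a sanity check, not as part of the original script:

```latex
\begin{aligned}
f(x,y) &= (x+1)\,y^2 + x^2\,y + x^2,
\qquad \frac{\partial^2 f}{\partial y^2} = 2(x+1) \\[4pt]
x \ge -1 &\;\Rightarrow\; \sup_{y} f(x,y) = +\infty \\[4pt]
x < -1 &\;\Rightarrow\; y^\star(x) = -\frac{x^2}{2(x+1)}, \qquad
\varphi(x) := \max_{y} f(x,y) = x^2 - \frac{x^4}{4(x+1)} \\[4pt]
\varphi'(x) = 0,\ x < -1 &\;\Longleftrightarrow\; 3x^3 - 4x^2 - 16x - 8 = (3x+2)(x^2 - 2x - 4) = 0 \\[4pt]
x^\star &= 1 - \sqrt{5}, \qquad y^\star = 1 + \sqrt{5}, \qquad f(x^\star, y^\star) = 4
\end{aligned}
```

The cubic has three roots: -2/3, 1 + √5 and 1 - √5. Only 1 - √5 ≈ -1.236 lies in x < -1. Because φ(x) grows without bound as x → -1 from below and as x → -∞, that root is the minimizer of φ. Both first-order partial derivatives vanish at (1 - √5, 1 + √5).

Conversely, whenever the x iterate sits at x ≥ -1, f is unbounded above in y. In that situation plain gradient ascent on y can diverge, and constraints become necessary.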
```python
r""" Solve a minimax problem

The problem is defined as

$$ \min_{x} \max_{y} f(x, y) = x^2(y+1) + y^2(x+1) $$

The first-order partial derivatives are

$$ \frac{\partial f}{\partial x} = 2x(y+1) + y^2 $$
$$ \frac{\partial f}{\partial y} = x^2 + 2y(x+1) $$

Guided by the first-order optimality condition (both partial derivatives
equal to zero), the alternating solver should converge to a stationary point.
If the solver diverges instead, add constraints to the problem.

Copyright (c) H. J. 2019
"""
# TF1-style graph API; tf.compat.v1 is available in TF 1.15 and TF 2.x.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

g = tf.Graph()
with g.as_default():
    # If required, add constraints to the variables.
    x = tf.Variable(2, dtype=tf.float32)
    y = tf.Variable(3, dtype=tf.float32)
    loss = x**2 * (y + 1) + y**2 * (x + 1)
    # Descend on loss w.r.t. x (min step); descend on -loss w.r.t. y (max step).
    opti_min = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=[x])
    opti_max = tf.train.GradientDescentOptimizer(0.1).minimize(-loss, var_list=[y])
    # Build the gradient ops once; calling tf.gradients inside the loop
    # would add new ops to the graph on every iteration.
    grads = tf.gradients(loss, [x, y])
    init = tf.global_variables_initializer()

with tf.Session(graph=g) as sess:
    sess.run(init)
    for _ in range(50):
        # Alternately solve the min-max problem: max step on y, then min step on x.
        sess.run(opti_max)
        print("after max: {:.4f} {:.4f}".format(sess.run(x), sess.run(y)), end=" ")
        sess.run(opti_min)
        print("after min: {:.4f} {:.4f}".format(sess.run(x), sess.run(y)), end=" ")
        # Check the gradients of the variables.
        dfx, dfy = sess.run(grads)
        print("df/dx: {:.4f}".format(dfx), "df/dy: {:.4f}".format(dfy))
        # Output the loss.
        _loss = sess.run(loss)
        print("loss {:.4f}".format(_loss))
```
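If you do not have a TensorFlow setup, the same alternating scheme can be sketched in plain Python using the closed-form gradients from the docstring. This is a minimal sketch, not the author's code. The helper names (`alternating_gda`, `project`, `max_fd_error`) and the optional box bounds are illustrative assumptions. The bounds implement "add constraints" as simple clipping (projected gradient descent-ascent), which is only one of several ways to constrain the variables. The step order mirrors the script: an ascent step on y, followed by a descent step on x that uses the updated y.

```python
import math


def f(x, y):
    """Objective from the docstring: f(x, y) = x^2 (y + 1) + y^2 (x + 1)."""
    return x * x * (y + 1) + y * y * (x + 1)


def grad_x(x, y):
    """df/dx = 2x(y + 1) + y^2, as derived in the docstring."""
    return 2 * x * (y + 1) + y * y


def grad_y(x, y):
    """df/dy = x^2 + 2y(x + 1), as derived in the docstring."""
    return x * x + 2 * y * (x + 1)


def project(v, bounds):
    """Clip v into bounds = (lo, hi); None leaves v unconstrained (hypothetical helper)."""
    if bounds is None:
        return v
    lo, hi = bounds
    return min(max(v, lo), hi)


def alternating_gda(x, y, lr=0.1, steps=50, x_bounds=None, y_bounds=None):
    """Alternate an ascent step on y with a descent step on x (hypothetical helper).

    With bounds, this is projected gradient descent-ascent. If an update
    overflows, the loop stops early and keeps the last finite iterate.
    Returns the final (x, y) and a history of (x, y, f) tuples.
    """
    history = [(x, y, f(x, y))]
    for _ in range(steps):
        new_y = project(y + lr * grad_y(x, y), y_bounds)      # max step on y
        new_x = project(x - lr * grad_x(x, new_y), x_bounds)  # min step on x, sees new y
        if not (math.isfinite(new_x) and math.isfinite(new_y)):
            break  # diverged; consider adding constraints
        x, y = new_x, new_y
        history.append((x, y, f(x, y)))
    return x, y, history


def max_fd_error(x, y, h=1e-6):
    """Largest gap between the analytic gradients and central finite differences."""
    fd_x = (f(x + h, y) - f(x - h, y)) / (2 * h)
    fd_y = (f(x, y + h) - f(x, y - h)) / (2 * h)
    return max(abs(fd_x - grad_x(x, y)), abs(fd_y - grad_y(x, y)))


if __name__ == "__main__":
    x, y, history = alternating_gda(2.0, 3.0, lr=0.1, steps=50)
    print("x {:.4f} y {:.4f} loss {:.4f}".format(x, y, f(x, y)))
    print("df/dx {:.4f} df/dy {:.4f}".format(grad_x(x, y), grad_y(x, y)))
    print("gradient check gap at (2, 3): {:.2e}".format(max_fd_error(2.0, 3.0)))
```

This sketch starts from (2, 3) with lr = 0.1. Given enough steps (for example 200), its unconstrained iterates approach (1 - √5, 1 + √5) ≈ (-1.2361, 3.2361), where both partial derivatives vanish. Passing `x_bounds`/`y_bounds` guarantees the iterates stay finite. The trade-off is that the run may stop on the boundary of the box rather than at an interior stationary point.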