Last active
March 21, 2022 22:25
-
-
Save mblondel/b65435b371d49f259fdee5ff7facd445 to your computer and use it in GitHub Desktop.
A small script to get numerical evidence that a function is convex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Authors: Mathieu Blondel, Vlad Niculae
# License: BSD 3 clause
import numpy as np | |
def _gen_pairs(gen, max_iter, max_inner, random_state, verbose):
    """Yield convex combinations of randomly sampled pairs of points.

    A single ``np.random.RandomState(random_state)`` drives all sampling.
    For each of ``max_iter`` outer iterations, two points ``M1`` and ``M2``
    are drawn (in that order) by calling the sampler with that RandomState.
    Then, for each of ``max_inner`` evenly spaced weights ``t`` in
    [0.01, 0.99], the tuple ``(t * M1 + (1 - t) * M2, M1, M2, t)`` is yielded.

    ``gen`` is either a tuple, interpreted as the shape passed to
    ``rng.randn``, or a callable invoked as ``gen(rng)``.
    """
    rng = np.random.RandomState(random_state)

    # A tuple is shorthand for standard-normal samples of that shape.
    if isinstance(gen, tuple):
        dims = gen

        def sampler(state):
            return state.randn(*dims)
    else:
        sampler = gen

    for outer in range(1, max_iter + 1):
        if verbose:
            print("iter", outer)
        first = sampler(rng)
        second = sampler(rng)
        # The weight grid is rebuilt per pair, as in the original loop.
        for weight in np.linspace(0.01, 0.99, max_inner):
            mixed = weight * first + (1 - weight) * second
            yield mixed, first, second, weight
def check_convex(func, gen, max_iter=1000, max_inner=10,
                 quasi=False, random_state=None, eps=1e-9, verbose=0):
    """
    Numerically check whether the definition of a convex (or quasi-convex)
    function holds for the input function.

    Random pairs of points M1, M2 are drawn. For several weights t in
    [0.01, 0.99], the value of func at M = t * M1 + (1 - t) * M2 is compared
    against the bound:

    - convex: t * func(M1) + (1 - t) * func(M2)
    - quasi-convex (``quasi=True``): max(func(M1), func(M2))

    If it answers "not convex" (or "not quasi-convex"), a counter-example
    has been found. The function is then non-convex, up to the tolerance
    ``eps`` and floating-point error. Don't lose time proving its convexity!

    If it answers "could be convex" (or "could be quasi-convex"), you can't
    completely rule out the possibility that the function is non-convex.
    To be completely sure, this needs to be proved analytically.

    This approach was explained by S. Boyd in his convex analysis lectures at
    Stanford.

    Parameters
    ----------
    func: callable
        Function func(M) to be tested. Its result is compared with ``>``,
        so it should return a scalar.
    gen: tuple or callable
        If tuple, shape of the function argument M; points are drawn with
        ``rng.randn(*gen)``. Small arrays are recommended.
        If callable, it is called as ``gen(rng)`` with a
        ``numpy.random.RandomState`` and must return a point M.
    max_iter: int
        Number of random pairs (M1, M2) to test.
    max_inner: int
        Number of evenly spaced weights t in [0.01, 0.99] tested per pair.
    quasi: bool (default=False)
        If True, use quasi-convex definition instead of convex.
    random_state: None or int
        Random seed to be used.
    eps: float
        Absolute tolerance: a violation must exceed ``eps`` to count.
    verbose: int
        Verbosity level.

    Returns
    -------
    bool
        False if a counter-example was found, True otherwise ("could be
        convex"). The verdict is also printed.
    """
    label = "quasi-convex" if quasi else "convex"
    for M, M1, M2, t in _gen_pairs(gen, max_iter, max_inner,
                                   random_state, verbose):
        # Evaluate func(M) first, then func(M1), func(M2) (original order).
        f_M = func(M)
        if quasi:
            # quasi-convex iff f(M) <= max(f(M1), f(M2))
            bound = max(func(M1), func(M2))
        else:
            # convex iff f(M) <= t * f(M1) + (1 - t) * f(M2)
            bound = t * func(M1) + (1 - t) * func(M2)
        diff = f_M - bound
        if diff > eps:
            # We found a counter-example. %g keeps tiny violations visible
            # (%f would print e.g. 5e-8 as 0.000000).
            print("not %s (diff=%g)" % (label, diff))
            return False
    # To be completely sure, this needs to be proved analytically.
    print("could be %s" % label)
    return True
if __name__ == "__main__":
    # Demo: the squared Euclidean norm <x, x> is convex, so the check is
    # expected to report "could be convex".
    def sqnorm(vec):
        return vec.dot(vec)

    check_convex(sqnorm, gen=(5,), max_iter=10000, max_inner=10,
                 random_state=0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment