# A small script to get numerical evidence that a function is convex.
# Authors: Mathieu Blondel, Vlad Niculae
# License: BSD 3 clause
import numpy as np | |
def _gen_pairs(gen, max_iter, max_inner, random_state, verbose): | |
rng = np.random.RandomState(random_state) | |
# if tuple, interpret as randn | |
if isinstance(gen, tuple): | |
shape = gen | |
gen = lambda rng: rng.randn(*shape) | |
for it in range(max_iter): | |
if verbose: | |
print("iter", it + 1) | |
M1 = gen(rng) | |
M2 = gen(rng) | |
for t in np.linspace(0.01, 0.99, max_inner): | |
M = t * M1 + (1 - t) * M2 | |
yield M, M1, M2, t | |
def check_convex(func, gen, max_iter=1000, max_inner=10,
                 quasi=False, random_state=None, eps=1e-9, verbose=0):
    """
    Numerically check whether the definition of a convex (or quasi-convex)
    function holds for the input function.

    Random pairs of points ``M1``, ``M2`` are drawn and, for several weights
    ``t`` in (0, 1), the point ``M = t * M1 + (1 - t) * M2`` is tested against

        f(M) <= t * f(M1) + (1 - t) * f(M2)          (convexity)

    or, when ``quasi=True``,

        f(M) <= max(f(M1), f(M2))                    (quasi-convexity).

    If it answers "not convex" ("not quasi-convex"), a counter-example has
    been found and the function is guaranteed not to have the property.
    Don't lose time proving it!

    If it answers "could be convex" ("could be quasi-convex"), you can't
    completely rule out the possibility that the function lacks the property.
    To be completely sure, this needs to be proved analytically.

    This approach was explained by S. Boyd in his convex analysis lectures at
    Stanford.

    Parameters
    ----------
    func:
        Function func(M) to be tested. It should return a scalar.
    gen: tuple or function
        If tuple, shape of the function argument M. Small arrays are recommended.
        If function, function gen(rng) for generating M.
    max_iter: int
        Max number of trials (random pairs of points).
    max_inner: int
        Number of values in [0.01, 0.99] to be tested for the definition of
        convexity, per pair of points.
    quasi: bool (default=False)
        If True, use quasi-convex definition instead of convex.
    random_state: None or int
        Random seed to be used.
    eps: float
        Tolerance: a violation counts only if it exceeds ``eps``. Note that a
        NaN difference never exceeds ``eps``, so NaN outputs are not reported.
    verbose: int
        Verbosity level.

    Returns
    -------
    could_be_convex: bool
        False if a counter-example was found, True otherwise. The verdict is
        also printed.
    """
    # Name of the property under test; with quasi=True, passing the test
    # says nothing about (full) convexity, so the verdict must not claim it.
    prop = "quasi-convex" if quasi else "convex"

    for M, M1, M2, t in _gen_pairs(gen, max_iter, max_inner,
                                   random_state, verbose):
        if quasi:
            # quasi-convex if f(M) <= max(f(M1), f(M2))
            # not quasi-convex if f(M) > max(f(M1), f(M2))
            diff = func(M) - max(func(M1), func(M2))
        else:
            # convex if f(M) <= t * f(M1) + (1 - t) * f(M2)
            # non-convex if f(M) > t * f(M1) + (1 - t) * f(M2)
            diff = func(M) - (t * func(M1) + (1 - t) * func(M2))

        if diff > eps:
            # We found a counter-example. Use %g rather than %f: with the
            # default eps=1e-9, %f would print real violations as 0.000000.
            print("not %s (diff=%g)" % (prop, diff))
            return False

    # To be completely sure, this needs to be proved analytically.
    print("could be %s" % prop)
    return True
if __name__ == "__main__":
    # Demo: the squared Euclidean norm is convex, so no counter-example
    # is expected and the check should report that it could be convex.
    def sqnorm(vec):
        """Return the squared Euclidean norm <vec, vec> of a 1-D array."""
        return np.dot(vec, vec)

    check_convex(
        sqnorm,
        gen=(5,),
        max_iter=10000,
        max_inner=10,
        random_state=0,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment