Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
A small script to get numerical evidence that a function is convex
# Authors: Mathieu Blondel, Vlad Niculae
# License: BSD 3 clause
import numpy as np
def _gen_pairs(gen, max_iter, max_inner, random_state, verbose):
    """Yield random end points together with points on the segment between them.

    Parameters
    ----------
    gen: tuple or function
        If tuple, shape of the arrays to draw with ``rng.randn``.
        If function, it is called as ``gen(rng)`` to produce each end point.
    max_iter: int
        Number of end-point pairs (M1, M2) to draw.
    max_inner: int
        Number of interpolation weights tried per pair, evenly spaced over
        [0.01, 0.99].
    random_state: None or int
        Seed passed to ``np.random.RandomState``.
    verbose: int
        If non-zero, print the (1-based) iteration number for each pair.

    Yields
    ------
    M, M1, M2, t
        The two end points, the weight ``t`` and
        ``M = t * M1 + (1 - t) * M2``.
    """
    rng = np.random.RandomState(random_state)

    # A tuple is interpreted as the shape of a standard-normal draw.
    if isinstance(gen, tuple):
        shape = gen

        def draw(rng):
            return rng.randn(*shape)
    else:
        draw = gen

    # The weights are the same for every pair, so build them once.
    # The end values 0 and 1 are excluded: there M coincides with an end point.
    weights = np.linspace(0.01, 0.99, max_inner)

    for it in range(max_iter):
        if verbose:
            print("iter", it + 1)

        M1 = draw(rng)
        M2 = draw(rng)

        for t in weights:
            M = t * M1 + (1 - t) * M2
            yield M, M1, M2, t
def check_convex(func, gen, max_iter=1000, max_inner=10,
                 quasi=False, random_state=None, eps=1e-9, verbose=0):
    """
    Numerically check whether the definition of a convex function holds for the
    input function.

    If answers "not convex", a counter-example has been found and
    the function is guaranteed to be non-convex. Don't lose time proving its
    convexity!

    If answers "could be convex", you can't completely rule out the possibility
    that the function is non-convex. To be completely sure, this needs to be
    proved analytically.

    This approach was explained by S. Boyd in his convex analysis lectures at
    Stanford.

    The answer is printed to standard output; nothing is returned.

    Parameters
    ----------
    func: function
        Function func(M) to be tested. It must return a value that can be
        compared with ``>`` and ``max`` (e.g. a float).

    gen: tuple or function
        If tuple, shape of the function argument M. Small arrays are
        recommended.
        If function, function for generating M; it is called as ``gen(rng)``
        with a ``np.random.RandomState`` instance.

    max_iter: int
        Max number of trials (random pairs of points).

    max_inner: int
        Number of weights t, evenly spaced over [0.01, 0.99], tested per
        trial for the definition of convexity.

    quasi: bool (default=False)
        If True, use quasi-convex definition instead of convex.

    random_state: None or int
        Random seed to be used.

    eps: float
        Tolerance: a violation is only reported when it exceeds eps.

    verbose: int
        Verbosity level.
    """
    for M, M1, M2, t in _gen_pairs(gen, max_iter, max_inner,
                                   random_state, verbose):
        f_mid = func(M)
        f1 = func(M1)
        f2 = func(M2)

        if quasi:
            # quasi-convex if f(M) <= max(f(M1), f(M2))
            bound = max(f1, f2)
        else:
            # convex if f(M) <= t * f(M1) + (1 - t) * f(M2)
            bound = t * f1 + (1 - t) * f2

        # A positive gap (beyond the tolerance) violates the definition.
        diff = f_mid - bound

        if diff > eps:
            # We found a counter-example.
            # %g keeps tiny gaps readable; %f would show them as 0.000000.
            print("not convex (diff=%g)" % diff)
            return

    # To be completely sure, this needs to be proved analytically.
    print("could be convex")
if __name__ == "__main__":
    # Demo: run the check on the squared Euclidean norm of 5-dimensional
    # vectors, with a fixed seed so the run is reproducible.
    def sqnorm(x):
        """Return the dot product of x with itself."""
        return np.dot(x, x)

    check_convex(sqnorm, gen=(5,), max_iter=10000, max_inner=10,
                 random_state=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.