Skip to content

Instantly share code, notes, and snippets.

@mblondel
Last active April 14, 2023 15:11
Show Gist options
  • Save mblondel/9097958637e762586c589dcd3cfc8234 to your computer and use it in GitHub Desktop.
Save mblondel/9097958637e762586c589dcd3cfc8234 to your computer and use it in GitHub Desktop.
# Mathieu Blondel, 2022
# BSD license
import numpy as np
from scipy.ndimage import convolve1d
from scipy.special import logsumexp
from sklearn.metrics.pairwise import euclidean_distances
def smoothed_conjugate_conv(f, x, eps=1.0):
    """
    Compute the smoothed convex conjugate f* via a 1-D convolution.

    The smoothed conjugate is

        f*(y) = eps * log sum_j exp((y * x_j - f(x_j)) / eps).

    Using y * x_j = 0.5 * y**2 + 0.5 * x_j**2 - 0.5 * (y - x_j)**2, it becomes

        f*(x_i) = eps * log sum_j exp(-(x_i - x_j)**2 / (2 * eps)) h_j
                  + 0.5 * x_i**2,

    with h_j = exp((0.5 * x_j**2 - f_j) / eps). On a uniformly spaced grid
    x_i - x_j = (i - j) * dx, so the kernel sum is a discrete convolution.

    f: array containing the values of f (may be +inf outside its domain).
    x: uniformly spaced grid on which f has been evaluated. It does not need
       to be symmetric around 0.
    eps: regularization strength.
    The grid on which f* is evaluated is assumed to be the same.

    Returns an array with the values of f* on x.
    """
    x = np.asarray(x, dtype=float).ravel()
    f = np.asarray(f, dtype=float).ravel()
    n = x.size
    if n == 0:
        return np.zeros(0)

    # Shift log h by its maximum so exp() cannot overflow to inf; the shift
    # is added back after the log. A non-finite maximum (all f = +inf, or
    # some f = -inf) is left unshifted: the result is then -inf / +inf anyway.
    log_h = (0.5 * x ** 2 - f) / eps
    shift = np.max(log_h)
    if not np.isfinite(shift):
        shift = 0.0
    h = np.exp(log_h - shift)

    # Sample the Gaussian at every grid offset -(n-1)dx, ..., 0, ..., (n-1)dx
    # (not at the grid values themselves), so no pair (x_i, x_j) is dropped
    # and the grid may be located anywhere on the real line.
    d = x - x[0]
    offsets = np.concatenate((-d[:0:-1], d))
    kernel = np.exp(-offsets ** 2 / (2 * eps))

    # Full linear convolution: entry i + n - 1 equals
    # sum_j h_j * kernel(x_i - x_j).
    Kh = np.convolve(h, kernel)[n - 1:2 * n - 1]

    # NOTE: for very small eps relative to the grid extent, Kh can still
    # underflow to 0, which yields -inf here.
    return eps * (np.log(Kh) + shift) + 0.5 * x ** 2
def smoothed_conjugate_dot(f, x, y=None, eps=1.0):
    """
    Compute the smoothed convex conjugate f* on an arbitrary evaluation grid.

        f*(y) = eps * log sum_j exp((y * x_j - f(x_j)) / eps)

    Mathematically this equals eps * log(K @ h) + 0.5 * y**2, with
    K_ij = exp(-(y_i - x_j)**2 / (2 * eps)) and
    h_j = exp((0.5 * x_j**2 - f_j) / eps).
    Forming K and h explicitly makes h overflow to inf and K underflow to 0
    for small eps or wide grids (0 * inf = nan). The sum is therefore
    evaluated in the log domain with logsumexp.

    f: array containing the values of f (may be +inf outside its domain).
    x: grid on which f has been evaluated (any spacing).
    y: grid on which to evaluate f*. If None, use x.
    eps: regularization strength.

    Returns an array with the values of f*, shaped like y.
    """
    if y is None:
        y = x
    x = np.asarray(x, dtype=float).ravel()
    f = np.asarray(f, dtype=float).ravel()
    y = np.asarray(y, dtype=float)

    # scores[i, j] = (y_i * x_j - f_j) / eps; a row-wise log-sum-exp gives f*(y_i).
    scores = (np.outer(y.ravel(), x) - f) / eps
    return (eps * logsumexp(scores, axis=1)).reshape(y.shape)
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Implementation used for the plots. Assign smoothed_conjugate_dot here
    # to compute f* with the matrix-product implementation instead.
    smoothed_conjugate = smoothed_conjugate_conv

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    # Smoothed relu. Relu's conjugate is the indicator of [0, 1]
    # (0 on [0, 1], +inf elsewhere), and the conjugate of that indicator is
    # relu again, so smoothing it gives a smooth approximation of relu.
    x = np.linspace(-5, 5, 500)
    relu_conj = np.where(np.logical_and(0 <= x, x <= 1), 0.0, np.inf)
    ax1.plot(x, np.maximum(x, 0), c="k", lw=3, label="relu")
    ax1.plot(x, smoothed_conjugate(relu_conj, x, eps=0.1), ls="--", lw=3,
             label="smoothed")
    ax1.set_title("Smoothed relu", size=16)
    ax1.legend()

    # Convex envelope / biconjugate: f** of a non-convex f is its (closed)
    # convex envelope; here a smoothed version of it.
    x = np.linspace(-3, 3, 500)
    f = x ** 2 + 0.3 * np.sin(6 * np.pi * x)
    ax2.plot(x, f, c="k", lw=3, label="f")
    conj = smoothed_conjugate(f, x, eps=0.01)
    biconj = smoothed_conjugate(conj, x, eps=0.01)
    ax2.plot(x, biconj, lw=3, label="smoothed biconjugate")
    ax2.set_title("Smoothed convex envelope (biconjugate)", size=16)
    ax2.legend()

    fig.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment