@ogrisel
Last active March 8, 2022 22:30
Spectrum of the extended feature Gram matrix of a single hidden layer ReLU MLP
"""Empirical evaluation of the extended feature Gram matrix of a ReLU MLP
Here we try to estimate the spectrum of the H^\infty matrix as defined in:
Gradient Descent Provably Optimizes Over-parameterized Neural Networks (2018)
Simon S. Du, Xiyu Zhai, Barnabas Poczos, Aarti Singh
https://arxiv.org/abs/1810.02054
Theorem 4.1 relies on the assumption that H^\infty has a strictly positive
minimum eigenvalue. The following computes an estimate of this eigenvalue
for a toy digits dataset with 1797 samples of 64 dimensions. In this case
we find that this assumption holds with \lambda_0 > 1.3e-2.
"""
from time import time
import numpy as np
import numba
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.preprocessing import normalize
# Workaround: https://github.com/numba/numba/issues/3341
numba.config.THREADING_LAYER = 'workqueue'
@numba.jit(parallel=True)
def compute_h_inf(X, n_iter=int(1e4), seed=0):
    n_samples, n_features = X.shape
    H_inf = np.zeros(shape=(n_samples, n_samples), dtype=X.dtype)
    W = np.random.RandomState(seed).randn(n_iter, n_features)
    W_X = W @ X.T > 0
    Gram = X @ X.T
    # Could be implemented with einsum as follows:
    # np.einsum('ij,ki,kj->ij', Gram, W_X, W_X) / n_iter
    # but using explicit numba loops makes it possible to use multi-threading.
    scale = 1. / n_iter
    for k in range(n_iter):
        for i in numba.prange(n_samples):
            for j in range(n_samples):
                H_inf[i, j] += scale * Gram[i, j] * W_X[k, i] * W_X[k, j]
    return H_inf
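# Minimal NumPy-only sketch of the einsum formulation mentioned in the
# comments above (hypothetical helper, not part of the original gist); it can
# be used to cross-check compute_h_inf on a small subset of the data:
def compute_h_inf_reference(X, n_iter=1000, seed=0):
    W = np.random.RandomState(seed).randn(n_iter, X.shape[1])
    W_X = (W @ X.T > 0).astype(X.dtype)  # ReLU activation indicators
    return np.einsum('ij,ki,kj->ij', X @ X.T, W_X, W_X) / n_iter

# Example cross-check (same seed and n_iter, so both should agree up to
# floating point rounding):
# np.testing.assert_allclose(compute_h_inf(X[:100], n_iter=100),
#                            compute_h_inf_reference(X[:100], n_iter=100))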
digits = load_digits()
X, y = digits.data, digits.target
n_samples, n_features = X.shape
print(f"Loaded digits data (n_samples={n_samples}, n_features={n_features})")
print("Normalizing X...")
X = normalize(X)
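# Added note (not in the original gist): normalize() rescales each row of X
# to unit Euclidean norm, presumably to match the paper's assumption that the
# inputs satisfy ||x_i||_2 = 1.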
print("Computing the spectrum of the data Gram matrix", end="", flush=True)
t0 = time()
eigvals_gram = np.linalg.eigvalsh(X @ X.T)
print(f" done in {time() - t0:0.3f}s")
print(f"lambda_min(XX^T): {eigvals_gram.min():0.3e}")
# We only have 64 features, so the rank of this Gram matrix is bounded by 64.
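# Added sketch (not in the original gist): a quick numerical-rank check of
# the claim above; uncommenting it should report at most 64 eigenvalues
# clearly above floating point noise.
# print("numerical rank of XX^T:",
#       int(np.sum(eigvals_gram > 1e-10 * eigvals_gram.max())))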
fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True, constrained_layout=True,
                               figsize=(12, 8))
ax0.semilogy(eigvals_gram[::-1])
ax0.set_title('Spectrum of the data Gram matrix $XX^T$')
ax0.set_ylabel('Eigenvalue (logscale)')
for n_iter in [1_000, 10_000, 100_000]:
    print(f"Computing extended feature Gram H_inf with n_iter={n_iter}...",
          end="", flush=True)
    t0 = time()
    H_inf = compute_h_inf(X, n_iter=n_iter)
    print(f" done in {time() - t0:0.3f}s")
    print(f"H_inf.shape={H_inf.shape}")
    print("Checking that H_inf is symmetric...", end="", flush=True)
    np.testing.assert_allclose(H_inf, H_inf.T)
    print(" ok")
    print("Computing the spectrum of H_inf...", end="", flush=True)
    t0 = time()
    eigvals = np.linalg.eigvalsh(H_inf)
    print(f" done in {time() - t0:0.3f}s")
    print(f"lambda_min(H_inf): {eigvals.min():0.3e}")
    ax1.semilogy(eigvals[::-1])
ax1.set_title(r'Spectrum of the extended feature Gram matrix: $H^\infty$')
ax1.set_ylabel('Eigenvalue (logscale)')
ax1.set_xlabel('Eigenvalue rank')
plt.show()
Loaded digits data (n_samples=1797, n_features=64)
Normalizing X...
Computing the spectrum of the data Gram matrix done in 0.763s
lambda_min(XX^T): -2.796e-13
Computing extended feature Gram H_inf with n_iter=1000... done in 2.793s
H_inf.shape=(1797, 1797)
Checking that H_inf is symmetric... ok
Computing the spectrum of H_inf... done in 0.592s
lambda_min(H_inf): 3.083e-03
Computing extended feature Gram H_inf with n_iter=10000... done in 18.585s
H_inf.shape=(1797, 1797)
Checking that H_inf is symmetric... ok
Computing the spectrum of H_inf... done in 0.578s
lambda_min(H_inf): 1.112e-02
Computing extended feature Gram H_inf with n_iter=100000... done in 209.125s
H_inf.shape=(1797, 1797)
Checking that H_inf is symmetric... ok
Computing the spectrum of H_inf... done in 0.607s
lambda_min(H_inf): 1.354e-02
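
The lambda_min estimate grows with n_iter (3.1e-3, then 1.1e-2, then 1.35e-2), which suggests the Monte Carlo average over random w has not fully converged for the smaller n_iter values. A simple way to gauge the remaining noise (a sketch reusing compute_h_inf from above; the second seed value is an arbitrary choice, not from the original gist) is to re-estimate H_inf with an independent seed and compare the smallest eigenvalues:

H_inf_a = compute_h_inf(X, n_iter=10_000, seed=0)
H_inf_b = compute_h_inf(X, n_iter=10_000, seed=1)
lmin_a = np.linalg.eigvalsh(H_inf_a).min()
lmin_b = np.linalg.eigvalsh(H_inf_b).min()
print(f"lambda_min(H_inf) with seed 0: {lmin_a:0.3e}")
print(f"lambda_min(H_inf) with seed 1: {lmin_b:0.3e}")
print(f"seed-to-seed spread: {abs(lmin_a - lmin_b):0.3e}")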
ogrisel commented Oct 7, 2018

Here are the spectra (log scale):

[Figure: eigenvalue spectra (log scale) of the data Gram matrix XX^T (top) and of the extended feature Gram matrix H^\infty (bottom)]
