Skip to content

Instantly share code, notes, and snippets.

@syockit
Last active September 19, 2021 09:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save syockit/8a69778b7217689cb371c231f76a3150 to your computer and use it in GitHub Desktop.
Save syockit/8a69778b7217689cb371c231f76a3150 to your computer and use it in GitHub Desktop.
# Benchmark leave-k-out cross-validation using the naive approach vs. the Woodbury formula.
# Reference: https://stats.stackexchange.com/a/513787/274906
# Finds the cross-validation test error for linear regression (with L2 regularisation).
# Requires IPython for the %timeit magic.
import numpy as np
from matplotlib import pyplot as plt
def naive(X, y, l, k):
    """Return the k-fold cross-validated RMSE of ridge regression, refitting per fold.

    Fold ``i`` holds out rows ``i, i + k, i + 2k, ...``.  For every fold the
    coefficients are obtained by solving the regularised normal equations
    ``(Xt.T @ Xt + l * I) C = Xt.T @ yt`` on the remaining rows, and the
    squared prediction error is accumulated over the held-out rows.

    Parameters
    ----------
    X : 2-D design matrix, one sample per row.
    y : targets, one entry per row of ``X``.
    l : value added to the diagonal of ``Xt.T @ Xt`` (L2 penalty).
    k : number of folds.

    Returns
    -------
    The square root of the summed squared held-out error divided by ``len(X)``.
    """
    held_out = np.zeros(len(X), dtype=bool)
    ss = 0
    for i in range(k):
        held_out[:] = False
        held_out[i::k] = True
        # Slice the training rows once per fold rather than on every use.
        X_train = X[~held_out]
        y_train = y[~held_out]
        # Adding l * I out of place keeps a fractional penalty intact even
        # when X (and hence X.T @ X) has an integer dtype; writing the
        # diagonal in place would cast the sum back to integers.
        XTX = X_train.T @ X_train + l * np.eye(X_train.shape[1])
        C = np.linalg.solve(XTX, X_train.T @ y_train)
        ss += np.sum((X[held_out] @ C - y[held_out]) ** 2)
    return np.sqrt(ss / len(X))
def woodbury(X, y, l, k):
    """Return the k-fold cross-validated RMSE of ridge regression from a single fit.

    The model is fitted once on all rows.  With ``w, U`` the eigen-decomposition
    of ``X.T @ X``, ``P = X @ U`` and ``PD = P / (w + l)``, the product
    ``PD @ P.T`` is the hat matrix ``X (X.T X + l I)^-1 X.T``.  For fold ``i``
    (rows ``i, i + k, i + 2k, ...``) the held-out residuals are recovered by
    solving ``(I - hat[fold, fold]) r = res[fold]``, where ``res`` are the
    residuals of the full fit, so no per-fold refit is needed.

    Parameters
    ----------
    X : 2-D design matrix, one sample per row.
    y : targets, one entry per row of ``X``.
    l : L2 penalty added to the eigenvalues of ``X.T @ X``.
    k : number of folds.

    Returns
    -------
    The square root of the summed squared held-out residuals divided by
    ``len(X)``.
    """
    ss = 0
    w, U = np.linalg.eigh(X.T @ X)
    P = X @ U
    PD = P / (w + l)
    # Residuals of the ridge fit on the complete data set.
    res = y - PD @ (P.T @ y)
    for i in range(k):
        P_fold = P[i::k]
        # I minus the fold's diagonal block of the hat matrix.
        H = np.eye(len(P_fold)) - PD[i::k] @ P_fold.T
        rescv = np.linalg.solve(H, res[i::k])
        ss += np.sum(rescv ** 2)
    return np.sqrt(ss / len(X))
def trueY(x):
    """Evaluate the noise-free target: a cubic polynomial plus a cosine term.

    ``x`` may be a scalar or a NumPy array; arrays are evaluated elementwise.
    """
    cubic = 1.3 + 0.4 * x + 0.9 * x**2 + 0.2 * x**3
    return cubic + 5 * np.cos(4 * x / np.pi)
def setup_input(n, m):
    """Build a polynomial regression problem with noisy targets.

    Parameters
    ----------
    n : number of sample points, evenly spaced on [-7, 5].
    m : number of polynomial features; column ``j`` of the design matrix
        holds ``x**j`` for ``j = 0 .. m - 1``.

    Returns
    -------
    (X, y1) : the ``(n, m)`` design matrix and the targets ``trueY(x)`` plus
        Gaussian noise (mean 0, standard deviation 3) drawn from NumPy's
        global random state.
    """
    x = np.linspace(-7, 5, n)
    # Broadcasting already produces the (n, m) matrix; the former np.vstack
    # wrapper merely re-stacked its rows (and raised for n == 0).
    X = x[:, None] ** np.arange(m)
    y0 = trueY(x)
    y1 = y0 + np.random.normal(0, 3, size=len(x))
    return X, y1
X,y1 = setup_input(100,8)
times = {
naive: [],
woodbury: [],
}
ks = [2,5,10,20,50,100]
for k in ks:
for fun in times.keys():
time = %timeit -o fun(X,y1,1.0,k)
times[fun].append(time)
labels = ["naive", "woodbury"]
xk = np.log10(ks)
for label, fun in zip(labels, times.keys()):
ty0, ty1, ty2 = np.array([[t.average, t.average-t.stdev, t.average+t.stdev] for t in times[fun]]).T
plt.plot(xk, ty0,"o-", label=label)
plt.fill_between(xk, ty1, ty2, alpha=0.2)
fig = plt.gcf()
fig.set_size_inches((4.5,4.5))
ax = plt.gca()
ax.set_xticks(xk)
ax.set_xticklabels(ks)
plt.title(f"{X.shape[0]} x {X.shape[1]} matrix")
plt.xlabel("folds")
plt.ylabel("time [s]")
plt.legend()
plt.yscale("log")
plt.savefig(f"bench-{X.shape[0]}x{X.shape[1]}.png", transparent=False, bbox_inches='tight', facecolor="white")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment