Last active July 12, 2023 18:29
Gist jiahao/837462f0d5f5a79152d60c58fd7ed268
A small example of double descent in Julia
using Plots
using StatsPlots
using LinearAlgebra
using ClassicalOrthogonalPolynomials
using ProgressMeter
using Statistics
k = 15 # Size of training data
l = 15 # Size of test data
# Model for generating data
# f(x) = 2*x+cos(25*x) # This was in the MIT tutorial paper
# Some arbitrary polynomial
f(x) = 2*x+x^2 - 0.1*x^3+10*x^14
# 1D Ackley function
# f(x) = -20*exp(-0.2*sqrt(0.5*x^2)) - exp(0.5*cos(2*pi*(x))) + exp(1) + 20
xlims = (-1, 1)
# Doing polynomial regression by hand
function regress(y, x,
        p, # Polynomial degree
        λ=√eps(one(eltype(x))) # Singular-value cutoff passed to pinv (truncated-SVD regularization, not ridge)
    )
    n = length(x)
    V = ones(n, p+1) # Column 1 holds P₀(x) = 1, the constant term used by y(x, c)
    for i ∈ 1:n
        for j ∈ 1:p
            V[i,j+1] = legendrep(j, x[i]) # Good old Legendre polynomials
        end
    end
    return pinv(V, atol=λ)*y # Minimum-norm least-squares solution; singular values at or below λ are dropped
end
# Evaluate the function with coefficients c
function y(x, c)
    p = length(c)
    v = c[1] # Constant term, P₀(x) = 1
    for i in 2:p
        v += c[i]*legendrep(i-1, x)
    end
    return v
end
# Run the simulation
N = 100 # Number of independent trials (fresh train/test draws each time)
maxp = 100 # Largest polynomial degree to fit
models = []
err_train = zeros(N, maxp+1)
err_test = zeros(N, maxp+1)
@showprogress for i in 1:N
    # Sample train data
    xk = (xlims[2]-xlims[1])*rand(k).+xlims[1]
    yk = f.(xk)
    # Sample test data
    xl = (xlims[2]-xlims[1])*rand(l).+xlims[1]
    yl = f.(xl)
    for p in 0:maxp
        c = regress(yk, xk, p, 0) # λ = 0: pinv falls back to its default relative tolerance
        push!(models, c) # Save coefficients
        err_train[i, p+1] = norm(yk.-(x->y(x,c)).(xk))^2/k # Training MSE
        err_test[i, p+1] = norm(yl.-(x->y(x,c)).(xl))^2/l # Test MSE
    end
end
# Generate plot
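# errorline expects one row per x value and one column per trial, hence permutedims
errorline(log10.(1:(maxp+1)), permutedims(log10.(err_train)), errorstyle=:ribbon, label="train", xlabel="log10(p+1)", ylabel="log10 MSE", ylim=(-7, 4))
errorline!(log10.(1:(maxp+1)), permutedims(log10.(err_test)), errorstyle=:ribbon, label="test")
vline!([log10(k)], label = "p+1 = k") # Interpolation threshold: as many coefficients as training points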
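The script builds its design matrix with `legendrep` from ClassicalOrthogonalPolynomials. The sketch below shows the same construction in self-contained form. It assumes a hand-written Bonnet recurrence in place of `legendrep`; the helper names `legendre_bonnet`, `legendre_design` and `monomial_design` are hypothetical. It also compares the condition number of the Legendre design matrix against a monomial basis on an equispaced grid.

```julia
using LinearAlgebra

# Hypothetical stand-in for `legendrep`: Legendre polynomial P_n(x) via
# Bonnet's recurrence (m+1) P_{m+1}(x) = (2m+1) x P_m(x) - m P_{m-1}(x)
function legendre_bonnet(n::Integer, x::Real)
    pprev, pcur = one(float(x)), float(x)   # P_0(x), P_1(x)
    n == 0 && return pprev
    for m in 1:n-1
        pprev, pcur = pcur, ((2m + 1) * x * pcur - m * pprev) / (m + 1)
    end
    return pcur
end

# Design matrix with entry (i, j+1) = P_j(x_i) for j = 0..p, the layout `regress` uses
legendre_design(x, p) = [legendre_bonnet(j, xi) for xi in x, j in 0:p]

# Same shape, but with monomial columns x_i^j
monomial_design(x, p) = [xi^j for xi in x, j in 0:p]

xs = range(-1, 1; length=201)   # 201 equispaced points on [-1, 1]
p = 10
κ_leg = cond(legendre_design(xs, p))
κ_mon = cond(monomial_design(xs, p))
println("cond(Legendre) = ", κ_leg, ", cond(monomial) = ", κ_mon)
```

The gap between the two condition numbers widens as p grows. This is one common reason to prefer an orthogonal basis for high-degree fits like the ones this script runs, up to maxp = 100.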
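The vertical line at p+1 = k marks where the number of coefficients equals the number of training points. The sketch below shows how `pinv` behaves on either side of that threshold. Random Gaussian matrices stand in for the Legendre design matrix `V`, and the variable names are illustrative, not taken from the script.

```julia
using LinearAlgebra
using Random

rng = MersenneTwister(0)    # fixed seed so the sketch is reproducible
k = 15                      # number of training points, as in the script
yobs = randn(rng, k)        # hypothetical targets

# Underparameterized (p + 1 < k): pinv gives the ordinary least-squares fit
V_under = randn(rng, k, 8)  # random stand-in for a 15×8 design matrix
c_under = pinv(V_under) * yobs
c_qr = V_under \ yobs       # least squares via QR, for comparison

# Overparameterized (p + 1 > k): infinitely many exact fits exist,
# and pinv picks the one with the smallest Euclidean norm
V_over = randn(rng, k, 40)
c_over = pinv(V_over) * yobs
c_minnorm = V_over' * ((V_over * V_over') \ yobs)   # closed-form minimum-norm interpolant
train_residual = norm(V_over * c_over - yobs)        # ≈ 0: training data interpolated
```

Past the threshold the fit interpolates the training points, so training error is essentially zero. The minimum-norm choice among all interpolants is the usual informal explanation for why test error can come back down as p keeps growing.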
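In `regress`, `λ` is passed to `pinv` as `atol`. That is a hard singular-value cutoff (truncated SVD) rather than ridge regression in the usual Tikhonov sense. The sketch below contrasts the two on a random stand-in matrix. The helper `ridge_svd` and the cutoff `τ` are hypothetical choices for illustration.

```julia
using LinearAlgebra
using Random

rng = MersenneTwister(1)    # fixed seed for reproducibility
A = randn(rng, 15, 40)      # stand-in for an overparameterized design matrix
b = randn(rng, 15)
F = svd(A)                  # thin SVD: A = F.U * Diagonal(F.S) * F.Vt, F.S descending

# Truncated SVD, as `pinv(V, atol=λ)` does in `regress`: singular values at or
# below the cutoff are discarded outright. This cutoff keeps the top 8.
τ = (F.S[8] + F.S[9]) / 2
c_tsvd = pinv(A; atol=τ) * b
c_tsvd_manual = F.V[:, 1:8] * ((F.U[:, 1:8]' * b) ./ F.S[1:8])

# Ridge (Tikhonov): every direction is kept but damped by the filter σ / (σ² + λ)
ridge_svd(F, b, λ) = F.V * ((F.S ./ (F.S .^ 2 .+ λ)) .* (F.U' * b))

# Cross-check against the normal equations (AᵀA + λI) c = Aᵀb for a moderate λ
λ = 10.0
c_ridge = ridge_svd(F, b, λ)
c_ridge_normal = (A' * A + λ * I) \ (A' * b)

# As λ → 0⁺, ridge approaches the untruncated minimum-norm pseudoinverse solution
c_ridge_tiny = ridge_svd(F, b, 1e-10)
```

Truncation zeroes out directions with small singular values, while ridge shrinks all of them smoothly. Both recover the plain pseudoinverse solution as their parameter goes to zero.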