

@ryanpeach
Last active November 28, 2023 08:22
A Python 3 implementation of the early stopping algorithms (Algorithms 7.1, 7.2, and 7.3) described in the Deep Learning book by Goodfellow, Bengio, and Courville.
""" Python 3 implementation of Deep Learning book early stop algorithm.
@book{Goodfellow-et-al-2016,
title={Deep Learning},
author={Ian Goodfellow and Yoshua Bengio and Aaron Courville},
publisher={MIT Press},
note={\url{http://www.deeplearningbook.org}},
year={2016}
}
"""
import numpy as np
class Network:
    """ Sometimes a Network object is described; use this definition. """

    def train(self, x: Iterable, y: Iterable) -> None:
        raise NotImplementedError()

    def error(self, x: Iterable, y: Iterable) -> float:
        raise NotImplementedError()

    def __call__(self, x: Iterable) -> Iterable:
        raise NotImplementedError()

    def clone(self) -> "Network":
        raise NotImplementedError()
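
# A minimal sketch of a concrete Network, assuming a linear model trained by one
# gradient-descent step per train() call. This class, its weights `w`, and its
# learning rate `lr` are illustrative assumptions, not part of the book's algorithms.
class LinearNetwork(Network):
    def __init__(self, n_features: int, lr: float = 0.01):
        self.w = np.zeros(n_features)  # Model weights
        self.lr = lr                   # Learning rate

    def train(self, x: Iterable, y: Iterable) -> None:
        # One gradient-descent step on mean squared error
        x, y = np.asarray(x), np.asarray(y)
        grad = 2 * x.T @ (x @ self.w - y) / len(y)
        self.w -= self.lr * grad

    def error(self, x: Iterable, y: Iterable) -> float:
        # Mean squared error of the current weights
        x, y = np.asarray(x), np.asarray(y)
        return float(np.mean((x @ self.w - y) ** 2))

    def __call__(self, x: Iterable) -> Iterable:
        return np.asarray(x) @ self.w

    def clone(self) -> "Network":
        new = LinearNetwork(len(self.w), self.lr)
        new.w = self.w.copy()
        return new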
class ConvergenceWarning(Warning):
    """ Used to indicate that a bounded loop reached its maximum iteration
    count and the underlying function failed to converge. """
    pass
def early_stopping(theta0, train_set, valid_set, n=1, p=100):
    """ The early stopping meta-algorithm for determining the best amount of time to train.
    REF: Algorithm 7.1 in the Deep Learning book.
    Parameters:
        theta0: Network; initial network.
        train_set: tuple; (x_train, y_train), the training input and output sets.
        valid_set: tuple; (x_valid, y_valid), the validation input and output sets.
        n: int; number of training steps between evaluations.
        p: int; "patience", the number of consecutive evaluations with a worsening validation error to tolerate before stopping.
    Returns:
        theta_prime: Network; the output network.
        i_prime: int; the number of training steps taken to produce the output network.
        v: float; the validation error of the output network.
    """
    x_train, y_train = train_set
    x_valid, y_valid = valid_set
    # Initialize variables
    theta = theta0.clone()       # The active network
    i = 0                        # The number of training steps taken
    j = 0                        # The number of evaluations since the last update of theta_prime
    v = np.inf                   # The best validation error observed thus far
    theta_prime = theta.clone()  # The best network found thus far
    i_prime = i                  # The step index of theta_prime
    while j < p:
        # Update theta by running the training algorithm for n steps
        for _ in range(n):
            theta.train(x_train, y_train)
        # Update values
        i += n
        v_new = theta.error(x_valid, y_valid)
        if v_new < v:
            # Better validation error: reset the waiting time, save the
            # network, and update the best error value
            j = 0
            theta_prime = theta.clone()
            i_prime = i
            v = v_new
        else:
            # Otherwise, update the waiting time
            j += 1
    return theta_prime, i_prime, v
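
# A sketch of a call, assuming the illustrative LinearNetwork above and
# pre-split data; the hyperparameters here are made up for demonstration:
#
#     best, i_best, v_best = early_stopping(
#         LinearNetwork(3), (x_subtrain, y_subtrain), (x_valid, y_valid), n=5, p=20)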
def early_stopping_retrain(train_set, theta0, split_percent=0.8, n=1, p=100):
    """ Meta-algorithm using early stopping to determine the number of training steps
    at which we start to overfit, then retraining on all the data for that many steps.
    REF: Algorithm 7.2 in the Deep Learning book.
    Parameters:
        train_set: tuple; (x_train, y_train), the training input and output sets.
        theta0: Network; initial network.
        split_percent: float; the fraction of the training set kept as the subtraining set; the remainder becomes the validation set.
        n: int; number of training steps between evaluations.
        p: int; "patience", the number of consecutive evaluations with a worsening validation error to tolerate before stopping.
    Returns:
        theta: Network; the output network.
        i_prime: int; the number of training steps taken to produce the output network.
    """
    x_train, y_train = train_set
    # Split x_train and y_train into x_subtrain, x_valid and y_subtrain, y_valid
    cut = int(len(x_train) * split_percent)
    x_subtrain, x_valid = x_train[:cut], x_train[cut:]
    y_subtrain, y_valid = y_train[:cut], y_train[cut:]
    # Run early_stopping on the subtrain/validation split to find i_prime
    _, i_prime, _ = early_stopping(theta0.clone(), (x_subtrain, y_subtrain), (x_valid, y_valid), n=n, p=p)
    # Reset theta and train on the full training set for the found number of steps
    theta = theta0.clone()
    for _ in range(i_prime):
        theta.train(x_train, y_train)
    return theta, i_prime
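
# Note that Algorithm 7.2 keeps only the step count i_prime from the inner
# early_stopping run; the network itself is retrained from the initial
# parameters on the full training set. Illustrative call (LinearNetwork is
# the sketch class defined above):
#
#     theta, i_prime = early_stopping_retrain((x_train, y_train), LinearNetwork(3))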
def early_stopping_continuous(train_set, theta0, split_percent=0.8, n=1, p=100, max_iteration=10_000):
    """ Meta-algorithm using early stopping to determine the objective value at which
    we start to overfit, then continuing training on all the data until that value is reached.
    REF: Algorithm 7.3 in the Deep Learning book.
    Parameters:
        train_set: tuple; (x_train, y_train), the training input and output sets.
        theta0: Network; initial network.
        split_percent: float; the fraction of the training set kept as the subtraining set; the remainder becomes the validation set.
        n: int; number of training steps between evaluations.
        p: int; "patience", the number of consecutive evaluations with a worsening validation error to tolerate before stopping.
        max_iteration: int; maximum number of n-step training blocks to run before giving up.
    Returns:
        theta_prime: Network; the output network.
        v_new: float; the validation error of the output network.
    Warns:
        ConvergenceWarning: if the target validation error is not reached within max_iteration blocks.
    """
    x_train, y_train = train_set
    # Split x_train and y_train into x_subtrain, x_valid and y_subtrain, y_valid
    cut = int(len(x_train) * split_percent)
    x_subtrain, x_valid = x_train[:cut], x_train[cut:]
    y_subtrain, y_valid = y_train[:cut], y_train[cut:]
    # Run early_stopping on the subtrain/validation split to find the target error v
    theta_prime, _, v = early_stopping(theta0.clone(), (x_subtrain, y_subtrain), (x_valid, y_valid), n=n, p=p)
    # Train on x_train and y_train until the value v is reached
    for _ in range(int(max_iteration)):
        # Train for n steps
        for _ in range(n):
            theta_prime.train(x_train, y_train)
        # Update the validation error
        v_new = theta_prime.error(x_valid, y_valid)
        # If at or below the target validation error, finish training
        if v_new <= v:
            return theta_prime, v_new
    # If training never reaches the target before max_iteration, emit a warning
    # and return the network as-is
    warnings.warn("early_stopping_continuous failed to converge.", ConvergenceWarning)
    return theta_prime, v_new
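
# A runnable smoke test under the assumptions above; the LinearNetwork class
# and the synthetic linear data are illustrative, not part of the book's algorithms.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(200, 3))
    y = x @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)

    theta, i_prime, v = early_stopping(
        LinearNetwork(3), (x[:160], y[:160]), (x[160:], y[160:]), n=5, p=20)
    print(f"early_stopping: stopped after {i_prime} steps, validation MSE {v:.4f}")

    theta, i_prime = early_stopping_retrain((x, y), LinearNetwork(3), n=5, p=20)
    print(f"early_stopping_retrain: retrained for {i_prime} steps")

    theta, v = early_stopping_continuous((x, y), LinearNetwork(3), n=5, p=20)
    print(f"early_stopping_continuous: reached validation MSE {v:.4f}")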