Skip to content

Instantly share code, notes, and snippets.

@aryan-jadon
Created September 15, 2022 04:49
Show Gist options
  • Save aryan-jadon/1e5205f9c7354fa7a9c03b21c1a49b25 to your computer and use it in GitHub Desktop.
Save aryan-jadon/1e5205f9c7354fa7a9c03b21c1a49b25 to your computer and use it in GitHub Desktop.
EarlyStoppingClass.py
import io
import copy
import torch
# Make use of a GPU or MPS (Apple) if one is available; otherwise use the CPU.
# MPS is preferred over CUDA when both are usable.
# torch.backends.mps.is_available() checks runtime availability, whereas the
# older torch.has_mps attribute only reflected build support and is deprecated
# in recent PyTorch releases. The getattr guard keeps very old torch versions
# (without a torch.backends.mps submodule) working.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class EarlyStopping:
    """Signal when training should stop because validation loss stopped improving.

    Call the instance once per epoch with the model and its validation loss.

    A loss counts as an improvement only when it is lower than the best loss
    seen so far by strictly more than ``min_delta``. Every other epoch
    increments a patience counter, including an exact plateau. When the
    counter reaches ``patience``, the call returns ``True``. If
    ``restore_best_weights`` is set, the model's parameters are also reset
    to the best snapshot taken so far.

    Args:
        patience: Number of consecutive non-improving epochs tolerated before
            stopping.
        min_delta: Minimum decrease in loss that counts as an improvement.
        restore_best_weights: If true, copy the best snapshot's weights back
            into the model when stopping.

    Attributes (all set in ``__init__``; the last four are updated on each call):
        patience, min_delta, restore_best_weights: The constructor arguments.
        best_model: Deep copy of the model holding the best weights seen so
            far (``None`` until the first call).
        best_loss: Lowest validation loss seen so far (``None`` until the
            first call).
        counter: Consecutive non-improving epochs since the last improvement.
        status: Progress string, e.g. ``"2/5"`` or ``"Stopped on 5"``.
    """

    def __init__(self, patience=5, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_model = None
        self.best_loss = None
        self.counter = 0
        self.status = ""

    def __call__(self, model: torch.nn.Module, val_loss) -> bool:
        """Record one epoch's validation loss and decide whether to stop.

        Args:
            model: The model being trained. On the first call it is deep-copied.
                On later improvements its ``state_dict`` is copied into that
                snapshot.
            val_loss: This epoch's validation loss. It is compared with ``-``
                and ``>``; presumably a Python float or a scalar tensor
                (TODO confirm against callers).

        Returns:
            True if training should stop, False otherwise.
        """
        if self.best_loss is None:
            # First epoch: take an independent snapshot so later in-place
            # weight updates to ``model`` do not alter it.
            self.best_loss = val_loss
            self.best_model = copy.deepcopy(model)
        elif self.best_loss - val_loss > self.min_delta:
            # Sufficient improvement: reset patience and refresh the snapshot
            # in place (load_state_dict copies values into best_model).
            self.best_loss = val_loss
            self.counter = 0
            self.best_model.load_state_dict(model.state_dict())
        else:
            # Use a plain ``else`` so every non-improving epoch counts toward
            # patience. This includes an improvement exactly equal to
            # min_delta, e.g. a flat loss with the default min_delta=0.
            self.counter += 1
            if self.counter >= self.patience:
                self.status = f"Stopped on {self.counter}"
                if self.restore_best_weights:
                    model.load_state_dict(self.best_model.state_dict())
                return True
        self.status = f"{self.counter}/{self.patience}"
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment