Skip to content

Instantly share code, notes, and snippets.

@tamuhey
Created December 26, 2018 02:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tamuhey/3b490998a4ea46e341870cd78f60d1fa to your computer and use it in GitHub Desktop.
Save tamuhey/3b490998a4ea46e341870cd78f60d1fa to your computer and use it in GitHub Desktop.
PyTorch Early Stop Class
import torch as to
import torch.nn as nn
class EarlyStop:
    """Check early stop, and save best params.

    Lower values are treated as better (e.g. a validation loss). Each call
    compares ``value`` against the best value seen so far:

    * ``value <= best_value`` (ties included) counts as an improvement: the
      non-improvement counter resets, an independent CPU copy of the model's
      ``state_dict`` is stored, and any extra keyword arguments are saved as
      attributes on this object (so they reflect the best call).
    * ``value > best_value`` counts as a non-improvement.

    With the counting used here, ``__call__`` returns True on the
    ``num_patience + 2``-th consecutive non-improving call.

    Examples:
        >>> e=EarlyStop(10)
        >>> model=nn.Linear(3,5)
        >>> x=to.rand(3,3)
        >>> output=model(x).sum()
        >>> e(output, model)
        False
        >>> e(output, model, another_attribute=10)
        False
        >>> e.another_attribute # save the value when model has best params
        10
    """

    def __init__(self, num_patience: int = 50):
        """Create the tracker.

        Args:
            num_patience: tolerance for consecutive non-improving calls; see
                the class docstring for the exact call on which stopping is
                signalled.
        """
        self.num_patience = num_patience
        self.count_early_stop = 0  # consecutive non-improving calls so far
        self.best_value = float("inf")  # so the first finite value is an improvement
        self.best_model_state_dict = None  # set on the first improvement
        self.cpu = to.device("cpu")

    def __call__(self, value: float, model: nn.Module, **kwargs) -> bool:
        """Record ``value`` for this step and report whether to stop.

        Args:
            value: the monitored scalar (lower is better). A single-element
                tensor is accepted and converted to a Python float.
            model: the module whose ``state_dict`` is snapshotted whenever
                ``value`` is an improvement.
            **kwargs: extra values stored as attributes on this object, but
                only on improving calls.

        Returns:
            True if training should stop, otherwise False.
        """
        # Store a plain float. Keeping a tensor that requires grad would hold
        # its entire autograd graph in memory for as long as it is the best.
        value = float(value)
        if self.best_value < value:
            if self.count_early_stop > self.num_patience:
                return True
            self.count_early_stop += 1
        else:
            self.best_value = value
            self.best_model_state_dict = self._state_dict_to_cpu(model.state_dict())
            self.count_early_stop = 0
            # NOTE: names are set verbatim, so a kwarg such as
            # ``num_patience=...`` would overwrite this object's own state.
            for k, v in kwargs.items():
                setattr(self, k, v)
        return False

    def state_dict(self):
        """Return the CPU snapshot of the best model's state.

        Returns None if no call has been made yet.
        """
        return self.best_model_state_dict

    def _state_dict_to_cpu(self, state_dict: dict):
        """Replace every tensor in ``state_dict`` with an independent CPU copy.

        The dict is updated in place and returned. Updating in place keeps
        any extra attributes PyTorch attaches to the dict returned by
        ``Module.state_dict``. Non-tensor entries are left unchanged.

        ``copy=True`` is essential. For a model already on CPU, a plain
        ``.to(cpu)`` returns the very same tensor, which shares storage with
        the live parameters and would be overwritten by later training steps.
        """
        for k, v in state_dict.items():
            if isinstance(v, to.Tensor):
                state_dict[k] = v.detach().to(self.cpu, copy=True)
        return state_dict
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment