@YashashGaurav
Last active September 19, 2022 22:47
Best version of the model checkpoint logging system that I have built for GDrive - kinda depends on wandb for naming, but easily customizable to any other service provider.
# Model Logging setup
import os
from os.path import isfile, join

import torch

try:
    import wandb  # used only to name the checkpoint directory after the current run
except ImportError:
    wandb = None


def log_checkpoint(
    epoch: int,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: object,  # cuz the base class seemed like a protected instance
    metric: float = None,
    track_greater_metric: bool = False,
):
"""Logs a checkpoint (a torch model) only if the metric passed is one of
top 3 (to save space of course) metrics in the directory that
the models are being saved in.
Format: if wandb instance exists the checkpoints are stored at:
checkpoint directory ->
{args["checkpoint_path"]}/trainings/{wandb.run.name}/
else,
checkpoint directory -> {args["checkpoint_path"]}/trainings/temp/
where the file is named:
{metric}_{project_name}_checkpoint.h5
if metric is not provided, we overwrite:
{project_name}_checkpoint.h5 given above checkpoint directory
Beyond the function params, the function also expects an 'args' dictionary
that has:
- args['checkpoint_epoch_step'] = number of epochs after which we want to
try to log a checkpoint. This can help save time if you want to skip a few
epochs and then log your checkpoint.
- args['project_name'] = a project name for the experiments that we are
running. We add this detail to the checkpoint file saved.
and 'hyper_params' dictionary where:
hyper_params['epochs'] = total number of epochs so that we definitely
log the last epoch's checkpoint.
:param epoch: current epoch index
:type epoch: int
:param model: model that you are trying to log using torch.save()
:type model: torch.nn.Module
:param optimizer: optimizer to be stored with its state
:type optimizer: torch.optim.Optimizer
:param lr_scheduler: LR Scheduler used for the experiment.
:type lr_scheduler: _LRScheduler - github/torch/optim/lr_scheduler.py#L25
:param metric: value that you want to log the model based on (like Val acc),
if not provided we save the model
checkpoints by, defaults to None
:type metric: float, optional
:param track_greater_metric: To be set to true if higher metric passed
means that the model is better
:type track_greater_metric: bool
"""
    if (
        epoch % args["checkpoint_epoch_step"] == 0
        or epoch == hyper_params["epochs"]
    ):
        state = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        }

        # we add a project name tag to the checkpoints saved.
        project_name = args["project_name"]

        # create path
        if wandb and wandb.run:  # a wandb run exists, so name the directory after it
            check_point_dir = (
                args["checkpoint_path"] + "trainings/" + wandb.run.name
            )
        else:
            check_point_dir = args["checkpoint_path"] + "trainings/temp"

        if not os.path.exists(check_point_dir):
            os.makedirs(check_point_dir)
        # metrics parsed from the checkpoint file names already in this directory
        onlyfiles = [
            float(f.split(f"_{project_name}_checkpoint.h5")[0])
            for f in os.listdir(check_point_dir)
            if isfile(join(check_point_dir, f))
            and f"_{project_name}_checkpoint.h5" in f
        ]
        if metric is not None:
            checkpoint_file_path = (
                check_point_dir + f"/{metric}_{project_name}_checkpoint.h5"
            )
            if len(onlyfiles) >= 3:
                # save only if the new metric beats the current third-best,
                # then evict the checkpoint that drops out of the top 3
                if track_greater_metric and metric > sorted(onlyfiles, reverse=True)[2]:
                    torch.save(state, checkpoint_file_path)
                    os.remove(
                        check_point_dir
                        + f"/{sorted(onlyfiles, reverse=True)[2]}_{project_name}_checkpoint.h5"
                    )
                elif (not track_greater_metric) and metric < sorted(onlyfiles)[2]:
                    torch.save(state, checkpoint_file_path)
                    os.remove(
                        check_point_dir
                        + f"/{sorted(onlyfiles)[2]}_{project_name}_checkpoint.h5"
                    )
            elif len(onlyfiles) < 3:
                torch.save(state, checkpoint_file_path)
        else:
            checkpoint_file_path = (
                check_point_dir + f"/{project_name}_checkpoint.h5"
            )
            torch.save(state, checkpoint_file_path)
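

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original checkpointing code).
# It shows one way the globals that log_checkpoint() expects (`args`,
# `hyper_params`, and optionally a wandb run) could be wired up in a training
# loop. The path, model, optimizer, scheduler, and metric values below are
# hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        "checkpoint_path": "/content/drive/MyDrive/",  # e.g. a mounted GDrive folder
        "checkpoint_epoch_step": 5,  # try to checkpoint every 5 epochs
        "project_name": "demo_project",
    }
    hyper_params = {"epochs": 20}

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    for epoch in range(1, hyper_params["epochs"] + 1):
        # ... training and validation for this epoch would happen here ...
        val_accuracy = 0.5 + epoch * 0.01  # stand-in for a real validation metric
        lr_scheduler.step()
        log_checkpoint(
            epoch,
            model,
            optimizer,
            lr_scheduler,
            metric=val_accuracy,
            track_greater_metric=True,  # higher accuracy means a better checkpoint
        )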