@solalatus · Created January 6, 2023 12:17
NHiTS model loading and finetuning
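This snippet resumes training of a Darts NHiTSModel from a saved checkpoint with a new optimizer, learning rate, logger and epoch budget. Because the loaded model caches its settings in several places, the new values have to be written into the wrapper object, its stored parameter dicts, and the inner PyTorch Lightning module before a fresh trainer is created and attached.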
# Load the model from an existing Darts checkpoint
# (positional arguments: model_name, work_dir; best=True selects the best checkpoint)
from darts.models import NHiTSModel

loaded_model = NHiTSModel.load_from_checkpoint("old_checkpoint", "/content/", best=True)
# Define the new finetuning parameters
from torch.optim import RAdam

MODEL_NAME = "finetune_only"
OPTIMIZER_CLS = RAdam
BASE_LR = 0.00001
EPOCHS = 25
# Initialize a fresh TensorBoard logger instance for the finetuning run
from pytorch_lightning import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger(save_dir="darts_logs/", name=MODEL_NAME, version="logs")
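With this logger, the training curves land under darts_logs/finetune_only/logs and can be inspected with the standard TensorBoard CLI (a general usage note, not part of the original gist):

# in a shell: tensorboard --logdir darts_logs/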
# Overwrite the parameters on the model wrapper and in its stored param dict
loaded_model.logger = tb_logger
loaded_model.n_epochs = EPOCHS
loaded_model.model_name = MODEL_NAME
loaded_model.load_ckpt_path = None  # prevent resuming from the old checkpoint path
loaded_model.model_params["model_name"] = MODEL_NAME
loaded_model.model_params["n_epochs"] = EPOCHS
loaded_model.model_params["optimizer_cls"] = OPTIMIZER_CLS
loaded_model.model_params["optimizer_kwargs"] = {"lr": BASE_LR}
loaded_model.model_params["lr_scheduler_cls"] = None  # drop any previous LR scheduler
loaded_model.model_params["lr_scheduler_kwargs"] = {}
# Overwrite the same settings on the inner PyTorch Lightning module
loaded_model.model.optimizer_kwargs = {"lr": BASE_LR}
loaded_model.model.optimizer_cls = OPTIMIZER_CLS
loaded_model.model.lr_scheduler_cls = None
loaded_model.model.lr_scheduler_kwargs = {}
loaded_model.model.n_epochs = EPOCHS
# ...and in the stored Lightning-module parameters
loaded_model.pl_module_params["optimizer_kwargs"] = {"lr": BASE_LR}
loaded_model.pl_module_params["optimizer_cls"] = OPTIMIZER_CLS
loaded_model.pl_module_params["lr_scheduler_cls"] = None
loaded_model.pl_module_params["lr_scheduler_kwargs"] = {}
# Point the trainer parameters at the new logger, epoch budget and checkpoint directory
loaded_model.trainer_params["logger"] = tb_logger
loaded_model.trainer_params["max_epochs"] = EPOCHS
loaded_model.trainer_params["val_check_interval"] = None
loaded_model.trainer_params["check_val_every_n_epoch"] = 1
loaded_model.trainer_params["default_root_dir"] = "/content/darts_logs/"
loaded_model.trainer_params["callbacks"][0].dirpath = "/content/darts_logs/" + MODEL_NAME + "/checkpoints"
# Initialize a new trainer from the updated params and point the model at it
trainer = loaded_model._init_trainer(loaded_model.trainer_params)
loaded_model.trainer = trainer
loaded_model.model.trainer = trainer
# Re-create the optimizers through the trainer strategy, then re-run the module setup
loaded_model.trainer.strategy.setup_optimizers(trainer)
loaded_model.model.setup("fit")
# Finally, continue finetuning / start the new training run
# (hubs_train_data and hubs_valid_data are the user's own training and validation TimeSeries)
loaded_model.fit(hubs_train_data, val_series=hubs_valid_data, verbose=True)
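Once finetuning has run, the model can be used and persisted as usual. A minimal sketch; the forecast horizon and output path below are illustrative assumptions, not from the original gist:

# Forecast 24 steps past the end of the validation series (horizon is an assumption)
forecast = loaded_model.predict(n=24, series=hubs_valid_data)
# Persist the finetuned model for later reuse (path is an assumption)
loaded_model.save("/content/finetuned_nhits.pt")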