Skip to content

Instantly share code, notes, and snippets.

@thomasweng15
Created July 27, 2022 03:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thomasweng15/47a90ba4041f5442f5a99200e8bc3094 to your computer and use it in GitHub Desktop.
wandb logging with pre-emption support: the W&B run id is persisted to disk so a restarted (pre-empted) job resumes the same run instead of creating a new one.
import os
import sys
import pytorch_lightning as pl
import pytorch_lightning.utilities.seed as seed_utils
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
import hydra
import yaml
from dataset import GraspDataModule
from networks import Decoder
from pathlib import Path
import wandb
import logging
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
def _load_or_create_run_id(path='run_id.yaml'):
    """Return a wandb run id that survives job pre-emption.

    On first call, generates a fresh id and caches it in ``path`` (inside
    hydra's per-run working directory). On subsequent calls — e.g. after the
    job is pre-empted and restarted in the same directory — the cached id is
    read back so wandb resumes the same run.
    """
    if os.path.exists(path):
        with open(path, 'r') as f:
            # safe_load is sufficient for this {'id': str} payload and avoids
            # FullLoader's object-construction surface on untrusted files.
            return yaml.safe_load(f)['id']
    run_id = wandb.util.generate_id()
    with open(path, 'w') as f:
        yaml.dump({'id': run_id}, f)
    return run_id


@hydra.main(config_path='config', config_name="train")
def main(cfg):
    """Train the grasp Decoder with pre-emption-safe logging.

    Sets up CSV/TensorBoard (and optionally wandb) loggers, checkpointing,
    and a PyTorch Lightning trainer, then runs ``trainer.fit``. Designed to
    be restartable: the launch command, the wandb run id, and checkpoints
    are all persisted so a pre-empted job can resume.

    Args:
        cfg: hydra/omegaconf config (see config/train.yaml) — assumed to
            carry seed, net, wandb, epochs, load_path, etc. fields used below.
    """
    # Record the exact launch command so a pre-empted job can be relaunched.
    with open('.hydra/command.txt', 'w') as f:
        f.write('python ' + ' '.join(sys.argv))

    seed_utils.seed_everything(cfg.seed, workers=True)

    decoder = Decoder(**cfg.net)
    dm = GraspDataModule(cfg)
    logging.info('Data loaded.')

    # Logging backends: CSV + TensorBoard always; wandb only when enabled.
    logging.info(f'Load path: {cfg.load_path}')
    loggers = [
        pl_loggers.CSVLogger(save_dir=cfg.csv_logs),
        pl_loggers.TensorBoardLogger(save_dir=cfg.tb_logs, default_hp_metric=False),
    ]
    if cfg.wandb:
        loggers.append(pl_loggers.WandbLogger(
            id=_load_or_create_run_id(),
            project="grasp-manifolds",
            name=str(Path(os.getcwd()).name),
            config=cfg,
            # 'allow' resumes the run when the id already exists server-side.
            resume='allow'))

    # save_top_k=-1 keeps every periodic checkpoint; save_last gives a
    # stable 'last.ckpt' to resume from after pre-emption.
    checkpoint_callback = ModelCheckpoint(monitor='loss/val',
                                          save_last=True,
                                          save_top_k=-1,
                                          every_n_epochs=cfg.every_n_epochs)
    trainer = pl.Trainer(gpus=cfg.gpu,
                         logger=loggers,
                         max_epochs=cfg.epochs,
                         log_every_n_steps=cfg.log_every_n_steps,
                         val_check_interval=cfg.val_check_interval,
                         callbacks=[checkpoint_callback],
                         resume_from_checkpoint=cfg.load_path,
                         profiler='simple',
                         num_sanity_val_steps=cfg.num_sanity_val_steps,
                         progress_bar_refresh_rate=0,
                         )
    logging.info(f"Process ID {os.getpid()}")

    trainer.fit(decoder, dm)
    if cfg.wandb:
        wandb.finish()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment