Created July 27, 2022 03:07
-
-
Save thomasweng15/47a90ba4041f5442f5a99200e8bc3094 to your computer and use it in GitHub Desktop.
wandb with pre-emption
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import pytorch_lightning as pl | |
import pytorch_lightning.utilities.seed as seed_utils | |
from pytorch_lightning import loggers as pl_loggers | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
import hydra | |
import yaml | |
from dataset import GraspDataModule | |
from networks import Decoder | |
from pathlib import Path | |
import wandb | |
import logging | |
# Root logging format for this script. The level is set to INFO explicitly so the
# logging.info(...) progress messages in main() are emitted even when nothing else
# configures the root logger (Python's default root level is WARNING, which would
# silently drop them). Hydra may reconfigure logging once main() starts.
logging.basicConfig(format='%(asctime)s %(message)s',
                    datefmt='%m/%d/%Y %I:%M:%S %p',
                    level=logging.INFO)
def _get_or_create_wandb_run_id(path='run_id.yaml'):
    """Return the W&B run id stored in *path*, creating and saving one if needed.

    Persisting the id in the working directory is what allows a restarted
    (e.g. pre-empted) job to pass the same id to W&B and continue the same run.

    Args:
        path: YAML file holding ``{'id': <run id>}``. Relative paths resolve
            against the current working directory (presumably the Hydra run
            directory -- confirm for your Hydra version / ``job.chdir`` setting).

    Returns:
        The run id.
    """
    if os.path.exists(path):
        with open(path, 'r') as f:
            # safe_load: the file only ever holds the plain mapping written below.
            stored = yaml.safe_load(f)
        if isinstance(stored, dict) and stored.get('id'):
            return stored['id']
        # open(path, 'w') truncates before writing, so a job killed at that
        # moment leaves an empty file. Regenerate instead of failing on every
        # subsequent restart.
        logging.warning(f'{path} holds no usable run id; generating a new one.')
    run_id = wandb.util.generate_id()
    with open(path, 'w') as f:
        yaml.dump({'id': run_id}, f)
    return run_id


@hydra.main(config_path='config', config_name="train")
def main(cfg):
    """Train the grasp ``Decoder`` using the Hydra config ``cfg``.

    Records the launching command in ``.hydra/command.txt``, seeds all RNGs,
    builds the model and data module, and attaches CSV and TensorBoard loggers.
    When ``cfg.wandb`` is truthy it also attaches a W&B logger whose run id is
    persisted in ``run_id.yaml``. It then runs ``trainer.fit``, optionally
    resuming from the checkpoint at ``cfg.load_path``.
    """
    import shlex  # local import: only needed to quote the saved command line

    # Save the exact command line so the run can be reproduced. Each argument
    # is shell-quoted so overrides containing spaces or metacharacters survive
    # being pasted back into a shell.
    with open('.hydra/command.txt', 'w') as f:
        command = 'python ' + ' '.join(shlex.quote(arg) for arg in sys.argv)
        f.write(command)

    seed_utils.seed_everything(cfg.seed, workers=True)
    decoder = Decoder(**cfg.net)
    dm = GraspDataModule(cfg)
    logging.info('Data loaded.')

    # Logging
    logging.info(f'Load path: {cfg.load_path}')
    loggers = [pl_loggers.CSVLogger(save_dir=cfg.csv_logs),
               pl_loggers.TensorBoardLogger(save_dir=cfg.tb_logs, default_hp_metric=False)]
    if cfg.wandb:
        # Reuse the run id saved by an earlier attempt in this directory.
        # With resume='allow', W&B continues that run if it exists and
        # otherwise starts a new run with this id.
        run_id = _get_or_create_wandb_run_id('run_id.yaml')
        loggers.append(pl_loggers.WandbLogger(
            id=run_id,
            project="grasp-manifolds",
            # The run is named after the current working directory.
            name=str(Path(os.getcwd()).name),
            # NOTE(review): cfg is passed as-is (a Hydra/OmegaConf config);
            # confirm W&B records nested keys, otherwise convert it to a dict.
            config=cfg,
            resume='allow'))

    # save_top_k=-1 keeps every checkpoint; save_last also writes last.ckpt.
    checkpoint_callback = ModelCheckpoint(monitor='loss/val', save_last=True, save_top_k=-1,
                                          every_n_epochs=cfg.every_n_epochs)
    trainer = pl.Trainer(gpus=cfg.gpu,
                         logger=loggers,
                         max_epochs=cfg.epochs,
                         log_every_n_steps=cfg.log_every_n_steps,
                         val_check_interval=cfg.val_check_interval,
                         callbacks=[checkpoint_callback],
                         # Checkpoint to resume training from (presumably None for a fresh run).
                         resume_from_checkpoint=cfg.load_path,
                         profiler='simple',
                         num_sanity_val_steps=cfg.num_sanity_val_steps,
                         progress_bar_refresh_rate=0,
                         )
    logging.info(f"Process ID {os.getpid()}")
    trainer.fit(decoder, dm)
    # Only reached if fit() returns normally.
    if cfg.wandb:
        wandb.finish()
# Script entry point: the @hydra.main decorator on main() builds cfg from the
# command line, so main is called here without arguments.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.