Created
July 29, 2019 01:51
-
-
Save williamFalcon/603d7183347591c80dc6173227d6705b to your computer and use it in GitHub Desktop.
PTL Trainer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# PTL Trainer demo: train CoolModel for one epoch, logging via a test_tube
# Experiment in the current working directory, then print how to view the
# logs in TensorBoard.
import os

from pytorch_lightning import Trainer
from test_tube import Experiment

# NOTE(review): CoolModel is neither defined nor imported in this snippet;
# it is presumably a LightningModule defined earlier in the same tutorial.
# Import or define it before running this script.
model = CoolModel()

# Resolve the log directory once so the experiment and the printed
# `tensorboard --logdir` hint are guaranteed to refer to the same path.
save_dir = os.getcwd()
exp = Experiment(save_dir=save_dir)

# train on cpu using only 10% of the data (for demo purposes)
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)

# Alternative configurations (uncomment one to use instead):
# train on 4 gpus
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3])

# train on 32 gpus across 4 nodes, 8 gpus per node (make sure to submit an
# appropriate SLURM job)
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3, 4, 5, 6, 7], nb_gpu_nodes=4)

# train (1 epoch only here for demo)
trainer.fit(model)

# view tensorboard logs
print(f'View tensorboard logs by running\ntensorboard --logdir {save_dir}')
print('and going to http://localhost:6006 on your browser')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment