Skip to content

Instantly share code, notes, and snippets.

@boegel
Created April 9, 2020 12:45
Show Gist options
  • Save boegel/e0c303497ba6275423b39ea1c10c7d73 to your computer and use it in GitHub Desktop.
Save boegel/e0c303497ba6275423b39ea1c10c7d73 to your computer and use it in GitHub Desktop.
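Add a --multi-gpu option to bonito's train.py: when set, the model is wrapped in torch.nn.DataParallel so training can use multiple GPUs, and checkpoints are saved from the unwrapped module so they stay loadable without DataParallel; see https://github.com/nanoporetech/bonito/issues/13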
--- ont-bonito-0.1.0/bonito/train.py.orig 2020-02-19 00:48:24.000000000 +0100
+++ ont-bonito-0.1.0/bonito/train.py 2020-04-09 13:49:42.964414000 +0200
@@ -69,6 +69,12 @@
print("[error]: Cannot use AMP: Apex package needs to be installed manually, See https://github.com/NVIDIA/apex")
exit(1)
+ if args.multi_gpu:
+ from torch.nn import DataParallel
+ model = DataParallel(model)
+ model.stride = model.module.stride
+ model.alphabet = model.module.alphabet
+
schedular = CosineAnnealingLR(optimizer, args.epochs * len(train_loader))
for epoch in range(1, args.epochs + 1):
@@ -85,7 +91,13 @@
epoch, workdir, val_loss, val_mean, val_median
))
- torch.save(model.state_dict(), os.path.join(workdir, "weights_%s.tar" % epoch))
+ if args.multi_gpu:
+ state = model.module.state_dict()
+ else:
+ state = model.state_dict()
+
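+ # save the model weights (taken from the unwrapped module under --multi-gpu, so checkpoint keys carry no "module." prefix)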
+ torch.save(state, os.path.join(workdir, "weights_%s.tar" % epoch))
with open(os.path.join(workdir, 'training.csv'), 'a', newline='') as csvfile:
csvw = csv.writer(csvfile, delimiter=',')
if epoch == 1:
@@ -116,6 +128,7 @@
parser.add_argument("--batch", default=32, type=int)
parser.add_argument("--chunks", default=1000000, type=int)
parser.add_argument("--validation_split", default=0.99, type=float)
+ parser.add_argument("--multi-gpu", action="store_true", default=False)
parser.add_argument("--amp", action="store_true", default=False)
parser.add_argument("-f", "--force", action="store_true", default=False)
return parser
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment