@burrussmp
Created May 5, 2020 20:16
Train a model in PyTorch.
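The listing below references several names that the gist itself does not define: UNet2D, train, validate, train_loader, validation_loader, and BASEDIR; they are assumed to come from elsewhere in the author's project. As a rough guide only, here is a minimal sketch of what the train and validate helpers might look like, assuming (input, target) batches and a plain cross-entropy loss; the original loss function and data format are not shown in the gist.

import torch
import torch.nn.functional as F

def train(model, device, train_loader, optimizer, epoch):
    # one pass over the training set; returns the average loss as a tensor
    model.train()
    total_loss = torch.tensor(0.0, device=device)
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)  # assumed loss; not shown in the gist
        loss.backward()
        optimizer.step()
        total_loss += loss.detach()
    avg_loss = total_loss / len(train_loader)
    print('Epoch', epoch, 'average training loss:', avg_loss.item())
    return avg_loss

def validate(model, device, validation_loader):
    # evaluate without gradients; returns the average validation loss as a tensor
    model.eval()
    total_loss = torch.tensor(0.0, device=device)
    with torch.no_grad():
        for data, target in validation_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.cross_entropy(output, target)  # assumed loss; not shown in the gist
    avg_loss = total_loss / len(validation_loader)
    print('Validation loss:', avg_loss.item())
    return avg_loss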
import os
import math

import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# hyper-parameters
# batch_size and kwargs are intended for DataLoader construction (not shown in this snippet);
# the StepLR scheduler below uses lr_scheduler_step_size, while scheduler_step is unused here
batch_size = 32
learning_rate = 0.001
scheduler_step = 30
epochs = 190
gamma = 0.5
lr_scheduler_step_size = 12
adam_betas = (0.9, 0.999)

use_cuda = torch.cuda.is_available()
torch.manual_seed(123456)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# path to the model checkpoint; reuse it unless restart is set
restart = False
pathToModel = os.path.join(BASEDIR, 'weights2.pt')

torch.manual_seed(65675)
print('Initializing model')
model = UNet2D()
if use_cuda:
    model.cuda()

# Adam optimizer with a step-decay learning-rate schedule
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=adam_betas)
scheduler = StepLR(optimizer, step_size=lr_scheduler_step_size, gamma=gamma)
# resume from the saved checkpoint (weights, best loss, loss history) unless restarting from scratch
if os.path.isfile(pathToModel) and not restart:
    print('Loading model.....')
    model.load_state_dict(torch.load(pathToModel))
    best_loss = torch.tensor(np.load(os.path.join(BASEDIR, 'lowest.npy')).tolist()).to(device)
    train_loss_save = np.load(os.path.join(BASEDIR, 'train_loss.npy')).tolist()
    val_loss_save = np.load(os.path.join(BASEDIR, 'val_loss.npy')).tolist()
else:
    # start fresh; keep best_loss as a tensor so it can be saved with .cpu() below
    best_loss = torch.tensor(math.inf).to(device)
    train_loss_save = []
    val_loss_save = []
# training loop: track losses, checkpoint whenever validation improves, and persist the history each epoch
for epoch in range(1, epochs + 1):
    train_loss = train(model, device, train_loader, optimizer, epoch)
    val_loss = validate(model, device, validation_loader)
    train_loss_save.append(train_loss.cpu().data.numpy())
    val_loss_save.append(val_loss.cpu().data.numpy())
    if val_loss < best_loss:
        print('Loss improved from', best_loss.item(), 'to', val_loss.item(), ': saving new model to', pathToModel)
        best_loss = val_loss
        torch.save(model.state_dict(), pathToModel)
    scheduler.step()
    np.save(os.path.join(BASEDIR, 'val_loss.npy'), np.array(val_loss_save))
    np.save(os.path.join(BASEDIR, 'train_loss.npy'), np.array(train_loss_save))
    np.save(os.path.join(BASEDIR, 'lowest.npy'), best_loss.cpu().data.numpy())
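Because the script writes train_loss.npy and val_loss.npy to BASEDIR every epoch, the loss history can be inspected while training runs or after it finishes. A small optional snippet for plotting the curves, assuming matplotlib is installed and BASEDIR is the same directory used above:

import os
import numpy as np
import matplotlib.pyplot as plt

# load the per-epoch loss histories saved by the training loop
train_hist = np.load(os.path.join(BASEDIR, 'train_loss.npy'))
val_hist = np.load(os.path.join(BASEDIR, 'val_loss.npy'))

plt.plot(train_hist, label='train loss')
plt.plot(val_hist, label='validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.savefig(os.path.join(BASEDIR, 'loss_curves.png'))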