Skip to content

Instantly share code, notes, and snippets.

@abhishekkrthakur
Created June 11, 2019 19:57
Show Gist options
  • Save abhishekkrthakur/b2eaebf18d71d63b6b5d3fa8b0ebf0dd to your computer and use it in GitHub Desktop.
Save abhishekkrthakur/b2eaebf18d71d63b6b5d3fa8b0ebf0dd to your computer and use it in GitHub Desktop.
def train_model(model,
                data_loader,
                dataset_size,
                optimizer,
                scheduler,
                num_epochs,
                *,
                device=None):
    """Train ``model`` with a sigmoid + binary cross-entropy objective.

    Each epoch makes one pass over ``data_loader`` and updates the weights
    batch by batch. It then advances ``scheduler`` once and prints the
    average training loss for that epoch.

    Args:
        model: ``torch.nn.Module`` returning raw logits. The sigmoid is
            applied inside ``nn.BCEWithLogitsLoss``.
        data_loader: Iterable yielding dicts with an ``"image"`` tensor and a
            ``"labels"`` tensor per batch. Both are cast to float. Labels
            must match the model output's shape, as ``nn.BCEWithLogitsLoss``
            requires.
        dataset_size: Number of samples the loader yields per epoch. It is
            the denominator of the reported epoch loss.
        optimizer: ``torch.optim.Optimizer`` over ``model``'s parameters.
        scheduler: LR scheduler stepped once after each epoch's updates, or
            ``None`` to keep the learning rate fixed.
        num_epochs: Number of passes over ``data_loader``.
        device: Device each batch is moved to. If omitted, the resolution
            order is: a module-level ``device`` global (the original
            behaviour), then the device of the model's first parameter,
            then CPU.

    Returns:
        The same ``model`` object, trained in place.
    """
    criterion = nn.BCEWithLogitsLoss()

    if device is None:
        # Backward compatibility: earlier versions read a module-level
        # ``device`` global that the surrounding script was expected to set.
        device = globals().get("device")
    if device is None:
        first_param = next(model.parameters(), None)
        device = (first_param.device if first_param is not None
                  else torch.device("cpu"))

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs - 1}")
        print("-" * 10)
        model.train()
        running_loss = 0.0

        for batch in data_loader:
            inputs = batch["image"].to(device, dtype=torch.float)
            labels = batch["labels"].to(device, dtype=torch.float)

            optimizer.zero_grad()
            # Force autograd on even if the caller wrapped us in no_grad().
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # The criterion uses the default 'mean' reduction. Re-weighting by
            # batch size makes the epoch figure a per-sample average,
            # assuming dataset_size counts the samples actually yielded.
            running_loss += loss.item() * inputs.size(0)

        # Since PyTorch 1.1 the scheduler must be stepped *after*
        # optimizer.step(). Stepping it first (as before) skipped the initial
        # learning rate, so every decay happened one epoch early.
        if scheduler is not None:
            scheduler.step()

        epoch_loss = running_loss / dataset_size
        print(f"Loss: {epoch_loss:.4f}")
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment