Skip to content

Instantly share code, notes, and snippets.

@AahanSingh
Last active August 9, 2021 10:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AahanSingh/d69348256e99fd2029d556b68aa6a6df to your computer and use it in GitHub Desktop.
Training step
def train_step(model, train_loader, device, optimizer, epoch, batch_size):
    """Train ``model`` for one epoch over ``train_loader`` and log metrics.

    For every batch this moves the data to ``device``, runs a forward pass,
    computes ``CapsuleLoss``, back-propagates and steps the optimizer, then
    logs the batch loss and batch accuracy through ``logging`` and MLflow.
    After the epoch, the mean per-batch loss and the wall-clock time taken
    are logged as well.

    Args:
        model: Network called as ``model(x, target)`` and returning
            ``(out, recon, mask)``; the norm of ``out`` over its last
            dimension is treated as per-class scores (assumed capsule
            lengths -- TODO confirm against the model definition).
        train_loader: Iterable of ``(x, target)`` batches whose ``len()`` is
            the number of batches (as for ``torch.utils.data.DataLoader``).
        device: Device the inputs and targets are moved to.
        optimizer: Optimizer that updates ``model``'s parameters.
        epoch: Epoch index; combined with the batch index to form a
            monotonically increasing global MLflow step.
        batch_size: Nominal batch size. Kept for interface compatibility;
            accuracy is computed from the actual number of samples in each
            batch, so a smaller final batch is scored correctly.
    """
    # Make sure dropout / batch-norm behave as in training, e.g. when the
    # caller evaluated the same model (model.eval()) before this epoch.
    model.train()
    num_batches = len(train_loader)
    avg_loss = 0.0
    start_time = time.time()
    for batch_no, (x, target) in enumerate(train_loader):
        x, target = x.to(device), target.to(device)
        # Clear gradients from the previous step to prevent accumulation.
        optimizer.zero_grad()
        out, recon, mask = model(x, target)
        loss = CapsuleLoss(out, mask, x, recon)
        loss.backward()
        optimizer.step()

        # Fetch the scalar loss once (each .item() is a device->host sync).
        loss_value = loss.item()

        # Accuracy bookkeeping needs no autograd graph.
        with torch.no_grad():
            logits = F.softmax(out.norm(dim=-1), dim=-1)
            _, pred_label = torch.max(logits, dim=1)
            n_correct = (pred_label == target).sum().item()
        # Divide by the real batch length: the last batch may be smaller
        # than batch_size when the loader does not drop it.
        n_samples = target.size(0)
        batch_acc = n_correct / n_samples if n_samples else 0.0

        logging.info(
            "Epoch = {0}\t Batch n.o.={1}\t Loss={2:.4f}\t Batch_acc={3:.4f}\r".format(
                epoch, batch_no, loss_value, batch_acc
            )
        )
        # Global step = batches in earlier epochs + index within this epoch.
        # len(train_loader) is already a batch count, so it must not be
        # divided by batch_size (doing so made steps collide across epochs).
        global_step = epoch * num_batches + batch_no
        mlflow.log_metric("Batch Accuracy", batch_acc, step=global_step)
        mlflow.log_metric("Loss", loss_value, step=global_step)
        avg_loss += loss_value
    total_time = time.time() - start_time
    # Average per batch; guard an empty loader instead of ZeroDivisionError.
    avg_loss /= max(num_batches, 1)
    logging.info("\nAvg Loss={0:.4f}\t time taken = {1:0.2f}".format(avg_loss, total_time))
    mlflow.log_metric("Average Loss", avg_loss, step=epoch)
    mlflow.log_metric("Time Taken", total_time, step=epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment