Skip to content

Instantly share code, notes, and snippets.

View Rishav09's full-sized avatar

Rishav Sapahia Rishav09

View GitHub Profile
# Switch the model into training mode before looping over the training batches below.
# NOTE(review): assumes `model` is a torch.nn.Module (torch is used in the loop) — confirm.
model.train()
for batch_idx, (data, target) in enumerate(loader['train']):
# move to GPU
if torch.cuda.is_available():
data, target = data.to('cuda', non_blocking=True), target.to('cuda', non_blocking = True) # noqa
optimizer.zero_grad()
output = model(data)
pred = torch.argmax(output, dim=1)
# pred = pred.to(torch.double)
# target = target.to(torch.double)