Skip to content

Instantly share code, notes, and snippets.

@souravs17031999
Created August 24, 2019 05:27
Show Gist options
  • Save souravs17031999/4b1cc2cae0c2d4e0bd4229356a6d88a0 to your computer and use it in GitHub Desktop.
Save souravs17031999/4b1cc2cae0c2d4e0bd4229356a6d88a0 to your computer and use it in GitHub Desktop.
def train_and_test(e, *, log_every=200):
    """Train the global ``model`` for ``e`` epochs, validating after each one.

    Relies on module-level names defined elsewhere in the file: ``model``,
    ``optimizer``, ``federated_train_loader``, ``validloader``, ``device``,
    ``torch`` and ``F`` (presumably ``torch.nn.functional``).

    Each training batch is processed where its data lives: the model is sent
    to ``images.location``, one optimizer step is taken, and the model is
    fetched back with ``model.get()``. NOTE(review): this looks like the
    PySyft pointer-tensor API; confirm against the library version in use.

    Args:
        e: Number of epochs to run.
        log_every: Print a progress line every ``log_every`` training
            batches (default 200, matching the previous hard-coded value).
            Pass 0 to disable per-batch progress lines.

    Returns:
        ``(train_losses, test_losses)``: lists holding, per epoch, the mean
        training loss and the mean validation loss as Python floats. A mean
        taken over an empty loader is reported as ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    epochs = e
    train_losses, test_losses = [], []
    # Batch counts do not change between epochs; compute them once.
    n_train = len(federated_train_loader)
    n_valid = len(validloader)
    model.train()
    print("training started...")
    for epoch in range(epochs):
        running_loss = 0.0
        batch = 0
        for images, labels in federated_train_loader:
            # Ship the model to the worker that holds this batch.
            model.send(images.location)
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = F.log_softmax(model(images), dim=1)
            loss = F.nll_loss(outputs, labels)
            loss.backward()
            optimizer.step()
            # Retrieve the updated model, then pull the loss value back too.
            model.get()
            running_loss += loss.get().item()
            batch += 1
            # Guard on log_every first so log_every=0 cannot divide by zero.
            if log_every and batch % log_every == 0:
                print(f" epoch {epoch + 1} batch {batch} completed")

        test_loss = 0.0
        accuracy = 0.0
        model.eval()
        with torch.no_grad():
            for images, labels in validloader:
                images, labels = images.to(device), labels.to(device)
                logps = F.log_softmax(model(images), dim=1)
                # .item() accumulates a Python float rather than a device tensor.
                test_loss += F.nll_loss(logps, labels).item()
                # exp() is monotonic, so topk on log-probabilities selects the
                # same class as topk on probabilities; no need to exponentiate.
                _, top_class = logps.topk(1, dim=1)
                equals = top_class == labels.view(*top_class.shape)
                accuracy += equals.float().mean().item()

        train_loss = running_loss / n_train if n_train else float("nan")
        valid_loss = test_loss / n_valid if n_valid else float("nan")
        valid_acc = accuracy / n_valid if n_valid else float("nan")
        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        print("Epoch: {}/{}.. ".format(epoch + 1, epochs),
              "Training Loss: {:.3f}.. ".format(train_loss),
              "Valid Loss: {:.3f}.. ".format(valid_loss),
              "Valid Accuracy: {:.3f}".format(valid_acc))
        # Restore training mode for the next epoch.
        model.train()
    return train_losses, test_losses
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment