import torch
import torch.nn.functional as F


def train(model, iterator, optimizer, loss_func):
    # Running totals for the epoch
    epoch_loss = 0
    epoch_accuracy = 0

    # Put the model in training mode (enables dropout, batch-norm updates, etc.)
    model.train()

    for batch in iterator:
        # Reset the gradients from the previous step
        optimizer.zero_grad()

        # Get the token indices and sequence lengths to pass into the model
        text, text_len = batch.sequence

        # Lengths must live on the CPU (pack_padded_sequence, likely used
        # inside the model, expects a CPU tensor for lengths)
        text_len = text_len.cpu()

        # Forward pass
        preds = model(text, text_len).squeeze()

        # Turn the target labels into one-hot vectors (20 classes)
        onehot = F.one_hot(batch.classification, num_classes=20)

        # Compute the loss against the one-hot float targets
        loss = loss_func(preds, onehot.float())

        # Compute the accuracy from the highest-scoring class per example
        acc = accuracy(torch.argmax(preds, dim=1).float(), batch.classification.float())

        # Backprop the loss to get the gradients
        loss.backward()

        # Then update all the weights
        optimizer.step()

        # Accumulate this batch's loss and accuracy
        epoch_loss += loss.item()
        epoch_accuracy += acc.item()

    # Return the average loss and accuracy over all batches
    return epoch_loss / len(iterator), epoch_accuracy / len(iterator)
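
The accuracy helper called above isn't defined in this gist. A minimal sketch of what it could look like, assuming it simply compares predicted class indices with the true labels and returns the mean match rate as a tensor (so .item() works on the result):

def accuracy(preds, labels):
    # Hypothetical helper: fraction of predictions that equal the labels,
    # returned as a tensor so the caller can call .item() on it
    return (preds == labels).float().mean()

With that in place, a training call might look like the following. The model and iterator names and the choice of loss are assumptions for illustration; nn.BCEWithLogitsLoss is one loss whose (input, target) signature matches the one-hot float targets produced above:

import torch.nn as nn
import torch.optim as optim

# Hypothetical setup: model and train_iterator come from elsewhere in the project
optimizer = optim.Adam(model.parameters())
loss_func = nn.BCEWithLogitsLoss()

train_loss, train_acc = train(model, train_iterator, optimizer, loss_func)
print(f"Loss: {train_loss:.3f} | Accuracy: {train_acc * 100:.2f}%")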