Skip to content

Instantly share code, notes, and snippets.

@henry16lin
Created February 13, 2020 16:39
Show Gist options
  • Save henry16lin/6fae622a38f0815c2e699739520ad7b2 to your computer and use it in GitHub Desktop.
Save henry16lin/6fae622a38f0815c2e699739520ad7b2 to your computer and use it in GitHub Desktop.
RNN_training
# Put the model on the target device first: torch.optim recommends moving a
# model before constructing its optimizer. nn.Module.to() moves the
# parameters in place and returns the same module.
net = net.to(device)
# CrossEntropyLoss combines LogSoftmax and NLLLoss, so it takes raw
# (unnormalised) class scores rather than probabilities.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters())
def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when denominator is 0 (empty iterator)."""
    return numerator / denominator if denominator else float('nan')


def _train_one_epoch(net, loss_function, optimizer, data_iter):
    """Run one optimisation pass over ``data_iter``.

    Returns ``(mean per-batch loss, sample-weighted accuracy)``. Labels are
    assumed to be 1-D class indices matching ``score.argmax(dim=1)`` — TODO
    confirm against the dataset.
    """
    net.train()
    total_loss, n_batches = 0.0, 0
    n_correct, n_samples = 0, 0
    for feature, label in data_iter:
        feature = feature.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        score = net(feature)
        loss = loss_function(score, label)
        loss.backward()
        optimizer.step()
        # .item() yields a plain Python float. Accumulating the tensor itself
        # would keep each batch's autograd history alive for the whole epoch.
        total_loss += loss.item()
        n_batches += 1
        n_correct += (score.argmax(dim=1) == label).sum().item()
        n_samples += label.size(0)
    return _ratio(total_loss, n_batches), _ratio(n_correct, n_samples)


def _evaluate(net, loss_function, data_iter):
    """Score ``net`` on ``data_iter`` with gradients disabled.

    Returns ``(mean per-batch loss, sample-weighted accuracy)``.
    """
    net.eval()
    total_loss, n_batches = 0.0, 0
    n_correct, n_samples = 0, 0
    with torch.no_grad():
        for feature, label in data_iter:
            feature = feature.to(device)
            label = label.to(device)
            score = net(feature)
            total_loss += loss_function(score, label).item()
            n_batches += 1
            n_correct += (score.argmax(dim=1) == label).sum().item()
            n_samples += label.size(0)
    return _ratio(total_loss, n_batches), _ratio(n_correct, n_samples)


def train(net, num_epochs, loss_function, optimizer, train_iter, val_iter):
    """Train ``net`` for ``num_epochs`` epochs, validating after each epoch.

    After every epoch, prints:
    - the mean per-batch train and validation loss;
    - the sample-weighted train and validation accuracy (NaN for an empty
      iterator);
    - the wall-clock time of the epoch.

    When training finishes, the last epoch index, ``net.state_dict()`` and
    ``optimizer.state_dict()`` are saved to
    ``os.path.join(model_save_path, 'last_model.pt')``. Nothing is saved
    when ``num_epochs <= 0``.

    Args:
        net: model to train; assumed to return per-class scores of shape
            (batch, num_classes), inferred from ``argmax(dim=1)``.
        num_epochs: number of passes over ``train_iter``.
        loss_function: criterion called as ``loss_function(score, label)``.
        optimizer: optimizer over ``net``'s parameters.
        train_iter: iterable of ``(feature, label)`` training batches.
        val_iter: iterable of ``(feature, label)`` validation batches.

    Relies on the module-level globals ``device`` and ``model_save_path``.
    """
    epoch = None
    for epoch in range(num_epochs):
        start = time.time()
        train_loss, train_acc = _train_one_epoch(net, loss_function, optimizer, train_iter)
        val_loss, val_acc = _evaluate(net, loss_function, val_iter)
        runtime = time.time() - start
        print('epoch: %d, train loss: %.4f, train acc: %.2f, val loss: %.4f, val acc: %.2f, time: %.2f' %
              (epoch, train_loss, train_acc, val_loss, val_acc, runtime))

    if epoch is None:
        # No epoch ran, so there is nothing to checkpoint. The previous code
        # raised NameError on the unbound ``epoch`` here.
        return

    # Save the final model together with the optimizer state so training can resume.
    state = {
        'epoch': epoch,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, os.path.join(model_save_path, 'last_model.pt'))
def predict(net, test_iter):
    """Run ``net`` over ``test_iter``, print and return its accuracy.

    Args:
        net: trained model; assumed to return per-class scores of shape
            (batch, num_classes), inferred from ``argmax(dim=1)``.
        test_iter: iterable of ``(batch, label)`` pairs.

    Returns:
        ``(acc, pred_list, true_list)``:
        - ``acc`` is the accuracy computed by ``accuracy_score``;
        - ``pred_list`` holds the predicted class index of every sample;
        - ``true_list`` holds the matching labels.
        Both lists are in iteration order and contain numpy scalars.

    Relies on the module-level ``device`` and ``accuracy_score``.
    """
    pred_list = []
    true_list = []
    net.eval()
    with torch.no_grad():
        for batch, label in test_iter:
            output = net(batch.to(device))
            # argmax of the raw scores equals argmax of their softmax (softmax is
            # order-preserving). Skipping it saves a pass and avoids float ties
            # when softmax saturates.
            pred_list.extend(output.argmax(dim=1).cpu().numpy())
            true_list.extend(label.cpu().numpy())
    # accuracy_score takes (y_true, y_pred); accuracy itself is symmetric.
    acc = accuracy_score(true_list, pred_list)
    print('test acc: %f' % acc)
    return acc, pred_list, true_list
# Script driver: fit the model, then score it on the held-out test set.
print('start to train...')
train(
    net=net,
    num_epochs=num_epochs,
    loss_function=loss_function,
    optimizer=optimizer,
    train_iter=train_iter,
    val_iter=val_iter,
)

print('start to predict test set...')
acc, pred_list, true_list = predict(net=net, test_iter=test_iter)
print('Done')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment