Created
February 13, 2020 16:39
-
-
Save henry16lin/6fae622a38f0815c2e699739520ad7b2 to your computer and use it in GitHub Desktop.
RNN_training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Move the model to the target device and define the training objective.
net.to(device)
# CrossEntropyLoss expects raw (unnormalised) class scores from the net:
# it internally applies LogSoftmax + NLLLoss.
loss_function = nn.CrossEntropyLoss()  # ~ nn.LogSoftmax()+nn.NLLLoss()
# Adam over all trainable parameters, library-default hyperparameters.
optimizer = optim.Adam(net.parameters())
def train(net, num_epochs, loss_function, optimizer, train_iter, val_iter,
          device=None, save_path=None):
    """Train `net` for `num_epochs` epochs, printing per-epoch metrics.

    After the last epoch, saves a checkpoint dict (epoch, model state,
    optimizer state) to `<save_path>/last_model.pt`.

    Args:
        net: the model to train (already constructed; moved to `device`).
        num_epochs: number of passes over `train_iter`.
        loss_function: criterion, e.g. nn.CrossEntropyLoss().
        optimizer: torch optimizer over net.parameters().
        train_iter, val_iter: iterables yielding (feature, label) batches.
        device: target device; if None, inferred from the model's parameters
            (backward compatible with the original's module-level `device`,
            since the model was moved there before calling train()).
        save_path: checkpoint directory; if None, falls back to the
            module-level `model_save_path` as before.
    """
    if device is None:
        device = next(net.parameters()).device
    for epoch in range(num_epochs):
        start = time.time()
        train_loss, val_loss_sum = 0.0, 0.0
        train_acc, val_acc = 0.0, 0.0
        n, m = 0, 0  # batch counters for train / validation
        net.train()
        for feature, label in train_iter:
            n += 1
            optimizer.zero_grad()
            # Variable() has been a no-op since torch 0.4 — plain .to() suffices.
            feature = feature.to(device)
            label = label.to(device)
            score = net(feature)
            loss = loss_function(score, label)
            loss.backward()
            optimizer.step()
            # Batch accuracy computed in torch directly (no sklearn / extra
            # CPU round-trip): fraction of argmax predictions matching labels.
            train_acc += (score.argmax(dim=1) == label).float().mean().item()
            # .item() detaches the scalar; accumulating the loss *tensor*
            # (as the original did) keeps the whole autograd graph of every
            # batch alive for the epoch — a memory leak.
            train_loss += loss.item()
        # Validation pass: eval mode + no_grad to disable dropout/batchnorm
        # updates and gradient tracking.
        net.eval()
        with torch.no_grad():
            for val_feature, val_label in val_iter:
                m += 1
                val_feature = val_feature.to(device)
                val_label = val_label.to(device)
                val_score = net(val_feature)
                val_loss_sum += loss_function(val_score, val_label).item()
                val_acc += (val_score.argmax(dim=1) == val_label).float().mean().item()
        runtime = time.time() - start
        # max(.., 1) guards against empty iterators (original divided by zero).
        print('epoch: %d, train loss: %.4f, train acc: %.2f, val loss: %.4f, val acc: %.2f, time: %.2f' %
              (epoch, train_loss / max(n, 1), train_acc / max(n, 1),
               val_loss_sum / max(m, 1), val_acc / max(m, 1), runtime))
    # Save the final (last-epoch) model checkpoint.
    if save_path is None:
        save_path = model_save_path  # module-level default, as before
    state = {
        'epoch': epoch,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, os.path.join(save_path, 'last_model.pt'))
def predict(net, test_iter, device=None):
    """Run `net` over `test_iter` and report test accuracy.

    Args:
        net: trained model producing per-class scores of shape (batch, classes).
        test_iter: iterable yielding (batch, label) pairs.
        device: target device; if None, inferred from the model's parameters
            (backward compatible with the original's module-level `device`).

    Returns:
        (acc, pred_list, true_list): accuracy in [0, 1] plus the flat lists
        of predicted and true class indices.
    """
    if device is None:
        device = next(net.parameters()).device
    pred_list = []
    true_list = []
    net.eval()
    with torch.no_grad():
        for batch, label in test_iter:
            output = net(batch.to(device))
            # Softmax is monotonic, so argmax(softmax(x)) == argmax(x);
            # the original's softmax pass was redundant and is dropped.
            pred_list.extend(output.argmax(dim=1).cpu().numpy())
            true_list.extend(label.cpu().numpy())
    # Plain accuracy (fraction of exact matches); replaces the sklearn
    # accuracy_score call with an equivalent dependency-free computation.
    correct = sum(int(p == t) for p, t in zip(pred_list, true_list))
    acc = correct / len(pred_list) if pred_list else 0.0
    print('test acc: %f' % acc)
    return acc, pred_list, true_list
# Driver: train the model, then evaluate it on the held-out test set.
print('start to train...')
train(net,num_epochs,loss_function,optimizer,train_iter,val_iter)
print('start to predict test set...')
acc,pred_list,true_list = predict(net,test_iter)
print('Done')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment