Skip to content

Instantly share code, notes, and snippets.

@arunm8489
Created August 2, 2020 04:39
Show Gist options
  • Save arunm8489/7cb1c6ecd832d9e133f9066a8b41b9e0 to your computer and use it in GitHub Desktop.
Save arunm8489/7cb1c6ecd832d9e133f9066a8b41b9e0 to your computer and use it in GitHub Desktop.
class Train:
    """Train and evaluate a multi-input classifier, checkpointing every epoch.

    Each batch ``x`` packs several model inputs side by side along dim 1
    (see :meth:`split_out`); ``y`` holds integer class labels that are fed to
    ``nn.CrossEntropyLoss``.

    Args:
        epochs: number of passes over ``train_loader`` in :meth:`train_model`.
        lr: Adam learning rate.
        train_loader: iterable of ``(x, y)`` batches used for training. When
            omitted, a module-level ``train_loader`` is looked up when the
            trainer is constructed (the old default bound it when the class
            was defined, so later reassignments were silently ignored).
        test_loader: iterable of ``(x, y)`` batches used for validation; same
            fallback as ``train_loader``.
        seq_len: number of leading text-token columns in ``x``.
        model: model whose trainable parameters the optimizer updates. Falls
            back to a module-level ``model``; if neither exists the optimizer
            is created on the first call to :meth:`train_model`. A different
            model passed to :meth:`train_model` gets its own optimizer.
        device: device batches are moved to. Falls back to a module-level
            ``device``, then to the device of the model's first parameter,
            then to CPU.
        checkpoint_path: file rewritten with the latest checkpoint each epoch.
        best_model_path: file the checkpoint is copied to on a new best loss.
        test_loss_min: initial best validation loss; a checkpoint is only
            copied to ``best_model_path`` when the validation loss is <= it.
    """

    def __init__(self, epochs, lr=0.01, train_loader=None, test_loader=None,
                 seq_len=440, *, model=None, device=None,
                 checkpoint_path='model1/chkpoint1_',
                 best_model_path='model1/bestmodel1.pt',
                 test_loss_min=3.95275):
        module_globals = globals()
        # Resolve notebook-style globals at call time, not class-definition time.
        self.train_loader = (train_loader if train_loader is not None
                             else module_globals.get('train_loader'))
        self.test_loader = (test_loader if test_loader is not None
                            else module_globals.get('test_loader'))
        self.epochs = epochs
        self.lr = lr
        self.seq_len = seq_len
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.best_model_path = best_model_path
        # Default carried over from the original code; presumably the best loss
        # of an earlier run being resumed -- TODO confirm.
        self.test_loss_min = test_loss_min
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = None
        self._optimizer_model = None  # the model self.optimizer was built for
        if model is None:
            model = module_globals.get('model')
        if model is not None:
            self._ensure_optimizer(model)

    def _ensure_optimizer(self, model):
        """Return an Adam optimizer over ``model``'s trainable parameters.

        The existing optimizer (and its state) is reused when it was built for
        this same model object; otherwise a new one is created.
        """
        if self.optimizer is None or self._optimizer_model is not model:
            parameters = [p for p in model.parameters() if p.requires_grad]
            self.optimizer = torch.optim.Adam(params=parameters, lr=self.lr)
            self._optimizer_model = model
        return self.optimizer

    def _device_for(self, model):
        """Return the device that batches for ``model`` should be moved to."""
        if self.device is not None:
            return self.device
        # The original code read a module-level ``device``; keep honouring it.
        device = globals().get('device')
        if device is not None:
            return device
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def split_out(self, x, seq_len):
        """Split a packed 2-D batch into the model's seven inputs.

        Column layout of ``x`` as indexed below: ``[:seq_len]`` text tokens,
        then one column each for state, prefix, category, sub-category and
        grade, then every remaining column as numeric features.

        The six categorical parts are cast to ``long`` (index inputs) and the
        numeric tail to ``float``. Casting happens *after* slicing so that
        fractional numeric features are not truncated to integers.

        Returns:
            ``(text, state, prefix, cat, sub_cat, grade, num)`` where ``text``
            is ``(batch, seq_len)``, each single-column part is ``(batch, 1)``
            and ``num`` is ``(batch, n_remaining_columns)``.
        """
        batch_size = x.shape[0]
        text = x[:, :seq_len].long()
        singles = tuple(
            x[:, seq_len + offset].reshape(batch_size, -1).long()
            for offset in range(5)
        )
        num = x[:, seq_len + 5:].reshape(batch_size, -1).float()
        return (text, *singles, num)

    def train_model(self, model):
        """Train ``model`` for ``self.epochs`` epochs and return metric history.

        After every epoch a checkpoint (epoch, best validation loss so far,
        model and optimizer state) is written to ``self.checkpoint_path``;
        when the validation loss is <= the best seen so far it is also copied
        to ``self.best_model_path``.

        Returns:
            dict mapping each metric name to a list with one value per epoch.
        """
        metrics = {'train_auc': [], 'test_auc': [], 'train accuracy': [],
                   'test accuracy': [], 'train loss': [], 'test loss': []}
        # Make sure the optimizer updates *this* model.
        self._ensure_optimizer(model)
        test_loss_min = self.test_loss_min
        for epoch in range(self.epochs):
            start_time = time.time()
            print(f'Epoch: {epoch + 1}')
            train_epoch_roc_auc, train_epoch_accuracy, train_epoch_loss = self.train_epoch(model)
            test_epoch_roc_auc, test_epoch_accuracy, test_epoch_loss = self.validation_epoch(model)
            metrics['train_auc'].append(train_epoch_roc_auc)
            metrics['train accuracy'].append(train_epoch_accuracy)
            metrics['train loss'].append(train_epoch_loss)
            metrics['test_auc'].append(test_epoch_roc_auc)
            metrics['test accuracy'].append(test_epoch_accuracy)
            metrics['test loss'].append(test_epoch_loss)
            print(f'Train : accuracy {train_epoch_accuracy}, auc score {train_epoch_roc_auc}, loss {train_epoch_loss}')
            print(f'Test : accuracy {test_epoch_accuracy}, auc score {test_epoch_roc_auc}, loss {test_epoch_loss}')
            is_best = test_epoch_loss <= test_loss_min
            checkpoint = {
                'epoch': epoch + 1,
                # Running minimum, so resuming from the latest checkpoint keeps
                # the correct "best so far" baseline.
                'valid_loss_min': min(test_loss_min, test_epoch_loss),
                'state_dict': model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }
            if is_best:
                print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(test_loss_min, test_epoch_loss))
                test_loss_min = test_epoch_loss
            # Always written; additionally copied to the best-model path when is_best.
            self.save_ckp(checkpoint, is_best, self.checkpoint_path, self.best_model_path)
            time_elapsed = time.time() - start_time
            print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('--'*50)
        # Persist the best loss so a later call continues from it.
        self.test_loss_min = test_loss_min
        return metrics

    def score(self, y, y_pred, y_prob_pred):
        """Return ``(roc_auc, accuracy)`` for true labels ``y``.

        ``y_pred`` are hard class predictions and ``y_prob_pred`` the scores of
        the last class used for ROC AUC. ROC AUC is undefined when ``y`` holds
        a single class, so NaN is returned for it instead of raising.
        """
        if len(set(y)) < 2:
            roc_score = float('nan')
        else:
            roc_score = roc_auc_score(y, y_prob_pred)
        acc = accuracy_score(y, y_pred)
        return roc_score, acc

    def _run_epoch(self, model, loader, train):
        """Run one pass of ``model`` over ``loader``.

        When ``train`` is True, weights are updated after every batch;
        otherwise gradients are disabled. Metrics are computed once over every
        sample seen in the pass (a true epoch-level AUC, not an average of
        per-batch AUCs).

        Returns:
            ``(roc_auc, accuracy, mean_loss)``; all NaN if ``loader`` is empty.
        """
        model.train(train)
        device = self._device_for(model)
        optimizer = self._ensure_optimizer(model) if train else None
        total_loss = 0.0
        n_seen = 0
        labels, predictions, scores = [], [], []
        # Only the training pass needs an autograd graph.
        with torch.set_grad_enabled(train):
            for x, y in loader:
                batch_size = x.shape[0]
                x = x.to(device)
                y = y.long().to(device)
                y_pred = model(*self.split_out(x, self.seq_len))
                loss = self.criterion(y_pred, y)
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                logits = y_pred.detach()
                labels.append(y.detach().cpu())
                predictions.append(logits.argmax(dim=1).cpu())
                # CrossEntropyLoss implies y_pred are logits; AUC must rank by
                # class probability, which one raw logit alone does not give.
                scores.append(torch.softmax(logits, dim=1)[:, -1].cpu())
                total_loss += loss.item() * batch_size
                n_seen += batch_size
        if n_seen == 0:
            nan = float('nan')
            return nan, nan, nan
        roc_auc, accuracy = self.score(
            torch.cat(labels).numpy(),
            torch.cat(predictions).numpy(),
            torch.cat(scores).numpy(),
        )
        return roc_auc, accuracy, total_loss / n_seen

    def validation_epoch(self, model):
        """Evaluate ``model`` on ``self.test_loader`` without updating it.

        Returns ``(roc_auc, accuracy, mean_loss)`` over the validation samples.
        """
        return self._run_epoch(model, self.test_loader, train=False)

    def train_epoch(self, model):
        """Run one optimisation pass of ``model`` over ``self.train_loader``.

        Returns ``(roc_auc, accuracy, mean_loss)`` over the samples trained on,
        measured on each batch's pre-update outputs.
        """
        return self._run_epoch(model, self.train_loader, train=True)

    def save_ckp(self, state, is_best, checkpoint_path, best_model_path):
        """Save a checkpoint and optionally copy it to the best-model path.

        Args:
            state: checkpoint object passed to ``torch.save``.
            is_best: whether this checkpoint has the lowest validation loss so far.
            checkpoint_path: file the checkpoint is written to.
            best_model_path: file the checkpoint is copied to when ``is_best``.
        """
        import os  # local import keeps this method self-contained

        def _make_parent(path):
            # torch.save / copyfile do not create missing directories.
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

        _make_parent(checkpoint_path)
        torch.save(state, checkpoint_path)
        if is_best:
            _make_parent(best_model_path)
            shutil.copyfile(checkpoint_path, best_model_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment