## Store our loss and accuracy for plotting
train_loss_set = []
learning_rate = []

# Gradients get accumulated by default, so clear them before training
model.zero_grad()

# tnrange is a tqdm wrapper around the normal Python range
for _ in tnrange(1, epochs + 1, desc='Epoch'):
    print("<" + "=" * 22 + f" Epoch {_} " + "=" * 22 + ">")
# Load BertForSequenceClassification, the pretrained BERT model with a
# single linear classification layer on top.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).to(device)

# Parameters
lr = 2e-5
adam_epsilon = 1e-8

# Number of training epochs (the BERT authors recommend between 2 and 4)
epochs = 3
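The snippet defines `lr` and `adam_epsilon` but the optimizer construction is missing. Tutorials of this vintage typically pair AdamW with a linear warmup-then-decay schedule (often via `transformers`' helpers); the torch-only sketch below is one way to wire the same parameters up. `num_warmup_steps` and `num_training_steps` are assumptions here; the latter would normally be `len(train_dataloader) * epochs`.

```python
import torch

def make_optimizer_and_scheduler(model, lr, adam_epsilon, num_warmup_steps, num_training_steps):
    """AdamW plus a linear warmup/decay learning-rate schedule."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=adam_epsilon)

    def lr_lambda(step):
        # Ramp up linearly during warmup, then decay linearly to zero
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```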
# Split into a training set and a validation set. Note this is a simple
# random split; train_test_split is not stratified unless stratify= is passed.
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(
    input_ids, labels, random_state=SEED, test_size=0.1)
train_masks, validation_masks, _, _ = train_test_split(
    attention_masks, input_ids, random_state=SEED, test_size=0.1)

# Convert all our data into torch tensors, the required data type for our model
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)
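With the tensors in place, the usual next step (not shown in the snippet) is to wrap them in a `TensorDataset` and `DataLoader` for batched iteration. A minimal sketch, where the batch size of 32 is illustrative:

```python
import torch
from torch.utils.data import TensorDataset, RandomSampler, DataLoader

def make_dataloader(inputs, masks, labels, batch_size=32, shuffle=True):
    """Bundle input ids, attention masks, and labels into a batched loader."""
    data = TensorDataset(inputs, masks, labels)
    # Shuffle training data each epoch; validation data can be left in order
    sampler = RandomSampler(data) if shuffle else None
    return DataLoader(data, sampler=sampler, batch_size=batch_size)
```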
import pandas as pd

df = pd.read_csv("raw/in_domain_train.tsv", delimiter='\t', header=None,
                 names=['sentence_source', 'label', 'label_notes', 'sentence'])
print(df.sample(5))

## Create the label and sentence lists
sentences = df.sentence.values
labels = df.label.values

# Check the distribution of the data based on labels
print("Distribution of data based on labels: ", df.label.value_counts())

# Set the maximum sequence length. The longest sequence in our training set
# is 47, but we'll leave room on the end anyway.
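The code that pads sequences to the maximum length is cut off here. Padding token-id sequences and building the matching attention masks can be sketched in plain Python (tutorials of this era often used Keras' `pad_sequences` for the same job; this version assumes 0 is the pad id, which holds for BERT's vocabulary):

```python
def pad_and_mask(token_id_seqs, max_len):
    """Right-pad each sequence of token ids to max_len and build attention masks."""
    input_ids, attention_masks = [], []
    for seq in token_id_seqs:
        seq = seq[:max_len]                                  # truncate overlong sequences
        padded = seq + [0] * (max_len - len(seq))            # right-pad with the pad id 0
        mask = [1] * len(seq) + [0] * (max_len - len(seq))   # 1 = real token, 0 = padding
        input_ids.append(padded)
        attention_masks.append(mask)
    return input_ids, attention_masks
```

The attention mask lets the model ignore pad positions, which is why it is split alongside `input_ids` in the train/validation split above.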
# To upload data from local disk at run time, uncomment the code below:
# from google.colab import files
# uploaded = files.upload()

# The code below mounts Google Drive in the current Colab session.
# This is helpful when we want to store the trained model and later
# download it to local disk.
import os
from google.colab import drive

drive.mount('/content/gdrive')
os.chdir('/content/gdrive/My Drive')
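Once Drive is mounted, persisting and reloading the trained model is a two-liner with torch; a minimal sketch (the path and function names are illustrative, not from the source):

```python
import torch

def save_model(model, path):
    """Persist only the learned weights, not the full module object."""
    torch.save(model.state_dict(), path)

def load_model(model, path):
    """Restore weights into a freshly constructed model of the same class."""
    model.load_state_dict(torch.load(path))
    return model
```

Saving the `state_dict` rather than the whole model keeps the checkpoint portable across code changes; you rebuild the architecture first (e.g. via `from_pretrained`) and then load the weights into it.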
<link href="../core-scaffold/core-scaffold.html" rel="import"> | |
<link href="../core-header-panel/core-header-panel.html" rel="import"> | |
<link href="../core-menu/core-menu.html" rel="import"> | |
<link href="../core-item/core-item.html" rel="import"> | |
<link href="../core-icon-button/core-icon-button.html" rel="import"> | |
<link href="../core-toolbar/core-toolbar.html" rel="import"> | |
<link href="../core-field/core-field.html" rel="import"> | |
<link href="../core-icon/core-icon.html" rel="import"> | |
<link href="../core-input/core-input.html" rel="import"> | |
<link href="../core-icons/core-icons.html" rel="import"> |