This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BERT_Arch(nn.Module):
    """Module that wraps the model passed in as ``bert_model``.

    The wrapped model is stored as ``self.bert`` and a ReLU activation is
    stored as ``self.relu``.

    NOTE(review): no ``forward`` method or additional layers are visible in
    this excerpt -- confirm the remainder of the class against the full
    source before relying on it.

    Args:
        bert_model: The model to wrap; kept as ``self.bert``.
    """

    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        # relu activation function
        self.relu = nn.ReLU()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# freeze all the parameters: with requires_grad disabled, no gradients are
# computed for the weights of bert_model
for params in bert_model.parameters():
    params.requires_grad = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler, RandomSampler

# wrap the validation tensors into a single dataset
valdata = TensorDataset(valseq, valmask, valy)

# number of samples per batch
b_size = 32

# sequential sampler: walks through the validation data in order
valsampler = SequentialSampler(valdata)

# dataLoader for validation set
valDataLoader = DataLoader(valdata, sampler=valsampler, batch_size=b_size)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# tokenizing and encoding the sequences in the validation set
tokensval = tokenizer.batch_encode_plus(
    valtext.tolist(),
    max_length=25,
    # `pad_to_max_length=True` is deprecated; `padding='max_length'` is the
    # equivalent setting that pads every sequence up to `max_length`
    padding='max_length',
    truncation=True,
)
# tokenizing and encoding the sequences in the training set
tokenstrain = tokenizer.batch_encode_plus ( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# get length (number of whitespace-separated tokens) of all the messages
# in the train set
sequence_len = [len(message.split()) for message in traintext]

# plot the distribution of those lengths as a 30-bin histogram
pd.Series(sequence_len).hist(bins=30)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# importing BERT-base pretrained model
bert_model = AutoModel.from_pretrained('bert-base-uncased')

# Loading the BERT tokenizer
bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Tokenizer: sample data
sample_text = ["this is a bert tutorial", "we will fine tune the bert model"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# split train dataset into train, validation and test sets
# first split: 70% train / 30% temporary hold-out, stratified on the label
# column; the hold-out part is split further below
traintext, temptext, trainlabels, templabels = train_test_split(
    df['text'],
    df['label'],
    random_state=2018,
    test_size=0.3,
    stratify=df['label'],
)
valtext, testtext, vallabels, testlabels = train_test_split(temptext, templabels, | |
random_state=2018, | |
test_size=0.5, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import pandas as pd | |
import torch.nn as nn | |
import transformers | |
import numpy as np | |
from sklearn.model_selection import train_test_split | |
from transformers import AutoModel, BertTokenizerFast | |
from sklearn.metrics import classification_report | |
# specify your GPU |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# rows before position 987 go to train_data, the remaining rows to valid_data
# (assumes new_dataset is a pandas DataFrame with a "Close" column -- confirm)
train_data = new_dataset[:987]

# copy the slice so the new column is added to an independent object rather
# than to a slice of new_dataset (avoids pandas' SettingWithCopyWarning)
valid_data = new_dataset[987:].copy()
valid_data['Predictions'] = predicted_closing_price

# plot the training closes, then the validation closes next to the predictions
plt.plot(train_data["Close"])
plt.plot(valid_data[['Close', "Predictions"]])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# save the model to the file "saved_model.h5"
lstm_model.save("saved_model.h5")