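# The snippets below assume a pandas DataFrame named `data` with a text column
# ('comment') and a binary label column ('label'); it is never loaded in the gist.
# A minimal sketch, with a hypothetical filename:
import pandas as pd

data = pd.read_csv('comments.csv')  # hypothetical file; any CSV with these two columns works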
# Stratified train/test split on the raw comments and labels
from sklearn.model_selection import train_test_split

X = data['comment'].values
y = data['label'].values
train_df, test_df, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
# Creating encodings from the BERT tokenizer for the test data
# (test_df is already a NumPy array, so there is no .values attribute to call)
encodings_test = construct_encodings(list(test_df.flatten()), tkzr, max_len=MAX_LEN)
# Constructing the dataset for the test set
tfdataset_test = construct_tfdataset(encodings_test, y_test)
BATCH_SIZE = 64
tfdataset_test = tfdataset_test.shuffle(len(test_df))  # optional for evaluation; features and labels shuffle together
tfdataset_test = tfdataset_test.batch(BATCH_SIZE)
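# A sketch of how the batched test dataset would be consumed after training,
# assuming the fine-tuned model (model_no_clean) from the training snippet:
loss_test, acc_test = model_no_clean.evaluate(tfdataset_test)
print(f'test loss: {loss_test:.4f}, test accuracy: {acc_test:.4f}')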
# Fine-tune DistilBERT for binary sequence classification
N_EPOCHS = 2
model_no_clean = TFDistilBertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
optimizer = optimizers.Adam(learning_rate=3e-5)
loss = losses.SparseCategoricalCrossentropy(from_logits=True)
model_no_clean.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
# batch_size must not be passed to fit() with a tf.data.Dataset; the dataset is already batched
model_no_clean.fit(tfdataset_train1, epochs=N_EPOCHS)
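# To monitor the held-out split built in the next snippet during training, the
# batched validation dataset can be passed to fit() as well; a sketch:
model_no_clean.fit(tfdataset_train1, epochs=N_EPOCHS, validation_data=tfdataset_cv1)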
# Split the training data into train and validation (cv) subsets
TEST_SPLIT = 0.2
BATCH_SIZE = 64
train_size = int(len(train_df) * (1 - TEST_SPLIT))
# reshuffle_each_iteration=False keeps take()/skip() from mixing the two subsets across epochs
tfdataset1 = tfdataset1.shuffle(len(train_df), reshuffle_each_iteration=False)
tfdataset_train1 = tfdataset1.take(train_size)
tfdataset_cv1 = tfdataset1.skip(train_size)
tfdataset_train1 = tfdataset_train1.batch(BATCH_SIZE)
tfdataset_cv1 = tfdataset_cv1.batch(BATCH_SIZE)  # batch the validation split as well
# Build a tf.data dataset from the tokenizer encodings (and labels, when given)
def construct_tfdataset(encodings, y=None):
    if y is not None:  # y.any() would crash on None and be falsy for all-zero labels
        return tf.data.Dataset.from_tensor_slices((dict(encodings), y))
    else:
        # this case is used when making predictions on unseen samples after training
        return tf.data.Dataset.from_tensor_slices(dict(encodings))

tfdataset1 = construct_tfdataset(encodings, y_train)
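# A sketch of the unlabeled branch in use for inference on new text; the example
# strings are hypothetical, and the model and tokenizer come from the other snippets:
new_texts = ['great video, thanks!', 'click here to win a free prize']
new_encodings = construct_encodings(new_texts, tkzr, max_len=MAX_LEN)
tfdataset_new = construct_tfdataset(new_encodings).batch(BATCH_SIZE)
logits = model_no_clean.predict(tfdataset_new).logits  # older transformers versions return a tuple; use [0] there
preds = tf.argmax(logits, axis=-1).numpy()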
# Tokenize a list of strings into padded/truncated encodings
def construct_encodings(x, tkzr, max_len, truncation=True, padding=True):
    return tkzr(x, max_length=max_len, truncation=truncation, padding=padding)

encodings = construct_encodings(list(train_df.flatten()), tkzr, max_len=MAX_LEN)
%%capture
#!python3 -m venv venv
#!source venv/bin/activate
#!pip install tensorflow transformers

import tensorflow as tf
from tensorflow.keras import activations, optimizers, losses
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification

MODEL_NAME = 'distilbert-base-uncased'
MAX_LEN = 40
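# The tokenizer referenced as `tkzr` in the encoding snippets is never instantiated
# in the gist; the standard way to create it for this checkpoint:
tkzr = DistilBertTokenizer.from_pretrained(MODEL_NAME)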
# Getting the vocab file and lowercasing flag from the TF Hub BERT layer
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
from bert.tokenization import bert_tokenization
tokenizer = bert_tokenization.FullTokenizer(vocab_file, do_lower_case)

def create_tokens_mask_segment(text, tokenizer, max_seq_length):
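    # The function body is cut off in the gist. A minimal sketch of the standard
    # three-input construction (token ids, input mask, segment ids), assuming
    # single-sentence inputs padded to max_seq_length:
    tokens = tokenizer.tokenize(text)[:max_seq_length - 2]  # leave room for [CLS]/[SEP]
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    ids = tokenizer.convert_tokens_to_ids(tokens)
    pad_len = max_seq_length - len(ids)
    input_ids = ids + [0] * pad_len                # pad token ids with zeros
    input_mask = [1] * len(ids) + [0] * pad_len    # 1 for real tokens, 0 for padding
    segment_ids = [0] * max_seq_length             # single sentence: all zeros
    return input_ids, input_mask, segment_ids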
import tensorflow_hub as hub
## Loading the pretrained BERT model from TensorFlow Hub
tf.keras.backend.clear_session()
# maximum length of a sequence in the data we have; set to 40 for now, you can change this
max_seq_length = 40
# BERT takes 3 inputs: token ids, input mask, and segment ids
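# The hub.KerasLayer call itself is missing from the gist. A sketch that matches the
# resolved_object attributes used above, with an assumed TF Hub handle for the
# uncased BERT-Base encoder:
bert_layer = hub.KerasLayer(
    'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4',
    trainable=True)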
import os
import datetime
import warnings
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau

warnings.filterwarnings("ignore")

# Creating the TensorBoard log directory and a checkpoint path
log_dir = os.path.join('logs1', 'fits', datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
filepath = "best_lstm_model.hdf5"
#history_loss = LossHistory(validation_data=validation_data)
#logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
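# The callback objects themselves are never constructed in this snippet; a sketch
# of how the imports and paths above would typically be wired together:
tensorboard_cb = TensorBoard(log_dir=log_dir)
checkpoint_cb = ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True)
early_stop_cb = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
reduce_lr_cb = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)
callbacks = [tensorboard_cb, checkpoint_cb, early_stop_cb, reduce_lr_cb]
# pass these to model.fit(..., callbacks=callbacks)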