@AmolMavuduru
Created December 1, 2020 20:23
Fine-tuning BERT for text classification.
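This gist fine-tunes bert-base-uncased for binary text classification: it loads the pretrained model and tokenizer, reads a preprocessed news dataset, splits it 70/30 into training and validation sets, tokenizes the text to a maximum of 256 tokens, and trains for one epoch with a small learning rate (5e-5).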
from transformers import TFBertForSequenceClassification, BertTokenizerFast
import tensorflow as tf
import pandas as pd
from sklearn.model_selection import train_test_split
transformer_model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)  # pretrained BERT with a 2-class classification head
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')  # padding/truncation options are passed at call time below
data = pd.read_csv('./data/combined_news_data_processed.csv')
data.dropna(inplace=True)  # drop rows with missing text or labels
X = data['text']
y = data['label']
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.3, random_state=42)
X_train_encoded = dict(tokenizer(list(X_train.values),
                                 add_special_tokens=True,    # add [CLS], [SEP]
                                 max_length=256,             # max length of the text that can go to BERT
                                 truncation=True,            # truncate longer sequences to max_length
                                 padding='max_length',       # add [PAD] tokens up to max_length
                                 return_attention_mask=True))
X_valid_encoded = dict(tokenizer(list(X_valid.values),
                                 add_special_tokens=True,    # add [CLS], [SEP]
                                 max_length=256,             # max length of the text that can go to BERT
                                 truncation=True,            # truncate longer sequences to max_length
                                 padding='max_length',       # add [PAD] tokens up to max_length
                                 return_attention_mask=True))
# Build tf.data pipelines pairing the encoded inputs with their integer labels
train_data = tf.data.Dataset.from_tensor_slices((X_train_encoded, list(y_train.values)))
valid_data = tf.data.Dataset.from_tensor_slices((X_valid_encoded, list(y_valid.values)))
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)  # I purposely used a small learning rate for fine-tuning the model
transformer_model.compile(optimizer=optimizer,
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # model outputs raw logits
                          metrics=['accuracy'])
transformer_model.fit(train_data.shuffle(1000).batch(16),  # batch_size must not be passed separately when fitting on a tf.data.Dataset
                      epochs=1,
                      validation_data=valid_data.batch(16))
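After training, a minimal sketch of saving the model and running a single prediction. The save path and sample headline are hypothetical, purely for illustration:

transformer_model.save_pretrained('./models/bert-news-classifier')   # hypothetical save path
sample = tokenizer(['Breaking: example headline to classify'],        # hypothetical input text
                   max_length=256, truncation=True, padding='max_length',
                   return_tensors='tf')
logits = transformer_model(sample).logits                             # raw scores for each of the 2 classes
predicted_label = int(tf.argmax(logits, axis=-1)[0])                  # 0 or 1
print(predicted_label)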