Last active
January 28, 2022 03:49
-
-
Save B046090010/fd532b7ab8ee6b6d220f186bf4289230 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# install package & env | |
! pip install simpletransformers | |
! pip install torchvision | |
# setting parameter of model | |
from simpletransformers.classification import ClassificationModel | |
train_args={ | |
'num_train_epochs': 1, | |
'train_batch_size': 16, | |
'eval_batch_size': 64, | |
'warmup_steps': 500, | |
'weight_decay': 0.01, | |
'logging_steps': 10, | |
'learning_rate': 5e-5, | |
'fp16': False, | |
'overwrite_output_dir': True | |
} | |
# create a classification model | |
model = ClassificationModel('bert', 'ckiplab/bert-base-chinese', use_cuda=True, cuda_device=0, args=train_args) # downloading ckiplab | |
# data pre-processing | |
import pandas as pd | |
fb_data = pd.read_csv('text_party.csv') | |
fb_label_df = fb_data.rename(columns={'label':'labels','clean_content':'text'}) # In order to input simple_bert | |
del(fb_data) | |
# data split | |
from sklearn.model_selection import train_test_split | |
train_df, eval_df = train_test_split(fb_label_df, test_size=0.2,random_state=1) | |
# model training | |
model.train_model(train_df) # train model | |
# model evaluation | |
import sklearn | |
result, model_outputs, wrong_predictions = model.eval_model(eval_df, acc=sklearn.metrics.accuracy_score) | |
result | |
# predict outcome | |
classification_text = list(eval_df['text']) | |
predictions, raw_outputs = model.predict(classification_text) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment