Skip to content

Instantly share code, notes, and snippets.

@B046090010
Last active January 28, 2022 03:49
Show Gist options
  • Star (0) — you must be signed in to star a gist.
  • Fork (0) — you must be signed in to fork a gist.
  • Save B046090010/fd532b7ab8ee6b6d220f186bf4289230 to your computer and use it in GitHub Desktop.
Save B046090010/fd532b7ab8ee6b6d220f186bf4289230 to your computer and use it in GitHub Desktop.
# install package & env
# Install through the running interpreter (`python -m pip`) so the packages
# land in the same environment that will import them. Unlike a `!` shell
# escape, this is valid Python outside a notebook, and check_call raises
# CalledProcessError if the install fails instead of silently continuing.
import subprocess
import sys

subprocess.check_call(
    [sys.executable, '-m', 'pip', 'install', 'simpletransformers', 'torchvision']
)
# setting parameter of model
import torch
from simpletransformers.classification import ClassificationModel

# Training hyper-parameters, passed unchanged to ClassificationModel via `args`.
train_args = {
    'num_train_epochs': 1,
    'train_batch_size': 16,
    'eval_batch_size': 64,
    'warmup_steps': 500,
    'weight_decay': 0.01,
    'logging_steps': 10,
    'learning_rate': 5e-5,
    'fp16': False,
    'overwrite_output_dir': True,
}

# create a classification model
# Use the GPU only when one is actually visible; hard-coding use_cuda=True
# cannot work on a CPU-only machine.
use_gpu = torch.cuda.is_available()
# NOTE(review): num_labels is not passed, so the label count comes from the
# pretrained config's default (typically 2) — confirm this matches the number
# of distinct labels in text_party.csv.
model = ClassificationModel(
    'bert',
    'ckiplab/bert-base-chinese',  # downloads the ckiplab checkpoint on first use
    use_cuda=use_gpu,
    cuda_device=0,
    args=train_args,
)
# data pre-processing
import pandas as pd

# Rename to the 'text' / 'labels' column names that simpletransformers reads.
fb_label_df = (
    pd.read_csv('text_party.csv')
    .rename(columns={'label': 'labels', 'clean_content': 'text'})
    # Rows with no text or no label carry no usable signal; drop them before
    # splitting so they reach neither the training nor the evaluation set.
    .dropna(subset=['text', 'labels'])
)

# data split: hold out 20% for evaluation; the fixed seed makes it reproducible.
from sklearn.model_selection import train_test_split
train_df, eval_df = train_test_split(fb_label_df, test_size=0.2, random_state=1)
# model training
model.train_model(train_df)

# model evaluation
# Import the submodule explicitly: a bare `import sklearn` does not guarantee
# that `sklearn.metrics` is loaded; it only worked if another import had
# already pulled it in.
import sklearn.metrics

# acc= supplies accuracy as an extra evaluation metric alongside the defaults.
result, model_outputs, wrong_predictions = model.eval_model(
    eval_df, acc=sklearn.metrics.accuracy_score
)
# A bare `result` expression only displays inside a notebook cell; print it so
# the metrics are also visible when run as a script.
print(result)

# predict outcome: run inference on the held-out texts.
classification_text = eval_df['text'].tolist()
predictions, raw_outputs = model.predict(classification_text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.