from datasets import Dataset, DatasetDict
from transformers import (AutoTokenizer, TrainingArguments, Trainer,
                          DataCollatorForTokenClassification,
                          AutoModelForTokenClassification, pipeline)
from sklearn.metrics import (precision_score, recall_score, f1_score,
                             classification_report, confusion_matrix,
                             ConfusionMatrixDisplay)
from shutil import rmtree
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os  # used for os.path.join in filter_threshold
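# --- Assumed setup (a sketch, not part of the original gist) ---
# The functions below use `label_list`, `label2id`, `id2label`, `tokenizer` and
# `data_collator` as globals without defining them. A minimal sketch of what they
# might look like, assuming a BIO label scheme and a generic multilingual
# checkpoint; the label names and model name are placeholders, adjust to your data.
label_list = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]  # hypothetical label set
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for label, i in label2id.items()}

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")  # assumed checkpoint
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)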
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)
    labels = []
    for i, label in enumerate(examples['entity']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
def df_to_dataset_for_model(df: pd.DataFrame):
    # So the original df remains unchanged
    df = df.copy(deep=True)
    df["entity"] = df["entity"].apply(lambda x: label2id[x])
    # Merge tokens into full sentences, grouped by text_id
    tmp_df = df.groupby("text_id")["token"].apply(list).reset_index()
    tmp_df["entity"] = df.groupby("text_id")["entity"].apply(list).reset_index()["entity"]
    tmp_df.columns = ["text_id", "tokens", "entity"]
    # Convert to Dataset format
    tmp_list = []
    for i in tmp_df.index:
        tmp_list.append({
            "text_id": tmp_df.loc[i, "text_id"],
            "tokens": tmp_df.loc[i, "tokens"],
            "entity": tmp_df.loc[i, "entity"],
        })
    dataset = Dataset.from_list(tmp_list)
    # Tokenize the dataset
    tokenized_dataset = dataset.map(tokenize_and_align_labels,
                                    batched=True,
                                    remove_columns=dataset.column_names)
    # Split into train and validation
    train_val = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
    train_val = DatasetDict({
        "train": train_val["train"],
        "validation": train_val["test"]
    })
    return train_val
def train_model(model_name: str,
                train_dataset: Dataset,
                val_dataset: Dataset,
                model_output_path: str):
    training_args = TrainingArguments(
        output_dir="./result",
        overwrite_output_dir=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        weight_decay=0.01,
        load_best_model_at_end=True,
        seed=42,
    )

    def model_init():
        return AutoModelForTokenClassification.from_pretrained(model_name, id2label=id2label, label2id=label2id)

    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(model_output_path)
    # Intermediate checkpoints are not needed once the best model is saved.
    rmtree("./result")
def extract_entities_from_result(tokens, result):
    # Map the pipeline's character-offset predictions back onto the original
    # tokens, assuming the input text was built with " ".join(tokens).
    predicted_entities = []
    current_index = 0
    for token in tokens:
        predicted = "O"
        for entry in result:
            if entry["start"] <= current_index < entry["end"]:
                predicted = f"B-{entry['entity_group']}" if current_index == entry['start'] else f"I-{entry['entity_group']}"
        predicted_entities.append(predicted)
        current_index += len(token) + 1
    return predicted_entities
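# A quick, hypothetical sanity check of extract_entities_from_result. The entity
# groups (PER, LOC), words and offsets below are made up for illustration; they
# mimic the dicts returned by a token-classification pipeline with
# aggregation_strategy="simple" on the text "John Smith visited Paris".
_example_tokens = ["John", "Smith", "visited", "Paris"]
_example_result = [
    {"entity_group": "PER", "score": 0.99, "word": "John Smith", "start": 0, "end": 10},
    {"entity_group": "LOC", "score": 0.97, "word": "Paris", "start": 19, "end": 24},
]
# extract_entities_from_result(_example_tokens, _example_result)
# -> ["B-PER", "I-PER", "O", "B-LOC"]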
def filter_threshold(model_path: str,
                     df: pd.DataFrame,
                     threshold: float,
                     output_dir: str,
                     output_filename_prefix: str):
    # Load the trained model as a token-classification pipeline
    ner = pipeline("token-classification",
                   model=model_path,
                   aggregation_strategy="simple")
    # `above` and `below` are kept for inspection/evaluation
    above = {
        'text_id': [],
        'text': [],
        'entity': [],
        'word': [],
        'confidence': [],
    }
    below = {
        'text_id': [],
        'text': [],
        'entity': [],
        'word': [],
        'confidence': [],
    }
    # `predicted_above` is the one to be returned
    predicted_above = {
        'text_id': [],
        'tokens': [],
        'predicted_entity': [],
    }
    # Get all text ids of unlabelled data
    text_ids = df["text_id"].unique().tolist()
    for text_id in tqdm(text_ids):
        tokens = df[df["text_id"] == text_id]["token"].to_list()
        text = " ".join(tokens)
        result = ner(text)
        if len(result) == 0:
            below["text_id"].append(text_id)
            below["text"].append(text)
            below["entity"].append("-")
            below["word"].append("-")
            below["confidence"].append(0)
            continue
        score = 0
        for entity in result:
            score += entity["score"]
        avg = score / len(result)
        # Keep sentences whose average prediction confidence clears the threshold
        if avg >= threshold:
            predicted_above["text_id"].append(text_id)
            predicted_above["tokens"].append(tokens)
            predicted_above["predicted_entity"].append(extract_entities_from_result(tokens, result))
            for entity in result:
                above["text_id"].append(text_id)
                above["text"].append(text)
                above["entity"].append(entity["entity_group"])
                above["word"].append(entity["word"])
                above["confidence"].append(entity["score"])
        else:
            for entity in result:
                below["text_id"].append(text_id)
                below["text"].append(text)
                below["entity"].append(entity["entity_group"])
                below["word"].append(entity["word"])
                below["confidence"].append(entity["score"])
    above_df = pd.DataFrame(above)
    below_df = pd.DataFrame(below)
    above_df.to_excel(os.path.join(output_dir, f"{output_filename_prefix}-above-{threshold}.xlsx"))
    below_df.to_excel(os.path.join(output_dir, f"{output_filename_prefix}-below-{threshold}.xlsx"))
    print(f"Above {threshold}: {len(above_df['text_id'].unique())} sentences")
    print(f"Below {threshold}: {len(below_df['text_id'].unique())} sentences")
    predicted_above = pd.DataFrame(predicted_above)
    predicted_above = predicted_above.explode(["tokens", "predicted_entity"]).reset_index(drop=True)
    predicted_above.columns = ["text_id", "token", "entity"]
    return predicted_above
def get_predicted_label_on_test_dataset(model_path: str, df: pd.DataFrame):
    # Load the trained model
    ner = pipeline("token-classification",
                   model=model_path,
                   aggregation_strategy="simple")
    predicted = []
    text_ids = df["text_id"].unique().tolist()
    for text_id in tqdm(text_ids):
        tokens = df[df["text_id"] == text_id]["token"].to_list()
        text = " ".join(tokens)
        result = ner(text)
        predicted.extend(extract_entities_from_result(tokens, result))
    # Adds the predictions to df in place rather than returning a copy.
    df["predicted_entity"] = predicted
def get_overall_performance(df: pd.DataFrame):
    # Overall accuracy
    total = len(df)
    correct = (df["entity"] == df["predicted_entity"]).sum()
    accuracy = correct / total
    print("Accuracy:", accuracy)
    # Precision, recall, and F1-score
    y_true = df['entity']
    y_pred = df['predicted_entity']
    precision = precision_score(y_true, y_pred, average='weighted', labels=label_list)
    recall = recall_score(y_true, y_pred, average='weighted', labels=label_list)
    f1 = f1_score(y_true, y_pred, average='weighted', labels=label_list)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)
def get_individual_label_score(df: pd.DataFrame):
    y_true = df['entity']
    y_pred = df['predicted_entity']
    report = classification_report(y_true, y_pred, labels=label_list, output_dict=True)
    return pd.DataFrame(report).transpose()
def get_confusion_matrix(df: pd.DataFrame):
    y_true = df['entity']
    y_pred = df['predicted_entity']
    cm = confusion_matrix(y_true, y_pred, labels=label_list)
    ax = plt.subplot()
    # annot=True to annotate cells, fmt='g' to disable scientific notation
    sns.heatmap(cm, annot=True, fmt='g', ax=ax, cmap="Oranges", vmin=0, vmax=160)
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title('Confusion Matrix')
    plt.xticks(rotation=270)
    plt.yticks(rotation=0)
    ax.xaxis.set_ticklabels(label_list)
    ax.yaxis.set_ticklabels(label_list)
def get_misclassified_report(df: pd.DataFrame, output_path: str):
    misclassified_dict = {
        'text_id': [],
        'token': [],
        'true_label': [],
        'pred_label': [],
        'text': [],
    }
    for i in tqdm(df.index):
        true_label = df.loc[i, 'entity']
        pred_label = df.loc[i, 'predicted_entity']
        text_id = df.loc[i, 'text_id']
        if true_label != pred_label:
            misclassified_dict['token'].append(df.loc[i, 'token'])
            misclassified_dict['text_id'].append(text_id)
            misclassified_dict['true_label'].append(true_label)
            misclassified_dict['pred_label'].append(pred_label)
            misclassified_dict['text'].append(' '.join(df[df['text_id'] == text_id]['token'].to_list()))
    misclassified_df = pd.DataFrame.from_dict(misclassified_dict)
    misclassified_df.to_excel(output_path)
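# --- Example end-to-end usage (a sketch, not part of the original gist) ---
# Assumes a labelled DataFrame and an unlabelled DataFrame with columns
# text_id / token (/ entity for the labelled one); the file names and the
# base checkpoint below are placeholders for illustration only.
if __name__ == "__main__":
    labelled_df = pd.read_excel("labelled.xlsx")      # hypothetical file
    unlabelled_df = pd.read_excel("unlabelled.xlsx")  # hypothetical file

    # 1. Train an initial model on the labelled data.
    train_val = df_to_dataset_for_model(labelled_df)
    train_model("bert-base-multilingual-cased",       # assumed base checkpoint
                train_val["train"], train_val["validation"],
                "./model-iter-1")

    # 2. Pseudo-label the unlabelled data, keeping only sentences whose average
    #    prediction confidence clears the threshold. The returned DataFrame
    #    could be concatenated with labelled_df for a further training round.
    pseudo_labelled_df = filter_threshold("./model-iter-1", unlabelled_df,
                                          threshold=0.9,
                                          output_dir=".",
                                          output_filename_prefix="iter-1")

    # 3. Evaluate on a held-out test DataFrame (same column layout as labelled_df).
    test_df = pd.read_excel("test.xlsx")              # hypothetical file
    get_predicted_label_on_test_dataset("./model-iter-1", test_df)
    get_overall_performance(test_df)
    print(get_individual_label_score(test_df))
    get_confusion_matrix(test_df)
    get_misclassified_report(test_df, "misclassified-iter-1.xlsx")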