from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, TrainingArguments, Trainer, DataCollatorForTokenClassification, AutoModelForTokenClassification, pipeline
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, confusion_matrix, ConfusionMatrixDisplay
from shutil import rmtree
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os
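# --- Assumed global setup (a sketch, not part of the original gist) ---
# The functions below rely on module-level `tokenizer`, `data_collator`,
# `label2id`, `id2label`, and `label_list`, which are never defined here.
# The label set and base checkpoint below are hypothetical placeholders;
# replace them with your own.
label_list = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]  # assumed BIO tag set
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for i, label in enumerate(label_list)}
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # assumed base checkpoint
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)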
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples['entity']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs
def df_to_dataset_for_model(df: pd.DataFrame):
    # Copy so the original df remains unchanged
    df = df.copy(deep=True)
    df["entity"] = df["entity"].apply(lambda x: label2id[x])

    # Merge per-token rows into one row per sentence
    tmp_df = df.groupby("text_id")["token"].apply(list).reset_index()
    tmp_df["entity"] = df.groupby("text_id")["entity"].apply(list).reset_index()["entity"]
    tmp_df.columns = ["text_id", "tokens", "entity"]

    # Convert to Dataset format
    tmp_list = []
    for i in tmp_df.index:
        tmp_list.append({
            "text_id": tmp_df.loc[i, "text_id"],
            "tokens": tmp_df.loc[i, "tokens"],
            "entity": tmp_df.loc[i, "entity"],
        })
    dataset = Dataset.from_list(tmp_list)

    # Tokenize the dataset
    tokenized_dataset = dataset.map(tokenize_and_align_labels,
                                    batched=True,
                                    remove_columns=dataset.column_names)

    # Split into train and validation
    train_val = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
    train_val = DatasetDict({
        "train": train_val["train"],
        "validation": train_val["test"]
    })
    return train_val
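# `df_to_dataset_for_model` expects a long-format DataFrame with one token per
# row, e.g. (columns inferred from the code above; the values are hypothetical):
#
#   text_id  token   entity
#   0        John    B-PER
#   0        lives   O
#   0        here    O
#   1        Acme    B-ORG
#   ...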
def train_model(model_name: str,
                train_dataset: Dataset,
                val_dataset: Dataset,
                model_output_path: str):
    training_args = TrainingArguments(
        output_dir="./result",
        overwrite_output_dir=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        weight_decay=0.01,
        load_best_model_at_end=True,
        seed=42,
    )

    def model_init():
        return AutoModelForTokenClassification.from_pretrained(model_name, id2label=id2label, label2id=label2id)

    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(model_output_path)

    # Clean up intermediate checkpoints
    rmtree("./result")
def extract_entities_from_result(tokens, result):
    # Map the pipeline's character-span predictions back onto the original
    # whitespace-separated tokens as BIO tags.
    predicted_entities = []
    current_index = 0
    for token in tokens:
        predicted = "O"
        for entry in result:
            if entry["start"] <= current_index < entry["end"]:
                predicted = f"B-{entry['entity_group']}" if current_index == entry['start'] else f"I-{entry['entity_group']}"
        predicted_entities.append(predicted)
        current_index += len(token) + 1  # +1 for the space between tokens
    return predicted_entities
def filter_threshold(model_path: str,
                     df: pd.DataFrame,
                     threshold: float,
                     output_dir: str,
                     output_filename_prefix: str):
    # Load model
    ner = pipeline("token-classification",
                   model=model_path,
                   aggregation_strategy="simple")

    # `above` and `below` hold per-entity rows for manual inspection
    above = {
        'text_id': [],
        'text': [],
        'entity': [],
        'word': [],
        'confidence': [],
    }
    below = {
        'text_id': [],
        'text': [],
        'entity': [],
        'word': [],
        'confidence': [],
    }
    # `predicted_above` is the one to be returned
    predicted_above = {
        'text_id': [],
        'tokens': [],
        'predicted_entity': [],
    }

    # Get all text ids of unlabelled data
    text_ids = df["text_id"].unique().tolist()
    for text_id in tqdm(text_ids):
        tokens = df[df["text_id"] == text_id]["token"].to_list()
        text = " ".join(tokens)
        result = ner(text)

        # No entity predicted at all: treat as below threshold
        if len(result) == 0:
            below["text_id"].append(text_id)
            below["text"].append(text)
            below["entity"].append("-")
            below["word"].append("-")
            below["confidence"].append(0)
            continue

        # Average confidence over all predicted entities in the sentence
        score = 0
        for entity in result:
            score += entity["score"]
        avg = score / len(result)

        # Keep sentences whose average confidence clears the threshold
        if avg >= threshold:
            predicted_above["text_id"].append(text_id)
            predicted_above["tokens"].append(tokens)
            predicted_above["predicted_entity"].append(extract_entities_from_result(tokens, result))
            for entity in result:
                above["text_id"].append(text_id)
                above["text"].append(text)
                above["entity"].append(entity["entity_group"])
                above["word"].append(entity["word"])
                above["confidence"].append(entity["score"])
        else:
            for entity in result:
                below["text_id"].append(text_id)
                below["text"].append(text)
                below["entity"].append(entity["entity_group"])
                below["word"].append(entity["word"])
                below["confidence"].append(entity["score"])

    above_df = pd.DataFrame(above)
    below_df = pd.DataFrame(below)
    above_df.to_excel(os.path.join(output_dir, f"{output_filename_prefix}-above-{threshold}.xlsx"))
    below_df.to_excel(os.path.join(output_dir, f"{output_filename_prefix}-below-{threshold}.xlsx"))
    print(f"Above {threshold}: {len(above_df['text_id'].unique())} sentences")
    print(f"Below {threshold}: {len(below_df['text_id'].unique())} sentences")

    # Return pseudo-labelled data in the same long (one token per row) format as the input
    predicted_above = pd.DataFrame(predicted_above)
    predicted_above = predicted_above.explode(["tokens", "predicted_entity"]).reset_index(drop=True)
    predicted_above.columns = ["text_id", "token", "entity"]
    return predicted_above
def get_predicted_label_on_test_dataset(model_path: str, df: pd.DataFrame):
    # Load model
    ner = pipeline("token-classification",
                   model=model_path,
                   aggregation_strategy="simple")

    predicted = []
    text_ids = df["text_id"].unique().tolist()
    for text_id in tqdm(text_ids):
        tokens = df[df["text_id"] == text_id]["token"].to_list()
        text = " ".join(tokens)
        result = ner(text)
        predicted.extend(extract_entities_from_result(tokens, result))

    # Adds the predictions to df in place
    df["predicted_entity"] = predicted
def get_overall_performance(df: pd.DataFrame):
    # Overall accuracy
    total = len(df)
    correct = (df["entity"] == df["predicted_entity"]).sum()
    accuracy = correct / total
    print("Accuracy:", accuracy)

    # Precision, recall, and F1-score
    y_true = df['entity']
    y_pred = df['predicted_entity']
    precision = precision_score(y_true, y_pred, average='weighted', labels=label_list)
    recall = recall_score(y_true, y_pred, average='weighted', labels=label_list)
    f1 = f1_score(y_true, y_pred, average='weighted', labels=label_list)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)
def get_individual_label_score(df: pd.DataFrame):
    y_true = df['entity']
    y_pred = df['predicted_entity']
    report = classification_report(y_true, y_pred, labels=label_list, output_dict=True)
    return pd.DataFrame(report).transpose()
def get_confusion_matrix(df: pd.DataFrame):
    y_true = df['entity']
    y_pred = df['predicted_entity']
    cm = confusion_matrix(y_true, y_pred, labels=label_list)

    ax = plt.subplot()
    sns.heatmap(cm, annot=True, fmt='g', ax=ax, cmap="Oranges", vmin=0, vmax=160)  # annot=True to annotate cells, fmt='g' to disable scientific notation
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title('Confusion Matrix')
    plt.xticks(rotation=270)
    plt.yticks(rotation=0)
    ax.xaxis.set_ticklabels(label_list)
    ax.yaxis.set_ticklabels(label_list)
def get_misclassified_report(df: pd.DataFrame, output_path: str):
    misclassified_dict = {
        'text_id': [],
        'token': [],
        'true_label': [],
        'pred_label': [],
        'text': [],
    }
    for i in tqdm(df.index):
        true_label = df.loc[i, 'entity']
        pred_label = df.loc[i, 'predicted_entity']
        text_id = df.loc[i, 'text_id']
        if true_label != pred_label:
            misclassified_dict['token'].append(df.loc[i, 'token'])
            misclassified_dict['text_id'].append(text_id)
            misclassified_dict['true_label'].append(true_label)
            misclassified_dict['pred_label'].append(pred_label)
            misclassified_dict['text'].append(' '.join(df[df['text_id'] == text_id]['token'].to_list()))

    misclassified_df = pd.DataFrame.from_dict(misclassified_dict)
    misclassified_df.to_excel(output_path)
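# --- Example usage (a sketch; file names, paths, and the checkpoint below are hypothetical) ---
# A typical run with these helpers: fine-tune on a labelled token-level DataFrame,
# pseudo-label an unlabelled set above a confidence threshold, then score the model
# on a held-out test DataFrame.
if __name__ == "__main__":
    labelled_df = pd.read_excel("labelled.xlsx")      # columns: text_id, token, entity
    unlabelled_df = pd.read_excel("unlabelled.xlsx")  # columns: text_id, token
    test_df = pd.read_excel("test.xlsx")              # columns: text_id, token, entity

    # Fine-tune the base checkpoint on the labelled data
    train_val = df_to_dataset_for_model(labelled_df)
    train_model("bert-base-cased",                    # assumed base checkpoint
                train_val["train"],
                train_val["validation"],
                model_output_path="./ner-model")

    # Pseudo-label unlabelled sentences whose average confidence is >= 0.9;
    # the result could be appended to labelled_df for another self-training round
    pseudo_labelled = filter_threshold("./ner-model", unlabelled_df,
                                       threshold=0.9,
                                       output_dir=".",
                                       output_filename_prefix="round1")

    # Evaluate on the test set
    get_predicted_label_on_test_dataset("./ner-model", test_df)
    get_overall_performance(test_df)
    print(get_individual_label_score(test_df))
    get_confusion_matrix(test_df)
    plt.show()
    get_misclassified_report(test_df, "misclassified.xlsx")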