Last active
February 5, 2023 15:23
-
-
Save ireneisdoomed/9c2981ceaa781dfea261731def8cb9ab to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## TRAINING | |
import logging | |
from datasets import load_dataset, DatasetDict, Dataset | |
import tensorflow as tf | |
from transformers import AutoTokenizer, DefaultDataCollator, TFAutoModelForSequenceClassification | |
def tokenize_function(dataset_split):
    """Tokenize the "text" column of a dataset split, padded/truncated to max length.

    Relies on the module-level ``tokenizer`` that main() installs before mapping.
    """
    texts = dataset_split["text"]
    return tokenizer(texts, padding="max_length", truncation=True)
def explode_label_columns(dataset, label2id, splits=("train", "test")):
    """Flatten multi-label examples into one row per (text, label) pair.

    Each example's list-valued ``label_descriptions`` column is exploded so
    every label becomes its own row, and the textual label is mapped to its
    integer class id via ``label2id``.

    Args:
        dataset: a DatasetDict containing the splits named in ``splits``.
        label2id: mapping from label description to integer class id.
        splits: names of the splits to process (default keeps the original
            hard-coded behavior of train + test).

    Returns:
        A new DatasetDict whose splits have columns ``text``, ``label``,
        ``label_description``.
    """
    ds = DatasetDict()
    for split in splits:
        pdf = (
            dataset[split]
            .to_pandas()
            .explode("label_descriptions")
            .rename({"label_descriptions": "label_description"}, axis=1)
            .reset_index()
        )
        pdf["label"] = pdf["label_description"].map(label2id)
        pdf = pdf[["text", "label", "label_description"]]
        ds[split] = Dataset.from_pandas(pdf, preserve_index=False)
    return ds
def subset_split(dataset, split, n_samples):
    """Return a reproducibly shuffled sample of ``n_samples`` rows from the given split."""
    # Fixed seed makes the subset deterministic across runs.
    shuffled = dataset[split].shuffle(seed=42)
    return shuffled.select(range(n_samples))
def convert_to_tf_dataset(dataset, data_collator, shuffle_flag, batch_size):
    """Wrap a tokenized HF dataset as a batched tf.data.Dataset for Keras training."""
    # BERT-style model inputs; everything else in the dataset is dropped.
    model_inputs = ["attention_mask", "input_ids", "token_type_ids"]
    return dataset.to_tf_dataset(
        columns=model_inputs,
        label_cols=["labels"],
        shuffle=shuffle_flag,
        collate_fn=data_collator,
        batch_size=batch_size,
    )
def get_label_metadata(dataset):
    """Derive label names and id mappings from the training split's features.

    Every feature column except ``text`` and ``label_descriptions`` is
    treated as a label.

    Args:
        dataset: the dataset object

    Returns:
        A tuple ``(labels, id2label, label2id)``: the list of label names,
        a dict mapping index -> name, and a dict mapping name -> index.
    """
    non_label_columns = ("text", "label_descriptions")
    labels = [name for name in dataset['train'].features.keys() if name not in non_label_columns]
    id2label = {idx: name for idx, name in enumerate(labels)}
    label2id = {name: idx for idx, name in enumerate(labels)}
    return labels, id2label, label2id
def main():
    """Fine-tune bert-base-uncased on the clinical-trial stop-reason dataset and save it.

    Downloads the dataset and base model, trains for 3 epochs on small
    subsets, and writes the model + tokenizer under ``models/``.
    """
    logging.basicConfig(level=logging.INFO)
    logging.info("Loading dataset")
    dataset = load_dataset("opentargets/clinical_trial_reason_to_stop", split='train').train_test_split(test_size=0.1)
    # `labels` and `tokenizer` are module-level globals so tokenize_function
    # (mapped over the dataset below) can see the tokenizer.
    global labels
    labels, id2label, label2id = get_label_metadata(dataset)
    dataset = explode_label_columns(dataset, label2id)
    logging.info("Tokenizing dataset")
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    # Small subsets keep the demo run fast; bump these for a real training run.
    small_train_dataset = subset_split(tokenized_datasets, "train", 500)
    small_eval_dataset = subset_split(tokenized_datasets, "test", 50)
    logging.info("Loading model")
    data_collator = DefaultDataCollator(return_tensors="tf")
    tf_train_dataset = convert_to_tf_dataset(small_train_dataset, data_collator, shuffle_flag=True, batch_size=32)
    tf_validation_dataset = convert_to_tf_dataset(small_eval_dataset, data_collator, shuffle_flag=False, batch_size=32)
    # Derive num_labels from the data instead of hard-coding 17, so the
    # classification head stays in sync if the label set ever changes.
    model = TFAutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=len(labels), id2label=id2label, label2id=label2id
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
        # The model outputs raw logits, hence from_logits=True.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=tf.metrics.SparseCategoricalAccuracy(),
    )
    logging.info("Training model")
    model.fit(tf_train_dataset, epochs=3, validation_data=tf_validation_dataset)
    model.save_pretrained('models/model_500n_3_epochs_classificator_tf', saved_model=True, save_format='tf')
    tokenizer.save_pretrained('models/model_500n_3_epochs_classificator_tf_tokenizer')
    logging.info("Model saved. Exiting.")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
## EVALUATION | |
from datasets import load_dataset, DatasetDict, Dataset | |
from evaluate import evaluator | |
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification | |
from typing import Union | |
def explode_label_columns(
    data: Union[Dataset, DatasetDict],
    split_names: list,
    label2id: dict) -> DatasetDict:
    """Flatten multi-label examples into one row per (text, label) pair.

    Args:
        data: a Dataset or a DatasetDict. A DatasetDict is processed split by
            split for every name in ``split_names``; a bare Dataset is stored
            under the first name in ``split_names``.
        split_names: split names to process / assign.
        label2id: mapping from label description to integer class id.

    Returns:
        A DatasetDict whose splits have columns ``text``, ``label``,
        ``label_description``.

    Raises:
        TypeError: if ``data`` is neither a Dataset nor a DatasetDict.
    """
    def fix_dataset(dataset):
        # One row per label description, with the textual label mapped to its id.
        pdf = (
            dataset.to_pandas()
            .explode("label_descriptions")
            .rename({"label_descriptions": "label_description"}, axis=1)
            .reset_index()
        )
        pdf["label"] = pdf["label_description"].map(label2id)
        pdf = pdf[["text", "label", "label_description"]]
        return Dataset.from_pandas(pdf, preserve_index=False)

    ds = DatasetDict()
    if isinstance(data, DatasetDict):
        for split in split_names:
            ds[split] = fix_dataset(data[split])
    elif isinstance(data, Dataset):
        ds[split_names[0]] = fix_dataset(data)
    else:
        # Previously an unsupported type silently produced an empty DatasetDict.
        raise TypeError(f"Expected Dataset or DatasetDict, got {type(data).__name__}")
    return ds
# Load the full dataset (all splits combined) for evaluation.
dataset_agg = load_dataset("opentargets/clinical_trial_reason_to_stop", split="all")
# Load the fine-tuned model from local disk (produced by the training script).
model = TFAutoModelForSequenceClassification.from_pretrained("./model_3_epochs_classificator_tf", local_files_only=True)
# Reuse the model's own label2id so the exploded labels match the trained head.
dataset = explode_label_columns(dataset_agg, ["all"], model.config.label2id)
# Evaluate from a local model
tokenizer = AutoTokenizer.from_pretrained("./model_3_epochs_classificator_tf", local_files_only=True, from_pt=False)
# HF `evaluate` pipeline evaluator for single-label text classification.
task_evaluator = evaluator("text-classification")
eval_results = task_evaluator.compute(
    model_or_pipeline=model,
    data=dataset["all"],
    label_mapping=model.config.label2id,
    tokenizer=tokenizer,
    metric="f1",
)
print(eval_results)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.