## TRAINING
import logging

import tensorflow as tf
from datasets import Dataset, DatasetDict, load_dataset
from transformers import AutoTokenizer, DefaultDataCollator, TFAutoModelForSequenceClassification


def tokenize_function(dataset_split):
    """Tokenize the `text` column, padding/truncating to the model's max length."""
    # `tokenizer` is defined globally in main() so that datasets.map can use it.
    return tokenizer(dataset_split["text"], padding="max_length", truncation=True)
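# Note: with padding="max_length", each encoded example is padded to the
# tokenizer's model_max_length (512 for bert-base-uncased), and the output
# dict carries the "input_ids", "token_type_ids" and "attention_mask"
# columns that convert_to_tf_dataset selects below.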
def explode_label_columns(dataset, label2id):
    """Explode the multi-label `label_descriptions` column into one row per label."""
    ds = DatasetDict()
    for split in ["train", "test"]:
        pdf = (
            dataset[split]
            .to_pandas()
            .explode("label_descriptions")
            .rename({"label_descriptions": "label_description"}, axis=1)
            .reset_index()
        )
        pdf["label"] = pdf["label_description"].map(label2id)
        pdf = pdf[["text", "label", "label_description"]]
        ds[split] = Dataset.from_pandas(pdf, preserve_index=False)
    return ds
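# Illustrative example (made-up record): a multi-label row such as
#   {"text": "Study closed early", "label_descriptions": ["Business_Administrative", "Covid19"]}
# is exploded into one single-label row per description, e.g.
#   {"text": "Study closed early", "label": label2id["Business_Administrative"], "label_description": "Business_Administrative"}
#   {"text": "Study closed early", "label": label2id["Covid19"], "label_description": "Covid19"}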
def subset_split(dataset, split, n_samples):
    """Return a shuffled, fixed-size sample of the given split."""
    return dataset[split].shuffle(seed=42).select(range(n_samples))


def convert_to_tf_dataset(dataset, data_collator, shuffle_flag, batch_size):
    """Wrap a tokenized dataset as a batched tf.data.Dataset for Keras."""
    return dataset.to_tf_dataset(
        columns=["attention_mask", "input_ids", "token_type_ids"],
        label_cols=["labels"],
        shuffle=shuffle_flag,
        collate_fn=data_collator,
        batch_size=batch_size,
    )
def get_label_metadata(dataset):
    """
    Extract label metadata from the raw dataset: a list of labels, a dictionary
    mapping label ids to labels, and a dictionary mapping labels to label ids.

    Args:
        dataset: the dataset object
    """
    labels = [label for label in dataset["train"].features.keys() if label not in ["text", "label_descriptions"]]
    id2label = dict(enumerate(labels))
    label2id = {label: idx for idx, label in enumerate(labels)}
    return labels, id2label, label2id
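# Illustrative output, assuming the per-label columns of the raw dataset
# include e.g. "Business_Administrative" and "Covid19":
#   labels   -> ["Business_Administrative", "Covid19", ...]
#   id2label -> {0: "Business_Administrative", 1: "Covid19", ...}
#   label2id -> {"Business_Administrative": 0, "Covid19": 1, ...}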
def main():
    logging.basicConfig(level=logging.INFO)

    logging.info("Loading dataset")
    dataset = load_dataset("opentargets/clinical_trial_reason_to_stop", split="train").train_test_split(test_size=0.1)
    labels, id2label, label2id = get_label_metadata(dataset)
    dataset = explode_label_columns(dataset, label2id)

    logging.info("Tokenizing dataset")
    # Made global so tokenize_function can see it when mapped over the dataset.
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    small_train_dataset = subset_split(tokenized_datasets, "train", 500)
    small_eval_dataset = subset_split(tokenized_datasets, "test", 50)

    logging.info("Loading model")
    data_collator = DefaultDataCollator(return_tensors="tf")
    tf_train_dataset = convert_to_tf_dataset(small_train_dataset, data_collator, shuffle_flag=True, batch_size=32)
    tf_validation_dataset = convert_to_tf_dataset(small_eval_dataset, data_collator, shuffle_flag=False, batch_size=32)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels=len(labels),  # 17 classes in this dataset
        id2label=id2label,
        label2id=label2id,
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.metrics.SparseCategoricalAccuracy()],
    )

    logging.info("Training model")
    model.fit(tf_train_dataset, epochs=3, validation_data=tf_validation_dataset)
    model.save_pretrained("models/model_500n_3_epochs_classificator_tf", saved_model=True, save_format="tf")
    tokenizer.save_pretrained("models/model_500n_3_epochs_classificator_tf_tokenizer")
    logging.info("Model saved. Exiting.")


if __name__ == "__main__":
    main()
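

## INFERENCE (illustrative sketch)
# A minimal example of reloading the artifacts saved by main() and classifying
# a single why-stopped sentence with a transformers pipeline. The input text
# is invented; the paths match the save_pretrained calls above.
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, pipeline

clf_model = TFAutoModelForSequenceClassification.from_pretrained("models/model_500n_3_epochs_classificator_tf")
clf_tokenizer = AutoTokenizer.from_pretrained("models/model_500n_3_epochs_classificator_tf_tokenizer")
classifier = pipeline("text-classification", model=clf_model, tokenizer=clf_tokenizer)
print(classifier("The study was terminated early due to slow enrollment"))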
## EVALUATION
from typing import Union

from datasets import Dataset, DatasetDict, load_dataset
from evaluate import evaluator
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification


def explode_label_columns(data: Union[Dataset, DatasetDict], split_names: list, label2id: dict) -> DatasetDict:
    """Explode the multi-label `label_descriptions` column into one row per label."""

    def fix_dataset(dataset):
        pdf = (
            dataset.to_pandas()
            .explode("label_descriptions")
            .rename({"label_descriptions": "label_description"}, axis=1)
            .reset_index()
        )
        pdf["label"] = pdf["label_description"].map(label2id)
        pdf = pdf[["text", "label", "label_description"]]
        return Dataset.from_pandas(pdf, preserve_index=False)

    ds = DatasetDict()
    if isinstance(data, DatasetDict):
        for split in split_names:
            ds[split] = fix_dataset(data[split])
    elif isinstance(data, Dataset):
        ds[split_names[0]] = fix_dataset(data)
    return ds
dataset_agg = load_dataset("opentargets/clinical_trial_reason_to_stop", split="all")

# Evaluate from a local model
model = TFAutoModelForSequenceClassification.from_pretrained("./model_3_epochs_classificator_tf", local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained("./model_3_epochs_classificator_tf", local_files_only=True)
dataset = explode_label_columns(dataset_agg, ["all"], model.config.label2id)

task_evaluator = evaluator("text-classification")
eval_results = task_evaluator.compute(
    model_or_pipeline=model,
    data=dataset["all"],
    label_mapping=model.config.label2id,
    tokenizer=tokenizer,
    metric="f1",
)
print(eval_results)
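
# The same evaluator can compute other metrics from the evaluate hub; for
# example, accuracy on the same exploded split (illustrative):
accuracy_results = task_evaluator.compute(
    model_or_pipeline=model,
    data=dataset["all"],
    label_mapping=model.config.label2id,
    tokenizer=tokenizer,
    metric="accuracy",
)
print(accuracy_results)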