"""## Loading the dataset""" | |
from datasets import load_dataset, load_metric, load_from_disk | |
from pathlib import Path | |
raw_datasets = load_dataset("samsum") | |
metric = load_metric("rouge") | |
model_checkpoint = "lidiya/bart-large-xsum-samsum" | |
"""## Preprocessing the data""" | |
from transformers import AutoTokenizer | |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
max_input_length = 512 | |
max_target_length = 128 | |
def preprocess_function(examples): | |
inputs = [doc for doc in examples["dialogue"]] | |
model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True) | |
# Setup the tokenizer for targets | |
with tokenizer.as_target_tokenizer(): | |
labels = tokenizer( | |
examples["summary"], max_length=max_target_length, truncation=True | |
) | |
model_inputs["labels"] = labels["input_ids"] | |
return model_inputs | |
tokenized_ds = raw_datasets.map(preprocess_function, batched=True) | |
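# Optional sanity check (a sketch; the decoded text depends on the samsum
# train split): the mapped dataset keeps the original columns and adds
# "input_ids", "attention_mask", and "labels".
sample = tokenized_ds["train"][0]
print(sample.keys())
print(tokenizer.decode(sample["input_ids"][:20]))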
"""## Fine-tuning the model""" | |
from transformers import ( | |
AutoModelForSeq2SeqLM, | |
DataCollatorForSeq2Seq, | |
Seq2SeqTrainingArguments, | |
Seq2SeqTrainer, | |
) | |
from optimum.intel import OVModelForSeq2SeqLM | |
# model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint) | |
ov_model_id = f"{model_checkpoint}-ov-fp32" | |
ov_model_id = "ptq_model" | |
if (Path(ov_model_id) / "config.json").exists(): | |
model = OVModelForSeq2SeqLM.from_pretrained(ov_model_id) | |
else: | |
model = OVModelForSeq2SeqLM.from_pretrained(model_checkpoint, export=True) | |
model.save_pretrained(ov_model_id) | |
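# Minimal smoke test (a sketch): confirm the OpenVINO model generates before
# running the full evaluation. The dialogue below is made up for illustration.
check_inputs = tokenizer("Amanda: I baked cookies. Do you want some?", return_tensors="pt")
check_outputs = model.generate(**check_inputs, max_new_tokens=32)
print(tokenizer.decode(check_outputs[0], skip_special_tokens=True))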
import nltk
import numpy as np

nltk.download("punkt")

def compute_metrics(eval_pred, tokenizer):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # ROUGE expects a newline after each sentence
    decoded_preds = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep the mid F1 score for each ROUGE variant
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions
    ]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
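# Hedged smoke test for compute_metrics: identical predictions and labels
# should score close to 100 on every ROUGE variant. The sentence is illustrative.
check_ids = np.array([tokenizer("Hannah asked Amanda for Betty's number.")["input_ids"]])
print(compute_metrics((check_ids, check_ids), tokenizer))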
"""Prediction""" | |
# from tqdm import tqdm | |
# def predict(ds): | |
# predictions = [] | |
# for input in tqdm(ds["input_ids"], total=len(ds["input_ids"])): | |
# input = input[None, :] | |
# pred = model.generate(input, max_new_tokens=max_target_length) | |
# predictions.append(pred.flatten().tolist()) | |
# return predictions | |
# tokenized_validation_dataset = tokenized_ds["validation"].with_format("torch") | |
# predictions = predict(tokenized_validation_dataset) | |
# metrics = compute_metrics((predictions, tokenized_ds["validation"]["labels"]), tokenizer) | |
# print(metrics) | |
"""[Uploaded the model](https://huggingface.co/transformers/model_sharing.html) to the [🤗 Model Hub](https://huggingface.co/models). You can use it to generate results as shown below.""" | |
from transformers import pipeline, SummarizationPipeline | |
# class BARTSummarizationPipeline(SummarizationPipeline): | |
# def __init__(self, max_length = 62, *args, **kwargs): | |
# self.max_length = max_length | |
# super().__init__(*args, **kwargs) | |
# def preprocess(self, inputs, truncation=True, **kwargs): | |
# kwargs.update({ 'max_length': self.max_length }) | |
# return super().preprocess(inputs, truncation=True, **kwargs) | |
# summarizer = BARTSummarizationPipeline(model=model, tokenizer=tokenizer, max_length=max_input_length) | |
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, max_length=max_input_length) | |
conversation = """Hannah: Hey, do you have Betty's number? | |
Amanda: Lemme check | |
Amanda: Sorry, can't find it. | |
Amanda: Ask Larry | |
Amanda: He called her last time we were at the park together | |
Hannah: I don't know him well | |
Amanda: Don't be shy, he's very nice | |
Hannah: If you say so.. | |
Hannah: I'd rather you texted him | |
Amanda: Just text him 🙂 | |
Hannah: Urgh.. Alright | |
Hannah: Bye | |
Amanda: Bye bye | |
""" | |
print(summarizer(conversation)) | |
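# The summarization pipeline returns a list with one dict per input; the
# generated text is under the "summary_text" key.
print(summarizer(conversation)[0]["summary_text"])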
"""## Evaluating the pipeline on the validation split"""
from evaluate import evaluator

# Build a task-specific evaluator (distinct name avoids shadowing the factory).
task_evaluator = evaluator("summarization")
task_evaluator.METRIC_KWARGS = {"use_stemmer": True}

results = task_evaluator.compute(
    model_or_pipeline=summarizer,
    data=raw_datasets["validation"],
    metric="rouge",
    input_column="dialogue",
    label_column="summary",
    generation_kwargs={"max_new_tokens": max_target_length},
)
print(results)
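# A hedged sketch: pull out just the ROUGE fields (the evaluator also reports
# throughput stats such as runtime alongside the metric values).
rouge_scores = {k: v for k, v in results.items() if k.startswith("rouge")}
print(rouge_scores)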