Skip to content

Instantly share code, notes, and snippets.

@ngaloppo
Last active August 10, 2023 16:32
Show Gist options
  • Save ngaloppo/3ac1de817d1588aa411f660e773b5d2c to your computer and use it in GitHub Desktop.
Save ngaloppo/3ac1de817d1588aa411f660e773b5d2c to your computer and use it in GitHub Desktop.
"""## Loading the dataset"""
from datasets import load_dataset, load_metric, load_from_disk
from pathlib import Path
# Hub id of the fine-tuned checkpoint whose tokenizer and model are used below.
model_checkpoint = "lidiya/bart-large-xsum-samsum"

# SAMSum splits (columns "dialogue" / "summary") and the ROUGE metric.
raw_datasets = load_dataset(path="samsum")
metric = load_metric(path="rouge")

"""## Preprocessing the data"""
from transformers import AutoTokenizer

# Token budgets applied by preprocess_function: dialogues (model inputs) and
# reference summaries (targets) are truncated to these lengths.
max_input_length = 512
max_target_length = 128

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint
)
def preprocess_function(examples):
    """Tokenize a batch of SAMSum examples for seq2seq training/evaluation.

    Args:
        examples: batched mapping (as passed by ``Dataset.map(batched=True)``)
            with a ``"dialogue"`` column of source texts and a ``"summary"``
            column of reference summaries.

    Returns:
        The tokenizer output for the dialogues, truncated to
        ``max_input_length`` tokens, with an added ``"labels"`` entry holding
        the summary token ids truncated to ``max_target_length`` tokens.
    """
    model_inputs = tokenizer(
        examples["dialogue"], max_length=max_input_length, truncation=True
    )
    # ``text_target=`` tokenizes with the target-side settings; it replaces
    # the deprecated ``tokenizer.as_target_tokenizer()`` context manager.
    labels = tokenizer(
        text_target=examples["summary"],
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# Apply preprocess_function to every split in batches; each example gains the
# tokenizer outputs for its dialogue plus a "labels" column.
tokenized_ds = raw_datasets.map(function=preprocess_function, batched=True)
"""## Loading the fine-tuned model as an OpenVINO model"""
from transformers import (
AutoModelForSeq2SeqLM,
DataCollatorForSeq2Seq,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
)
from optimum.intel import OVModelForSeq2SeqLM
# Directory holding the OpenVINO model to evaluate. If it already contains an
# exported model (config.json present) it is loaded directly; otherwise the
# fine-tuned checkpoint is exported to OpenVINO and cached there.
# NOTE(review): "ptq_model" presumably names a post-training-quantized model.
# The fallback branch writes the *unquantized* export into that directory, so
# later runs would silently load FP32 weights from it -- confirm intended.
ov_model_id = "ptq_model"
if (Path(ov_model_id) / "config.json").exists():
    model = OVModelForSeq2SeqLM.from_pretrained(ov_model_id)
else:
    print(
        f"{ov_model_id}/config.json not found; "
        f"exporting {model_checkpoint} to OpenVINO and saving it to {ov_model_id}"
    )
    model = OVModelForSeq2SeqLM.from_pretrained(model_checkpoint, export=True)
    model.save_pretrained(ov_model_id)
import nltk
import numpy as np
# nltk.sent_tokenize (used by compute_metrics) needs the Punkt sentence
# tokenizer data. Older NLTK releases read the "punkt" package, while newer
# ones (3.8.2+/3.9) look up "punkt_tab" instead, so fetch both.
for _punkt_package in ("punkt", "punkt_tab"):
    nltk.download(_punkt_package)
def _replace_ignore_index(token_ids, pad_token_id):
    """Return *token_ids* as a list of ints with every ``-100`` set to *pad_token_id*.

    ``-100`` is not a decodable token id, so it is swapped for the pad id
    (which ``skip_special_tokens=True`` then drops during decoding).
    """
    ids = np.asarray(token_ids)
    return np.where(ids != -100, ids, pad_token_id).tolist()


def _rouge_fmeasure(score):
    """Return the F-measure carried by one ROUGE result entry.

    ``datasets.load_metric("rouge")`` returns aggregate scores whose
    ``.mid.fmeasure`` holds the value; other ROUGE implementations (e.g. the
    one in the ``evaluate`` package) may return a plain number instead.
    """
    mid = getattr(score, "mid", None)
    return float(score) if mid is None else mid.fmeasure


def compute_metrics(eval_pred, tokenizer):
    """Score generated summaries against reference summaries with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` pair of generated and reference
            token ids. Each may be a padded 2-D array or a list of (possibly
            ragged) token-id sequences; ``-100`` entries are ignored.
        tokenizer: tokenizer used to decode token ids back to text.

    Returns:
        dict mapping each ROUGE key produced by the module-level ``metric`` to
        its F-measure scaled to 0-100, plus ``"gen_len"`` (mean number of
        non-pad tokens per prediction); every value is rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    pad_id = tokenizer.pad_token_id

    # Normalise both sides per sequence so padded arrays and ragged lists
    # (e.g. from a manual generate() loop) are handled the same way, and so
    # -100 padding never reaches the decoder or inflates gen_len.
    predictions = [_replace_ignore_index(seq, pad_id) for seq in predictions]
    labels = [_replace_ignore_index(seq, pad_id) for seq in labels]

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = [
        "\n".join(nltk.sent_tokenize(text.strip())) for text in decoded_preds
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(text.strip())) for text in decoded_labels
    ]

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep only the F-measure of each ROUGE variant, as a percentage.
    result = {key: _rouge_fmeasure(value) * 100 for key, value in result.items()}

    # Mean generated length: non-pad tokens per prediction. Sequences are
    # converted to arrays so the comparison is element-wise even for lists.
    result["gen_len"] = np.mean(
        [np.count_nonzero(np.asarray(seq) != pad_id) for seq in predictions]
    )
    return {k: round(v, 4) for k, v in result.items()}
"""Prediction"""
# Disabled alternative to the `evaluate` evaluator run further below: generate
# a summary for each tokenized validation example, one at a time, with
# model.generate(), then score the generated token ids with compute_metrics().
# The predictions built here are per-example Python lists, and the labels come
# from preprocess_function without padding, so both sides can be ragged.
# NOTE(review): confirm compute_metrics accepts ragged per-sequence input
# before re-enabling this path.
#
# from tqdm import tqdm
#
# def predict(ds):
#     predictions = []
#     for input in tqdm(ds["input_ids"], total=len(ds["input_ids"])):
#         input = input[None, :]  # add a leading batch dimension of size 1
#         pred = model.generate(input, max_new_tokens=max_target_length)
#         predictions.append(pred.flatten().tolist())
#     return predictions
#
# tokenized_validation_dataset = tokenized_ds["validation"].with_format("torch")
# predictions = predict(tokenized_validation_dataset)
# metrics = compute_metrics((predictions, tokenized_ds["validation"]["labels"]), tokenizer)
# print(metrics)
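"""## Inference and evaluation: the fine-tuned checkpoint is shared on the [🤗 Model Hub](https://huggingface.co/models) (see [model sharing](https://huggingface.co/transformers/model_sharing.html)). Below, the OpenVINO model loaded above is wrapped in a summarization pipeline, used to summarize a sample dialogue, and then evaluated with ROUGE on the validation split."""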
from transformers import pipeline, SummarizationPipeline
# Wrap the OpenVINO seq2seq model in a Transformers summarization pipeline.
# Keyword arguments the pipeline does not consume itself are forwarded to
# model.generate(), so a `max_length` given here caps the *generated summary*,
# not the dialogue. Cap generation at the same token budget used for the
# reference summaries (and by the evaluator run below), and ask the tokenizer
# to truncate over-long dialogues rather than pass them through untruncated.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    truncation=True,
    max_new_tokens=max_target_length,
)

# Sample chat used as a quick smoke test of the pipeline.
conversation = """Hannah: Hey, do you have Betty's number?
Amanda: Lemme check
Amanda: Sorry, can't find it.
Amanda: Ask Larry
Amanda: He called her last time we were at the park together
Hannah: I don't know him well
Amanda: Don't be shy, he's very nice
Hannah: If you say so..
Hannah: I'd rather you texted him
Amanda: Just text him 🙂
Hannah: Urgh.. Alright
Hannah: Bye
Amanda: Bye bye
"""
print(summarizer(conversation))
from evaluate import evaluator
# Score the pipeline on the validation split with ROUGE using the `evaluate`
# summarization evaluator. The instance gets its own name so the imported
# `evaluator` factory function is not shadowed.
summarization_evaluator = evaluator("summarization")
# Extra keyword arguments the evaluator passes to the metric's compute().
# Stemming matches the use_stemmer=True setting in compute_metrics; existing
# defaults are kept rather than overwritten.
summarization_evaluator.METRIC_KWARGS = {
    **summarization_evaluator.METRIC_KWARGS,
    "use_stemmer": True,
}
results = summarization_evaluator.compute(
    model_or_pipeline=summarizer,
    data=raw_datasets["validation"],
    metric="rouge",
    input_column="dialogue",
    label_column="summary",
    generation_kwargs={"max_new_tokens": max_target_length},
)
print(results)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment