@pszemraj
Last active September 4, 2023 16:38
unsupervised summary eval using several metrics, including a new 'max salient similarity' score to compute faithfulness w.r.t. original document.
"""
eval_summaries.py - evaluate summary/document pairs via a variety of metrics,
Metrics include max salient similarity, topic similarity, compression factor,
readability scores, and spelling error fraction
details:
python eval_summaries.py --help
this script was developed while evaluating summaries generated with the textsum package
https://github.com/pszemraj/textsum - try it out!
"""
import csv
import json
import logging
import pathlib
import pprint as pp
import re
import sqlite3
import fire
import numpy as np
import sentence_transformers
from scipy.spatial.distance import jensenshannon
from sentence_splitter import SentenceSplitter
from sentence_transformers import SentenceTransformer, util
from sklearn.decomposition import NMF, LatentDirichletAllocation, TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from spellchecker import SpellChecker
from textstat import flesch_kincaid_grade, gunning_fog
from tqdm.auto import tqdm
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# metric functions
def spelling_error_frac(text: str) -> float:
"""spelling_error_frac - calculates the fraction of words in a text that are misspelled"""
spell = SpellChecker()
words = re.findall(r"\b\w+\b", text.lower())
misspelled_words = spell.unknown(words)
return len(misspelled_words) / len(words)
def cosine_similarity_score(
document: str,
summary: str,
ngram_range: tuple = (1, 3),
dtype: np.dtype = np.float32,
) -> float:
"""
cosine_similarity_score - calculates the cosine similarity between the tfidf vectors of a document and a summary
:param str document: document to be summarized
:param str summary: summary of the document
:param tuple ngram_range: ngram range to use, defaults to (1, 3)
:param np.dtype dtype: dtype to use for the tfidf matrix, defaults to np.float32
:return float: cosine similarity between the tfidf vectors of the document and summary
"""
vectorizer = TfidfVectorizer(ngram_range=ngram_range, dtype=dtype)
tfidf_matrix = vectorizer.fit_transform([document, summary])
return cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2]).flatten()[0]
def topic_similarity_score(
document: str,
summary: str,
n_topics: int = 10,
method: str = "lda",
distance: str = "jsd",
) -> float:
"""
topic_similarity_score - calculates the similarity between the topics of a document and a summary using topic modeling
:param str document: document to be summarized
:param str summary: summary of the document
:param int n_topics: number of topics to use, defaults to 10
:param str method: method to use for topic modeling, defaults to "lda" = Latent Dirichlet Allocation
:param str distance: distance metric to use, defaults to "jsd" = Jensen-Shannon Divergence
:raises ValueError: if method is not one of "lda", "nmf", or "lsa"
:return float: similarity between the topics of the document and summary
"""
vectorizer = TfidfVectorizer(stop_words="english")
tfidf_matrix = vectorizer.fit_transform([document, summary])
if method == "lda":
model = LatentDirichletAllocation(n_components=n_topics)
elif method == "nmf":
model = NMF(n_components=n_topics, max_iter=1000)
elif method == "lsa":
model = TruncatedSVD(n_components=n_topics)
else:
raise ValueError(
"Invalid method specified. Choose from 'lda', 'nmf', or 'lsa'."
)
topic_matrix = model.fit_transform(tfidf_matrix)
topic_matrix /= topic_matrix.sum(axis=1)[
:, np.newaxis
] # Normalize the rows to sum to 1
return (
1 - jensenshannon(topic_matrix[0], topic_matrix[1])
if distance == "jsd"
else cosine_similarity(topic_matrix[0:1], topic_matrix[1:2]).flatten()[0]
)
def compression_factor(document: str, summary: str) -> float:
"""
compression_factor - calculates the compression multiple document->summary in characters
ex: compression_factor("hello world", "hello") = 2.2
"""
# normalize all whitespace to single spaces in the document and summary
_nrml_doc = re.sub(r"\s+", " ", document)
_nrml_summary = re.sub(r"\s+", " ", summary)
return round(len(_nrml_doc) / len(_nrml_summary), 3)
def readability_scores(summary: str) -> dict:
"""readability_scores - calculates the readability scores of a summary"""
return {
"flesch_kincaid": flesch_kincaid_grade(summary),
"gunning_fog": gunning_fog(summary),
}
def max_salient_similarity(
document: str,
summary: str,
model: sentence_transformers.SentenceTransformer,
splitter: SentenceSplitter,
doc_chunk_size: int = 5,
summary_chunk_size: int = 1,
distance: str = "cosine",
) -> float:
"""
max_salient_similarity - calculates the semantic similarity between a document and a summary
:param str document: document to be summarized
:param str summary: summary of the document
:param sentence_transformers.SentenceTransformer model: sbert model to use for encoding
:param SentenceSplitter splitter: splitter to use for splitting the document and summary into sentences
:param int doc_chunk_size: number of sentences to encode at a time, defaults to 5
:param int summary_chunk_size: number of sentences to encode at a time, defaults to 1
:param str distance: distance metric to use, defaults to "cosine" = cosine similarity
:return float: semantic similarity between the document and summary
"""
document_sentences = splitter.split(document)
summary_sentences = splitter.split(summary)
document_embeddings = np.max(
[
np.max(
model.encode(
document_sentences[i : i + doc_chunk_size],
normalize_embeddings=True,
show_progress_bar=False,
),
axis=0,
)
for i in tqdm(
range(0, len(document_sentences), doc_chunk_size),
desc="Document embeddings",
)
],
axis=0,
) # note: changing the np.max in this function changes the pooling strategy
summary_embeddings = np.max(
[
np.max(
model.encode(
summary_sentences[i : i + summary_chunk_size],
normalize_embeddings=True,
show_progress_bar=False,
),
axis=0,
)
for i in tqdm(
range(0, len(summary_sentences), summary_chunk_size),
desc="Summary embeddings",
)
],
axis=0,
)
similarity = (
util.cos_sim(document_embeddings, summary_embeddings).flatten()
if distance == "cosine"
else util.dot_score(document_embeddings, summary_embeddings).flatten()
)
return float(similarity)
# fnnctions for saving results
def save_to_csv(
run_name: str,
summary_name: str,
scores: dict,
params: dict,
csv_path: str = "evaluation_results.csv",
):
"""write the results of the evaluation to a csv file"""
file_exists = pathlib.Path(csv_path).is_file()
with open(csv_path, "a", encoding="utf-8") as f:
writer = csv.writer(f)
if not file_exists:
header = (
["run_name", "summary_name"] + list(scores.keys()) + list(params.keys())
)
writer.writerow(header)
row = [run_name, summary_name] + list(scores.values()) + list(params.values())
writer.writerow(row)
def save_to_database(
run_name: str,
summary_name: str,
scores: dict,
params: dict = {},
db_path: str = "evaluation_results.sqlite",
csv_path: str = "evaluation_results.csv",
):
"""
save_to_database - saves the results of the evaluation to a sqlite database and a csv file
:param str run_name: name of the run
:param str summary_name: name of the summary file
:param dict scores: scores of the evaluation
:param dict params: run parameters, defaults to {}
:param str db_path: path to the sqlite database, defaults to "evaluation_results.sqlite"
:param str csv_path: path to the csv file, defaults to "evaluation_results.csv"
"""
save_to_csv(run_name, summary_name, scores, params, csv_path)
conn = sqlite3.connect(db_path)
c = conn.cursor()
c.execute(
"CREATE TABLE IF NOT EXISTS results ("
"run_name TEXT, summary_name TEXT, cosine_similarity REAL, "
"topic_similarity REAL, compression_factor REAL, "
"misspelled_percentage REAL, "
"flesch_kincaid REAL, gunning_fog REAL, max_salient_similarity REAL, params TEXT)"
) # Create table if it doesn't exist
c.execute(
"INSERT INTO results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
run_name,
summary_name,
scores["cosine_similarity"],
scores["topic_similarity"],
scores["compression_factor"],
scores["misspelled_percentage"],
scores["flesch_kincaid"],
scores["gunning_fog"],
scores["max_salient_similarity"],
json.dumps(params),
),
)
conn.commit()
conn.close()
def evaluate_summaries(
summary_directory: str,
document_directory: str,
run_name: str = None,
drop_section_scores: bool = True,
topic_similarity_method: str = "nmf",
n_topics: int = 15,
wandb_logging: bool = False,
wandb_project_name: str = "SummaRazor-evaluation",
db_path: str = "summ_evaluation_results.sqlite",
csv_path: str = "summ_evaluation_results.csv",
inference_param_path: str = "summarization_parameters.json",
sbert_model_name: str = "sentence-transformers/paraphrase-MiniLM-L3-v2",
doc_chunk_size: int = 5,
no_salient_similarity: bool = False,
debug: bool = False,
):
"""
evaluate_summaries - evaluate the summaries+reference documents by directory using multiple metrics
:param str summary_directory: summary directory
:param str document_directory: document directory (source of the summaries)
:param str run_name: name of the run, defaults to None & uses document directory name
:param bool drop_section_scores: whether to drop section scores, defaults to True
:param str topic_similarity_method: method to calculate topic similarity, defaults to "nmf"
:param int n_topics: number of topics to use, defaults to 15
:param bool wandb_logging: log to wandb, defaults to False
:param str wandb_project_name: wandb project name, defaults to "SummaRazor-evaluation"
:param str db_path: file path to sqlite database, defaults to "summ_evaluation_results.sqlite"
:param str csv_path: file path to csv file, defaults to "summ_evaluation_results.csv"
:param str inference_param_path: relative path to inference parameters file, defaults to "summarization_parameters.json"
:param str sbert_model_name: sentence bert model name, defaults to "sentence-transformers/paraphrase-MiniLM-L3-v2"
:param int doc_chunk_size: document chunk size to encode (in sentences), defaults to 5
:param bool no_salient_similarity: disable semantic similarity calculation with SBERT, defaults to False
:param bool debug: enable debug mode, defaults to False
"""
summary_path = pathlib.Path(summary_directory)
document_path = pathlib.Path(document_directory)
assert summary_path.is_dir(), f"Summary directory {summary_path} does not exist"
assert document_path.is_dir(), f"Document directory {document_path} does not exist"
if debug:
logger.setLevel(logging.DEBUG)
logger.debug("Debug mode enabled")
logger.debug(
f"Summary directory:\t{summary_path}\nDocument directory:\t{document_path}"
)
if run_name is None:
run_name = summary_path.name
if wandb_logging:
import wandb
logger.info("Logging to wandb")
wandb.init(project=wandb_project_name, name=run_name)
param_file = summary_path / inference_param_path
params = {}
if param_file.is_file():
logger.info(f"Loading parameters from {param_file}")
with open(param_file, "r", encoding="utf-8") as f:
params = json.load(f)
if wandb_logging:
wandb.config.update(params)
else:
logger.warning(f"Could not find parameters file at {param_file}")
if not no_salient_similarity:
logger.info(f"Loading SBERT model {sbert_model_name}")
sbert_model = SentenceTransformer(sbert_model_name)
splitter = SentenceSplitter(language="en")
# update params
params["topics"] = n_topics
params["tm_method"] = topic_similarity_method
params["drop_section_scores"] = drop_section_scores
params["sbert_model_name"] = sbert_model_name
params["no_salient_similarity"] = no_salient_similarity
# evaluate summaries
files = list(summary_path.glob("*_summary.txt"))
for summary_file in tqdm(files, desc="Evaluating summaries"):
logger.debug(f"Evaluating {summary_file.name}")
_document_name = summary_file.name.replace("_summary.txt", "")
document_file = document_path / f"{_document_name}.txt"
logger.debug(
f"Document file: {document_file}\nExists: {document_file.is_file()}"
)
try:
with open(document_file, "r", encoding="utf-8") as f:
document = f.read()
with open(summary_file, "r", encoding="utf-8") as f:
summary = f.read()
if drop_section_scores:
logger.debug("Dropping section scores")
summary = re.sub(
r"\nSection Scores.*?---\n", "", summary, flags=re.DOTALL
)
scores = {
"cosine_similarity": cosine_similarity_score(document, summary),
"topic_similarity": topic_similarity_score(
document, summary, n_topics=n_topics, method=topic_similarity_method
),
"compression_factor": compression_factor(document, summary),
"misspelled_percentage": spelling_error_frac(summary),
}
if (
not no_salient_similarity
): # Compute the semantic similarity only if the flag is not set
scores["max_salient_similarity"] = max_salient_similarity(
document,
summary,
model=sbert_model,
splitter=splitter,
doc_chunk_size=doc_chunk_size,
)
else:
scores["max_salient_similarity"] = np.nan
scores.update(readability_scores(summary))
logger.debug(f"Scores for {summary_file.name}: {pp.pformat(scores)}")
except Exception as e:
logger.error(f"Error evaluating {summary_file.name}: {e}")
print(f"Error evaluating {summary_file.name}: {e}")
continue
save_to_database(
run_name,
summary_file.name,
scores,
params,
db_path=db_path,
csv_path=csv_path,
) # save to database and csv
if wandb_logging:
wandb.log(scores)
logging.info(f"Results saved to {db_path} and {csv_path}.")
if wandb_logging:
wandb.finish()
if __name__ == "__main__":
fire.Fire(evaluate_summaries)

Improving 'Unsupervised' Summarization with Max Salient Similarity

Peter Szemraj Mar 16th, 2023

⚠️ Please note that while this is promising, it is still under research/testing and can probably be improved. Ideas/feedback welcome!

A novel approach to unsupervised evaluation of abstractive summaries is presented, which uses semantic similarity and max-pooling to account for paraphrasing and capture the most important information in a document.

[figure: scores]

About

Evaluating abstractive summaries can be challenging, especially when looking for an unsupervised metric that is close to human judgment. Traditional methods, such as cosine similarity with n-grams, may not be effective due to the nature of abstractive summarization, which often involves rephrasing or paraphrasing the original text. Therefore, it is crucial to find a metric that takes into account the semantic similarity between the document and the summary while accounting for paraphrasing.
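
As a rough illustration (not a benchmark; exact scores depend on the texts and model), the sketch below scores a made-up paraphrase pair with TF-IDF n-gram cosine and with a paraphrase-trained SBERT model: the lexical metric tends to score a faithful paraphrase low, while the embedding similarity is much more forgiving of rewording.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, util

# made-up paraphrase pair, purely illustrative
doc = "The committee postponed the vote until next week due to scheduling conflicts."
summ = "Because of calendar clashes, the panel delayed voting until the following week."

# n-gram view: little lexical overlap, so the TF-IDF cosine tends to be low
tfidf = TfidfVectorizer(ngram_range=(1, 3)).fit_transform([doc, summ])
print("tfidf cosine:", cosine_similarity(tfidf[0:1], tfidf[1:2]).item())

# embedding view: a paraphrase-trained SBERT model scores the shared meaning
model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")
emb = model.encode([doc, summ], normalize_embeddings=True)
print("sbert cosine:", util.cos_sim(emb[0], emb[1]).item())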

How/What

Initially, an attempt was made to use semantic similarity with mean pooling to encode text chunks (3 sentences per document chunk and 1 sentence per summary chunk) and then compute their similarity. However, this approach did not yield satisfactory results because it failed to capture the essential information contained in the summaries and the document.

An alternative approach was then explored: using max-pooling instead of mean-pooling when combining the encoded chunks. Max-pooling takes the maximum value of each feature in the vector, effectively highlighting the most salient information in the document and summary. I was excited to find this to be the first (and only) metric in the script that closely approximated my own human judgments on a so-called gauntlet of long documents for summarization. You can find the definition in the max_salient_similarity function in eval_summaries.py.
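
Concretely, the pooling used here is an elementwise maximum over embedding vectors, so each dimension of the pooled vector keeps whichever sentence (or chunk) expressed that feature most strongly. A toy sketch with made-up numbers:

import numpy as np

# toy, made-up 4-dim "sentence embeddings" just to show the pooling step
sentence_embeddings = np.array(
    [
        [0.1, 0.9, 0.0, 0.2],
        [0.8, 0.1, 0.3, 0.0],
        [0.2, 0.2, 0.7, 0.1],
    ]
)

max_pooled = np.max(sentence_embeddings, axis=0)    # [0.8, 0.9, 0.7, 0.2]: strongest value per dimension
mean_pooled = np.mean(sentence_embeddings, axis=0)  # averaging dilutes features only one sentence expresses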

The success of this approach can be attributed to its ability to capture the most important or "salient" information in both the document and the summary (see this interactive plot). By focusing on this key information via max pooling, the method can better assess how well the summary captures the essential aspects of the document that "stood out". This approach is particularly intuitive for evaluating summaries because they are designed to convey the most important information from a document in a concise manner.
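
For reference, a minimal standalone call to the metric could look like the sketch below, assuming eval_summaries.py is importable from the working directory; the input file paths are placeholders.

from pathlib import Path

from sentence_splitter import SentenceSplitter
from sentence_transformers import SentenceTransformer

from eval_summaries import max_salient_similarity

model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")
splitter = SentenceSplitter(language="en")

document = Path("my_document.txt").read_text(encoding="utf-8")  # placeholder path
summary = Path("my_document_summary.txt").read_text(encoding="utf-8")  # placeholder path

score = max_salient_similarity(document, summary, model=model, splitter=splitter)
print(f"max salient similarity: {score:.4f}")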

  • Note: of the models evaluated in the initial pass here, the bigbird-pegasus checkpoints (arxiv and bigpatent) produce pretty terrible summaries (see linked folder), while the long-t5 models and LED-L are typically the 'best'.

Other interesting notes:

  • max_salient_similarity works quite well in my testing with sentence-transformers/paraphrase-MiniLM-L3-v2, which is quite fast even on CPU.
    • Tests using other paraphrase models from SBERT did not yield a substantial or consistent improvement in separating model performance (w.r.t. human eval; the numbers themselves barely change).
    • The intuition behind using a paraphrase-trained SBERT model is that in abstractive summarization, ideas are often paraphrased or simplified in the summary.
    • At the time of writing, I've only briefly explored other SBERT/embedding models and found them less useful than the paraphrase models. Still, some, like intfloat/e5-base, should be tried.
  • Exploration of other pooling methods, such as hierarchical pooling, may be warranted as well (see the sketch below); min pooling should be at least as bad as mean pooling for summarization use cases, if not worse.
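
For anyone who wants to experiment, a small hypothetical helper (not part of eval_summaries.py) makes pooling strategies easy to swap and compare on the same chunk embeddings:

import numpy as np

# hypothetical helper (not in eval_summaries.py): pool a (n_chunks, dim) array of
# chunk embeddings with a configurable strategy so max/mean/min can be compared
_POOLERS = {
    "max": lambda arr: np.max(arr, axis=0),
    "mean": lambda arr: np.mean(arr, axis=0),
    "min": lambda arr: np.min(arr, axis=0),
}


def pool_embeddings(chunk_embeddings, strategy: str = "max") -> np.ndarray:
    """Collapse a (n_chunks, dim) array of embeddings into a single (dim,) vector."""
    if strategy not in _POOLERS:
        raise ValueError(f"unknown pooling strategy: {strategy}")
    return _POOLERS[strategy](np.asarray(chunk_embeddings))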

Concerns & Future Work

However, there are some potential drawbacks to this method. First, in its current form it cannot be used as an optimization objective, because it assumes that all summaries are approximately the same length and that summary length is not a critical factor or concern. Consequently, there is no penalty for overly long summaries, which would be favored if this method were used for optimization. This is an area for further research and improvement.
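
One hypothetical direction, sketched below purely as an illustration and not implemented or validated, would be to damp the score when a summary overshoots a target length, loosely in the spirit of a brevity penalty:

import numpy as np


def length_penalized_score(
    score: float, summary_chars: int, target_chars: int, alpha: float = 1.0
) -> float:
    """Hypothetical, untested adjustment: damp the similarity score when the summary
    overshoots a target character length, so longer summaries are not favored by default."""
    if summary_chars <= target_chars:
        return score
    # exponential penalty grows with how far the summary exceeds the target length
    return score * float(np.exp(-alpha * (summary_chars / target_chars - 1.0)))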

Second, it is unclear at the time of writing how well this metric penalizes models that 'make up'/hallucinate facts or information in the summary that are not present in the source document. This needs to be evaluated further, potentially by adjusting the granularity of pooling on the document side (and possibly the summary side as well).
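
A simple, untested way to start probing this would be to score the same document/summary pair at several pooling granularities and see how sensitive the metric is; the setup mirrors the usage sketch above and the file paths are placeholders.

from pathlib import Path

from sentence_splitter import SentenceSplitter
from sentence_transformers import SentenceTransformer

from eval_summaries import max_salient_similarity

model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")
splitter = SentenceSplitter(language="en")
document = Path("my_document.txt").read_text(encoding="utf-8")  # placeholder path
summary = Path("my_document_summary.txt").read_text(encoding="utf-8")  # placeholder path

# untested idea: compare scores across document-chunk granularities; hallucinated
# content in the summary might show up as a larger drop at finer granularity
for chunk_size in (1, 3, 5, 10):
    score = max_salient_similarity(
        document, summary, model=model, splitter=splitter, doc_chunk_size=chunk_size
    )
    print(f"doc_chunk_size={chunk_size}: {score:.4f}")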

Conclusion

In conclusion, the use of semantic similarity with max-pooling proves to be a promising method for unsupervised evaluation of abstractive summaries. This approach accounts for paraphrasing and focuses on capturing the most important information from a document, which is more in line with human judgments. Further research and experimentation with this method could lead to more accurate and reliable metrics for evaluating summarization systems, especially for documents/data without gold-standard references to compute ROUGE scores against.

fire
numpy
pyspellchecker
scikit-learn
scipy
sentence-splitter
sentence-transformers
textstat
tqdm
@pszemraj
Author

update 1

I believe there was a bug in how I was encoding the document: each sentence was encoded separately, then the sentence embeddings were max-pooled in chunks of doc_chunk_size, and then all of those chunk vectors were max-pooled together again for the final comparison.

I had planned for pooling to happen only once, at the end: join the sentences into string chunks of doc_chunk_size, encode each chunk, and then max-pool all of those chunk embeddings. Below is an initial pass at this (I have not tried running it yet).

NOTE: it is entirely possible that the updated version (the way I originally intended) is worse and that I accidentally found the approach that actually works (logic: sentence-transformers are trained on sentences, and weaker representations of larger chunks might render things not useful). We'll see.

original

def max_salient_similarity(
    document: str,
    summary: str,
    model: sentence_transformers.SentenceTransformer,
    splitter: SentenceSplitter,
    doc_chunk_size: int = 5,
    summary_chunk_size: int = 1,
    distance: str = "cosine",
) -> float:
    """
    max_salient_similarity - calculates the semantic similarity between a document and a summary
    :param str document: document to be summarized
    :param str summary: summary of the document
    :param sentence_transformers.SentenceTransformer model: sbert model to use for encoding
    :param SentenceSplitter splitter: splitter to use for splitting the document and summary into sentences
    :param int doc_chunk_size: number of sentences to encode at a time, defaults to 5
    :param int summary_chunk_size: number of sentences to encode at a time, defaults to 1
    :param str distance: distance metric to use, defaults to "cosine" = cosine similarity
    :return float: semantic similarity between the document and summary
    """
    document_sentences = splitter.split(document)
    summary_sentences = splitter.split(summary)

    document_embeddings = np.max(
        [
            np.max(
                model.encode(
                    document_sentences[i : i + doc_chunk_size],
                    normalize_embeddings=True,
                    show_progress_bar=False,
                ),
                axis=0,
            )
            for i in tqdm(
                range(0, len(document_sentences), doc_chunk_size),
                desc="Document embeddings",
            )
        ],
        axis=0,
    )  # note: changing the np.max in this function changes the pooling strategy
    summary_embeddings = np.max(
        [
            np.max(
                model.encode(
                    summary_sentences[i : i + summary_chunk_size],
                    normalize_embeddings=True,
                    show_progress_bar=False,
                ),
                axis=0,
            )
            for i in tqdm(
                range(0, len(summary_sentences), summary_chunk_size),
                desc="Summary embeddings",
            )
        ],
        axis=0,
    )

    similarity = (
        util.cos_sim(document_embeddings, summary_embeddings).flatten()
        if distance == "cosine"
        else util.dot_score(document_embeddings, summary_embeddings).flatten()
    )

    return float(similarity)

'fixed' version (originally intended method)

⚠️ to be tested/updated/integrated if it yields a better representation of the differences between summaries from different models

import numpy as np
import sentence_transformers
from sentence_splitter import SentenceSplitter
from sentence_transformers import util
from tqdm import tqdm


def max_salient_similarity(
    document: str,
    summary: str,
    model: sentence_transformers.SentenceTransformer,
    splitter: SentenceSplitter,
    doc_chunk_size: int = 5,
    summary_chunk_size: int = 1,
    distance: str = "cosine",
) -> float:
    """
    max_salient_similarity - calculates the semantic similarity between a document and a summary
    :param str document: document to be summarized
    :param str summary: summary of the document
    :param sentence_transformers.SentenceTransformer model: sbert model to use for encoding
    :param SentenceSplitter splitter: splitter to use for splitting the document and summary into sentences
    :param int doc_chunk_size: number of sentences to encode at a time, defaults to 5
    :param int summary_chunk_size: number of sentences to encode at a time, defaults to 1
    :param str distance: distance metric to use, defaults to "cosine" = cosine similarity
    :return float: semantic similarity between the document and summary
    """
    if distance not in ["cosine", "dot"]:
        raise ValueError(
            "Invalid distance metric. Supported metrics are 'cosine' and 'dot'."
        )

    document_sentences = splitter.split(document)
    summary_sentences = splitter.split(summary)

    document_embeddings = np.max(
        [
            model.encode(
                " ".join(document_sentences[i : i + doc_chunk_size]),
                normalize_embeddings=True,
                show_progress_bar=False,
            )
            for i in tqdm(
                range(0, len(document_sentences), doc_chunk_size),
                desc="Document embeddings",
            )
        ],
        axis=0,
    )

    summary_embeddings = np.max(
        [
            model.encode(
                " ".join(summary_sentences[i : i + summary_chunk_size]),
                normalize_embeddings=True,
                show_progress_bar=False,
            )
            for i in tqdm(
                range(0, len(summary_sentences), summary_chunk_size),
                desc="Summary embeddings",
            )
        ],
        axis=0,
    )

    similarity = (
        util.cos_sim(document_embeddings, summary_embeddings).flatten()
        if distance == "cosine"
        else util.dot_score(document_embeddings, summary_embeddings).flatten()
    )

    return float(similarity)
