@pszemraj
Created February 21, 2024 21:40
textsum - run summarization on directory on CPU with IPEX optimization
"""
cli.py - Command line interface for textsum.
this edition: fast CPU inference with intel IPEX https://archive.ph/oY5b1
Usage:
textsum-dir --help
"""
import os
import logging
import pprint as pp
import random
from pathlib import Path
from typing import Optional
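
# cap OpenMP/MKL thread counts before torch and IPEX are imported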
os.environ["OMP_NUM_THREADS"] = f"{max(1, os.cpu_count() - 2)}"
os.environ["MKL_NUM_THREADS"] = f"{max(1, os.cpu_count() - 2)}"
os.environ["OMP_THREAD_LIMIT"] = f"{os.cpu_count()}"
import fire # noqa: E402
import intel_extension_for_pytorch as ipex # noqa: E402
import textsum # noqa: E402
import torch # noqa: E402
from intel_extension_for_pytorch.quantization import convert, prepare # noqa: E402
from textsum.summarize import Summarizer # noqa: E402
from textsum.utils import enable_tf32, setup_logging # noqa: E402
from tqdm.auto import tqdm # noqa: E402
EXAMPLE_TEXT = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."


def log_mem_footprint(test_model):
    """Prints the memory footprint of the given model."""
    fp = test_model.get_memory_footprint() * (10**-9)
    print(f"model memory footprint:\t{round(fp, 2)} GB")


def quantize_model_ipex(
    model: Summarizer, batch_size: int = 8, jit_trace: bool = False
) -> Summarizer:
    """
    Quantize the provided PyTorch model using Intel Extension for PyTorch (IPEX)
    dynamic quantization and evaluate its memory footprint.
    See the IPEX docs at https://archive.ph/ClFBd for more details.

    Parameters:
        - model: The Summarizer whose underlying PyTorch model will be quantized.
        - batch_size: Batch size of the dummy inputs used for preparation/tracing.
        - jit_trace: Whether to trace the quantized model with torch.jit.trace.

    Returns:
        - model: The Summarizer with its model replaced by the quantized version.
    """
    dynamic_qconfig = ipex.quantization.default_dynamic_qconfig_mapping
    log_mem_footprint(model.model)

    # dummy forward-pass data for quantization preparation
    vocab_size = model.model.config.vocab_size
    seq_length = model.tokenizer.model_max_length
    example_inputs = torch.randint(vocab_size, size=[batch_size, seq_length])

    # prepare the model for dynamic quantization
    prepared_model = prepare(
        model.model,
        dynamic_qconfig,
        example_inputs=example_inputs,
        bn_folding=False,
    )
    # convert the model to its quantized version
    model.model = convert(prepared_model)
    log_mem_footprint(model.model)
    logging.info(f"dtype is {model.model.dtype}")

    if jit_trace:
        data = torch.randint(vocab_size, size=[batch_size, seq_length])
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model.model, (data,), check_trace=False, strict=False
            )
        model.model = torch.jit.freeze(traced_model)

    return model
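
# Rough standalone usage (a sketch; the argument values here are illustrative, not defaults):
#   summarizer = Summarizer(model_name_or_path="pszemraj/long-t5-tglobal-base-16384-book-summary", use_cuda=False)
#   summarizer = quantize_model_ipex(summarizer, batch_size=1, jit_trace=False)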


def main(
    input_dir: str,
    output_dir: Optional[str] = None,
    model: str = "pszemraj/long-t5-tglobal-base-16384-book-summary",
    no_cuda: bool = False,
    tf32: bool = False,
    force_cache: bool = False,
    load_in_8bit: bool = False,
    compile: bool = False,
    quantize: bool = False,
    use_amp: bool = False,
    optimum_onnx: bool = False,
    batch_length: int = 4096,
    batch_stride: int = 16,
    num_beams: int = 4,
    length_penalty: float = 0.8,
    repetition_penalty: float = 2.5,
    max_length_ratio: float = 0.25,
    min_length: int = 8,
    encoder_no_repeat_ngram_size: int = 4,
    no_repeat_ngram_size: int = 3,
    early_stopping: bool = True,
    shuffle: bool = False,
    lowercase: bool = False,
    loglevel: Optional[int] = logging.INFO,
    logfile: Optional[str] = None,
    file_extension: str = "txt",
    skip_completed: bool = False,
    disable_progress_bar: bool = False,
):
"""
Main function to summarize text files in a directory.
Args:
input_dir (str, required): The directory containing the input files.
output_dir (str, optional): Directory to write the output files. If None, writes to input_dir/summarized.
model (str, optional): The name of the model to use for summarization. Default: "pszemraj/long-t5-tglobal-base-16384-book-summary".
no_cuda (bool, optional): Flag to not use cuda if available. Default: False.
tf32 (bool, optional): Enable tf32 data type for computation (requires ampere series GPU or newer). Default: False.
force_cache (bool, optional): Force the use_cache flag to True in the Summarizer. Default: False.
load_in_8bit (bool, optional): Flag to load the model in 8 bit precision (requires bitsandbytes). Default: False.
compile (bool, optional): Compile the model for inference (requires torch 2.0+). Default: False.
optimum_onnx (bool, optional): Optimize the model for inference (requires onnxruntime-tools). Default: False.
batch_length (int, optional): The length of each batch. Default: 4096.
batch_stride (int, optional): The stride of each batch. Default: 16.
num_beams (int, optional): The number of beams to use for beam search. Default: 4.
length_penalty (float, optional): The length penalty to use for decoding. Default: 0.8.
repetition_penalty (float, optional): The repetition penalty to use for beam search. Default: 2.5.
max_length_ratio (float, optional): The maximum length of the summary as a ratio of the batch length. Default: 0.25.
min_length (int, optional): The minimum length of the summary. Default: 8.
encoder_no_repeat_ngram_size (int, optional): Encoder no repeat ngram size (input text). Smaller values mean more unique summaries. Default: 4.
no_repeat_ngram_size (int, optional): The decoder no repeat ngram size (output text). Default: 3.
early_stopping (bool, optional): Whether to use early stopping. Default: True.
shuffle (bool, optional): Shuffle the input files before summarizing. Default: False.
lowercase (bool, optional): Whether to lowercase the input text. Default: False.
loglevel (int, optional): The log level to use (default: 20 - INFO). Default: 30.
logfile (str, optional): Path to the log file. This will set loglevel to INFO (if not set) and write to the file.
file_extension (str, optional): The file extension to use when searching for input files., defaults to "txt"
skip_completed (bool, optional): Skip files that have already been summarized. Default: False.
Returns:
None
"""
    setup_logging(loglevel, logfile)
    logging.info("starting textsum cli")
    logging.info(f"textsum version:\t{textsum.__version__}")

    params = {
        "min_length": min_length,
        "encoder_no_repeat_ngram_size": encoder_no_repeat_ngram_size,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams,
        "num_beam_groups": 1,
        "length_penalty": length_penalty,
        "early_stopping": early_stopping,
        "do_sample": False,
    }

    if tf32:
        enable_tf32()  # enable tf32 for computation

    summarizer = Summarizer(
        model_name_or_path=model,
        use_cuda=not no_cuda,
        token_batch_length=batch_length,
        batch_stride=batch_stride,
        max_length_ratio=max_length_ratio,
        load_in_8bit=load_in_8bit,
        compile_model=False,  # if requested, compiled below with the ipex backend instead
        optimum_onnx=optimum_onnx,
        force_cache=force_cache,
        disable_progress_bar=disable_progress_bar,
        **params,
    )

    if quantize:
        logging.info("quantizing model")
        summarizer = quantize_model_ipex(summarizer)

    # general ipex optimizations
    summarizer.model = ipex.optimize(
        summarizer.model,
        weights_prepack=False,
        conv_bn_folding=False,
        linear_bn_folding=False,
        replace_dropout_with_identity=True,
        auto_kernel_selection=True,
    )

    if compile:
        logging.info("compiling model")
        summarizer.model = torch.compile(summarizer.model, backend="ipex")

    summarizer.print_config()
    logging.info(summarizer.config)

    # get the input files
    input_files = list(Path(input_dir).glob(f"*.{file_extension}"))
    logging.info(f"found {len(input_files)} input files")
    if shuffle:
        logging.info("shuffling input files")
        random.SystemRandom().shuffle(input_files)

    # get the output directory
    output_dir = Path(output_dir) if output_dir else Path(input_dir) / "summarized"
    output_dir.mkdir(exist_ok=True, parents=True)

    failed_files = []
    completed_files = []
    for f in tqdm(input_files, desc="summarizing files"):
        _prospective_output_file = output_dir / f"{f.stem}_summary.txt"
        if skip_completed and _prospective_output_file.exists():
            logging.info(f"skipping file (found existing summary):\t{str(f)}")
            continue
        try:
            if use_amp:
                with torch.cpu.amp.autocast():
                    _ = summarizer.summarize_file(
                        file_path=f, output_dir=output_dir, lowercase=lowercase
                    )
            else:
                _ = summarizer.summarize_file(
                    file_path=f, output_dir=output_dir, lowercase=lowercase
                )
            completed_files.append(str(f))
        except Exception as e:
            logging.error(f"failed to summarize file:\t{f}")
            logging.error(e)
            print(e)
            failed_files.append(f)
            if isinstance(e, RuntimeError):
                # if a runtime error occurs, exit immediately
                logging.error("Stopping summarization: runtime error")
                failed_files.extend(input_files[input_files.index(f) + 1 :])
                break

    logging.info(f"failed to summarize {len(failed_files)} files")
    if len(failed_files) > 0:
        logging.info(f"failed files:\n\t{pp.pformat(failed_files)}")

    logging.debug("saving summarizer params and config")
    summarizer.save_params(output_path=output_dir, hf_tag=model)
    summarizer.save_config(output_dir / "textsum_config.json")

    logging.info(
        f"finished summarizing files - output dir:\n\t{str(output_dir.resolve())}"
    )


def run():
    """Entry point for console_scripts"""
    fire.Fire(main)


if __name__ == "__main__":
    run()