Created
February 21, 2024 21:40
-
-
Save pszemraj/a7fe99569d22dffb0568e253de805de7 to your computer and use it in GitHub Desktop.
textsum: run summarization on a directory of files on CPU, with Intel IPEX optimization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
cli.py - Command line interface for textsum. | |
this edition: fast CPU inference with intel IPEX https://archive.ph/oY5b1 | |
Usage: | |
textsum-dir --help | |
""" | |
import os | |
import logging | |
import pprint as pp | |
import random | |
from pathlib import Path | |
from typing import Optional | |
# Size the OpenMP / MKL thread pools, leaving two cores free (minimum of 1 thread).
# These are set before the torch / IPEX imports below (hence the E402 noqa markers),
# presumably so the thread pools pick them up when those libraries load.
# os.cpu_count() returns None when the count cannot be determined, so fall back to 1
# instead of crashing with `None - 2`.
_n_cpus = os.cpu_count() or 1
os.environ["OMP_NUM_THREADS"] = f"{max(1, _n_cpus - 2)}"
os.environ["MKL_NUM_THREADS"] = f"{max(1, _n_cpus - 2)}"
os.environ["OMP_THREAD_LIMIT"] = f"{_n_cpus}"
import fire # noqa: E402 | |
import intel_extension_for_pytorch as ipex # noqa: E402 | |
import textsum # noqa: E402 | |
import torch # noqa: E402 | |
from intel_extension_for_pytorch.quantization import convert, prepare # noqa: E402 | |
from textsum.summarize import Summarizer # noqa: E402 | |
from textsum.utils import enable_tf32, setup_logging # noqa: E402 | |
from tqdm.auto import tqdm # noqa: E402 | |
# Sample English passage (about the Eiffel Tower) available as example input text.
# Adjacent string literals are concatenated at compile time; each fragment ends with
# the single space that separates it from the next sentence.
EXAMPLE_TEXT = (
    "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. "
    "Its base is square, measuring 125 metres (410 ft) on each side. "
    "During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. "
    "It was the first structure to reach a height of 300 metres. "
    "Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). "
    "Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
)
def log_mem_footprint(test_model):
    """
    Print the memory footprint of ``test_model`` to stdout, in GB.

    ``test_model`` must provide ``get_memory_footprint()``. Its result is scaled
    by 1e-9 (presumably bytes -> GB; confirm against the model class) and
    rounded to two decimals before printing.
    """
    bytes_to_gb = 10**-9
    footprint_gb = test_model.get_memory_footprint() * bytes_to_gb
    message = "model memory footprint:\t{} GB".format(round(footprint_gb, 2))
    print(message)
def quantize_model_ipex(model: Summarizer, batch_size: int = 8, jit_trace=False):
    """
    Dynamically quantize the model wrapped by a ``Summarizer`` with Intel Extension
    for PyTorch (IPEX), printing its memory footprint before and after.

    See ipex docs at https://archive.ph/ClFBd for more details

    Parameters:
    - model: The ``Summarizer`` whose ``.model`` attribute is quantized; the
      attribute is replaced by the converted (and optionally traced) module.
    - batch_size: Batch dimension of the random token-id tensor used as the
      example input for ``prepare`` (and for ``torch.jit.trace`` when enabled).
    - jit_trace: Whether to additionally trace the quantized model with
      ``torch.jit.trace`` and freeze it with ``torch.jit.freeze``.

    Returns:
    - model: The same ``Summarizer`` instance, now wrapping the quantized model.
    """
    dynamic_qconfig = ipex.quantization.default_dynamic_qconfig_mapping
    log_mem_footprint(model.model)

    # Dummy forward-pass data: random token ids with shape
    # (batch_size, tokenizer.model_max_length).
    # NOTE(review): model_max_length can be a very large sentinel for tokenizers
    # without a configured limit, which would make this tensor unallocatable —
    # confirm for the models used.
    vocab_size = model.model.config.vocab_size
    seq_length = model.tokenizer.model_max_length
    example_inputs = torch.randint(vocab_size, size=[batch_size, seq_length])

    # Prepare the model for dynamic quantization
    prepared_model = prepare(
        model.model,
        dynamic_qconfig,
        example_inputs=example_inputs,
        bn_folding=False,
    )
    # Convert the model to its quantized version
    model.model = convert(prepared_model)
    log_mem_footprint(model.model)
    logging.info(f"dtype is {model.model.dtype}")

    if jit_trace:
        # Reuse the same dummy batch as the tracing input instead of building a
        # second identically-shaped random tensor.
        # NOTE(review): tracing with a single positional input assumes forward()
        # can run on input_ids alone; encoder-decoder models may also need
        # decoder inputs — confirm.
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model.model, (example_inputs,), check_trace=False, strict=False
            )
        model.model = torch.jit.freeze(traced_model)
    return model
def main(
    input_dir: str,
    output_dir: Optional[str] = None,
    model: str = "pszemraj/long-t5-tglobal-base-16384-book-summary",
    no_cuda: bool = False,
    tf32: bool = False,
    force_cache: bool = False,
    load_in_8bit: bool = False,
    compile: bool = False,  # noqa: A002 - shadows the builtin, but it is the public CLI flag name
    quantize: bool = False,
    use_amp: bool = False,
    optimum_onnx: bool = False,
    batch_length: int = 4096,
    batch_stride: int = 16,
    num_beams: int = 4,
    length_penalty: float = 0.8,
    repetition_penalty: float = 2.5,
    max_length_ratio: float = 0.25,
    min_length: int = 8,
    encoder_no_repeat_ngram_size: int = 4,
    no_repeat_ngram_size: int = 3,
    early_stopping: bool = True,
    shuffle: bool = False,
    lowercase: bool = False,
    loglevel: Optional[int] = logging.INFO,
    logfile: Optional[str] = None,
    file_extension: str = "txt",
    skip_completed: bool = False,
    disable_progress_bar: bool = False,
):
    """
    Main function to summarize text files in a directory.

    Each matching file is summarized in turn. If a RuntimeError occurs, processing
    stops and the remaining files are reported as failed. Any other exception is
    logged and the next file is attempted. Afterwards, the summarizer params and
    config are saved to the output directory.

    Args:
        input_dir (str, required): The directory containing the input files.
        output_dir (str, optional): Directory to write the output files. If None, writes to input_dir/summarized.
        model (str, optional): The name of the model to use for summarization. Default: "pszemraj/long-t5-tglobal-base-16384-book-summary".
        no_cuda (bool, optional): Flag to not use cuda if available. Default: False.
        tf32 (bool, optional): Enable tf32 data type for computation (requires ampere series GPU or newer). Default: False.
        force_cache (bool, optional): Force the use_cache flag to True in the Summarizer. Default: False.
        load_in_8bit (bool, optional): Flag to load the model in 8 bit precision (requires bitsandbytes). Default: False.
        compile (bool, optional): Compile the model with ``torch.compile(..., backend="ipex")`` after IPEX optimization (requires torch 2.0+). Default: False.
        quantize (bool, optional): Apply IPEX dynamic quantization to the model before ``ipex.optimize``. Default: False.
        use_amp (bool, optional): Run each file's summarization under ``torch.cpu.amp.autocast()``. Default: False.
        optimum_onnx (bool, optional): Optimize the model for inference (requires onnxruntime-tools). Default: False.
        batch_length (int, optional): The length of each batch. Default: 4096.
        batch_stride (int, optional): The stride of each batch. Default: 16.
        num_beams (int, optional): The number of beams to use for beam search. Default: 4.
        length_penalty (float, optional): The length penalty to use for decoding. Default: 0.8.
        repetition_penalty (float, optional): The repetition penalty to use for beam search. Default: 2.5.
        max_length_ratio (float, optional): The maximum length of the summary as a ratio of the batch length. Default: 0.25.
        min_length (int, optional): The minimum length of the summary. Default: 8.
        encoder_no_repeat_ngram_size (int, optional): Encoder no repeat ngram size (input text). Smaller values mean more unique summaries. Default: 4.
        no_repeat_ngram_size (int, optional): The decoder no repeat ngram size (output text). Default: 3.
        early_stopping (bool, optional): Whether to use early stopping. Default: True.
        shuffle (bool, optional): Shuffle the input files before summarizing. Default: False.
        lowercase (bool, optional): Whether to lowercase the input text. Default: False.
        loglevel (int, optional): The log level to use. Default: 20 (logging.INFO).
        logfile (str, optional): Path to the log file. This will set loglevel to INFO (if not set) and write to the file.
        file_extension (str, optional): Extension of the input files, matched non-recursively as ``*.<ext>`` in input_dir. Default: "txt".
        skip_completed (bool, optional): Skip files whose ``<stem>_summary.txt`` already exists in the output directory. Default: False.
        disable_progress_bar (bool, optional): Passed to the Summarizer to disable its progress bar; the per-file progress bar in this function is always shown. Default: False.

    Returns:
        None
    """
    from contextlib import nullcontext  # local import: only used for the optional AMP context

    setup_logging(loglevel, logfile)
    logging.info("starting textsum cli")
    logging.info(f"textsum version:\t{textsum.__version__}")

    params = {
        "min_length": min_length,
        "encoder_no_repeat_ngram_size": encoder_no_repeat_ngram_size,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams,
        "num_beam_groups": 1,
        "length_penalty": length_penalty,
        "early_stopping": early_stopping,
        "do_sample": False,
    }

    if tf32:
        enable_tf32()  # enable tf32 for computation

    summarizer = Summarizer(
        model_name_or_path=model,
        use_cuda=not no_cuda,
        token_batch_length=batch_length,
        batch_stride=batch_stride,
        max_length_ratio=max_length_ratio,
        load_in_8bit=load_in_8bit,
        # compilation (if requested) is done below with the IPEX backend instead
        compile_model=False,
        optimum_onnx=optimum_onnx,
        force_cache=force_cache,
        disable_progress_bar=disable_progress_bar,
        **params,
    )

    if quantize:
        logging.info("quantizing model")
        summarizer = quantize_model_ipex(summarizer)

    # general ipex optimizations
    summarizer.model = ipex.optimize(
        summarizer.model,
        weights_prepack=False,
        conv_bn_folding=False,
        linear_bn_folding=False,
        replace_dropout_with_identity=True,
        auto_kernel_selection=True,
    )

    if compile:
        logging.info("compiling model")
        summarizer.model = torch.compile(summarizer.model, backend="ipex")

    summarizer.print_config()
    logging.info(summarizer.config)

    # get the input files
    input_files = list(Path(input_dir).glob(f"*.{file_extension}"))
    logging.info(f"found {len(input_files)} input files")
    if not input_files:
        logging.warning(f"no *.{file_extension} files found in {input_dir}")
    if shuffle:
        logging.info("shuffling input files")
        random.SystemRandom().shuffle(input_files)

    # get the output directory
    output_dir = Path(output_dir) if output_dir else Path(input_dir) / "summarized"
    output_dir.mkdir(exist_ok=True, parents=True)

    # both lists hold str paths so they are reported consistently
    failed_files = []
    completed_files = []
    for idx, f in enumerate(tqdm(input_files, desc="summarizing files")):
        _prospective_output_file = output_dir / f"{f.stem}_summary.txt"
        if skip_completed and _prospective_output_file.exists():
            logging.info(f"skipping file (found existing summary):\t{str(f)}")
            continue
        try:
            # autocast only when requested; nullcontext keeps a single call site
            amp_context = torch.cpu.amp.autocast() if use_amp else nullcontext()
            with amp_context:
                summarizer.summarize_file(
                    file_path=f, output_dir=output_dir, lowercase=lowercase
                )
            completed_files.append(str(f))
        except Exception as e:
            # logging.exception records the message plus the traceback
            logging.exception(f"failed to summarize file:\t{f}")
            print(e)  # also echo to stdout in case log output is not shown on the console
            failed_files.append(str(f))
            if isinstance(e, RuntimeError):
                # if a runtime error occurs, exit immediately
                logging.error("Stopping summarization: runtime error")
                failed_files.extend(str(p) for p in input_files[idx + 1 :])
                break

    logging.info(f"summarized {len(completed_files)} files")
    logging.info(f"failed to summarize {len(failed_files)} files")
    if failed_files:
        logging.info(f"failed files:\n\t{pp.pformat(failed_files)}")

    logging.debug("saving summarizer params and config")
    summarizer.save_params(output_path=output_dir, hf_tag=model)
    summarizer.save_config(output_dir / "textsum_config.json")

    logging.info(
        f"finished summarizing files - output dir:\n\t{str(output_dir.resolve())}"
    )
def run():
    """Entry point for console_scripts: hand ``main`` to Fire, which builds the CLI from its signature."""
    cli_target = main
    fire.Fire(cli_target)


if __name__ == "__main__":
    # also allow invoking this file directly with `python <file> ...`
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment