Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Last active January 23, 2024 07:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save pszemraj/14f7b13bd2d953176db2371e5d320915 to your computer and use it in GitHub Desktop.
Save pszemraj/14f7b13bd2d953176db2371e5d320915 to your computer and use it in GitHub Desktop.
basic implementation of a custom wrapper class for using the grammar synthesis text2text models
"""
Class for correcting text using a pretrained grammar synthesis model.
- models are available here: https://hf.co/models?other=grammar%20synthesis
requirements for this snippet:
pip install -U transformers accelerate
NOTE: if you want to use 8-bit to fit the model on a smaller GPU, you need bitsandbytes:
pip install -U transformers accelerate bitsandbytes
"""
import warnings
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class GrammarSynthesizer:
    """
    Class for correcting text using a pretrained grammar synthesis model.

    Models are available here: https://hf.co/models?other=grammar%20synthesis

    Example usage with the XL model::

        corrector = GrammarSynthesizer("pszemraj/flan-t5-xl-grammar-synthesis")
        raw_text = 'sky is blu.'
        results = corrector(raw_text, num_beams=2)
        print(results)
    """

    # Token count above which _prepare_inputs warns about overly long input.
    DEFAULT_MAX_INPUT_LENGTH = 384
    # Default ``max_length`` forwarded to ``model.generate``.
    DEFAULT_MAX_LENGTH = 128
    # Default ``num_beams`` forwarded to ``model.generate``.
    DEFAULT_NUM_BEAMS = 4

    def __init__(
        self,
        model_name_or_path: str,
        should_compile: bool = True,
        load_in_8bit: bool = False,
    ):
        """
        Initializes the GrammarSynthesizer.

        Args:
            model_name_or_path: The name or path of the pretrained model.
            should_compile: If True, tries to compile the model for faster inference.
            load_in_8bit: If True, loads model in 8-bit precision (lower memory
                usage). Requires bitsandbytes.
        """
        self.model_name_or_path = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Forward load_in_8bit explicitly: the loader needs it, and omitting it
        # made construction fail with a TypeError.
        self.model = self._load_and_compile_model(
            model_name_or_path, should_compile, load_in_8bit
        )

    def _load_and_compile_model(
        self,
        model_name_or_path: str,
        should_compile: bool,
        load_in_8bit: bool = False,
    ):
        """
        Load and (optionally) compile the model.

        Also records on ``self.compiled_model`` whether compilation was
        requested and did not raise.

        Args:
            model_name_or_path: The name or path of the pretrained model.
            should_compile: If True, tries to compile the model for faster inference.
            load_in_8bit: If True, loads model in 8-bit precision (lower memory
                usage). Requires bitsandbytes.

        Returns:
            The loaded and potentially compiled model.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name_or_path, load_in_8bit=load_in_8bit, device_map="auto"
        )
        if should_compile:
            # Compilation is a best-effort speed-up; on failure keep using the
            # uncompiled model instead of aborting.
            try:
                model = torch.compile(model)
            except Exception as e:
                print(f"Unable to compile model for faster inference. Reason: {e}")
                should_compile = False
        self.compiled_model = should_compile
        return model

    def _prepare_inputs(self, input_text: str):
        """
        Prepares the inputs for the model.

        Encodes ``input_text`` with the tokenizer, moves the result to the
        model's device, and warns when the token count exceeds
        ``DEFAULT_MAX_INPUT_LENGTH``. The input is not truncated.

        Args:
            input_text: The input text to prepare.

        Returns:
            The prepared inputs (encoded token ids on the model's device).
        """
        inputs = self.tokenizer.encode(input_text, return_tensors="pt").to(
            self.model.device
        )
        # The encoded tensor carries a leading batch dimension, so len(inputs)
        # would report the batch size; the token count is the last dimension.
        if inputs.shape[-1] > self.DEFAULT_MAX_INPUT_LENGTH:
            warnings.warn(
                "Input is longer than model training data. Unexpected behavior may occur. "
                "Consider batch-processing smaller chunks."
            )
        return inputs

    def generate_text(
        self,
        input_text: str,
        max_length: int = DEFAULT_MAX_LENGTH,
        num_beams: int = DEFAULT_NUM_BEAMS,
        **kwargs,
    ):
        """
        Generates text from the input.

        Args:
            input_text: The input text to generate from.
            max_length: The maximum length of the generated text.
            num_beams: The number of beams for beam search.
            **kwargs: Extra keyword arguments forwarded to ``model.generate``.

        Returns:
            The generated text. Inputs shorter than two characters are
            returned unchanged (with a warning) without calling the model.
        """
        if len(input_text) < 2:
            warnings.warn(
                f"input text is too short to correct, returning:\t{input_text}"
            )
            return input_text
        inputs = self._prepare_inputs(input_text)
        outputs = self.model.generate(
            inputs, max_length=max_length, num_beams=num_beams, **kwargs
        )
        # Only the first returned sequence is decoded.
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def __call__(self, input_text: str, **kwargs):
        """Shorthand for :meth:`generate_text`."""
        return self.generate_text(input_text, **kwargs)
@pszemraj
Copy link
Author

pszemraj commented Jun 2, 2023

click-able link for the models

See here: https://hf.co/models?other=grammar%20synthesis

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment