"""
email_gen.py - generates the rest of an email using a text-generation model from transformers
"""
import argparse
import logging
import pprint as pp
import time
from pathlib import Path

import torch
import transformers

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    filename="email_gen.log",
    filemode="a",
)
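# NOTE: with the config above, logs are appended to email_gen.log in the
# working directory rather than echoed to the console.
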
def call_model(
    generator: transformers.Pipeline,
    prompt: str,
    num_beams=4,
    min_length=4,
    max_length=64,
    no_repeat_ngram_size=3,
    temperature=0.3,
    max_time=600,
    partial_text=False,
    verbose=False,
):
"""
call_model - a helper function that calls the model and returns the generated text
Args:
generator (transformers.pipeline): the model to use
prompt (str): the prompt to use
num_beams (int, optional): number of beams to use for beam search (default: 4)
min_length (int, optional): min length of generated text (default: 4)
max_length (int, optional): max length of generated text (default: 64)
no_repeat_ngram_size (int, optional): no repeat ngram size (default: 3)
temperature (float, optional): temperature for generation (default: 0.3)
max_time (int, optional): max time for generation in seconds (default: 600)
partial_text (bool, optional): return only the generated text (default: False)
verbose (bool, optional): print verbose output (default: False)
Returns:
str: the generated text (email)
"""
    st = time.perf_counter()
    logging.info(
        f"generating response for prompt:\n{prompt} with num_beams:\n\t{num_beams}"
    )
    print(f"generating response for prompt:\t{prompt}\twith num_beams:\t{num_beams}")
    result = generator(
        prompt,
        # offset the length limits by the prompt length (characters, a rough
        # proxy for tokens) so they apply to the generated continuation
        min_length=min_length + len(prompt),
        max_length=max_length + len(prompt),
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=3.5,
        length_penalty=0.8,
        temperature=temperature,  # only affects sampling, so a no-op with do_sample=False
        num_beams=num_beams,
        max_time=max_time,
        remove_invalid_values=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
        do_sample=False,
        early_stopping=True,
        return_full_text=not partial_text,
    )
    response = result[0]["generated_text"]
    if verbose:
        w_prompt = f"<PROMPT>{prompt}<END-OF-PROMPT>"
        pp.pprint(w_prompt + response)
    rt = (time.perf_counter() - st) / 60
    logging.info(f"runtime: {rt:.2f} minutes")
    return response

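# Example (illustrative) of calling call_model directly from a REPL or another
# script; the model tag is this gist's default, but any causal-LM
# text-generation pipeline should work the same way:
#   >>> gen = transformers.pipeline(
#   ...     "text-generation", "pszemraj/distilgpt2-email-generation"
#   ... )
#   >>> call_model(gen, "Dear Dr. Smith,\n\nThank you for", verbose=True)
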
def get_parser():
    """
    get_parser - a helper function for the argparse module
    """
    parser = argparse.ArgumentParser(
        description="generate the rest of an email from a prompt with a text-generation model."
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        required=False,
        default=None,
        help="prompt to generate from",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        required=False,
        default="pszemraj/distilgpt2-email-generation",
        help="the model tag on huggingface OR path to model directory",
    )
    parser.add_argument(
        "-i",
        "--input-path",
        required=False,
        type=str,
        default=None,
        help="path to the input text file (optional)",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        default=None,
        type=str,
        help="path to the output text file (optional)",
    )
    # generation params
    parser.add_argument(
        "-nb",
        "--num-beams",
        required=False,
        default=4,
        type=int,
        help="number of beams to use for beam search (default: 4)",
    )
    parser.add_argument(
        "-ml",
        "--max-length",
        required=False,
        default=64,
        type=int,
        help="max length of generated text (default: 64)",
    )
    parser.add_argument(
        "--min-length",
        required=False,
        default=4,
        type=int,
        help="min length of generated text (default: 4)",
    )
    parser.add_argument(
        "-r",
        "--no-repeat-ngram-size",
        required=False,
        default=3,
        type=int,
        help="no repeat ngram size (default: 3)",
    )
    parser.add_argument(
        "-t",
        "--temperature",
        required=False,
        default=0.3,
        type=float,
        help="temperature for generation (default: 0.3)",
    )
    parser.add_argument(
        "-mt",
        "--max-time",
        required=False,
        default=600,
        type=int,
        help="max time for generation in seconds (default: 600)",
    )
    parser.add_argument(
        "-pt",
        "--partial-text",
        default=False,
        action="store_true",
        help="return only the generated text (default: False)",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="print verbose output",
    )
    return parser

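# Example: parsing a minimal set of flags (illustrative values; unspecified
# flags fall back to the defaults above):
#   >>> get_parser().parse_args(["-p", "Hi Jane,", "--max-length", "96"])
#   Namespace(prompt='Hi Jane,', model='pszemraj/distilgpt2-email-generation',
#             max_length=96, ...)
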
if __name__ == "__main__":
    args = get_parser().parse_args()
    logging.info(f"args: {args}")
    input_path = Path(args.input_path) if args.input_path is not None else None
    output_path = Path(args.output_path) if args.output_path is not None else None
    prompt = args.prompt
    assert (
        prompt is not None or input_path is not None
    ), "must provide either prompt or input_path"
    if prompt is None:
        with open(input_path, "r", encoding="utf-8") as f:
            prompt = f.read()
    model_tag = args.model
    num_beams = args.num_beams
    max_length = args.max_length
    min_length = args.min_length
    no_repeat_ngram_size = args.no_repeat_ngram_size
    temperature = args.temperature
    max_time = args.max_time
    partial_text = args.partial_text
    verbose = args.verbose
    # use_fast=False selects the slow (Python) tokenizer; run on GPU 0 if
    # CUDA is available, otherwise fall back to CPU
    email_gen = transformers.pipeline(
        "text-generation",
        model_tag,
        use_fast=False,
        device=0 if torch.cuda.is_available() else -1,
    )
    generated_text = call_model(
        email_gen,
        prompt,
        num_beams=num_beams,
        min_length=min_length,
        max_length=max_length,
        no_repeat_ngram_size=no_repeat_ngram_size,
        temperature=temperature,
        max_time=max_time,
        partial_text=partial_text,
        verbose=verbose,
    )
    print(
        "\n" * 2,
        generated_text,
    )
    if output_path is not None:
        with open(output_path, "w", encoding="utf-8", errors="ignore") as f:
            f.write(generated_text)
        if verbose:
            print(f"wrote generated text to {output_path}")