run summarization on a directory with the Anthropic API + langchain
"""
anthropic_run_summarization.py - Generate summaries using langchain + LLMs
For usage details, run `python anthropic_run_summarization.py --help` and fire will print them.
Notes:
- you need to have ANTHROPIC_API_KEY set as an environment variable (easiest way is export ANTHROPIC_API_KEY=memes123)
- install the dependencies using the requirements.txt file or below
pip install fire langchain langchain-community langchain-anthropic clean-text tqdm tiktoken
*I honestly have no idea what the langchain requirements will be at any time; they change every week
"""
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Optional
import fire
from cleantext import clean
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain.globals import set_llm_cache
from tqdm.auto import tqdm
from transformers import GPT2TokenizerFast
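# local helper modules (expected next to this script: ./bin/join_textdocs.py and
# ./convert_langchain_steps.py)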
from bin.join_textdocs import merge_text_files
from convert_langchain_steps import convert_langchain_json2text
# cache LLM responses in a local SQLite database so repeated calls are free
from langchain.cache import SQLiteCache
os.environ.setdefault("MAX_RECURSIVE_SAMPLES", "100")  # confirmation threshold for large runs
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
map_custom_prompt = """Write a SparkNotes-style summary that focuses on the most important points and takeaways from the text. Use clear, concise language to ensure that the summary is rich in test-relevant information and easy to understand for learners at all levels.
- Simplify complex ideas into simpler explanations, providing the necessary context to make the material accessible.
- Provide analysis and insight to deepen understanding, facilitating learning and critical thinking for a wide audience, including those unfamiliar with the topic.
- Prioritize and thoroughly cover the essential elements of the text, striking a balance between comprehensive detail and brevity to maintain engagement.
Summarize the text:\n'{text}'"""
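# prompt applied to each chunk in the map step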
map_prompt_template = PromptTemplate(
    input_variables=["text"], template=map_custom_prompt
)
combine_custom_prompt = """Synthesize key ideas and insights from the provided summaries into a unified, cohesive summary. Highlight common themes and synthesize analysis to provide a deeper understanding of the material. The final summary should be concise, engaging, and easy to follow, providing a well-rounded overview of the original texts.
Weave together the most important aspects to create a narrative that captures the essence of the material, facilitating both understanding and retention.
Summaries for consolidation:\n`{text}\n`
"""
combine_prompt_template = PromptTemplate(
    template=combine_custom_prompt, input_variables=["text"]
)
def get_timestamp() -> str:
    return datetime.now().strftime("%Y%b%d%H-%M")


def get_cache_location(model_name: str) -> str:
    """get the cache location (string path to the per-model SQLite db)"""
    location = Path.home() / ".langchain" / "anthropic" / f".cached-{model_name}.db"
    location.parent.mkdir(parents=True, exist_ok=True)
    return str(location.resolve())
def read_and_clean_file(file_path, lower: bool = False) -> str:
    """
    read a file and clean it

    :param file_path: path to the file
    :param lower: whether to lowercase the text
    """
    with open(file_path, "r", encoding="utf-8") as f:
        context = clean(f.read(), lower=lower)
    return context
def save_output_to_file(
    out_dir,
    sub_dir,
    output,
    json_output,
    file_name: str = "output",
):
    """save the summary text and full JSON output to files"""
    out_dir = Path(out_dir) / sub_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    output_file = out_dir / f"{file_name}_summary.txt"
    with output_file.open("w", encoding="utf-8") as f:
        f.write(output)

    output_file = out_dir / f"{file_name}_summary.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(json_output, f, ensure_ascii=False, indent=4)
class CostTracker:
    """
    CostTracker - tracks the estimated cost of using an Anthropic inference model

    Refer to https://www.anthropic.com/api for more details on pricing.
    """

    MODEL_COSTS = {
        "claude-3-opus-20240229": {"input": 0.015, "output": 0.075},
        "claude-3-sonnet-20240229": {"input": 0.003, "output": 0.015},
        "claude-3-haiku-20240307": {"input": 0.00025, "output": 0.00125},
    }  # costs in $ per 1000 tokens

    def __init__(
        self,
        input_cost_rate: float = None,
        output_cost_rate: float = None,
        model: str = "claude-3-sonnet-20240229",
    ):
        """
        __init__

        :param float input_cost_rate: cost per 1000 tokens for input, defaults to None
        :param float output_cost_rate: cost per 1000 tokens for output, defaults to None
        :param str model: model name, defaults to "claude-3-sonnet-20240229"
        :raises ValueError: if input_cost_rate or output_cost_rate is None and the model is not found in MODEL_COSTS
        """
        self.logger = logging.getLogger(__name__)
        self.model = model
        if self.model in self.MODEL_COSTS:
            self.input_cost_rate = self.MODEL_COSTS[self.model]["input"]
            self.output_cost_rate = self.MODEL_COSTS[self.model]["output"]
        else:
            self.logger.warning(
                f"Costs for model {self.model} not found. Defaulting to input_cost_rate={input_cost_rate} & output_cost_rate={output_cost_rate}"
            )
            self.input_cost_rate = input_cost_rate
            self.output_cost_rate = output_cost_rate
        if self.input_cost_rate is None or self.output_cost_rate is None:
            raise ValueError(
                f"input_cost_rate or output_cost_rate not set for {self.model}; set them manually."
            )
        self.num_batches = 0
        self.num_docs = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        # use the community re-upload of the Claude tokenizer to count tokens
        self.tokenizer = GPT2TokenizerFast.from_pretrained("Xenova/claude-tokenizer")

    def track_batch(self, input_text: str, output_text: str):
        """update token counts with one document's input and output text"""
        self.num_batches += 1
        self.total_input_tokens += self.num_tokens_in_text(input_text)
        self.total_output_tokens += self.num_tokens_in_text(output_text)
        self.num_docs += 1

    def get_input_cost(self):
        return round(self.total_input_tokens / 1000 * self.input_cost_rate, 2)

    def get_output_cost(self):
        return round(self.total_output_tokens / 1000 * self.output_cost_rate, 2)

    def get_total_cost(self):
        return self.get_input_cost() + self.get_output_cost()

    @property
    def cost_per_doc(self):
        return self.get_total_cost() / self.num_docs if self.num_docs else 0

    def print_summary(self):
        print(str(self))

    def num_tokens_in_text(self, text: str) -> int:
        if len(text) == 0:
            return 0
        return len(self.tokenizer.encode(text, padding=False, truncation=False))

    def __str__(self):
        header = "-" * 40 + "\nCOST SUMMARY:\n" + "-" * 40 + "\n"
        return (
            f"{header}"
            f"model name: {self.model}\n"
            f"Total input tokens: {self.total_input_tokens}\n"
            f"Total output tokens: {self.total_output_tokens}\n"
            f"Cost (estimated): ${self.get_total_cost()} (input: ${self.get_input_cost()}, output: ${self.get_output_cost()})\n"
            f"Documents processed: {self.num_docs}\n"
            f"Cost per document: ${self.cost_per_doc}\n"
        )
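# CostTracker usage sketch (standalone; the strings are illustrative):
#   tracker = CostTracker(model="claude-3-haiku-20240307")
#   tracker.track_batch("some input text", "a generated summary")
#   tracker.print_summary()  # prints token counts and estimated $ cost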
def generate_summaries(
    input_dir: str,
    output_dir: Optional[str] = None,
    model: str = "claude-3-sonnet-20240229",
    chunk_size: int = 32768,
    max_tokens_to_sample: int = 3072,
    chunk_overlap: int = 32,
    temperature: float = 0.0,
    map_reduce_chain: bool = False,
    refine_chain: bool = False,
    use_custom_prompts: bool = False,
    recompute_completed: bool = False,
    cost_rate_inputs: Optional[float] = None,
    cost_rate_outputs: Optional[float] = None,
    max_cost: Optional[float] = None,
    extract_steps: Optional[bool] = True,
    log_level: str = "WARNING",
    recursive: bool = False,
    file_extension: str = ".txt",
    min_chars: int = 1000,
    no_cache: bool = False,
) -> None:
    """
    Generate summaries from text files using a specified language model.

    :param str input_dir: The directory containing the input text files.
    :param Optional[str] output_dir: The directory to write the summary files to. If None, a summaries folder is created next to the input directory.
    :param str model: The language model to use for summarization.
    :param int chunk_size: The size (in tokens) of text chunks to feed to the model at a time.
    :param int max_tokens_to_sample: The maximum number of tokens the model may generate per call.
    :param int chunk_overlap: The number of overlapping tokens between chunks.
    :param float temperature: The randomness of the model's output.
    :param bool map_reduce_chain: If True, use a map-reduce chain for summarization.
    :param bool refine_chain: If True, use a refine chain (iteratively refine the summary chunk by chunk).
    :param bool use_custom_prompts: If True, use the custom map/combine prompts defined above (map-reduce only).
    :param bool recompute_completed: If True, recompute summaries for already-completed files (currently unused).
    :param Optional[float] cost_rate_inputs: The cost rate for input text.
    :param Optional[float] cost_rate_outputs: The cost rate for output text.
    :param Optional[float] max_cost: Stop processing once the estimated cost ($) exceeds this value.
    :param Optional[bool] extract_steps: If True, convert the map-reduce intermediate-step JSON to text files and merge them.
    :param str log_level: The level of logging to use.
    :param bool recursive: If True, recursively process subdirectories in the input directory.
    :param str file_extension: The file extension for the input text files.
    :param int min_chars: The minimum number of characters for a text to be summarized.
    :param bool no_cache: If True, disable the SQLite LLM cache.
    :raises ValueError: If the input directory does not exist or contains no valid text files.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level)
    logger.debug(f"Logging level set to {log_level}.")
    logger.info(f"Generating summaries for {input_dir} using {model}.")

    path = Path(input_dir)
    assert path.exists() and path.is_dir(), f"Path {path} does not exist or is not a directory."
    if not no_cache:
        set_llm_cache(SQLiteCache(database_path=get_cache_location(model)))
    if output_dir is None:
        output_dir = path.parent
    if not refine_chain and not map_reduce_chain:
        logger.warning("No chain specified, using map_reduce chain.")
        map_reduce_chain = True

    llm = ChatAnthropic(
        cache=not no_cache,
        model_name=model,
        temperature=temperature,
        max_tokens_to_sample=max_tokens_to_sample,
        default_request_timeout=180,
        verbose=False,
    )
    anthropic_tk = GPT2TokenizerFast.from_pretrained("Xenova/claude-tokenizer")
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        anthropic_tk,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
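    # note: with from_huggingface_tokenizer, chunk_size and chunk_overlap are
    # measured in tokens (via the Claude tokenizer above), not characters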
    if refine_chain:
        logger.info("Using refine chain.")
        chain_refine = load_summarize_chain(
            llm, chain_type="refine", return_intermediate_steps=True
        )
        if use_custom_prompts:
            logger.warning(
                "Detected use_custom_prompts=True; custom prompts do not work with the refine chain."
            )
            use_custom_prompts = False
    if map_reduce_chain:
        logger.info("Using map reduce chain.")
        chain_map_reduce = (
            load_summarize_chain(
                llm,
                chain_type="map_reduce",
                return_intermediate_steps=True,
                verbose=False,
            )
            if not use_custom_prompts
            else load_summarize_chain(
                llm,
                chain_type="map_reduce",
                map_prompt=map_prompt_template,
                combine_prompt=combine_prompt_template,
                return_intermediate_steps=True,
                verbose=False,
            )
        )

    output_dir = (
        Path(output_dir)
        / f"{path.stem}-{model}-summaries-{chunk_size}-cp_{use_custom_prompts}"
        if output_dir is not None
        else path.parent
        / f"{path.stem}-{model}-summaries-{chunk_size}-cp_{use_custom_prompts}"
    )
    cost_tracker = CostTracker(
        input_cost_rate=cost_rate_inputs,
        output_cost_rate=cost_rate_outputs,
        model=model,
    )

    if recursive:
        logger.warning("Recursive mode is on. This will take a while.")
    source_files = (
        [f for f in path.iterdir() if f.is_file() and f.suffix == file_extension]
        if not recursive
        else [f for f in path.rglob(f"*{file_extension}") if f.is_file()]
    )
    if not source_files:
        raise ValueError(
            f"No {file_extension} files found. Check input dir:\n\t{path.resolve()}"
        )
    if len(source_files) > int(os.environ.get("MAX_RECURSIVE_SAMPLES", 100)):
        # require the user to confirm before running on a large number of files
        logger.warning(
            f"Found {len(source_files)} files in {path.resolve()}. This will take a while."
        )
        response = input("Do you want to continue? (y/n): ").lower()
        if response != "y":
            logger.info("Exiting.")
            return
    logger.info(f"Found {len(source_files)} files in:\t{path.resolve()}")
    logging.getLogger("httpx").setLevel(logging.WARNING)

    skip_files = []
    for i, doc_path in enumerate(tqdm(source_files, desc="API Inference"), start=1):
        file_name = Path(doc_path).stem
        input_text = read_and_clean_file(doc_path)
        if len(input_text) < min_chars:
            logger.info(
                f"Skipping {file_name} as it has less than {min_chars} characters."
            )
            skip_files.append(file_name)
            continue
        docs = text_splitter.create_documents([input_text])
        if refine_chain:
            refine_output = chain_refine(
                {"input_documents": docs}, return_only_outputs=True
            )
            cost_tracker.track_batch(
                input_text, refine_output["output_text"]
            )  # note: slightly underestimates cost (overlap, summary-of-summary)
            save_output_to_file(
                output_dir,
                "refine_output",
                refine_output["output_text"],
                refine_output,
                file_name,
            )
        if map_reduce_chain:
            map_reduce_output = chain_map_reduce(
                {"input_documents": docs}, return_only_outputs=True
            )
            cost_tracker.track_batch(
                input_text, map_reduce_output["output_text"]
            )  # note: slightly underestimates cost (overlap, summary-of-summary)
            save_output_to_file(
                output_dir,
                "map_reduce_output",
                map_reduce_output["output_text"],
                map_reduce_output,
                file_name,
            )
        logger.info(f"current cost: {cost_tracker.get_total_cost()}")
        if max_cost and cost_tracker.get_total_cost() > max_cost:
            logger.info(f"max cost reached: {cost_tracker.get_total_cost()}")
            break

    if extract_steps and map_reduce_chain:
        out_dir = convert_langchain_json2text(output_dir / "map_reduce_output")
        merge_text_files(out_dir, fancy=True)
print(f"Generated {i} summaries. Output saved to\n\t{output_dir}")
if skip_files:
print(f"Skipped {len(skip_files)} files with less than {min_chars} characters.")
print(f"Skipped files:\n\t{skip_files}")
cost_tracker.print_summary()
logger.info("Done.")
if __name__ == "__main__":
    fire.Fire(generate_summaries)