Created
March 14, 2024 15:32
-
-
Save detour1999/dda425f62abaa4cf2e0b474d171dac92 to your computer and use it in GitHub Desktop.
map reduce langchain
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from openai import OpenAI | |
from dotenv import load_dotenv | |
import json | |
import os | |
import structlog | |
from langchain.prompts import PromptTemplate | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.docstore.document import Document | |
from langchain.chains import AnalyzeDocumentChain, LLMChain, StuffDocumentsChain | |
from langchain.chains.summarize import load_summarize_chain | |
from langchain_openai import OpenAI | |
from typing import List, Optional | |
from pydantic import BaseModel, Field | |
from langchain.output_parsers import PydanticOutputParser | |
from langchain.pydantic_v1 import BaseModel, Field, validator | |
# Module-level structured logger; event names below follow "function.event" style.
LOGGER = structlog.get_logger(__name__)
# Load environment variables (e.g. OPEN_AI_KEY) from a local .env file at import time.
load_dotenv()
class SummarizedDocument(BaseModel):
    """Schema for the structured summary the combine prompt asks the LLM to emit.

    Used by PydanticOutputParser to generate format instructions for the
    combine step of the map-reduce summarization chain.

    NOTE(review): this file imports BaseModel/Field from both `pydantic` and
    `langchain.pydantic_v1`; the later import wins — confirm which pydantic
    version PydanticOutputParser expects here.
    """

    # Optional metadata the model may or may not be able to extract.
    author: Optional[str] = Field(description="The author or authors of the document")
    title: Optional[str] = Field(description="The title of the document")
    subject: Optional[str] = Field(description="The subject of the document")
    keywords: List[str] = Field(description="Keywords relating to the document")
    summary: str = Field(description="A summary of the document")
parser = PydanticOutputParser(pydantic_object=SummarizedDocument) | |
def summarize_content(content: str) -> dict:
    """
    Summarize the given content with a map-reduce LangChain summarization chain.

    The map step condenses each chunk of the input; the combine step asks the
    model for a JSON object shaped by ``SummarizedDocument`` (summary, keywords,
    subject, ...) via the parser's format instructions.

    Args:
        content: The text content to be summarized.

    Returns:
        A dictionary with the summarized content and additional metadata such
        as keywords and subject.

    Raises:
        json.JSONDecodeError: If the model output is not valid JSON even after
            stripping a surrounding Markdown code fence.
    """
    # Temperature 0 for deterministic output; max_tokens caps the summary size.
    client = OpenAI(api_key=os.environ.get('OPEN_AI_KEY'), temperature=0, max_tokens=800)
    LOGGER.debug("summarize_content.client_initialized")

    map_prompt_template = """
Write a summary of this chunk of text that includes the main points and any important details.
```{text}```
"""
    combine_prompt_template = """
You are a researcher who can summarize texts. Write a summary of the texts delimited by triple backquotes.
Include the summary, at most ten keywords and the subject of the text.
{format_instructions}
```{text}```
"""
    map_prompt = PromptTemplate(template=map_prompt_template, input_variables=["text"])
    # The combine prompt embeds PydanticOutputParser's format instructions so
    # the model emits JSON matching SummarizedDocument.
    combine_prompt = PromptTemplate(
        template=combine_prompt_template,
        input_variables=["text"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )
    LOGGER.debug("summarize_content.prompts_initialized")

    summary_chain = load_summarize_chain(
        llm=client,
        chain_type='map_reduce',
        map_prompt=map_prompt,
        combine_prompt=combine_prompt,
    )
    LOGGER.debug("summarize_content.summary_chain_loaded", chain_type='map_reduce')

    # AnalyzeDocumentChain splits the raw text and routes the chunks through
    # the map-reduce chain. NOTE(review): it is unclear whether `token_max`
    # passed as an invoke kwarg actually reaches the reduce step — confirm
    # against the installed langchain version.
    summarize_document_chain = AnalyzeDocumentChain(combine_docs_chain=summary_chain)
    summary = summarize_document_chain.invoke(content, token_max=2000)
    LOGGER.debug("summarize_content.invoke_summary_chain", output_text=summary["output_text"])

    # Models frequently wrap JSON in a Markdown code fence (```json ... ```),
    # which the parser's format instructions tolerate but json.loads does not.
    # Strip a surrounding fence before parsing instead of failing outright.
    raw = summary["output_text"]
    candidate = raw.strip()
    if candidate.startswith("```"):
        fence_lines = candidate.splitlines()[1:]  # drop opening fence (may carry a "json" tag)
        if fence_lines and fence_lines[-1].strip().startswith("```"):
            fence_lines = fence_lines[:-1]  # drop closing fence
        candidate = "\n".join(fence_lines).strip()
    try:
        resp = json.loads(candidate)
    except json.JSONDecodeError:
        LOGGER.error("summarize_content.json_decode_error", error=raw)
        raise
    LOGGER.debug("summarize_content.parsed_response", response=resp)
    return resp
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment