Skip to content

Instantly share code, notes, and snippets.

@detour1999
Created March 14, 2024 15:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save detour1999/dda425f62abaa4cf2e0b474d171dac92 to your computer and use it in GitHub Desktop.
Save detour1999/dda425f62abaa4cf2e0b474d171dac92 to your computer and use it in GitHub Desktop.
map reduce langchain
from openai import OpenAI
from dotenv import load_dotenv
import json
import os
import structlog
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.chains import AnalyzeDocumentChain, LLMChain, StuffDocumentsChain
from langchain.chains.summarize import load_summarize_chain
from langchain_openai import OpenAI
from typing import List, Optional
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
from langchain.pydantic_v1 import BaseModel, Field, validator
# Module-level structlog logger, named after this module.
LOGGER = structlog.get_logger(__name__)
# Load variables from a local .env file into the process environment at import
# time; summarize_content reads OPEN_AI_KEY from os.environ.
load_dotenv()
# Schema of the structured summary requested from the model. It is handed to
# PydanticOutputParser to generate the format instructions embedded in the
# combine prompt.
# NOTE: documented with comments rather than a class docstring on purpose: a
# docstring would be emitted as the schema "description" and thereby change
# the format instructions sent to the model.
class SummarizedDocument(BaseModel):
    # Optional metadata: default to None explicitly instead of relying on the
    # implicit "Optional means not required" behaviour of pydantic v1.
    author: Optional[str] = Field(default=None, description="The author or authors of the document")
    title: Optional[str] = Field(default=None, description="The title of the document")
    subject: Optional[str] = Field(default=None, description="The subject of the document")
    # Required fields.
    keywords: List[str] = Field(description="Keywords relating to the document")
    summary: str = Field(description="A summary of the document")
# Output parser built from SummarizedDocument. In this file it is only used
# for get_format_instructions(), which summarize_content injects into the
# combine prompt.
parser = PydanticOutputParser(pydantic_object=SummarizedDocument)
def summarize_content(content: str) -> dict:
    """Summarize ``content`` with a map-reduce summarization chain.

    Each chunk of the text is summarized with the map prompt; the combine
    prompt then asks for a final summary, keywords and subject, formatted
    according to the module-level ``parser``'s format instructions.

    Args:
        content: The text content to be summarized.

    Returns:
        The chain's ``output_text`` decoded with ``json.loads``. The prompt
        requests the ``SummarizedDocument`` fields, but the decoded value is
        not validated against that model here.

    Raises:
        json.JSONDecodeError: If the model output is not valid JSON. The raw
            output is logged before the exception is re-raised.
    """
    # NOTE: ``OpenAI`` resolves to the ``langchain_openai`` LLM wrapper here,
    # because that import comes after (and shadows) ``from openai import OpenAI``.
    client = OpenAI(api_key=os.environ.get('OPEN_AI_KEY'), temperature=0, max_tokens=800)
    LOGGER.debug("summarize_content.client_initialized")

    # Templates are spelled out as explicit literals so that the indentation
    # of this function body never leaks into the prompt text.
    map_prompt_template = (
        "\n"
        "Write a summary of this chunk of text that includes the main points and any important details.\n"
        "```{text}```\n"
    )
    combine_prompt_template = (
        "\n"
        "You are a researcher who can summarize texts. Write a summary of the texts delimited by triple backquotes.\n"
        "Include the summary, at most ten keywords and the subject of the text.\n"
        "{format_instructions}\n"
        "```{text}```\n"
    )
    map_prompt = PromptTemplate(template=map_prompt_template, input_variables=["text"])
    combine_prompt = PromptTemplate(
        template=combine_prompt_template,
        input_variables=["text"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )
    LOGGER.debug("summarize_content.prompts_initialized")

    # ``token_max`` is configured on the chain itself: when passed as a
    # keyword to ``invoke`` it is not forwarded to the reduce step, so the
    # intended limit would never take effect.
    summary_chain = load_summarize_chain(
        llm=client,
        chain_type='map_reduce',
        map_prompt=map_prompt,
        combine_prompt=combine_prompt,
        token_max=2000,
    )
    LOGGER.debug("summarize_content.summary_chain_loaded", chain_type='map_reduce')

    summarize_document_chain = AnalyzeDocumentChain(combine_docs_chain=summary_chain)
    summary = summarize_document_chain.invoke(content)
    output_text = summary["output_text"]
    LOGGER.debug("summarize_content.invoke_summary_chain", output_text=output_text)

    try:
        resp = json.loads(output_text)
    except json.JSONDecodeError:
        # Keep the raw model output in the log so malformed responses can be
        # inspected, then let the caller handle the failure.
        LOGGER.error("summarize_content.json_decode_error", error=output_text)
        raise

    LOGGER.debug("summarize_content.parsed_response", response=resp)
    return resp
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment