Skip to content

Instantly share code, notes, and snippets.

@jbesw
Created August 24, 2023 17:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jbesw/3e2ba20a7e40d97f187437f84a137b2d to your computer and use it in GitHub Desktop.
Save jbesw/3e2ba20a7e40d97f187437f84a137b2d to your computer and use it in GitHub Desktop.
LangChain and AWS Lambda
import json
import boto3
import os
from langchain.llms.bedrock import Bedrock
from langchain import PromptTemplate
from typing import Optional, List, Mapping, Any, Dict
from langchain.retrievers import AmazonKendraRetriever
from langchain.chains import RetrievalQA
# Required configuration, read from the environment in this fixed order; a
# missing variable raises KeyError at import time.
S3_BUCKET_NAME, PROMPT_TEMPLATE_S3_KEY, BEDROCK_MODEL_ID, REGION = (
    os.environ[var_name]
    for var_name in (
        "S3_BUCKET_NAME",
        "PROMPT_TEMPLATE_S3_KEY",
        "BEDROCK_MODEL_IDENTIFIER",
        "AWS_REGION",
    )
)

# Optional: None when the variable is not set.
KENDRA_INDEX_ID = os.environ.get("KENDRA_INDEX_ID")

# Kendra client bound explicitly to the configured region, created once at import.
KENDRA_CLIENT = boto3.client("kendra", region_name=REGION)

# Extra keyword arguments forwarded to the Bedrock model invocation.
# NOTE(review): "max_tokens_to_sample" looks like an Anthropic Claude parameter
# name — confirm it matches the model behind BEDROCK_MODEL_IDENTIFIER.
model_args = dict(max_tokens_to_sample=4096, temperature=0.5)

# Shared LLM instance used by every chain this module builds.
llm = Bedrock(
    model_id=BEDROCK_MODEL_ID,
    model_kwargs=model_args,
    verbose=True,
)
def build_chain():
    """Build a RetrievalQA chain over the Kendra index using the Bedrock LLM.

    The prompt template text is fetched from S3 (``S3_BUCKET_NAME`` /
    ``PROMPT_TEMPLATE_S3_KEY``) on every call. It is wrapped in a
    ``PromptTemplate`` whose input variables are ``context`` and ``question``,
    so the stored template is expected to use exactly those placeholders.

    Returns:
        A ``RetrievalQA`` chain using the "stuff" strategy, configured to also
        return the retrieved source documents.
    """
    # Reuse the module-level Kendra client (created once at import, bound to
    # REGION) instead of letting the retriever construct a second client.
    retriever = AmazonKendraRetriever(index_id=KENDRA_INDEX_ID, client=KENDRA_CLIENT)

    # Read the prompt template text from S3.
    s3 = boto3.resource("s3")
    obj = s3.Object(S3_BUCKET_NAME, PROMPT_TEMPLATE_S3_KEY)
    prompt_template = obj.get()["Body"].read().decode("utf-8")
    print(prompt_template)

    prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    return RetrievalQA.from_chain_type(
        llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
def run_chain(chain, prompt: str, history: Optional[List] = None):
    """Run *chain* on *prompt* and return the answer with its source documents.

    Args:
        chain: A callable chain (such as the one returned by ``build_chain``)
            that accepts the prompt and returns a mapping containing
            ``"result"`` and ``"source_documents"`` keys.
        prompt: The user's question.
        history: Accepted for signature compatibility with chat samples but
            not used. Defaults to ``None`` instead of a shared mutable list.

    Returns:
        A dict with ``"answer"`` (the chain's ``"result"``) and
        ``"source_documents"``.
    """
    print('prompt:', prompt)
    result = chain(prompt)
    # Rename keys to the shape the chat samples expect.
    return {
        "answer": result["result"],
        "source_documents": result["source_documents"],
    }
def _json_response(status_code, payload):
    """Wrap *payload* as a JSON proxy-style response with permissive CORS headers."""
    return {
        "statusCode": status_code,
        "headers": {
            "Content-Type": "application/json",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
        },
        "body": json.dumps(payload),
    }


def _extract_query(event):
    """Return the non-empty string ``query`` from the event's JSON body, else None."""
    body = event.get("body") if isinstance(event, dict) else None
    if body is None:
        return None
    try:
        payload = json.loads(body)
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
        return None
    if not isinstance(payload, dict):
        return None
    query = payload.get("query")
    if not isinstance(query, str) or not query.strip():
        return None
    return query


def lambda_handler(event, context):
    """Answer the ``query`` in the request body using the Kendra + Bedrock chain.

    Args:
        event: Lambda event; ``event["body"]`` is expected to be a JSON string
            of the form ``{"query": "..."}``.
        context: Lambda context object (unused).

    Returns:
        A response dict with ``statusCode``, JSON/CORS ``headers`` and a JSON
        ``body``. On success (200) the body holds ``answer`` and
        ``source_documents`` (the ``metadata['source']`` of each retrieved
        document). If the body is missing, not JSON, not an object, or lacks a
        non-empty string ``query``, a 400 with an ``error`` message is returned
        instead of raising.
    """
    print(f"boto3-version: {boto3.__version__}")
    print('kendra index:', KENDRA_INDEX_ID)
    print('S3_BUCKET_NAME:', S3_BUCKET_NAME)
    print('PROMPT_TEMPLATE_S3_KEY:', PROMPT_TEMPLATE_S3_KEY)

    query = _extract_query(event)
    if query is None:
        return _json_response(
            400,
            {"error": "Request body must be a JSON object with a non-empty string 'query' field."},
        )
    print(f"query: {query}")

    # NOTE(review): the chain (including the S3 prompt read) is rebuilt on every
    # invocation; consider caching if prompt hot-reloading is not required.
    chain = build_chain()
    result = run_chain(chain, query)
    print(result['answer'])

    source_docs = []
    if 'source_documents' in result:
        print('Sources:')
        for doc in result['source_documents']:
            source = doc.metadata['source']
            print(source)
            source_docs.append(source)

    return _json_response(200, {"answer": result['answer'], "source_documents": source_docs})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment