Skip to content

Instantly share code, notes, and snippets.

@samhita-alla
Last active November 10, 2023 22:07
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samhita-alla/99de0ac46776cb3d4f08ad7e4908512d to your computer and use it in GitHub Desktop.
LangChain x Flyte
import json
import os
from functools import partial
import flytekit
from flytekit import ImageSpec, Resources, Secret, map_task, task, workflow
# Secrets Manager coordinates for the combined Pinecone/OpenAI credentials blob.
SECRET_GROUP = "arn:aws:secretsmanager:us-east-2:356633062068:secret"
SECRET_KEY = "flyte_langchain-YtD8OW"

# Container registry and flytekit base image shared by both task containers.
_REGISTRY = "ghcr.io/samhita-alla"
_BASE_IMAGE = "ghcr.io/flyteorg/flytekit:py3.11-1.7.0"

# Image for the embedding task: downloads YouTube audio, transcribes it,
# and writes embeddings to Pinecone. ffmpeg is needed for audio handling.
embed_image = ImageSpec(
    name="langchain-flyte-vectordb",
    packages=[
        "langchain",
        "pinecone-client",
        "huggingface_hub",
        "sentence_transformers",
        "yt_dlp",
        "pydub",
        "openai",
    ],
    apt_packages=["ffmpeg"],
    registry=_REGISTRY,
    base_image=_BASE_IMAGE,
)

# Image for the query task: retrieves from Pinecone and runs an LLM chain.
# spacy/textstat (and the en_core_web_sm model) support LangChain's
# Flyte callback analytics.
query_image = ImageSpec(
    name="langchain-flyte-query",
    packages=[
        "langchain",
        "pinecone-client",
        "huggingface_hub",
        "sentence_transformers",
        "openai",
        "spacy",
        "textstat",
        "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0.tar.gz",
    ],
    registry=_REGISTRY,
    base_image=_BASE_IMAGE,
)
@task(
    cache=True,
    cache_version="1.0",
    secret_requests=[
        Secret(
            group=SECRET_GROUP,
            key=SECRET_KEY,
            mount_requirement=Secret.MountType.FILE,
        ),
    ],
    container_image=embed_image,
    requests=Resources(mem="5Gi"),
)
def embed_and_store(url: str, index_name: str) -> str:
    """Transcribe one YouTube video, chunk the transcript, embed the chunks,
    and store them in a Pinecone index.

    :param url: YouTube video URL to download and transcribe.
    :param index_name: Name of the Pinecone index to write the vectors to.
    :return: Human-readable confirmation message.
    """
    import openai
    import pinecone
    from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
    from langchain.document_loaders.generic import GenericLoader
    from langchain.document_loaders.parsers import OpenAIWhisperParser
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.vectorstores import Pinecone

    # Fetch and JSON-parse the mounted secret once instead of re-reading
    # the secret file for every individual credential.
    secrets = json.loads(
        flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_KEY)
    )
    pinecone.init(
        api_key=secrets["pinecone_api_key"],
        environment=secrets["pinecone_environment"],
    )
    openai.api_key = secrets["openai_api_key"]

    # Directory to save the downloaded audio files.
    save_dir = os.path.join(flytekit.current_context().working_directory, "youtube")

    # Download the audio and transcribe it with OpenAI Whisper.
    loader = GenericLoader(YoutubeAudioLoader([url], save_dir), OpenAIWhisperParser())
    docs = loader.load()
    text = " ".join(doc.page_content for doc in docs)

    # Split the transcript into overlapping chunks sized for embedding.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
    splits = text_splitter.split_text(text)

    huggingface_embeddings = HuggingFaceEmbeddings(
        cache_folder=os.path.join(
            flytekit.current_context().working_directory, "embeddings-cache-folder"
        )
    )
    Pinecone.from_texts(
        texts=splits, embedding=huggingface_embeddings, index_name=index_name
    )
    return f"{url} data is stored in the vectordb."
@task(
    disable_deck=False,
    secret_requests=[
        Secret(
            group=SECRET_GROUP,
            key=SECRET_KEY,
            mount_requirement=Secret.MountType.FILE,
        ),
    ],
    container_image=query_image,
    requests=Resources(mem="5Gi"),
)
def query_vectordb(index_name: str, query: str) -> str:
    """Answer a natural-language query against an existing Pinecone index
    using a retrieval QA chain over gpt-3.5-turbo.

    :param index_name: Name of the existing Pinecone index to retrieve from.
    :param query: The question to answer.
    :return: The chain's answer text.
    """
    import pinecone
    from langchain.callbacks import FlyteCallbackHandler
    from langchain.chains import RetrievalQA
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import Pinecone

    # Fetch and JSON-parse the mounted secret once instead of re-reading
    # the secret file for every individual credential.
    secrets = json.loads(
        flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_KEY)
    )
    pinecone.init(
        api_key=secrets["pinecone_api_key"],
        environment=secrets["pinecone_environment"],
    )

    # Embeddings must match those used at indexing time for retrieval to work.
    huggingface_embeddings = HuggingFaceEmbeddings(
        cache_folder=os.path.join(
            flytekit.current_context().working_directory, "embeddings-cache-folder"
        )
    )
    vectordb = Pinecone.from_existing_index(index_name, huggingface_embeddings)
    retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 2})

    # "stuff" chain: concatenate the top-k retrieved chunks into the prompt.
    qa_chain = RetrievalQA.from_chain_type(
        llm=ChatOpenAI(
            model_name="gpt-3.5-turbo",
            callbacks=[FlyteCallbackHandler()],  # surfaces LLM metrics on the Flyte deck
            temperature=0,
            openai_api_key=secrets["openai_api_key"],
        ),
        chain_type="stuff",
        retriever=retriever,
    )
    result = qa_chain.run(query)
    return result
@workflow
def flyte_youtube_embed_wf(
    index_name: str = "flyte-youtube-data",
    urls: list[str] = [
        "https://youtu.be/CNmO1q3MamM",
        "https://youtu.be/8rLj_YVOpzE",
        "https://youtu.be/sGqS8PFQz6c",
        "https://youtu.be/1668vZczslw",
        "https://youtu.be/NrFOXQKrREA",
        "https://youtu.be/4ktHNeT8kq4",
        "https://youtu.be/gMyTz8gKWVc",
    ],
) -> list[str]:
    """Fan ``embed_and_store`` out over every URL as a Flyte map task.

    ``index_name`` is bound via ``functools.partial`` so only ``url`` varies
    across the mapped task instances. Returns one confirmation string per URL.
    """
    return map_task(partial(embed_and_store, index_name=index_name))(url=urls)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment