Skip to content

Instantly share code, notes, and snippets.

@moradology
Last active June 6, 2023 16:45
Show Gist options
  • Save moradology/e3799d40ffb02559c69d4632167e9c9a to your computer and use it in GitHub Desktop.
An example of using vector embeddings to constrain LLM output
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import pathlib
import subprocess
import tempfile
import textwrap
from langchain.llms import OpenAI
from langchain.chains import LLMChain
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
def get_github_docs(repo_owner, repo_name):
    """Shallow-clone a GitHub repository and yield its markdown files as Documents.

    Args:
        repo_owner: GitHub user or organization that owns the repository.
        repo_name: Repository name.

    Yields:
        Document: page_content is a markdown file's full text;
        metadata["source"] is a permalink to that file pinned to the
        cloned commit SHA.

    Raises:
        subprocess.CalledProcessError: if either git command fails.
    """
    with tempfile.TemporaryDirectory() as d:
        # List-form argv with shell=False: repo_owner/repo_name come straight
        # from the CLI, so never interpolate them into a shell string.
        subprocess.check_call(
            [
                "git", "clone", "--depth", "1",
                f"https://github.com/{repo_owner}/{repo_name}.git", ".",
            ],
            cwd=d,
        )
        # Resolve the commit SHA so the generated source URLs stay stable
        # even after the branch moves.
        git_sha = (
            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=d)
            .decode("utf-8")
            .strip()
        )
        repo_path = pathlib.Path(d)
        # Recursive glob replaces the original hand-enumerated depth patterns
        # (*.md, */*.md, ... up to four levels) and also finds files nested
        # deeper; set-union dedupes, sorted() makes the order deterministic.
        markdown_files = sorted(
            set(repo_path.glob("**/*.md")) | set(repo_path.glob("**/*.mdx"))
        )
        print(f"The markdown files to be used as context: {markdown_files}")
        for markdown_file in markdown_files:
            relative_path = markdown_file.relative_to(repo_path)
            github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
            with open(markdown_file, "r", encoding="utf-8") as f:
                yield Document(page_content=f.read(), metadata={"source": github_url})
def build_index(sources):
    """Split documents into overlapping chunks and embed them into a Chroma index.

    Args:
        sources: Iterable of Document objects (e.g. from get_github_docs).

    Returns:
        A Chroma vector store built over all chunks using OpenAI embeddings.
    """
    splitter = CharacterTextSplitter(separator=" ", chunk_size=1024, chunk_overlap=128)
    source_chunks = [
        # dict(...) copies the metadata per chunk: the original shared one
        # dict object across every chunk of a source, so mutating one chunk's
        # metadata would silently mutate all of its siblings.
        Document(page_content=chunk, metadata=dict(source.metadata))
        for source in sources
        for chunk in splitter.split_text(source.page_content)
    ]
    return Chroma.from_documents(source_chunks, OpenAIEmbeddings())
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GitHub Project Information")
    # required=True: without it a missing flag yields None, and the failure
    # only surfaces later as a confusing clone of https://github.com/None/None.git
    parser.add_argument("-p", "--project", required=True, help="GitHub project name")
    parser.add_argument("-u", "--profile", required=True, help="GitHub profile name")
    parser.add_argument("-t", "--topic", required=True, help="Topic for content generation")
    args = parser.parse_args()

    # Clone the repo's markdown docs and embed them into a vector index.
    project_docs = get_github_docs(args.profile, args.project)
    vector_index = build_index(project_docs)

    prompt_template = """You are a senior software engineer. Use the context below (which is taken from a github repository) to write a very short blog post about the topic and its relation to this repository.
Context: {context}
Topic: {topic}
Blog post:"""
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "topic"]
    )
    # temperature=0 keeps the completion as deterministic as the model allows.
    llm = OpenAI(temperature=0, model_name="text-davinci-003")
    chain = LLMChain(llm=llm, prompt=PROMPT)

    def generate_content(topic):
        """Run the chain once per nearest-neighbor context chunk (k=4)."""
        docs = vector_index.similarity_search(topic, k=4)
        inputs = [{"context": doc.page_content, "topic": topic} for doc in docs]
        return chain.apply(inputs)

    for d in generate_content(args.topic):
        print(textwrap.fill(d["text"], width=100))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment