Last active
June 6, 2023 16:45
-
-
Save moradology/e3799d40ffb02559c69d4632167e9c9a to your computer and use it in GitHub Desktop.
An example of using vector embeddings to constrain LLM output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import argparse | |
import pathlib | |
import subprocess | |
import tempfile | |
import textwrap | |
from langchain.llms import OpenAI | |
from langchain.chains import LLMChain | |
from langchain.docstore.document import Document | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.prompts import PromptTemplate | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.vectorstores import Chroma | |
def get_github_docs(repo_owner, repo_name):
    """Yield a Document for every markdown file in a GitHub repository.

    Performs a depth-1 clone into a temporary directory, resolves the
    cloned commit SHA, and yields one Document per ``.md``/``.mdx`` file
    found anywhere in the tree. Each Document's ``metadata["source"]`` is
    the file's GitHub permalink pinned to that SHA.

    Args:
        repo_owner: GitHub user or organization name.
        repo_name: Repository name.

    Yields:
        Document: page_content is the file's text; metadata holds the
        pinned GitHub URL.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or
        ``git rev-parse`` fails.
    """
    with tempfile.TemporaryDirectory() as d:
        # List-form argv with shell=False: repo_owner/repo_name are
        # interpolated into a URL argument, never into a shell string,
        # so shell metacharacters in them cannot inject commands.
        subprocess.check_call(
            [
                "git",
                "clone",
                "--depth",
                "1",
                f"https://github.com/{repo_owner}/{repo_name}.git",
                ".",
            ],
            cwd=d,
        )
        git_sha = (
            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=d)
            .decode("utf-8")
            .strip()
        )
        repo_path = pathlib.Path(d)
        # rglob walks every depth; the original chained one-level globs
        # missed .md files nested more than three directories deep and
        # .mdx files more than one deep. sorted() makes the ordering
        # deterministic across runs.
        markdown_files = sorted(
            set(repo_path.rglob("*.md")) | set(repo_path.rglob("*.mdx"))
        )
        print(f"The markdown files to be used as context: {markdown_files}")
        for markdown_file in markdown_files:
            relative_path = markdown_file.relative_to(repo_path)
            github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
            # Explicit encoding: GitHub markdown is UTF-8; relying on the
            # platform default can mis-decode on Windows.
            with open(markdown_file, "r", encoding="utf-8") as f:
                yield Document(page_content=f.read(), metadata={"source": github_url})
def build_index(sources):
    """Build a Chroma vector index from an iterable of source documents.

    Every source document is split into overlapping ~1024-character
    chunks (so each embedding stays a manageable size), and each chunk
    inherits its parent document's metadata — i.e. the GitHub source URL
    survives into the index for later attribution.
    """
    splitter = CharacterTextSplitter(separator=" ", chunk_size=1024, chunk_overlap=128)
    chunked_docs = [
        Document(page_content=piece, metadata=doc.metadata)
        for doc in sources
        for piece in splitter.split_text(doc.page_content)
    ]
    return Chroma.from_documents(chunked_docs, OpenAIEmbeddings())
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="GitHub Project Information") | |
parser.add_argument("-p", "--project", help="GitHub project name") | |
parser.add_argument("-u", "--profile", help="GitHub profile name") | |
parser.add_argument("-t", "--topic", help="Topic for content generation") | |
args = parser.parse_args() | |
github_project = args.project | |
github_profile = args.profile | |
topic = args.topic | |
project_docs = get_github_docs(github_profile, github_project) | |
vector_index = build_index(project_docs) | |
prompt_template = """You are a senior software engineer. Use the context below (which is taken from a github repository) to write a very short blog post about the topic and its relation to this repository. | |
Context: {context} | |
Topic: {topic} | |
Blog post:""" | |
PROMPT = PromptTemplate( | |
template=prompt_template, input_variables=["context", "topic"] | |
) | |
llm = OpenAI(temperature=0, model_name="text-davinci-003") | |
chain = LLMChain(llm=llm, prompt=PROMPT) | |
def generate_content(topic): | |
docs = vector_index.similarity_search(topic, k=4) | |
inputs = [{"context": doc.page_content, "topic": topic} for doc in docs] | |
return chain.apply(inputs) | |
generated_content = generate_content(topic) | |
for d in generated_content: | |
print(textwrap.fill(d["text"], width=100)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment