Skip to content

Instantly share code, notes, and snippets.

@maceip
Created March 15, 2024 01:35
Show Gist options
  • Save maceip/e264e22fb27f65832b599810556259b7 to your computer and use it in GitHub Desktop.
Save maceip/e264e22fb27f65832b599810556259b7 to your computer and use it in GitHub Desktop.
Retrieval-augmented generation (RAG): index Solidity repositories in a FAISS vector store with LangChain and Vertex AI, then feed the retrieved code to an LLM as context to improve the quality of generated Solidity code.
# dependencies:
# pip install --upgrade --user -q google-cloud-aiplatform langchain==0.0.332 faiss-cpu==1.7.4
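# pip install fastapi python-dotenv requests uvicorn
# (asyncio ships with Python; the PyPI "asyncio" package is an obsolete backport — do not install it)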
#
# google cloud auth:
# get a service account key json, then:
# export GOOGLE_APPLICATION_CREDENTIALS=service.json
#
# works on ubuntu 22.04
# pip 22.0.2
# Python 3.10.12
import logging
import os
import time
from typing import List
from urllib.parse import urlsplit, urlunsplit

import requests
import uvicorn
import vertexai
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
# Vertex AI
from google.cloud import aiplatform
from langchain.chains import RetrievalQA
from langchain.embeddings import VertexAIEmbeddings
from langchain.llms import VertexAI
from langchain.prompts import PromptTemplate
from langchain.schema.document import Document
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
# Load variables from a .env file (if present) into the process environment
# before anything below reads configuration.
load_dotenv()
logger = logging.getLogger(__name__)

# Browser origins allowed to call this API during local development.
origins = [
    "http://localhost:3000",  # React default
    "http://localhost:8000",  # FastAPI default
]

app = FastAPI()
# Permit credentialed cross-origin requests from the origins above, with any
# HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    **{
        "allow_origins": origins,
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    },
)

print(f"Vertex AI SDK version: {aiplatform.__version__}")
class CustomVertexAIEmbeddings(VertexAIEmbeddings):
    """Vertex AI embeddings that send documents in fixed-size, throttled batches.

    Overrides ``embed_documents`` so a large corpus is split into batches of
    ``num_instances_per_batch`` texts. Between batches it waits so that at most
    ``requests_per_minute`` embedding calls are made per minute.
    """

    # Maximum embedding API calls per minute (passed to ``rate_limit``).
    requests_per_minute: int
    # Number of texts sent in each ``client.get_embeddings`` call.
    num_instances_per_batch: int

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` batch by batch, throttled to ``requests_per_minute``.

        Args:
            texts: Documents to embed.

        Returns:
            The ``values`` of every embedding returned by the client, in batch
            order (presumably one vector per input text — depends on the
            Vertex AI client's contract).

        Raises:
            ValueError: if ``num_instances_per_batch`` is less than 1. Such a
                batch size could never shrink the remaining work, so the loop
                would not terminate.
        """
        batch_size = self.num_instances_per_batch
        if batch_size < 1:
            raise ValueError(
                f"num_instances_per_batch must be >= 1, got {batch_size!r}"
            )
        limiter = rate_limit(self.requests_per_minute)
        docs = list(texts)
        vectors: List[List[float]] = []
        # Batch because the embedding endpoint caps the number of instances
        # per request (the author notes a maximum of 5).
        for start in range(0, len(docs), batch_size):
            batch = docs[start : start + batch_size]
            vectors.extend(r.values for r in self.client.get_embeddings(batch))
            next(limiter)  # sleeps if the previous batch finished too recently
        return vectors
# GitHub personal access token for the contents API. It is read from the
# environment (or from the .env file loaded by load_dotenv()) so it never lives
# in source control. Any token that was ever pasted into this file is exposed
# and should be revoked.
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN", "")  # @param {type:"string"}
# Repository to index, as "owner/repo"; override with the GITHUB_REPO variable.
GITHUB_REPO = os.environ.get("GITHUB_REPO", "volt-protocol/ethereum-credit-guild")  # @param {type:"string"}
def crawl_github_repo(url: str, is_sub_dir: bool, access_token: str = GITHUB_TOKEN):
    """Recursively collect the URLs of Solidity/Vyper files in a GitHub repo.

    Args:
        url: ``"owner/repo"`` when ``is_sub_dir`` is False. Otherwise, a
            contents-API directory URL (an item's ``"url"`` field).
        is_sub_dir: Whether ``url`` is already a contents-API URL.
        access_token: GitHub token sent as a Bearer credential. It is forwarded
            to every recursive call. When empty, the request is anonymous.

    Returns:
        The ``html_url`` of every ``.sol``/``.vy`` file found, depth-first.
        Directories whose names start with ``.`` are skipped.

    Raises:
        requests.HTTPError: if any contents-API request fails.
        requests.Timeout: if a request stalls for more than 30 seconds.
    """
    ignore_list = ["__init__.py"]
    if is_sub_dir:
        api_url = url
    else:
        api_url = f"https://api.github.com/repos/{url}/contents"

    headers = {"Accept": "application/vnd.github.v3+json"}
    if access_token:
        # Only send credentials when a token is configured. An unset token
        # then yields an anonymous request instead of a malformed header.
        headers["Authorization"] = f"Bearer {access_token}"

    response = requests.get(api_url, headers=headers, timeout=30)
    response.raise_for_status()  # Check for any request errors

    # NOTE(review): the contents API response is not paginated here; very
    # large directories may be truncated by GitHub — confirm if that matters.
    files = []
    for item in response.json():
        name = item["name"]
        if (
            item["type"] == "file"
            and name not in ignore_list
            and name.endswith((".sol", ".vy"))
        ):
            files.append(item["html_url"])
        elif item["type"] == "dir" and not name.startswith("."):
            # Forward the caller's token so sub-directories use the same credentials.
            files.extend(crawl_github_repo(item["url"], True, access_token))
            time.sleep(0.1)  # brief pause between directory requests
    return files
def extract_solidity_code(github_url):
raw_url = github_url.replace("github.com", "raw.githubusercontent.com").replace(
"/blob/", "/"
)
response = requests.get(raw_url)
response.raise_for_status() # Check for any request errors
sol_code = response.text
return sol_code
def rate_limit(max_per_minute):
    """Create a throttle that allows at most ``max_per_minute`` ticks per minute.

    Call ``next()`` on the returned generator after each unit of work. The
    first ``next()`` prints ``Waiting`` and returns immediately. Each later
    call sleeps just long enough that at least ``60 / max_per_minute`` seconds
    separate it from the moment the previous call returned. It prints ``.``
    whenever it has to sleep.

    Args:
        max_per_minute: Maximum ticks per minute; must be positive.

    Returns:
        An infinite generator that yields ``None``.

    Raises:
        ValueError: immediately, if ``max_per_minute`` is not positive. Zero
            would divide by zero; a negative value would silently disable
            throttling.
    """
    if max_per_minute <= 0:
        raise ValueError(f"max_per_minute must be positive, got {max_per_minute!r}")
    period = 60 / max_per_minute

    def _ticks():
        print("Waiting")
        while True:
            # monotonic() is unaffected by wall-clock adjustments (NTP steps,
            # manual changes) that could distort intervals measured with time().
            before = time.monotonic()
            yield
            sleep_time = period - (time.monotonic() - before)
            if sleep_time > 0:
                print(".", end="")
                time.sleep(sleep_time)

    return _ticks()
# Vertex AI project and region used by the /crawl/ endpoint.
_VERTEX_PROJECT_ID = "litapp-fb3d7"  # @param {type:"string"}
_VERTEX_LOCATION = "us-central1"  # @param {type:"string"}

# Embedding throttling: at most 100 embedding requests/minute, 5 texts each.
_EMBEDDING_QPM = 100
_EMBEDDING_NUM_BATCH = 5

# Generation prompt. RetrievalQA fills {question} with the query and {context}
# with the retrieved code chunks.
_RAG_PROMPT = """
You are a proficient solidity developer. Respond with syntactically correct code for the question below. Make sure you follow these rules:
1. Use context to understand the APIs and how to use it & apply.
2. Do not add license information to the output code.
3. Ensure all the requirements in the question are met.
Question:
{question}
Context:
{context}
Helpful Response :
"""

_USER_QUESTION = "Create a Solidity smart contract that allows a user to stake and lend tokens"


def _load_solidity_documents(repo: str, access_token: str) -> List[Document]:
    """Download every ``.sol`` file in ``repo`` as a LangChain Document.

    Each Document's metadata holds the file's GitHub URL and its index in the
    crawled URL list. That list also contains ``.vy`` files, which are skipped.
    """
    file_urls = crawl_github_repo(repo, False, access_token)
    print(f"code_files_urls len: {len(file_urls)}")
    documents = [
        Document(
            page_content=extract_solidity_code(url),
            metadata={"url": url, "file_index": index},
        )
        for index, url in enumerate(file_urls)
        if url.endswith(".sol")
    ]
    print(f"code_strings len: {len(documents)}")
    return documents


def _build_retriever(documents: List[Document]):
    """Chunk ``documents`` with a Solidity-aware splitter and index them in FAISS.

    Returns a similarity-search retriever that yields the top 5 chunks.
    """
    splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.SOL, chunk_size=2000, chunk_overlap=200
    )
    chunks = splitter.split_documents(documents)
    print(f"texts len: {len(chunks)}")
    embeddings = CustomVertexAIEmbeddings(
        requests_per_minute=_EMBEDDING_QPM,
        num_instances_per_batch=_EMBEDDING_NUM_BATCH,
        model_name="textembedding-gecko@latest",
    )
    db = FAISS.from_documents(chunks, embeddings)
    print(f"db: {db}")
    # "mmr" is the alternative search_type worth comparing against "similarity".
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    print(f"retriever: {retriever}")
    return retriever


def _build_qa_chain(retriever) -> RetrievalQA:
    """Combine the code-generation model, the RAG prompt and ``retriever``."""
    code_llm = VertexAI(
        model_name="code-bison@002",
        max_output_tokens=2048,
        temperature=0.1,
        verbose=False,
    )
    prompt = PromptTemplate(template=_RAG_PROMPT, input_variables=["context", "question"])
    return RetrievalQA.from_llm(
        llm=code_llm,
        prompt=prompt,
        retriever=retriever,
        return_source_documents=True,
    )


@app.post("/crawl/")
def crawl_git():
    """Index the configured repo's Solidity files and answer a fixed RAG query.

    This is a plain ``def``, not ``async def``, because every step blocks:
    HTTP via ``requests``, ``time.sleep`` in the crawler and rate limiter,
    FAISS indexing, and Vertex AI calls. FastAPI runs sync endpoints in a
    worker thread, so the event loop stays responsive.

    NOTE(review): the whole index is rebuilt on every request; consider
    caching it if this endpoint is called more than once.

    Returns:
        ``{"results": <generated text>}``.

    Raises:
        HTTPException: 404 when the repo contains no ``.sol`` files, 502 when an
            upstream HTTP request fails, and 500 for any other failure.
    """
    try:
        vertexai.init(project=_VERTEX_PROJECT_ID, location=_VERTEX_LOCATION)
        documents = _load_solidity_documents(GITHUB_REPO, GITHUB_TOKEN)
        if not documents:
            # FAISS cannot build an index from zero chunks; fail clearly instead.
            raise HTTPException(status_code=404, detail=f"No .sol files found in {GITHUB_REPO}")
        qa_chain = _build_qa_chain(_build_retriever(documents))
        results = qa_chain({"query": _USER_QUESTION})
    except HTTPException:
        raise
    except requests.RequestException as e:
        logger.exception("Upstream HTTP request failed while crawling %s", GITHUB_REPO)
        raise HTTPException(status_code=502, detail=f"{e}") from e
    except Exception as e:
        logger.exception("RAG pipeline failed")
        raise HTTPException(status_code=500, detail=f"{e}") from e
    print(results["result"])
    return {"results": results["result"]}
@app.get("/")
def main():
    """Serve a minimal placeholder HTML page at the site root."""
    # Blank first/last entries reproduce the leading and trailing newlines.
    page = "\n".join(("", "<body>", "macmac", "</body>", ""))
    return HTMLResponse(content=page)
if __name__ == "__main__":
    # Serve the app with uvicorn (already imported at module level). By default
    # it listens on 0.0.0.0:8000; the HOST/PORT environment variables override
    # that for deployment.
    # NOTE(review): /crawl/ is unauthenticated and spends Vertex AI quota;
    # consider HOST=127.0.0.1 outside trusted networks.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment