Created
March 15, 2024 01:35
-
-
Save maceip/e264e22fb27f65832b599810556259b7 to your computer and use it in GitHub Desktop.
Retrieval-augmented generation (RAG): index Solidity repositories in a FAISS vector store with LangChain and Vertex AI, then use the retrieved code as context to improve the quality of LLM-generated Solidity code.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# dependencies: | |
# pip install --upgrade --user -q google-cloud-aiplatform langchain==0.0.332 faiss-cpu==1.7.4 | |
# pip install requests uvicorn fastapi python-dotenv
# | |
# google cloud auth: | |
# get a service account key json, then: | |
# export GOOGLE_APPLICATION_CREDENTIALS=service.json | |
# | |
# works on ubuntu 22.04 | |
# pip 22.0.2 | |
# Python 3.10.12 | |
import logging
import os
import time
from typing import List

import requests
import uvicorn
import vertexai
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
# Vertex AI
from google.cloud import aiplatform
from langchain.chains import RetrievalQA
from langchain.embeddings import VertexAIEmbeddings
from langchain.llms import VertexAI
from langchain.prompts import PromptTemplate
from langchain.schema.document import Document
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
# Load environment settings first so later configuration lookups can see them.
load_dotenv()

logger = logging.getLogger(__name__)

app = FastAPI()

# Browser origins allowed to call this API cross-origin.
origins = [
    "http://localhost:3000",  # React default
    "http://localhost:8000",  # FastAPI default
]

# CORS policy: only the origins listed above, with credentials, and with any
# HTTP method or request header accepted.
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Report the installed Vertex AI SDK version once at import time.
print(f"Vertex AI SDK version: {aiplatform.__version__}")
class CustomVertexAIEmbeddings(VertexAIEmbeddings):
    """Vertex AI embeddings that send texts in small, rate-limited batches.

    Texts are submitted ``num_instances_per_batch`` at a time, and a pause is
    inserted after every request so that requests are spaced at least
    ``60 / requests_per_minute`` seconds apart.
    """

    # Maximum number of embedding requests to issue per minute.
    requests_per_minute: int
    # Number of texts sent in one embedding request.
    num_instances_per_batch: int

    # Overriding embed_documents method
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` batch by batch and return one vector per text.

        Args:
            texts: The strings to embed, in order.

        Returns:
            The ``values`` of every embedding returned by the client, in the
            same order as ``texts``.

        Raises:
            ValueError: If ``num_instances_per_batch`` is not positive. A zero
                or negative batch size would otherwise never consume the input
                and loop forever.
        """
        batch_size = self.num_instances_per_batch
        if batch_size <= 0:
            raise ValueError(
                f"num_instances_per_batch must be positive, got {batch_size}"
            )

        limiter = rate_limit(self.requests_per_minute)
        results = []
        docs = list(texts)
        # Working in batches because the API accepts maximum 5
        # documents per request to get embeddings
        for start in range(0, len(docs), batch_size):
            chunk = self.client.get_embeddings(docs[start : start + batch_size])
            results.extend(chunk)
            # Advance the limiter after each request; from the second call on
            # it sleeps for whatever is left of the per-request interval.
            next(limiter)
        return [r.values for r in results]
# GitHub personal access token. Read from the environment (or a .env file via
# load_dotenv()) instead of being committed to source control. Any token that
# was ever hard-coded here must be treated as leaked and revoked.
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN", "")
# Repository to crawl, as "owner/name"; override with the GITHUB_REPO variable.
GITHUB_REPO = os.environ.get("GITHUB_REPO", "volt-protocol/ethereum-credit-guild")
def crawl_github_repo(
    url: str,
    is_sub_dir: bool,
    access_token: str = GITHUB_TOKEN,
    timeout: float = 30.0,
) -> List[str]:
    """Recursively list the Solidity/Vyper source files of a GitHub repository.

    Args:
        url: ``"owner/repo"`` when ``is_sub_dir`` is false; otherwise the full
            GitHub contents-API URL of a directory.
        is_sub_dir: Whether ``url`` is already a contents-API URL.
        access_token: GitHub token sent as a Bearer credential. When empty, no
            Authorization header is sent at all.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        The ``html_url`` of every ``.sol`` / ``.vy`` file found, descending
        into all directories whose name does not start with ``"."``.

    Raises:
        requests.HTTPError: If GitHub answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    ignore_list = ["__init__.py"]
    source_suffixes = (".sol", ".vy")

    if is_sub_dir:
        api_url = url
    else:
        api_url = f"https://api.github.com/repos/{url}/contents"

    headers = {"Accept": "application/vnd.github.v3+json"}
    if access_token:
        headers["Authorization"] = f"Bearer {access_token}"

    response = requests.get(api_url, headers=headers, timeout=timeout)
    response.raise_for_status()  # Check for any request errors

    files = []
    for item in response.json():
        name = item["name"]
        if item["type"] == "file":
            if name not in ignore_list and name.endswith(source_suffixes):
                files.append(item["html_url"])
        elif item["type"] == "dir" and not name.startswith("."):
            # Forward the caller's token and timeout; the recursive call used
            # to fall back to the module default token.
            files.extend(crawl_github_repo(item["url"], True, access_token, timeout))
            # Short pause between directory listings to go easy on the API.
            time.sleep(0.1)
    return files
def extract_solidity_code(github_url, timeout=30.0):
    """Download the raw text of a file given its GitHub ``html_url``.

    Args:
        github_url: A browser URL such as
            ``https://github.com/<owner>/<repo>/blob/<ref>/<path>``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The response body as text.

    Raises:
        requests.HTTPError: If the raw-content host answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    # Rewrite only the first occurrence of each marker so that a file path
    # which itself contains "github.com" or "/blob/" is left untouched.
    raw_url = github_url.replace(
        "github.com", "raw.githubusercontent.com", 1
    ).replace("/blob/", "/", 1)
    response = requests.get(raw_url, timeout=timeout)
    response.raise_for_status()  # Check for any request errors
    return response.text
def rate_limit(max_per_minute):
    """Generator that spaces successive ``next()`` calls evenly in time.

    The first ``next()`` prints ``Waiting`` and returns immediately. Each later
    ``next()`` measures the time since the previous one returned and, if that
    is shorter than ``60 / max_per_minute`` seconds, prints a dot and sleeps
    for the remainder. Every step yields ``None``.
    """
    min_interval = 60 / max_per_minute
    print("Waiting")
    while True:
        started = time.time()
        yield
        # Time still owed for this slot after the caller's own work.
        remaining = min_interval - (time.time() - started)
        if remaining > 0:
            print(".", end="")
            time.sleep(remaining)
# Prompt given to the code model; {question} and {context} are filled in by
# the retrieval chain.
_RAG_PROMPT = """
You are a proficient solidity developer. Respond with the syntactically correct code for the question below. Make sure you follow these rules:
1. Use context to understand the APIs and how to use it & apply.
2. Do not add license information to the output code.
3. Ensure all the requirements in the question are met.
Question:
{question}
Context:
{context}
Helpful Response :
"""


def _load_solidity_documents(code_files_urls):
    """Download every ``.sol`` URL and wrap its source text in a Document."""
    documents = []
    for index, file_url in enumerate(code_files_urls):
        if not file_url.endswith(".sol"):
            continue
        documents.append(
            Document(
                page_content=extract_solidity_code(file_url),
                # file_index is the position in the full URL list, not in the
                # filtered list of Solidity files.
                metadata={"url": file_url, "file_index": index},
            )
        )
    return documents


def _build_retriever(documents):
    """Chunk the documents, embed them into a FAISS index, return a retriever."""
    text_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.SOL, chunk_size=2000, chunk_overlap=200
    )
    texts = text_splitter.split_documents(documents)
    print(f"texts len: {len(texts)}")

    # Initialize Embedding API
    EMBEDDING_QPM = 100
    EMBEDDING_NUM_BATCH = 5
    embeddings = CustomVertexAIEmbeddings(
        requests_per_minute=EMBEDDING_QPM,
        num_instances_per_batch=EMBEDDING_NUM_BATCH,
        model_name="textembedding-gecko@latest",
    )

    # Create Index from embedded code chunks
    db = FAISS.from_documents(texts, embeddings)
    print(f"db: {db}")

    retriever = db.as_retriever(
        search_type="similarity",  # "mmr" is the alternative worth testing
        search_kwargs={"k": 5},
    )
    print(f"retreiver: {retriever}")
    return retriever


def _build_qa_chain(retriever):
    """Wire the code-generation model and the retriever into a RetrievalQA chain."""
    prompt_template = PromptTemplate(
        template=_RAG_PROMPT, input_variables=["context", "question"]
    )
    code_llm = VertexAI(
        model_name="code-bison@002",
        max_output_tokens=2048,
        temperature=0.1,
        verbose=False,
    )
    return RetrievalQA.from_llm(
        llm=code_llm,
        prompt=prompt_template,
        retriever=retriever,
        return_source_documents=True,
    )


@app.post("/crawl/")
async def crawl_git():
    """Crawl ``GITHUB_REPO``, index its Solidity code and run one RAG query.

    Steps: list the repository's source files, save the list to
    ``code_files_urls.txt``, download the ``.sol`` files, build a FAISS
    retriever over their chunks, and ask the code model a fixed question using
    the retrieved chunks as context.

    Returns:
        ``{"results": <generated answer>}``.

    Raises:
        HTTPException: Status 400 carrying the underlying error message when
            any step fails, or when the repository has no ``.sol`` files.

    NOTE(review): every step here is blocking (HTTP requests, sleeps, model
    calls) yet the handler is ``async``, so the event loop is presumably
    stalled for the duration of a request — confirm whether that is acceptable.
    """
    PROJECT_ID = "litapp-fb3d7"  # @param {type:"string"}
    LOCATION = "us-central1"  # @param {type:"string"}
    vertexai.init(project=PROJECT_ID, location=LOCATION)

    try:
        code_files_urls = crawl_github_repo(GITHUB_REPO, False, GITHUB_TOKEN)

        # Keep a copy of the crawl result on disk. The list in memory is used
        # directly; re-reading the file that was just written added nothing.
        with open("code_files_urls.txt", "w", encoding="utf-8") as f:
            f.writelines(f"{file_url}\n" for file_url in code_files_urls)
        print(f"len: {len(code_files_urls)}")
        print(f"code_files_urls len: {len(code_files_urls)}")

        code_strings = _load_solidity_documents(code_files_urls)
        print(f"code_strings len: {len(code_strings)}")
        if not code_strings:
            # Fail with a clear message here rather than with whatever error
            # index construction raises when given nothing to embed.
            raise HTTPException(
                status_code=400,
                detail=f"No Solidity (.sol) files found in {GITHUB_REPO}",
            )

        retriever = _build_retriever(code_strings)
        qa_chain = _build_qa_chain(retriever)

        user_question = "Create a Solidity smart contract that allows a user to stake and lend tokens"
        results = qa_chain({"query": user_question})
        print(results["result"])
        return {"results": results["result"]}
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it.
        raise
    except Exception as e:
        logger.exception("crawl_git failed")
        raise HTTPException(status_code=400, detail=f"{e}") from e
@app.get("/")
def main():
    """Serve the placeholder landing page as an HTML response."""
    return HTMLResponse(content="\n<body>\nmacmac\n</body>\n")
if __name__ == "__main__":
    # Script entry point: serve the app on every network interface, port 8000.
    # uvicorn is already imported at the top of the file, so it is not
    # imported again here.
    uvicorn.run(app, host="0.0.0.0", port=8000)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment