Building a RAG System with Ollama and LanceDB: A Comprehensive Tutorial
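The listing below implements a small retrieval-augmented generation (RAG) pipeline end to end: abstract base classes for the LLM, embedder, and vector store; async Ollama-backed implementations of the first two; a LanceDB-backed store whose table schema is generated from the embedding model; a small factory that assembles the components from a config dict; and an example main() that ingests a BBC news dataset and answers a question against it.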
import asyncio
import json
import httpx
import lancedb
import pandas as pd
from abc import ABC, abstractmethod
from typing import List, Dict, Any, AsyncGenerator, Optional
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
# Base Classes
class LLM(ABC):
    @abstractmethod
    async def generate(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate a response for the given prompt"""
        pass

    @abstractmethod
    async def streaming_generate(self, prompt: str, system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
        """Generate a streaming response for the given prompt"""
        pass


class Embedder(ABC):
    @abstractmethod
    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a list of documents"""
        pass

    @abstractmethod
    async def embed_query(self, query: str) -> List[float]:
        """Embed a single query"""
        pass


class VectorStore(ABC):
    @abstractmethod
    async def store_embeddings(self, documents: List[str], embeddings: List[List[float]],
                               metadata: Optional[List[Dict[str, Any]]] = None) -> None:
        """Store document embeddings with optional metadata"""
        pass

    @abstractmethod
    async def search(self, query_embedding: List[float], limit: int = 3) -> List[Dict[str, Any]]:
        """Search for similar documents using the query embedding"""
        pass
# Ollama LLM Implementation
class AsyncOllamaLLM(LLM):
    def __init__(self, model_name: str = "llama3.1", base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url
        # Disable httpx's default 5-second timeout: local generation can be slow
        self.client = httpx.AsyncClient(timeout=None)

    async def _post_request(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        response = await self.client.post(
            f"{self.base_url}/{endpoint}",
            json=payload
        )
        response.raise_for_status()
        return response.json()

    async def generate(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = await self._post_request("api/chat", {
            "model": self.model_name,
            "messages": messages,
            "stream": False
        })
        return response['message']['content']

    async def streaming_generate(self, prompt: str, system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        async with self.client.stream(
            "POST",
            f"{self.base_url}/api/chat",
            json={
                "model": self.model_name,
                "messages": messages,
                "stream": True
            }
        ) as response:
            async for line in response.aiter_lines():
                if line.strip():
                    # Each line of the stream is a standalone JSON object;
                    # the final chunk (done=True) carries no new content
                    chunk = json.loads(line)
                    if not chunk.get("done"):
                        yield chunk['message']['content']

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.client.aclose()
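
# Minimal standalone usage sketch (assumption: an Ollama server is running at
# localhost:11434 with the llama3.1 model pulled). Because the class is an
# async context manager, it can also be used directly, outside the factory below:
async def _demo_llm():
    async with AsyncOllamaLLM(model_name="llama3.1") as llm:
        print(await llm.generate("Say hello in one word."))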
# Ollama Embedder Implementation
class AsyncOllamaEmbedder(Embedder):
    def __init__(self, model_name: str = "mxbai-embed-large", base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url
        self.client = httpx.AsyncClient(timeout=None)
        # Initialize the LanceDB registry embedder, which supplies the schema
        # helpers (SourceField/VectorField) and the vector dimensionality
        registry = EmbeddingFunctionRegistry.get_instance()
        self._lance_embedder = registry.get("ollama").create(name=model_name)

    @property
    def SourceField(self):
        return self._lance_embedder.SourceField

    @property
    def VectorField(self):
        return self._lance_embedder.VectorField

    def ndims(self):
        return self._lance_embedder.ndims()

    async def _get_embedding(self, text: str) -> List[float]:
        response = await self.client.post(
            f"{self.base_url}/api/embeddings",
            json={
                "model": self.model_name,
                "prompt": text
            }
        )
        response.raise_for_status()
        return response.json()['embedding']

    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        # Embed documents one at a time; the requests are awaited sequentially
        embeddings = []
        for doc in documents:
            embedding = await self._get_embedding(doc)
            embeddings.append(embedding)
        return embeddings

    async def embed_query(self, query: str) -> List[float]:
        return await self._get_embedding(query)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.client.aclose()
# LanceDB Vector Store Implementation
def create_lance_schema(embedder):
    """Build the table schema dynamically: the vector size (embedder.ndims())
    is only known once the embedding model has been chosen."""
    class LanceDBSchema(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()
        index: int
        title: str
        url: str

    return LanceDBSchema
class AsyncLanceDBStore(VectorStore):
    def __init__(self, embedder, db_path: str = "./lancedb", table_name: str = "documents"):
        self.db = lancedb.connect(db_path)
        self.table_name = table_name
        self.table = None
        self._lock = asyncio.Lock()
        self.schema = create_lance_schema(embedder)

    async def store_embeddings(self, documents: List[str], embeddings: List[List[float]],
                               metadata: Optional[List[Dict[str, Any]]] = None) -> None:
        if metadata is None:
            metadata = [{"title": "", "url": "", "index": i} for i in range(len(documents))]
        # Reuse the table if it already exists in the database
        existing_tables = self.db.table_names()
        if self.table_name in existing_tables:
            self.table = self.db.open_table(self.table_name)
        data = []
        for doc, emb, meta in zip(documents, embeddings, metadata):
            data.append({
                "text": doc,
                "vector": emb,
                "index": meta.get("index", 0),
                "title": meta.get("title", ""),
                "url": meta.get("url", "")
            })
        async with self._lock:
            if self.table is None:
                # create_table and add are blocking calls, so run them in a thread
                self.table = await asyncio.to_thread(
                    self.db.create_table,
                    self.table_name,
                    data=data,
                    schema=self.schema.to_arrow_schema()
                )
            else:
                await asyncio.to_thread(self.table.add, data)

    async def search(self, query_embedding: List[float], limit: int = 3) -> List[Dict[str, Any]]:
        if self.table is None:
            raise ValueError("No documents have been stored yet")
        async with self._lock:
            results = await asyncio.to_thread(
                lambda: self.table.search(query_embedding)
                .limit(limit)
                .to_pydantic(self.schema)
            )
        return [
            {
                "text": r.text,
                "title": r.title,
                "url": r.url,
                "index": r.index
            }
            for r in results
        ]
# Component Factory
class AsyncComponentFactory:
    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @staticmethod
    async def create_llm(type: str, **kwargs) -> LLM:
        if type == "ollama":
            return AsyncOllamaLLM(**kwargs)
        raise ValueError(f"Unknown LLM type: {type}")

    @staticmethod
    async def create_embedder(type: str, **kwargs) -> Embedder:
        if type == "ollama":
            return AsyncOllamaEmbedder(**kwargs)
        raise ValueError(f"Unknown embedder type: {type}")

    @staticmethod
    async def create_vector_store(type: str, embedder: Embedder, **kwargs) -> VectorStore:
        if type == "lancedb":
            return AsyncLanceDBStore(embedder=embedder, **kwargs)
        raise ValueError(f"Unknown vector store type: {type}")
# Main RAG System
class BBCNewsRAG:
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.factory = AsyncComponentFactory(config)
        self.llm = None
        self.embedder = None
        self.vector_store = None

    async def initialize(self):
        """Initialize all RAG components"""
        self.llm = await self.factory.create_llm(**self.config["llm"])
        self.embedder = await self.factory.create_embedder(**self.config["embedder"])
        self.vector_store = await self.factory.create_vector_store(
            embedder=self.embedder,
            **self.config["vector_store"]
        )

    async def ingest_data(self, df: pd.DataFrame):
        """Ingest the BBC news data from a pandas DataFrame into the vector store"""
        documents = df['text'].tolist()
        embeddings = await self.embedder.embed_documents(documents)
        metadata = df.apply(
            lambda row: {
                'title': row['title'],
                'url': row['url'],
                'index': row['index']
            },
            axis=1
        ).tolist()
        await self.vector_store.store_embeddings(
            documents=documents,
            embeddings=embeddings,
            metadata=metadata
        )

    async def query(self, question: str, system_prompt: Optional[str] = None) -> str:
        """Query the RAG system: retrieve relevant articles, then generate an answer"""
        query_embedding = await self.embedder.embed_query(question)
        results = await self.vector_store.search(query_embedding)
        response = await self.llm.generate(
            prompt=f"Question: {question}\nContext: {results}",
            system_prompt=system_prompt or "Answer the question based on the provided context."
        )
        return response

    async def close(self):
        """Clean up resources"""
        if self.llm:
            await self.llm.__aexit__(None, None, None)
        if self.embedder:
            await self.embedder.__aexit__(None, None, None)
# Example usage
async def main():
    # Configuration
    config = {
        "llm": {
            "type": "ollama",
            "model_name": "llama3.2"
        },
        "embedder": {
            "type": "ollama",
            "model_name": "nomic-embed-text"
        },
        "vector_store": {
            "type": "lancedb",
            "db_path": "./data/lancedb",
            "table_name": "documents"
        }
    }
    # Create the RAG instance
    rag = BBCNewsRAG(config)
    await rag.initialize()
    try:
        # Read the data using pandas
        df = pd.read_csv('data/data.txt')
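        # Note (assumption): despite the .txt extension, the file is parsed as
        # CSV, and ingest_data expects it to provide 'text', 'title', 'url',
        # and 'index' columns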
        # Ingest the data
        await rag.ingest_data(df)
        # Example query
        question = "Who is Aarin Chiekrie? What does he do?"
        response = await rag.query(
            question,
            system_prompt="You are a helpful assistant that provides accurate information about BBC news based on the provided articles."
        )
        print(f"Question: {question}")
        print(f"Answer: {response}")
    finally:
        # Clean up
        await rag.close()


if __name__ == "__main__":
    asyncio.run(main())
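
For reference, streaming_generate is implemented above but never exercised in main(). A streaming variant of the query step might look like the sketch below; the stream_answer helper is illustrative and not part of the original gist:

async def stream_answer(rag: BBCNewsRAG, question: str) -> None:
    # Retrieve context exactly as BBCNewsRAG.query does...
    query_embedding = await rag.embedder.embed_query(question)
    results = await rag.vector_store.search(query_embedding)
    # ...but print tokens as they arrive instead of waiting for the full answer
    async for token in rag.llm.streaming_generate(
        prompt=f"Question: {question}\nContext: {results}",
        system_prompt="Answer the question based on the provided context."
    ):
        print(token, end="", flush=True)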