Custom Implementation of LLM for use within LlamaIndex RAG pipeline
from typing import Any

import requests
from llama_index.core import Settings, SimpleDirectoryReader, SummaryIndex
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback


class OurLLM(CustomLLM):
    context_window: int = 3900
    num_output: int = 256
    model_name: str = "mixtral"
    dummy_response: str = "My response"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Forward the prompt to the remote completion endpoint.
        url = "http://34.125.222.119:8080/completion"
        headers = {"Content-Type": "application/json"}
        data = {
            "prompt": f"Only a single and concise response is required for the question: Question {prompt}",
            "n_predict": 128,
        }
        response = requests.post(url, headers=headers, json=data)
        # Cache the generated text so stream_complete can replay it.
        self.dummy_response = response.json()["content"]
        return CompletionResponse(text=self.dummy_response)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # Replay the most recently cached response one character at a time.
        response = ""
        for token in self.dummy_response:
            response += token
            yield CompletionResponse(text=response, delta=token)


# Define our LLM
Settings.llm = OurLLM()

# Define embed model
Settings.embed_model = "local:BAAI/bge-base-en-v1.5"

# Load your data
documents = SimpleDirectoryReader("./data").load_data()
index = SummaryIndex.from_documents(documents)

# Query and print responses
query_engine = index.as_query_engine()
while True:
    user_input = input("Enter your query (type 'quit' to exit): ")
    if user_input.lower() == "quit":
        break
    response = query_engine.query(user_input)
    print(response)
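Note that stream_complete above never contacts the server: it only replays the text cached by the most recent complete call, character by character. If the endpoint at http://34.125.222.119:8080/completion is a llama.cpp server (an assumption; the gist never says which backend it is), true token streaming could be sketched by passing "stream": true and reading the server-sent-event lines the server emits. A minimal sketch under that assumption:

import json
import requests
from llama_index.core.llms import CompletionResponse, CompletionResponseGen


def stream_from_server(prompt: str) -> CompletionResponseGen:
    """Hypothetical streaming variant. Assumes a llama.cpp-style
    /completion endpoint that honors "stream": true and emits SSE
    lines of the form `data: {"content": "...", ...}`."""
    url = "http://34.125.222.119:8080/completion"  # same endpoint as above
    data = {"prompt": prompt, "n_predict": 128, "stream": True}
    text = ""
    with requests.post(url, json=data, stream=True) as resp:
        for line in resp.iter_lines():
            if not line:
                continue  # skip SSE keep-alive blanks
            if line.startswith(b"data: "):
                line = line[len(b"data: "):]
            token = json.loads(line).get("content", "")
            text += token
            yield CompletionResponse(text=text, delta=token)

Wiring something like this into stream_complete would let the query engine surface tokens as they arrive, rather than replaying a finished completion after the fact.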