Last active
September 14, 2023 15:43
-
-
Save thoraxe/d462fae1cd16b23dccce66cb1a8984a9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_falcon_tgis_context(temperature, repetition_penalty):
    """Build a llama-index ServiceContext backed by a TGIS inference server.

    Args:
        temperature: Sampling temperature forwarded to the text-generation
            backend.
        repetition_penalty: Repetition penalty forwarded to the backend.

    Returns:
        A ``ServiceContext`` wired to a ``HuggingFaceTextGenInference``
        predictor (reached via ``TGIS_SERVER_URL``/``TGIS_SERVER_PORT`` env
        vars), a local BAAI/bge-base-en embedding model, and a Llama-style
        ``[INST] ... [/INST]`` query wrapper prompt.
    """
    system_prompt = """
- You are a helpful AI assistant and provide the answer for the question based on the given context.
- You answer the question as truthfully as possible using the provided text, and if the answer is not contained within the text below, you say "I don't know".
"""

    ## This will wrap the default prompts that are internal to llama-index
    #query_wrapper_prompt = SimpleInputPrompt(">>QUESTION<<{query_str}\n>>ANSWER<<")
    query_wrapper_prompt = Prompt("[INST] {query_str} [/INST]")

    print("Changing default model")
    # Change default model
    #embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
    embed_model = 'local:BAAI/bge-base-en'

    print("Getting server environment variables")
    # Fall back to sensible local defaults when the env vars are unset.
    server_url = os.getenv('TGIS_SERVER_URL', 'http://localhost')
    server_port = os.getenv('TGIS_SERVER_PORT', '8049')

    print(f"Initializing TGIS predictor with server_url: {server_url}, server_port: {server_port}")
    inference_server_url = f"{server_url}:{server_port}/"
    print(f"Inference Service URL: {inference_server_url}")

    tgis_predictor = LangChainLLM(
        llm=HuggingFaceTextGenInference(
            inference_server_url=inference_server_url,
            max_new_tokens=256,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            server_kwargs={},
        ),
    )

    print("Creating service_context")
    service_context = ServiceContext.from_defaults(
        chunk_size=1024,
        llm=tgis_predictor,
        query_wrapper_prompt=query_wrapper_prompt,
        system_prompt=system_prompt,
        embed_model=embed_model,
    )
    return service_context
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Imports
from llama_index import set_global_service_context, StorageContext, load_index_from_storage
from llama_index.vector_stores import RedisVectorStore
from llama_index import VectorStoreIndex, SimpleDirectoryReader, Document
from llama_index import ServiceContext
from model_context import get_falcon_tgis_context, get_falcon_tgis_context_sentence_window
from llama_index.prompts.prompts import SimpleInputPrompt
import os, time
import logging
import sys

# basicConfig already attaches a stdout StreamHandler to the root logger;
# the extra addHandler(StreamHandler(sys.stdout)) that used to follow it
# duplicated every log record, so it has been removed.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Select Model
#service_context = get_stablelm_context()
#service_context = get_falcon_tgis_context_sentence_window(0.7, 1.03)
service_context = get_falcon_tgis_context(0.7, 1.03)
#service_context = ServiceContext.from_defaults(embed_model="local")

# Load data
# Get Redis hostname from env else default to localhost.
redis_hostname = os.getenv('REDIS_SERVER_HOSTNAME', 'localhost')
print("Connecting to Redis at " + redis_hostname)

# overwrite=False: we are querying an existing index, not rebuilding it.
vector_store = RedisVectorStore(
    index_name="web-console",
    index_prefix="llama",
    redis_url=f"redis://{redis_hostname}:6379",
    overwrite=False,
)

storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)

#query = "How do I get the SSH key for a cluster from Hive?"
query = "How do I install the web terminal?"
response = index.as_query_engine(verbose=True, streaming=True).query(query)

# Collect the file names of the source chunks the answer was drawn from.
# join builds the same string the old += loop produced, in one pass.
referenced_documents = "\n\nReferenced documents:\n" + "".join(
    source_node.node.metadata['file_name'] + '\n'
    for source_node in response.source_nodes
)

print()
print(query)
print(str(response))
print(referenced_documents)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Imports
from llama_index.vector_stores import RedisVectorStore
from llama_index import VectorStoreIndex
from model_context import get_falcon_tgis_context
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.query_engine.router_query_engine import RouterQueryEngine
from llama_index.selectors.llm_selectors import LLMSingleSelector, LLMMultiSelector
import os, time
import logging
import sys

# basicConfig already attaches a stdout StreamHandler to the root logger;
# the extra addHandler(StreamHandler(sys.stdout)) that used to follow it
# duplicated every log record, so it has been removed.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Select Model
#service_context = get_stablelm_context()
#service_context = get_falcon_tgis_context_sentence_window(0.7, 1.03)
service_context = get_falcon_tgis_context(0.7, 1.03)
#service_context = ServiceContext.from_defaults(embed_model="local")

# Load data
# Get Redis hostname from env else default to localhost.
redis_hostname = os.getenv('REDIS_SERVER_HOSTNAME', 'localhost')
print("Connecting to Redis at " + redis_hostname)

# https://gpt-index.readthedocs.io/en/stable/examples/query_engine/RouterQueryEngine.html
# attempt to use a router query setup
print("Setting up vector stores")
# overwrite=False: this script only queries pre-built indices.  The previous
# overwrite=True would have recreated (and emptied) the existing Redis
# indices on every run, which the querying below clearly does not intend.
ops_sop_vector_store = RedisVectorStore(
    index_name="ops-sop",
    redis_url=f"redis://{redis_hostname}:6379",
    overwrite=False,
)
web_console_store = RedisVectorStore(
    index_name="web-console",
    redis_url=f"redis://{redis_hostname}:6379",
    overwrite=False,
)

print("Setting up vector indices")
ops_index = VectorStoreIndex.from_vector_store(ops_sop_vector_store, service_context=service_context)
web_console_index = VectorStoreIndex.from_vector_store(web_console_store, service_context=service_context)

print("Setting up query engines")
ops_query_engine = ops_index.as_query_engine(
    verbose=True,
    streaming=True,
)
web_console_engine = web_console_index.as_query_engine(
    verbose=True,
    streaming=True,
)

print("Setting up tools")
# Renamed from os_query_tool for consistency with ops_query_engine/ops_index.
ops_query_tool = QueryEngineTool.from_defaults(
    query_engine=ops_query_engine,
    description="Documents related to SRE and operations questions about troubleshooting managed OpenShift clusters.",
)
web_console_query_tool = QueryEngineTool.from_defaults(
    query_engine=web_console_engine,
    description="User and administrator documentation related to the OpenShift web console and its configuration.",
)

print("Setting up router")
# The LLM selector routes each query to whichever tool's description fits best.
query_engine = RouterQueryEngine(
    service_context=service_context,
    selector=LLMSingleSelector.from_defaults(service_context=service_context),
    query_engine_tools=[
        ops_query_tool,
        web_console_query_tool,
    ],
)

#query = input("What's your query? ")
query = "I am an OpenShift administrator and I would like to know how to install the web terminal."
response = query_engine.query(query)

# Collect the file names of the source chunks the answer was drawn from.
# join builds the same string the old += loop produced, in one pass.
referenced_documents = "\n\nReferenced documents:\n" + "".join(
    source_node.node.metadata['file_name'] + '\n'
    for source_node in response.source_nodes
)

print()
print(query)
print(str(response))
print(referenced_documents)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment