# model_context.py -- builds a llama-index ServiceContext backed by a TGIS
# (Text Generation Inference Server) endpoint. The imports below are assumed
# from the llama-index / langchain versions this gist targets.
import os

from langchain.llms import HuggingFaceTextGenInference
from llama_index import Prompt, ServiceContext
from llama_index.llms import LangChainLLM


def get_falcon_tgis_context(temperature, repetition_penalty):
    system_prompt = """
- You are a helpful AI assistant and provide the answer for the question based on the given context.
- You answer the question as truthfully as possible using the provided text, and if the answer is not contained within the text below, you say "I don't know".
"""

    # This wraps the default prompts that are internal to llama-index
    #query_wrapper_prompt = SimpleInputPrompt(">>QUESTION<<{query_str}\n>>ANSWER<<")
    query_wrapper_prompt = Prompt("[INST] {query_str} [/INST]")

    print("Changing default model")
    # Change the default embedding model
    #embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
    embed_model = 'local:BAAI/bge-base-en'

    print("Getting server environment variables")
    server_url = os.getenv('TGIS_SERVER_URL', 'http://localhost')  # Get server URL from env, else default
    server_port = os.getenv('TGIS_SERVER_PORT', '8049')  # Get server port from env, else default
    print(f"Initializing TGIS predictor with server_url: {server_url}, server_port: {server_port}")

    inference_server_url = f"{server_url}:{server_port}/"
    print(f"Inference Service URL: {inference_server_url}")

    tgis_predictor = LangChainLLM(
        llm=HuggingFaceTextGenInference(
            inference_server_url=inference_server_url,
            max_new_tokens=256,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            server_kwargs={},
        ),
    )

    print("Creating service_context")
    service_context = ServiceContext.from_defaults(
        chunk_size=1024,
        llm=tgis_predictor,
        query_wrapper_prompt=query_wrapper_prompt,
        system_prompt=system_prompt,
        embed_model=embed_model,
    )
    return service_context
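
# Minimal smoke test for the helper above -- a sketch, not part of the
# original gist. Assumes a TGIS instance is reachable at the env-var defaults;
# the 0.7 / 1.03 values mirror the callers below.
if __name__ == "__main__":
    ctx = get_falcon_tgis_context(0.7, 1.03)
    # from_defaults() wraps the LLM in an LLMPredictor; print it to confirm.
    print(ctx.llm_predictor)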
# --- query script: loads the "web-console" Redis index and runs one query ---
# Imports
from llama_index import set_global_service_context, StorageContext, load_index_from_storage
from llama_index.vector_stores import RedisVectorStore
from llama_index import VectorStoreIndex, SimpleDirectoryReader, Document
from llama_index import ServiceContext
from model_context import get_falcon_tgis_context, get_falcon_tgis_context_sentence_window
from llama_index.prompts.prompts import SimpleInputPrompt
import os, time
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
# Select Model
#service_context = get_stablelm_context()
#service_context = get_falcon_tgis_context_sentence_window(0.7, 1.03)
service_context = get_falcon_tgis_context(0.7, 1.03)
#service_context = ServiceContext.from_defaults(embed_model="local")
# Load data
redis_hostname = os.getenv('REDIS_SERVER_HOSTNAME', 'localhost')  # Get Redis hostname from env, else default
print("Connecting to Redis at " + redis_hostname)
vector_store = RedisVectorStore(
    index_name="web-console",
    index_prefix="llama",
    redis_url=f"redis://{redis_hostname}:6379",
    overwrite=False,
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_vector_store(vector_store, service_context=service_context)
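
# The SimpleDirectoryReader / Document imports above are unused here, which
# suggests the "web-console" index was populated in a separate ingestion step.
# A sketch of what that step might look like, kept commented out so it does
# not overwrite the index mid-script; the ./web-console-docs path is a
# hypothetical placeholder, not from the original gist:
#
#   documents = SimpleDirectoryReader("./web-console-docs").load_data()
#   ingest_store = RedisVectorStore(
#       index_name="web-console",
#       index_prefix="llama",
#       redis_url=f"redis://{redis_hostname}:6379",
#       overwrite=True,
#   )
#   ingest_ctx = StorageContext.from_defaults(vector_store=ingest_store)
#   VectorStoreIndex.from_documents(
#       documents, storage_context=ingest_ctx, service_context=service_context
#   )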
#query = "How do I get the SSH key for a cluster from Hive?"
query = "How do I install the web terminal?"
response = index.as_query_engine(verbose=True, streaming=True).query(query)
referenced_documents = "\n\nReferenced documents:\n"
for source_node in response.source_nodes:
    #print(source_node.node.metadata['file_name'])
    referenced_documents += source_node.node.metadata['file_name'] + '\n'
print()
print(query)
print(str(response))
print(referenced_documents)
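
# Because the query engine is created with streaming=True, the answer can also
# be rendered token-by-token instead of via str(); llama-index's
# StreamingResponse exposes a stream printer for that:
#
#   streaming_response = index.as_query_engine(verbose=True, streaming=True).query(query)
#   streaming_response.print_response_stream()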
# --- router script: routes a query between the "ops-sop" and "web-console" indices ---
# Imports
from llama_index.vector_stores import RedisVectorStore
from llama_index import VectorStoreIndex
from model_context import get_falcon_tgis_context
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.query_engine.router_query_engine import RouterQueryEngine
from llama_index.selectors.llm_selectors import LLMSingleSelector, LLMMultiSelector
import os, time
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
# Select Model
#service_context = get_stablelm_context()
#service_context = get_falcon_tgis_context_sentence_window(0.7, 1.03)
service_context = get_falcon_tgis_context(0.7, 1.03)
#service_context = ServiceContext.from_defaults(embed_model="local")
# Load data
redis_hostname = os.getenv('REDIS_SERVER_HOSTNAME', 'localhost')  # Get Redis hostname from env, else default
print("Connecting to Redis at " + redis_hostname)
# https://gpt-index.readthedocs.io/en/stable/examples/query_engine/RouterQueryEngine.html
# attempt to use a router query setup
print("Setting up vector stores")
ops_sop_vector_store = RedisVectorStore(
    index_name="ops-sop",
    redis_url=f"redis://{redis_hostname}:6379",
    overwrite=True,
)
web_console_store = RedisVectorStore(
    index_name="web-console",
    redis_url=f"redis://{redis_hostname}:6379",
    overwrite=True,
)
print("Setting up vector indices")
ops_index = VectorStoreIndex.from_vector_store(ops_sop_vector_store, service_context=service_context)
web_console_index = VectorStoreIndex.from_vector_store(web_console_store, service_context=service_context)
print("Setting up query engines")
ops_query_engine = ops_index.as_query_engine(
    verbose=True,
    streaming=True,
)
web_console_engine = web_console_index.as_query_engine(
    verbose=True,
    streaming=True,
)
print("Setting up tools")
os_query_tool = QueryEngineTool.from_defaults(
    query_engine=ops_query_engine,
    description="Documents related to SRE and operations questions about troubleshooting managed OpenShift clusters.",
)
web_console_query_tool = QueryEngineTool.from_defaults(
    query_engine=web_console_engine,
    description="User and administrator documentation related to the OpenShift web console and its configuration.",
)
print("Setting up router")
query_engine = RouterQueryEngine(
    service_context=service_context,
    selector=LLMSingleSelector.from_defaults(service_context=service_context),
    query_engine_tools=[
        os_query_tool,
        web_console_query_tool,
    ],
)
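
# LLMMultiSelector is imported above but unused; a sketch of swapping it in so
# the router can fan a question out to more than one engine when both tool
# descriptions apply (commented out to keep this script's behavior unchanged):
#
#   multi_engine = RouterQueryEngine(
#       service_context=service_context,
#       selector=LLMMultiSelector.from_defaults(service_context=service_context),
#       query_engine_tools=[os_query_tool, web_console_query_tool],
#   )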
#query = input("What's your query? ")
query = "I am an OpenShift administrator and I would like to know how to install the web terminal."
response = query_engine.query(query)
referenced_documents = "\n\nReferenced documents:\n"
for source_node in response.source_nodes:
    #print(source_node.node.metadata['file_name'])
    referenced_documents += source_node.node.metadata['file_name'] + '\n'
print()
print(query)
print(str(response))
print(referenced_documents)
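
# To see which tool the router actually picked, inspect the response metadata.
# The "selector_result" key is an assumption about the llama-index version in
# use; print response.metadata to confirm what is recorded.
#
#   print(response.metadata.get("selector_result"))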