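"""Small local RAG chat loop built with LlamaIndex.

Loads a quantized Mistral-7B-Instruct GGUF model through llama.cpp, embeds local
documents with a HuggingFace embedding model, persists (and later reloads) a vector
index, and answers questions via a CondenseQuestionChatEngine.
"""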
import os

import transformers

from llama_index import core
from llama_index.core import chat_engine
from llama_index.core import indices
from llama_index.core import readers
from llama_index.core import utils
from llama_index.core.base.llms import types
from llama_index.core.indices import vector_store
from llama_index.embeddings import huggingface
from llama_index.llms.llama_cpp import base

MAX_CHAT = 100

custom_prompt = core.PromptTemplate(
    """\
<s><INST>
Given a conversation (between Human and Assistant) and a follow up message from Human, \
rewrite the message to be a standalone question that captures all relevant context \
from the conversation.
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
</INST></s>
"""
)

custom_chat_history = [
    types.ChatMessage(
        role=types.MessageRole.USER,
        content="Hello assistant, we are having an insightful discussion about Thijs Metsch today. Answer questions in a"
                " positive, helpful and empathetic way.",
    ),
    types.ChatMessage(role=types.MessageRole.ASSISTANT, content="Okay, sounds good."),
]


def get_models(path_to_gguf):
    # Local LLM loaded from a quantized GGUF file through llama.cpp.
    llm = base.LlamaCPP(
        model_path=path_to_gguf,
        context_window=2048,
        max_new_tokens=256,
        verbose=False,
    )
    # Embedding model downloaded from HuggingFace and cached locally.
    embedding_model = huggingface.HuggingFaceEmbedding(
        model_name="WhereIsAI/UAE-Large-V1",
        cache_folder="hugging_cache",
    )
    # Register both as the global defaults for LlamaIndex.
    core.Settings.llm = llm
    core.Settings.embed_model = embedding_model
    # Use the Mistral/Mixtral instruct tokenizer for token counting.
    utils.set_global_tokenizer(
        transformers.AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1").encode
    )
    return llm, embedding_model


def get_index(path_to_storage, path_to_data):
    if not os.path.exists(path_to_storage):
        # load the documents and create the index
        documents = readers.SimpleDirectoryReader(path_to_data).load_data()
        index = vector_store.VectorStoreIndex.from_documents(documents, show_progress=True)
        # store it for later
        index.storage_context.persist(persist_dir=path_to_storage)
    else:
        # load the existing index
        storage_context = core.StorageContext.from_defaults(persist_dir=path_to_storage)
        index = indices.load_index_from_storage(storage_context)
    return index


def chat():
    llm, embed_model = get_models("<path to model>/mistral-7b-instruct-v0.2.Q4_K_M.gguf")
    index = get_index("./storage", "./data")
    query_engine = index.as_query_engine()
    # Chat engine that condenses the chat history plus follow-up into a standalone question.
    ce = chat_engine.CondenseQuestionChatEngine.from_defaults(
        llm=llm,
        query_engine=query_engine,
        condense_question_prompt=custom_prompt,
        chat_history=custom_chat_history,
    )
    # Simple REPL loop; type 'exit' to quit.
    for _ in range(MAX_CHAT):
        q = input('\nUser: ')
        if q == 'exit':
            break
        streaming_response = ce.stream_chat(f"<s>[INST]{q}[/INST]")
        for token in streaming_response.response_gen:
            print(token, end="")


if __name__ == '__main__':
    chat()
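

# Rough usage sketch (assumptions, not part of the original gist): with llama-index's
# split packaging, the imports above map to something like
#   pip install llama-index-core llama-index-llms-llama-cpp llama-index-embeddings-huggingface transformers
# Point get_models() at a local mistral-7b-instruct-v0.2.Q4_K_M.gguf, put the documents
# to index into ./data, and run the script; the vector index is persisted to ./storage
# on the first run and reloaded on later runs.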