Monkey-patching the https://github.com/microsoft/graphrag package so that it supports locally hosted LLM models served by Ollama. See https://chishengliu.com/posts/graphrag-local-ollama/ for details.
# Monkey patch the graphrag package to support local models deployed via Ollama. Only tested with graphrag version 0.3.2.
# See https://chishengliu.com/posts/graphrag-local-ollama/ for details.
#
# Save this file as "graphrag_monkey_patch.py" so that the imports below resolve.
#
# Usage:
# * How to patch the `graphrag.index` CLI:
#   * Save https://github.com/microsoft/graphrag/blob/v0.3.2/graphrag/index/__main__.py as "index.py"
#   * Replace line 8 with:
#     ```python
#     from graphrag.index.cli import index_cli
#     from graphrag_monkey_patch import patch_graphrag
#     patch_graphrag()
#     ```
#   * Then, you can use `python index.py` instead of `python -m graphrag.index` to do the indexing.
# * How to patch the `graphrag.query` CLI:
#   * Save https://github.com/microsoft/graphrag/blob/v0.3.2/graphrag/query/__main__.py as "query.py"
#   * Replace line 9 with:
#     ```python
#     from graphrag.query.cli import run_global_search, run_local_search
#     from graphrag_monkey_patch import patch_graphrag
#     patch_graphrag()
#     ```
#   * Then, you can use `python query.py` instead of `python -m graphrag.query` to do the queries.
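# * Example invocations after patching (illustrative only -- the exact CLI flags depend on your graphrag
#   version and workspace layout, so double-check them against the graphrag 0.3.2 docs):
#   * `python index.py --root ./ragtest`
#   * `python query.py --root ./ragtest --method global "What are the top themes in this corpus?"`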
def patch_graphrag():
    patch_openai_embeddings_llm()
    patch_query_embedding()
    patch_global_search()


def patch_openai_embeddings_llm():
    # Route indexing-time embedding requests to the local Ollama embedding model.
    # Reference: https://github.com/microsoft/graphrag/issues/370#issuecomment-2211821370
    from graphrag.llm.openai.openai_embeddings_llm import OpenAIEmbeddingsLLM
    import ollama

    async def _execute_llm(self, input, **kwargs):
        embedding_list = []
        for inp in input:
            embedding = ollama.embeddings(model=self.configuration.model, prompt=inp)
            embedding_list.append(embedding["embedding"])
        return embedding_list

    OpenAIEmbeddingsLLM._execute_llm = _execute_llm
def patch_query_embedding():
    # Route query-time embedding requests to the local Ollama embedding model.
    # Reference: https://github.com/microsoft/graphrag/issues/345#issuecomment-2230317752
    from graphrag.query.llm.oai.embedding import OpenAIEmbedding
    import ollama
    from tenacity import (
        AsyncRetrying,
        RetryError,
        Retrying,
        retry_if_exception_type,
        stop_after_attempt,
        wait_exponential_jitter,
    )

    def _embed_with_retry(self, text, **kwargs):
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    embedding = (ollama.embeddings(model=self.model, prompt=text)["embedding"] or [])
                    return (embedding, len(text))
        except RetryError as e:
            self._reporter.error(
                message="Error at embed_with_retry()",
                details={self.__class__.__name__: str(e)},
            )
            return ([], 0)
        else:
            # TODO: why not just throw in this case?
            return ([], 0)

    async def _aembed_with_retry(self, text, **kwargs):
        try:
            retryer = AsyncRetrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            async for attempt in retryer:
                with attempt:
                    embedding = (ollama.embeddings(model=self.model, prompt=text)["embedding"] or [])
                    return (embedding, len(text))
        except RetryError as e:
            self._reporter.error(
                message="Error at embed_with_retry()",
                details={self.__class__.__name__: str(e)},
            )
            return ([], 0)
        else:
            # TODO: why not just throw in this case?
            return ([], 0)

    OpenAIEmbedding._embed_with_retry = _embed_with_retry
    OpenAIEmbedding._aembed_with_retry = _aembed_with_retry
def patch_global_search():
    # Make the global-search "map" step tolerate responses that fail JSON parsing instead of crashing.
    # Reference: https://github.com/microsoft/graphrag/issues/575#issuecomment-2252045859
    from graphrag.query.structured_search.global_search.search import GlobalSearch
    import logging
    import time
    from graphrag.query.llm.text_utils import num_tokens
    from graphrag.query.structured_search.base import SearchResult

    log = logging.getLogger(__name__)

    async def _map_response_single_batch(self, context_data, query, **llm_kwargs):
        """Generate answer for a single chunk of community reports."""
        start_time = time.time()
        search_prompt = ""
        try:
            search_prompt = self.map_system_prompt.format(context_data=context_data)
            search_messages = [
                {"role": "user", "content": search_prompt + "\n\n### USER QUESTION ### \n\n" + query}
            ]
            async with self.semaphore:
                search_response = await self.llm.agenerate(
                    messages=search_messages, streaming=False, **llm_kwargs
                )
                log.info("Map response: %s", search_response)
            try:
                # parse search response json
                processed_response = self.parse_search_response(search_response)
            except ValueError:
                # Clean up and retry parse
                try:
                    # parse search response json
                    processed_response = self.parse_search_response(search_response)
                except ValueError:
                    log.warning(
                        "Warning: Error parsing search response json - skipping this batch"
                    )
                    processed_response = []

            return SearchResult(
                response=processed_response,
                context_data=context_data,
                context_text=context_data,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
            )

        except Exception:
            log.exception("Exception in _map_response_single_batch")
            return SearchResult(
                response=[{"answer": "", "score": 0}],
                context_data=context_data,
                context_text=context_data,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
            )

    GlobalSearch._map_response_single_batch = _map_response_single_batch
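As a quick sanity check before running the patched pipeline, you can call the local Ollama server directly to confirm that both models are pulled and responding. The snippet below is illustrative and not part of the gist; the model names are taken from the example config that follows and should be adjusted to whatever you actually serve locally.

```python
# Standalone sanity check for a local Ollama setup (illustrative; assumes `pip install ollama`).
import ollama

# Embedding model used by the patched embedding paths above.
emb = ollama.embeddings(model="nomic-embed-text", prompt="hello world")
print("embedding length:", len(emb["embedding"]))

# Chat model that graphrag will call through Ollama's OpenAI-compatible endpoint.
chat = ollama.chat(model="llama3.1:8b", messages=[{"role": "user", "content": "Reply with one word."}])
print("chat reply:", chat["message"]["content"])
```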
# An example config that works without errors when combined with the monkey-patch script in this Gist.
# See https://chishengliu.com/posts/graphrag-local-ollama/ for details.

encoding_model: cl100k_base
skip_workflows: []
llm:
  api_key: NONE
  type: openai_chat # or azure_openai_chat
  model: llama3.1:8b
  model_supports_json: true # recommended if this is available for your model.
  max_tokens: 8191
  # request_timeout: 180.0
  api_base: http://localhost:11434/v1
  # api_version: 2024-02-15-preview
  # organization: <organization_id>
  # deployment_name: <azure_model_deployment_name>
  # tokens_per_minute: 150_000 # set a leaky bucket throttle
  # requests_per_minute: 10_000 # set a leaky bucket throttle
  # max_retries: 10
  # max_retry_wait: 10.0
  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
  # concurrent_requests: 25 # the number of parallel inflight requests that may be made
  # temperature: 0 # temperature for sampling
  # top_p: 1 # top-p sampling
  # n: 1 # Number of completions to generate

parallelization:
  stagger: 0.3
  # num_threads: 50 # the number of threads to use for parallel processing

async_mode: threaded # or asyncio

embeddings:
  ## parallelization: override the global parallelization settings for embeddings
  async_mode: threaded # or asyncio
  # target: required # or all
  llm:
    api_key: NONE
    type: openai_embedding # or azure_openai_embedding
    model: nomic-embed-text
    api_base: http://localhost:11434/api
    # api_version: 2024-02-15-preview
    # organization: <organization_id>
    # deployment_name: <azure_model_deployment_name>
    # tokens_per_minute: 150_000 # set a leaky bucket throttle
    # requests_per_minute: 10_000 # set a leaky bucket throttle
    # max_retries: 10
    # max_retry_wait: 10.0
    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
    # concurrent_requests: 25 # the number of parallel inflight requests that may be made
    # batch_size: 16 # the number of documents to send in a single request
    batch_max_tokens: 8191 # the maximum number of tokens to send in a single request

chunks:
  size: 300
  overlap: 100
  group_by_columns: [id] # by default, we don't allow chunks to cross documents

input:
  type: file # or blob
  file_type: text # or csv
  base_dir: "input"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$"

cache:
  type: file # or blob
  base_dir: "cache"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

storage:
  type: file # or blob
  base_dir: "output/${timestamp}/artifacts"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

entity_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/entity_extraction.txt"
  entity_types: [organization,person,geo,event]
  max_gleanings: 1

summarize_descriptions:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/summarize_descriptions.txt"
  max_length: 500

claim_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  # enabled: true
  prompt: "prompts/claim_extraction.txt"
  description: "Any claims or facts that could be relevant to information discovery."
  max_gleanings: 1

community_reports:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

cluster_graph:
  max_cluster_size: 10

embed_graph:
  enabled: false # if true, will generate node2vec embeddings for nodes
  # num_walks: 10
  # walk_length: 40
  # window_size: 2
  # iterations: 3
  # random_seed: 597832

umap:
  enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
  graphml: false
  raw_entities: false
  top_level_nodes: false

local_search:
  # text_unit_prop: 0.5
  # community_prop: 0.1
  # conversation_history_max_turns: 5
  # top_k_mapped_entities: 10
  # top_k_relationships: 10
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000

global_search:
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000
  # data_max_tokens: 12000
  # map_max_tokens: 1000
  # reduce_max_tokens: 2000
  # concurrency: 32
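Note that the chat model still goes through graphrag's regular OpenAI client, so `llm.api_base` points at Ollama's OpenAI-compatible endpoint (`/v1`), while the embedding calls are rerouted through the `ollama` package by the monkey patch above. As a rough check that the `/v1` endpoint is serving the configured chat model, something like the following should work (illustrative sketch, not part of the config; assumes the `openai` Python package is installed):

```python
# Illustrative check of Ollama's OpenAI-compatible endpoint, using the same values as the config above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:11434/v1", api_key="NONE")
resp = client.chat.completions.create(
    model="llama3.1:8b",
    messages=[{"role": "user", "content": "Reply with the single word: ok"}],
)
print(resp.choices[0].message.content)
```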
The function `patch_all` doesn't seem to be available in graphrag_monkey_patch.py

Good catch! I forgot that I renamed the function to `patch_graphrag`. Thanks.