@MortalHappiness
Last active December 13, 2024 21:00
Monkey-patching the https://github.com/microsoft/graphrag package so that it supports locally hosted LLM models served via Ollama. See https://chishengliu.com/posts/graphrag-local-ollama/ for details.
# Monkey patch the graphrag package to support local models deployed via Ollama. Only tested with graphrag version 0.3.2.
# See https://chishengliu.com/posts/graphrag-local-ollama/ for details.
#
# Usage:
# * How to patch the `graphrag.index` CLI:
#   * Save https://github.com/microsoft/graphrag/blob/v0.3.2/graphrag/index/__main__.py as "index.py"
#   * Replace line 8 with:
#     ```python
#     from graphrag.index.cli import index_cli
#     from graphrag_monkey_patch import patch_graphrag
#     patch_graphrag()
#     ```
#   * Then, you can use `python index.py` instead of `python -m graphrag.index` to do the indexing.
# * How to patch the `graphrag.query` CLI:
#   * Save https://github.com/microsoft/graphrag/blob/v0.3.2/graphrag/query/__main__.py as "query.py"
#   * Replace line 9 with:
#     ```python
#     from graphrag.query.cli import run_global_search, run_local_search
#     from graphrag_monkey_patch import patch_graphrag
#     patch_graphrag()
#     ```
#   * Then, you can use `python query.py` instead of `python -m graphrag.query` to do the queries.
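#
# * For reference, after the edit the top of "index.py" looks like this (a sketch only; the
#   rest of the file is the unmodified body of __main__.py from v0.3.2, which parses the
#   command-line arguments and calls index_cli):
#     ```python
#     from graphrag.index.cli import index_cli
#     from graphrag_monkey_patch import patch_graphrag
#     patch_graphrag()
#
#     # ... the original argument parsing and the final index_cli(...) call follow, unchanged ...
#     ```
# * Example invocations (the flags follow the standard graphrag 0.3.x CLI; the "./ragtest"
#   root and the question are placeholders):
#   * `python index.py --root ./ragtest`
#   * `python query.py --root ./ragtest --method global "What are the top themes in the documents?"`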
def patch_graphrag():
    patch_openai_embeddings_llm()
    patch_query_embedding()
    patch_global_search()
def patch_openai_embeddings_llm():
    # Reference: https://github.com/microsoft/graphrag/issues/370#issuecomment-2211821370
    from graphrag.llm.openai.openai_embeddings_llm import OpenAIEmbeddingsLLM
    import ollama

    async def _execute_llm(self, input, **kwargs):
        embedding_list = []
        for inp in input:
            embedding = ollama.embeddings(model=self.configuration.model, prompt=inp)
            embedding_list.append(embedding["embedding"])
        return embedding_list

    OpenAIEmbeddingsLLM._execute_llm = _execute_llm
def patch_query_embedding():
    # Reference: https://github.com/microsoft/graphrag/issues/345#issuecomment-2230317752
    from graphrag.query.llm.oai.embedding import OpenAIEmbedding
    import ollama
    from tenacity import (
        AsyncRetrying,
        RetryError,
        Retrying,
        retry_if_exception_type,
        stop_after_attempt,
        wait_exponential_jitter,
    )

    def _embed_with_retry(self, text, **kwargs):
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    embedding = (ollama.embeddings(model=self.model, prompt=text)["embedding"] or [])
                    return (embedding, len(text))
        except RetryError as e:
            self._reporter.error(
                message="Error at embed_with_retry()",
                details={self.__class__.__name__: str(e)},
            )
            return ([], 0)
        else:
            # TODO: why not just throw in this case?
            return ([], 0)

    async def _aembed_with_retry(self, text, **kwargs):
        try:
            retryer = AsyncRetrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            async for attempt in retryer:
                with attempt:
                    embedding = (ollama.embeddings(model=self.model, prompt=text)["embedding"] or [])
                    return (embedding, len(text))
        except RetryError as e:
            self._reporter.error(
                message="Error at embed_with_retry()",
                details={self.__class__.__name__: str(e)},
            )
            return ([], 0)
        else:
            # TODO: why not just throw in this case?
            return ([], 0)

    OpenAIEmbedding._embed_with_retry = _embed_with_retry
    OpenAIEmbedding._aembed_with_retry = _aembed_with_retry
def patch_global_search():
    # Reference: https://github.com/microsoft/graphrag/issues/575#issuecomment-2252045859
    from graphrag.query.structured_search.global_search.search import GlobalSearch
    import logging
    import time
    from graphrag.query.llm.text_utils import num_tokens
    from graphrag.query.structured_search.base import SearchResult

    log = logging.getLogger(__name__)

    async def _map_response_single_batch(self, context_data, query, **llm_kwargs):
        """Generate answer for a single chunk of community reports."""
        start_time = time.time()
        search_prompt = ""
        try:
            search_prompt = self.map_system_prompt.format(context_data=context_data)
            search_messages = [{"role": "user", "content": search_prompt + "\n\n### USER QUESTION ### \n\n" + query}]
            async with self.semaphore:
                search_response = await self.llm.agenerate(
                    messages=search_messages, streaming=False, **llm_kwargs
                )
                log.info("Map response: %s", search_response)
            try:
                # parse search response json
                processed_response = self.parse_search_response(search_response)
            except ValueError:
                # Clean up and retry parse
                try:
                    # parse search response json
                    processed_response = self.parse_search_response(search_response)
                except ValueError:
                    log.warning(
                        "Warning: Error parsing search response json - skipping this batch"
                    )
                    processed_response = []

            return SearchResult(
                response=processed_response,
                context_data=context_data,
                context_text=context_data,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
            )

        except Exception:
            log.exception("Exception in _map_response_single_batch")
            return SearchResult(
                response=[{"answer": "", "score": 0}],
                context_data=context_data,
                context_text=context_data,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
            )

    GlobalSearch._map_response_single_batch = _map_response_single_batch
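If you would rather not copy and edit `__main__.py` at all, the same patch can also be applied from a small wrapper script that then runs the stock CLI module in-process. The sketch below is not part of the original Gist; the file name `run_index.py` and the `--root ./ragtest` flag are only illustrative, and the same pattern works for `graphrag.query`.

```python
# run_index.py -- apply the monkey patch, then execute the unmodified
# `graphrag.index` CLI in this process. Run it like the stock CLI,
# e.g. `python run_index.py --root ./ragtest`.
import runpy

from graphrag_monkey_patch import patch_graphrag

# Patch the graphrag classes before the CLI code uses them; the patch replaces
# methods on the classes themselves, so no other changes are needed.
patch_graphrag()

# Equivalent to `python -m graphrag.index`; the module parses sys.argv itself.
runpy.run_module("graphrag.index", run_name="__main__")
```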
# An example config that works without errors when combined with the monkey-patch script in this Gist.
# See https://chishengliu.com/posts/graphrag-local-ollama/ for details.
encoding_model: cl100k_base
skip_workflows: []

llm:
  api_key: NONE
  type: openai_chat # or azure_openai_chat
  model: llama3.1:8b
  model_supports_json: true # recommended if this is available for your model.
  max_tokens: 8191
  # request_timeout: 180.0
  api_base: http://localhost:11434/v1
  # api_version: 2024-02-15-preview
  # organization: <organization_id>
  # deployment_name: <azure_model_deployment_name>
  # tokens_per_minute: 150_000 # set a leaky bucket throttle
  # requests_per_minute: 10_000 # set a leaky bucket throttle
  # max_retries: 10
  # max_retry_wait: 10.0
  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
  # concurrent_requests: 25 # the number of parallel inflight requests that may be made
  # temperature: 0 # temperature for sampling
  # top_p: 1 # top-p sampling
  # n: 1 # Number of completions to generate

parallelization:
  stagger: 0.3
  # num_threads: 50 # the number of threads to use for parallel processing

async_mode: threaded # or asyncio

embeddings:
  ## parallelization: override the global parallelization settings for embeddings
  async_mode: threaded # or asyncio
  # target: required # or all
  llm:
    api_key: NONE
    type: openai_embedding # or azure_openai_embedding
    model: nomic-embed-text
    api_base: http://localhost:11434/api
    # api_version: 2024-02-15-preview
    # organization: <organization_id>
    # deployment_name: <azure_model_deployment_name>
    # tokens_per_minute: 150_000 # set a leaky bucket throttle
    # requests_per_minute: 10_000 # set a leaky bucket throttle
    # max_retries: 10
    # max_retry_wait: 10.0
    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
    # concurrent_requests: 25 # the number of parallel inflight requests that may be made
    # batch_size: 16 # the number of documents to send in a single request
    batch_max_tokens: 8191 # the maximum number of tokens to send in a single request

chunks:
  size: 300
  overlap: 100
  group_by_columns: [id] # by default, we don't allow chunks to cross documents

input:
  type: file # or blob
  file_type: text # or csv
  base_dir: "input"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$"

cache:
  type: file # or blob
  base_dir: "cache"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

storage:
  type: file # or blob
  base_dir: "output/${timestamp}/artifacts"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

entity_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/entity_extraction.txt"
  entity_types: [organization, person, geo, event]
  max_gleanings: 1

summarize_descriptions:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/summarize_descriptions.txt"
  max_length: 500

claim_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  # enabled: true
  prompt: "prompts/claim_extraction.txt"
  description: "Any claims or facts that could be relevant to information discovery."
  max_gleanings: 1

community_reports:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

cluster_graph:
  max_cluster_size: 10

embed_graph:
  enabled: false # if true, will generate node2vec embeddings for nodes
  # num_walks: 10
  # walk_length: 40
  # window_size: 2
  # iterations: 3
  # random_seed: 597832

umap:
  enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
  graphml: false
  raw_entities: false
  top_level_nodes: false

local_search:
  # text_unit_prop: 0.5
  # community_prop: 0.1
  # conversation_history_max_turns: 5
  # top_k_mapped_entities: 10
  # top_k_relationships: 10
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000

global_search:
  # llm_temperature: 0 # temperature for sampling
  # llm_top_p: 1 # top-p sampling
  # llm_n: 1 # Number of completions to generate
  # max_tokens: 12000
  # data_max_tokens: 12000
  # map_max_tokens: 1000
  # reduce_max_tokens: 2000
  # concurrency: 32
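The config above assumes both models have already been pulled into the local Ollama instance (for example with `ollama pull llama3.1:8b` and `ollama pull nomic-embed-text`). A quick, optional sanity check before indexing, using the same `ollama` client the patch relies on; the model names simply mirror the `llm.model` and `embeddings.llm.model` values above:

```python
import ollama

# Both calls should succeed if the models named in settings.yaml are available locally.
ollama.chat(model="llama3.1:8b", messages=[{"role": "user", "content": "ping"}])
ollama.embeddings(model="nomic-embed-text", prompt="ping")
print("Ollama chat and embedding models are reachable.")
```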
@superbarne

the function patch_all doesn't seem to be available in graphrag_monkey_patch.py

@MortalHappiness (Author)

> the function patch_all doesn't seem to be available in graphrag_monkey_patch.py

Good catch! I forgot that I renamed the function to patch_graphrag. Thanks.
