Last active
April 25, 2024 00:03
-
-
Save shawnlewis/ea95a99eade3ed5260008ade78f672f8 to your computer and use it in GitHub Desktop.
LLama index weave integration starter code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any | |
import datetime | |
import weave | |
import copy | |
from llama_index.legacy.callbacks.base_handler import BaseCallbackHandler | |
from llama_index.core.indices.base import BaseIndex | |
import os.path | |
from llama_index.core import ( | |
VectorStoreIndex, | |
SimpleDirectoryReader, | |
StorageContext, | |
load_index_from_storage, | |
) | |
from typing import Any, Dict, List, Optional | |
from llama_index.legacy.callbacks.schema import CBEventType | |
from weave import run_context | |
from weave.weave_client import generate_id, Call | |
from weave import graph_client_context | |
from weave.trace.serialize import to_json | |
from weave.trace_server.trace_server_interface import ( | |
StartedCallSchemaForInsert, | |
CallStartReq, | |
CallEndReq, | |
EndedCallSchemaForInsert, | |
) | |
class WeaveCallbackHandler(BaseCallbackHandler):
    """Mirror llama-index callback events into Weave calls.

    ``on_event_start`` opens a Weave call for the event (creating and caching
    an anonymous op per event type on first use) and pushes the call onto
    Weave's run stack so nested events become child calls; ``on_event_end``
    closes the call and restores the stack.
    """

    def __init__(
        self,
        event_starts_to_ignore: Optional[List[CBEventType]] = None,
        event_ends_to_ignore: Optional[List[CBEventType]] = None,
    ) -> None:
        # Cache of anonymous weave ops, keyed by lower-cased event type name,
        # so each event type is saved to the server only once.
        self._ops: Dict[str, Any] = {}
        # contextvars reset tokens keyed by event id, so on_event_end can
        # restore the run stack pushed by the matching on_event_start.
        self._event_tokens: Dict[str, Any] = {}
        # BUGFIX: this attribute was printed in the "LOG EV"/"LOG END" debug
        # lines but never assigned (the start_trace/end_trace methods that
        # set it were commented out), raising AttributeError on first event.
        self._weave_trace_id: Optional[str] = None
        super().__init__(
            event_starts_to_ignore=event_starts_to_ignore or [],
            event_ends_to_ignore=event_ends_to_ignore or [],
        )

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return id of event."""
        gc = graph_client_context.require_graph_client()
        event_type_name = event_type.name.lower()
        # Weird stuff we have to do to make an anonymous op.
        op = self._ops.get(event_type_name)
        if op is None:

            @weave.op()
            def resolve_fn():
                return "fake_op"

            resolve_fn.name = event_type_name
            op = resolve_fn
            self._ops[event_type_name] = op
        op_str = gc._save_op(op).uri()

        # BUGFIX: normalize the payload *before* building the Call, so the
        # stack entry and the server request share the same name-keyed
        # inputs (previously the Call was built from the raw
        # CBEventType-keyed dict), and serialize it only once.
        payload = payload or {}
        payload = {k.name: v for k, v in payload.items()}
        inputs_json = to_json(payload, gc._project_id(), gc.server)

        # TODO: the below might be nicer if we used the client API instead of
        # server API. Otherwise, we should refactor the API a little bit to make
        # this easier. Doing stack manipulation here feels bad.
        # Parent weave call: inherit trace from the current run if any,
        # otherwise start a fresh trace.
        current_run = run_context.get_current_run()
        if current_run and current_run.id:
            parent_id = current_run.id
            trace_id = current_run.trace_id
        else:
            parent_id = None
            trace_id = generate_id()
        # Have to manually append to the run stack :(
        new_stack = copy.copy(run_context._run_stack.get())
        call = Call(
            project_id=gc._project_id(),
            id=event_id,
            op_name=op_str,
            trace_id=trace_id,
            parent_id=parent_id,
            inputs=inputs_json,
        )
        new_stack.append(call)
        self._event_tokens[event_id] = run_context._run_stack.set(new_stack)
        print("LOG EV", event_id, parent_id, self._weave_trace_id)
        gc.server.call_start(
            CallStartReq(
                start=StartedCallSchemaForInsert(
                    project_id=gc._project_id(),
                    id=event_id,
                    op_name=op_str,
                    trace_id=trace_id,
                    parent_id=parent_id,
                    started_at=datetime.datetime.now(),
                    attributes={},
                    inputs=inputs_json,
                )
            )
        )
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        gc = graph_client_context.require_graph_client()
        payload = payload or {}
        payload = {k.name: v for k, v in payload.items()}
        print("LOG END", event_id, self._weave_trace_id)
        gc.server.call_end(
            CallEndReq(
                end=EndedCallSchemaForInsert(
                    project_id=gc._project_id(),
                    id=event_id,  # type: ignore
                    ended_at=datetime.datetime.now(),
                    output=to_json(payload, gc._project_id(), gc.server),
                    summary={},
                )
            )
        )
        # BUGFIX: an end event whose start was never recorded (e.g. handler
        # installed mid-trace) previously raised KeyError here.
        token = self._event_tokens.pop(event_id, None)
        if token is not None:
            run_context._run_stack.reset(token)
# Install the Weave handler as llama-index's global callback handler so every
# event emitted below is traced.
import llama_index.core

llama_index.core.global_handler = WeaveCallbackHandler()

# Build the vector store on first run; reuse the persisted copy afterwards.
PERSIST_DIR = "./storage"
if os.path.exists(PERSIST_DIR):
    # Reload the previously persisted index.
    ctx = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(ctx)
else:
    # Index the documents, then persist the result for next time.
    docs = SimpleDirectoryReader("data").load_data()
    index = VectorStoreIndex.from_documents(docs)
    index.storage_context.persist(persist_dir=PERSIST_DIR)

weave.init("shawn/weave-llama-index3")

# Plain llama-index query, without a Weave Model wrapper.
query_engine = index.as_query_engine()
response = query_engine.query("What did the author do growing up?")
print(response)
# Or wrap things in a Weave Model so we can version the Index, parameters,
# and code together.
class LLamaIndexEngineModel(weave.Model):
    """Weave Model wrapping a llama-index index so queries are versioned."""

    index: BaseIndex
    # Other parameters here

    @weave.op()
    def query(self, query: str) -> Any:
        """Answer *query* via a query engine built from ``self.index``."""
        # Other model params can be passed to as_query_engine here.
        return self.index.as_query_engine().query(query)
# Run the same question through the versioned Weave Model wrapper.
engine_model = LLamaIndexEngineModel(index=index)
model_response = engine_model.query("What did the author do growing up?")
print("RESPONSE", model_response)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment