Skip to content

Instantly share code, notes, and snippets.

@shawnlewis
Last active April 25, 2024 00:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save shawnlewis/ea95a99eade3ed5260008ade78f672f8 to your computer and use it in GitHub Desktop.
LlamaIndex Weave integration starter code
from typing import Any
import datetime
import weave
import copy
from llama_index.legacy.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.indices.base import BaseIndex
import os.path
from llama_index.core import (
VectorStoreIndex,
SimpleDirectoryReader,
StorageContext,
load_index_from_storage,
)
from typing import Any, Dict, List, Optional
from llama_index.legacy.callbacks.schema import CBEventType
from weave import run_context
from weave.weave_client import generate_id, Call
from weave import graph_client_context
from weave.trace.serialize import to_json
from weave.trace_server.trace_server_interface import (
StartedCallSchemaForInsert,
CallStartReq,
CallEndReq,
EndedCallSchemaForInsert,
)
class WeaveCallbackHandler(BaseCallbackHandler):
    """Callback handler that forwards llama-index events to Weave as trace calls.

    Each llama-index event becomes a Weave call: ``on_event_start`` records a
    call start (and pushes the call onto Weave's run stack so nested events
    become child calls), and ``on_event_end`` records the matching call end
    and pops that stack frame.
    """

    def __init__(
        self,
        event_starts_to_ignore: Optional[List[CBEventType]] = None,
        event_ends_to_ignore: Optional[List[CBEventType]] = None,
    ) -> None:
        # Normalize None -> [] here rather than using mutable defaults.
        event_starts_to_ignore = (
            event_starts_to_ignore if event_starts_to_ignore else []
        )
        event_ends_to_ignore = event_ends_to_ignore if event_ends_to_ignore else []
        # Cache of anonymous weave ops, keyed by lowercased event-type name,
        # so each event type is saved as a single op.
        self._ops: Dict[str, Any] = {}
        # contextvars reset tokens from the run-stack push done per event
        # start, keyed by event_id so on_event_end can pop the right frame.
        self._event_tokens: Dict[str, Any] = {}
        super().__init__(
            event_starts_to_ignore=event_starts_to_ignore,
            event_ends_to_ignore=event_ends_to_ignore,
        )

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return id of event."""
        gc = graph_client_context.require_graph_client()
        event_type_name = event_type.name.lower()

        # Weird stuff we have to do to make an anonymous op: wrap a dummy
        # function, rename it after the event type, and cache it.
        op = self._ops.get(event_type_name)
        if not op:

            @weave.op()
            def resolve_fn():
                return "fake_op"

            resolve_fn.name = event_type_name
            op = resolve_fn
            self._ops[event_type_name] = op
        op_def_ref = gc._save_op(op)
        op_str = op_def_ref.uri()

        # TODO: the below might be nicer if we used the client API instead of
        # server API. Otherwise, we should refactor the API a little bit to make
        # this easier. Doing stack manipulation here feels bad.

        # Parent weave call: nest under the current run if there is one,
        # otherwise start a fresh trace.
        current_run = run_context.get_current_run()
        if current_run and current_run.id:
            parent_id = current_run.id
            trace_id = current_run.trace_id
        else:
            parent_id = None
            trace_id = generate_id()

        # Normalize payload keys (enum members -> their names) BEFORE any
        # serialization, so the stack Call and the server request carry the
        # same inputs. (Previously the Call serialized the raw payload while
        # the server request got the normalized one.)
        payload = payload or {}
        payload = {k.name: v for k, v in payload.items()}
        inputs = to_json(payload, gc._project_id(), gc.server)

        # Have to manually append to stack :( — nested events then pick this
        # call up as their parent. Keep the reset token for on_event_end.
        new_stack = copy.copy(run_context._run_stack.get())
        call = Call(
            project_id=gc._project_id(),
            id=event_id,
            op_name=op_str,
            trace_id=trace_id,
            parent_id=parent_id,
            inputs=inputs,
        )
        new_stack.append(call)
        self._event_tokens[event_id] = run_context._run_stack.set(new_stack)

        # NOTE: previously printed self._weave_trace_id, which is never
        # assigned (its only assignment lives in the commented-out
        # start_trace) and raised AttributeError; print the real trace_id.
        print("LOG EV", event_id, parent_id, trace_id)
        gc.server.call_start(
            CallStartReq(
                start=StartedCallSchemaForInsert(
                    project_id=gc._project_id(),
                    id=event_id,
                    op_name=op_str,
                    trace_id=trace_id,
                    parent_id=parent_id,
                    started_at=datetime.datetime.now(),
                    attributes={},
                    inputs=inputs,
                )
            )
        )
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends: record the call end and pop the run stack."""
        gc = graph_client_context.require_graph_client()
        payload = payload or {}
        # Same enum-key normalization as on_event_start.
        payload = {k.name: v for k, v in payload.items()}
        # NOTE: previously printed self._weave_trace_id (never assigned, see
        # on_event_start) which raised AttributeError.
        print("LOG END", event_id)
        gc.server.call_end(
            CallEndReq(
                end=EndedCallSchemaForInsert(
                    project_id=gc._project_id(),
                    id=event_id,  # type: ignore
                    ended_at=datetime.datetime.now(),
                    output=to_json(payload, gc._project_id(), gc.server),
                    summary={},
                )
            )
        )
        # Pop the run-stack frame pushed in on_event_start. Guard against an
        # end event with no matching start (pop() without a default raised
        # KeyError here).
        token = self._event_tokens.pop(event_id, None)
        if token is not None:
            run_context._run_stack.reset(token)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched.

        No-op: trace ids are derived per-event in on_event_start. This method
        is abstract on BaseCallbackHandler, so leaving it unimplemented (as
        the commented-out original did) made the class uninstantiable.
        """

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited. No-op (see start_trace)."""
import llama_index.core

# Route every llama-index callback event through Weave.
# NOTE(review): the handler is installed before weave.init() is called below;
# require_graph_client() would fail for any event fired before init — confirm
# no callbacks run during index construction.
llama_index.core.global_handler = WeaveCallbackHandler()

# Manually save or load a vector store.
PERSIST_DIR = "./storage"
if os.path.exists(PERSIST_DIR):
    # A persisted index exists — load it back from disk.
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
else:
    # First run: read the documents, build the index, then persist it.
    documents = SimpleDirectoryReader("data").load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)

weave.init("shawn/weave-llama-index3")

# Do a llama-index query without a Weave Model.
query_engine = index.as_query_engine()
response = query_engine.query("What did the author do growing up?")
print(response)
# Or using a Weave Model so we can version Index, parameters, and code
class LLamaIndexEngineModel(weave.Model):
    """Weave Model wrapping a llama-index index, so the index, parameters,
    and query code are versioned together."""

    index: BaseIndex
    # Other parameters here

    @weave.op()
    def query(self, query: str) -> Any:
        # Can pass other model params into as_query_engine() here.
        return self.index.as_query_engine().query(query)
# Run the same question through the versioned Model's traced query op.
model = LLamaIndexEngineModel(index=index)
response = model.query("What did the author do growing up?")
print("RESPONSE", response)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment