Skip to content

Instantly share code, notes, and snippets.

@meyer1994
Last active June 4, 2024 19:42
Show Gist options
  • Save meyer1994/35348d4e2df8831a220c2f1cf8952c4e to your computer and use it in GitHub Desktop.
Save meyer1994/35348d4e2df8831a220c2f1cf8952c4e to your computer and use it in GitHub Desktop.
Simple SQL memory implementation for llama-index
import datetime as dt
from typing import Any, List, Optional
import sqlalchemy as sa
import sqlalchemy.orm as orm
from llama_index.core.llms import ChatMessage
from llama_index.core.storage.chat_store import BaseChatStore
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.agent.openai import OpenAIAgent
class Base(orm.DeclarativeBase):
    """Declarative base class that the ORM models in this file inherit from."""
class Message(Base):
    """ORM row holding one chat message that belongs to a conversation key."""

    __tablename__ = "messages"

    id: orm.Mapped[int] = orm.mapped_column(sa.Integer, primary_key=True)
    # Conversation identifier; indexed because lookups filter on it.
    key: orm.Mapped[str] = orm.mapped_column(sa.String, index=True)
    role: orm.Mapped[str] = orm.mapped_column(sa.String)
    # Nullable: a ChatMessage may carry no text content (for example an
    # assistant turn that only holds tool calls), and a NOT NULL column
    # would make inserting such a message fail.
    content: orm.Mapped[Optional[str]] = orm.mapped_column(
        sa.String,
        nullable=True,
    )
    additional_kwargs: orm.Mapped[dict[str, Any]] = orm.mapped_column(sa.JSON)
    # needed for sorting
    created_at: orm.Mapped[dt.datetime] = orm.mapped_column(
        sa.DateTime(timezone=True),
        default=lambda: dt.datetime.now(dt.UTC),
    )

    def as_chat_message(self) -> ChatMessage:
        """Build a ``ChatMessage`` from this row's role, content and kwargs."""
        return ChatMessage(
            role=self.role,  # type: ignore
            content=self.content,
            additional_kwargs=self.additional_kwargs,
        )
class SQLChatStore(BaseChatStore):
    """Chat store that reads and writes ``Message`` rows through a session.

    Messages are grouped by ``key`` and returned oldest first. The store only
    adds, deletes and queries through ``session``; it never commits, so the
    owner of the session decides when the work is persisted.
    """

    session: orm.Session

    class Config:
        arbitrary_types_allowed = True

    def set_messages(self, key: str, messages: List[ChatMessage]) -> None:
        """Replace everything stored under ``key`` with ``messages``.

        Existing rows are removed first; only appending would duplicate the
        history whenever a caller writes back the full message list.
        """
        self.session.execute(sa.delete(Message).where(Message.key == key))
        for message in messages:
            self.add_message(key, message)

    def _get_messages(self, key: str) -> List[Message]:
        """Return the ORM rows stored under ``key``, oldest first."""
        query = (
            sa.select(Message)
            .where(Message.key == key)
            # ``id`` breaks ties between rows that share a timestamp.
            .order_by(Message.created_at.asc(), Message.id.asc())
        )
        return list(self.session.scalars(query))

    def get_messages(self, key: str) -> List[ChatMessage]:
        """Return the messages stored under ``key``, oldest first."""
        return [row.as_chat_message() for row in self._get_messages(key)]

    def add_message(self, key: str, message: ChatMessage) -> None:
        """Add ``message`` to the session as a new row under ``key``."""
        # NOTE(review): ``additional_kwargs`` goes into a JSON column, so it
        # presumably has to be JSON-serializable - confirm for tool-call data.
        row = Message(
            key=key,
            role=message.role,
            content=message.content,
            additional_kwargs=message.additional_kwargs,
        )
        self.session.add(row)

    def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
        """Delete every message under ``key`` and return the removed ones.

        Returns an empty list when nothing is stored under ``key``.
        """
        removed = self.get_messages(key)
        self.session.execute(sa.delete(Message).where(Message.key == key))
        return removed

    def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
        """Delete the message at position ``idx`` under ``key``.

        ``idx`` indexes the oldest-first message list. Returns the removed
        message, or ``None`` when ``idx`` is out of range.
        """
        rows = self._get_messages(key)
        try:
            target = rows[idx]
        except IndexError:
            return None
        # Convert before deleting so the row's attributes are still loaded.
        removed = target.as_chat_message()
        # Delete by primary key; matching ``Message.key`` against the row id
        # would compare the conversation key with an integer id.
        self.session.execute(sa.delete(Message).where(Message.id == target.id))
        return removed

    def delete_last_message(self, key: str) -> Optional[ChatMessage]:
        """Delete the newest message under ``key`` and return it.

        Returns ``None`` when nothing is stored under ``key``.
        """
        query = (
            sa.select(Message)
            .where(Message.key == key)
            .order_by(Message.created_at.desc(), Message.id.desc())
            # Without the limit, ``scalar_one_or_none`` raises as soon as the
            # key holds more than one message.
            .limit(1)
        )
        newest = self.session.scalars(query).first()
        if newest is None:
            return None
        removed = newest.as_chat_message()
        self.session.delete(newest)
        return removed

    def get_keys(self) -> List[str]:
        """Return the distinct keys that have at least one stored message."""
        query = sa.select(Message.key).distinct()
        return list(self.session.scalars(query))
# Demo: create an in-memory SQLite database with the tables registered on Base.
print("creating db")
engine = sa.create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

print("running agent")
with orm.Session(engine) as session:
    # Plug the SQL-backed store into the memory buffer handed to the agent.
    store = SQLChatStore(session=session)
    buffer = ChatMemoryBuffer.from_defaults(chat_store=store)
    agent = OpenAIAgent.from_tools(memory=buffer)

    res = agent.chat("Hello, world!")
    print(res)
    res = agent.chat("What?")
    print(res)

    # Print the content of every stored message, oldest first.
    print("checking memory")
    query = sa.select(Message).order_by(Message.created_at.asc())
    for row in session.scalars(query):
        print(row.content)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment