Skip to content

Instantly share code, notes, and snippets.

@Struki84
Last active July 13, 2023 18:24
Show Gist options
  • Save Struki84/6cbd1d9796ae4eaeb919a336143c6a81 to your computer and use it in GitHub Desktop.
Save Struki84/6cbd1d9796ae4eaeb919a336143c6a81 to your computer and use it in GitHub Desktop.
// An example custom DB adapter: a GORM-based wrapper around a Postgres database
// that implements the DBAdapter interface.
package main
import (
	"context"
	"database/sql/driver"
	"encoding/json"
	"errors"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/schema"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)
// Sentinel errors reported by the Postgres adapter; compare them with errors.Is.
var (
	// ErrDBConnection is reported when the database connection cannot be opened.
	ErrDBConnection = errors.New("can't connect to database")

	// ErrDBMigration is reported when auto-migrating the schema fails.
	ErrDBMigration = errors.New("can't migrate database")

	// ErrMissingSessionID is reported when saving without a session ID.
	ErrMissingSessionID = errors.New("session id can not be empty")
)
// ChatHistory is the persisted form of one conversation. It holds the session
// ID, the rendered buffer string and the message list in a jsonb column, and
// is saved as a single row.
type ChatHistory struct {
	ID           int       `gorm:"primary_key"`
	SessionID    string    `gorm:"type:varchar(256)"`
	BufferString string    `gorm:"type:text"`
	ChatHistory  *Messages `gorm:"type:jsonb;column:chat_history" json:"chat_history"`
}
// Message is the JSON-serialisable form of a single chat message, stored as an
// element of the chat_history jsonb column.
type Message struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

// Messages is an ordered list of Message values persisted as one JSON array.
type Messages []Message

// Value implements driver.Valuer: the list is stored as its JSON encoding.
func (m Messages) Value() (driver.Value, error) {
	return json.Marshal(m)
}

// Scan implements sql.Scanner, decoding a JSON array read from the database
// into m.
//
// database/sql may hand text-like column values over as either []byte or
// string, so both are decoded. A SQL NULL (nil) resets m to an empty (nil)
// list instead of failing. Any other source type is rejected.
func (m *Messages) Scan(src interface{}) error {
	switch v := src.(type) {
	case nil:
		*m = nil
		return nil
	case []byte:
		return json.Unmarshal(v, m)
	case string:
		return json.Unmarshal([]byte(v), m)
	default:
		return errors.New("could not scan type into Message")
	}
}
// PostgreAdapter stores chat memory in Postgres through GORM; it is the
// example implementation of the DBAdapter interface.
type PostgreAdapter struct {
	gorm      *gorm.DB     // open database handle
	history   *ChatHistory // record reused across SaveDBContext calls
	sessionID string       // current session, set via SetSessionId
}
// NewPostgreAdapter opens a GORM connection to Postgres and auto-migrates the
// ChatHistory table.
//
// On failure the returned error wraps ErrDBConnection or ErrDBMigration (match
// with errors.Is) and includes the underlying GORM/driver error text, which the
// caller needs to diagnose the problem.
func NewPostgreAdapter() (*PostgreAdapter, error) {
	// TODO: supply a real connection string, e.g.
	// "host=localhost user=... password=... dbname=... port=5432 sslmode=disable".
	dsn := ""

	// Named db (not gorm) so the local does not shadow the gorm package.
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrDBConnection, err)
	}

	if err := db.AutoMigrate(&ChatHistory{}); err != nil {
		return nil, fmt.Errorf("%w: %v", ErrDBMigration, err)
	}

	return &PostgreAdapter{
		gorm:    db,
		history: &ChatHistory{},
	}, nil
}
// SetSessionId records the session ID that the adapter reports via
// GetSessionId and that SaveDBContext writes under.
func (adapter *PostgreAdapter) SetSessionId(id string) { adapter.sessionID = id }

// GetSessionId reports the session ID last stored by SetSessionId ("" if none).
func (adapter *PostgreAdapter) GetSessionId() string { return adapter.sessionID }
// SaveDBContext persists a session's messages and their rendered buffer string
// as a single ChatHistory row, inserting it or updating the existing one.
//
// id selects the session to write. When it is empty, the adapter's current
// session (see SetSessionId) is used. ErrMissingSessionID is returned when
// neither is set; otherwise any error comes from GORM.
func (adapter *PostgreAdapter) SaveDBContext(id string, msgs []schema.ChatMessage, bufferString string) error {
	sessionID := id
	if sessionID == "" {
		sessionID = adapter.sessionID
	}
	if sessionID == "" {
		return ErrMissingSessionID
	}

	// Non-nil even when msgs is empty, so the column stores [] rather than null.
	newMsgs := make(Messages, 0, len(msgs))
	for _, msg := range msgs {
		newMsgs = append(newMsgs, Message{
			Type: string(msg.GetType()),
			Text: msg.GetText(),
		})
	}

	// The cached record's primary key is only valid for the session it was
	// saved under. If it was never saved or belongs to another session, fetch
	// this session's existing row instead. Reusing the key would overwrite the
	// other session's history, and always inserting would duplicate rows for a
	// session across adapter instances.
	if adapter.history.ID == 0 || adapter.history.SessionID != sessionID {
		// Clear the key first so GORM does not add it as a query condition.
		// When no row matches, ID stays 0 and Save below inserts a new one.
		adapter.history.ID = 0
		err := adapter.gorm.
			Where("session_id = ?", sessionID).
			FirstOrInit(adapter.history).Error
		if err != nil {
			return err
		}
	}

	adapter.history.SessionID = sessionID
	adapter.history.ChatHistory = &newMsgs
	adapter.history.BufferString = bufferString

	// history is already a *ChatHistory; pass it as-is, not as **ChatHistory.
	return adapter.gorm.Save(adapter.history).Error
}
// LoadDBMemory is a placeholder: it ignores id and always reports an empty
// (nil) history with no error.
//
// TODO: implement custom retrieval logic, e.g. read the session's ChatHistory
// row and convert its Messages back into schema.ChatMessage values.
func (adapter *PostgreAdapter) LoadDBMemory(id string) ([]schema.ChatMessage, error) {
	var stored []schema.ChatMessage
	return stored, nil
}
// ClearDBContext is a placeholder: it deletes nothing and always succeeds.
//
// TODO: implement custom delete logic that removes the stored history for id.
func (adapter *PostgreAdapter) ClearDBContext(id string) error {
	var noErr error
	return noErr
}
// main wires the Postgres-backed persistent memory into a zero-shot ReAct
// agent with a SerpAPI search tool and asks it a single question.
//
// Every setup step is required for the next one (a nil adapter, LLM or
// executor would panic later), so any setup error terminates the program.
func main() {
	postgreAdapter, err := NewPostgreAdapter()
	if err != nil {
		log.Fatal(err)
	}

	persistentMemoryBuffer := memory.NewPersistentBuffer(postgreAdapter)
	persistentMemoryBuffer.DB.SetSessionId("USID-001")

	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	// Named search (not serpapi) so the local does not shadow the package.
	search, err := serpapi.New()
	if err != nil {
		log.Fatal(err)
	}

	executor, err := agents.Initialize(
		llm,
		[]tools.Tool{search},
		agents.ZeroShotReactDescription,
		agents.WithMemory(persistentMemoryBuffer),
		agents.WithMaxIterations(3),
	)
	if err != nil {
		log.Fatal(err)
	}

	input := "Who is the current CEO of Twitter?"
	answer, err := chains.Run(context.Background(), executor, input)
	if err != nil {
		log.Fatal(err)
	}
	log.Print(answer)
}
package memory
import (
"github.com/tmc/langchaingo/memory"
"github.com/tmc/langchaingo/schema"
)
// DBAdapter is the storage backend used by PersistentBuffer. Any database can
// back the buffer as long as it implements these methods.
//
// PersistentBuffer always passes the value from GetSessionId as the id
// argument of the Load/Save/Clear methods.
type DBAdapter interface {
	// SetSessionId sets the session the adapter reports via GetSessionId.
	SetSessionId(id string)
	// GetSessionId returns the current session ID.
	GetSessionId() string

	// LoadDBMemory returns the stored messages for session id.
	LoadDBMemory(id string) ([]schema.ChatMessage, error)
	// SaveDBContext stores msgs and their rendered bufferString for session id.
	SaveDBContext(id string, msgs []schema.ChatMessage, bufferString string) error
	// ClearDBContext removes the stored history for session id.
	ClearDBContext(id string) error
}
// PersistentBuffer is a schema.Memory implementation. It keeps the
// conversation in a memory.ChatMessageHistory and loads, saves and clears it
// in a database through a DBAdapter.
type PersistentBuffer struct {
	ChatHistory    *memory.ChatMessageHistory // in-process conversation
	DB             DBAdapter                  // storage backend
	ReturnMessages bool                       // return the message list instead of a rendered string
	InputKey       string                     // key passed to getInputValue for the user input
	OutputKey      string                     // key passed to getInputValue for the AI output
	HumanPrefix    string                     // speaker label for human turns in the buffer string
	AIPrefix       string                     // speaker label for AI turns in the buffer string
	MemoryKey      string                     // key under which the memory variable is returned
}
// Compile-time checks that PersistentBuffer satisfies schema.Memory both as a
// value and through the *PersistentBuffer returned by NewPersistentBuffer.
var (
	_ schema.Memory = PersistentBuffer{}
	_ schema.Memory = (*PersistentBuffer)(nil)
)
// NewPersistentBuffer returns a PersistentBuffer backed by dbAdapter. It starts
// with an empty in-process history, the "Human"/"AI" speaker prefixes and the
// "history" memory key. Messages are rendered as a string, not returned as a
// list.
func NewPersistentBuffer(dbAdapter DBAdapter) *PersistentBuffer {
	return &PersistentBuffer{
		ChatHistory: memory.NewChatMessageHistory(),
		DB:          dbAdapter,
		HumanPrefix: "Human",
		AIPrefix:    "AI",
		MemoryKey:   "history",
		// ReturnMessages, InputKey and OutputKey keep their zero values.
	}
}
// MemoryVariables reports the single key, MemoryKey, under which
// LoadMemoryVariables returns the conversation.
func (buffer PersistentBuffer) MemoryVariables() []string {
	keys := []string{buffer.MemoryKey}
	return keys
}
// LoadMemoryVariables loads the current session's history from the DB and
// returns it under MemoryKey. When ReturnMessages is set, the value is the
// message list; otherwise it is a buffer string rendered with HumanPrefix and
// AIPrefix.
//
// The loaded history also replaces the buffer's in-process history, so a
// following SaveContext appends to, and persists, the whole conversation.
// This relies on DBAdapter.LoadDBMemory returning the complete stored history
// for the session. An adapter that returns nothing resets the in-process
// history.
func (buffer PersistentBuffer) LoadMemoryVariables(inputs map[string]any) (map[string]any, error) {
	sessionID := buffer.DB.GetSessionId()
	msgs, err := buffer.DB.LoadDBMemory(sessionID)
	if err != nil {
		return nil, err
	}

	loaded := memory.NewChatMessageHistory(memory.WithPreviousMessages(msgs))
	if buffer.ChatHistory != nil {
		// buffer is a copy (value receiver), so reassigning buffer.ChatHistory
		// would be lost on return. The pointer itself is shared with the
		// caller's PersistentBuffer, so overwrite what it points to instead.
		*buffer.ChatHistory = *loaded
	} else {
		buffer.ChatHistory = loaded
	}

	if buffer.ReturnMessages {
		return map[string]any{
			buffer.MemoryKey: buffer.ChatHistory.Messages(),
		}, nil
	}

	bufferString, err := schema.GetBufferString(buffer.ChatHistory.Messages(), buffer.HumanPrefix, buffer.AIPrefix)
	if err != nil {
		return nil, err
	}
	return map[string]any{
		buffer.MemoryKey: bufferString,
	}, nil
}
// SaveContext appends one exchange (the user input from inputs and the AI
// output from outputs, each resolved via getInputValue) to the in-process
// history. It then persists the whole history and its rendered buffer string
// for the current session through the DBAdapter.
//
// Both values are resolved before the shared history is touched, so a missing
// key returns an error without leaving a dangling user message behind.
func (buffer PersistentBuffer) SaveContext(inputs map[string]any, outputs map[string]any) error {
	sessionID := buffer.DB.GetSessionId()

	userInputValue, err := getInputValue(inputs, buffer.InputKey)
	if err != nil {
		return err
	}
	aiOutputValue, err := getInputValue(outputs, buffer.OutputKey)
	if err != nil {
		return err
	}

	buffer.ChatHistory.AddUserMessage(userInputValue)
	buffer.ChatHistory.AddAIMessage(aiOutputValue)

	msgs := buffer.ChatHistory.Messages()
	bufferString, err := schema.GetBufferString(msgs, buffer.HumanPrefix, buffer.AIPrefix)
	if err != nil {
		return err
	}

	return buffer.DB.SaveDBContext(sessionID, msgs, bufferString)
}
// Clear asks the DBAdapter to delete the current session's stored history and,
// only if that succeeds, empties the in-process history as well.
func (buffer PersistentBuffer) Clear() error {
	if err := buffer.DB.ClearDBContext(buffer.DB.GetSessionId()); err != nil {
		return err
	}
	buffer.ChatHistory.Clear()
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment