package copilot

import (
	"context"
	"encoding/json"
	"reflect"
	"sync"

	"github.com/incident-io/core/server/lib/errors"
	"github.com/incident-io/core/server/lib/log"
	"github.com/incident-io/core/server/lib/safe"
	"github.com/incident-io/core/server/pkg/rbac"
	goopenai "github.com/sashabaranov/go-openai"
	"gorm.io/gorm"
)
// ResultSpeculator allows speculatively executing the next prompt iteration based on
// speculative results provided by tool calls.
//
// The aim is to:
// - Create a speculator at the start of a prompt iteration that calls tools.
// - Register all the tool calls for this iteration.
// - If those tools provide speculative results, and we get a result from all of them, we
//   can proceed to optimistically execute the next iteration of the prompt.
// - If the tools then return real results that match the speculative ones, we can use
//   whatever our optimistic execution produced as the next prompt iteration, often
//   allowing us to return to the calling prompt immediately in cases where our optimistic
//   evaluation completed before the tool calls finished.
//
// This tends to save ~3s of latency per tool call.
type ResultSpeculator struct {
	ctx                context.Context
	cancel             context.CancelFunc
	completionRequest  goopenai.ChatCompletionRequest
	toolCalls          []goopenai.ToolCall
	toolCallResults    []goopenai.ChatCompletionMessage
	toolCallsRemaining int
	runner             ResultSpeculatorFunc
	running            bool
	response           goopenai.ChatCompletionResponse
	done               chan error
	sync.Mutex
}
// ResultSpeculatorFunc is called with a completionRequest that incorporates the
// speculative tool call responses. Whatever it returns is used as the next prompt
// iteration's result, accessed via the WaitIfRunning method.
type ResultSpeculatorFunc func(
	ctx context.Context, completionRequest goopenai.ChatCompletionRequest,
) (
	goopenai.ChatCompletionResponse, error,
)
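
// newSpeculatorRunner is a hypothetical helper, not part of the original file: a minimal
// sketch of a ResultSpeculatorFunc that simply forwards the speculative completion
// request to OpenAI, assuming a *goopenai.Client is available. The real runner used by
// the prompt loop likely layers tracing, retries, and similar concerns on top of this.
func newSpeculatorRunner(client *goopenai.Client) ResultSpeculatorFunc {
	return func(ctx context.Context, completionRequest goopenai.ChatCompletionRequest) (goopenai.ChatCompletionResponse, error) {
		// The request already includes the speculative tool call responses appended by
		// Receive, so we can send it as-is.
		return client.CreateChatCompletion(ctx, completionRequest)
	}
}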
// NewResultSpeculator builds a speculator for a prompt iteration that produced the given
// tool calls, ready to receive a speculative result for each of them.
func NewResultSpeculator(ctx context.Context, completionRequest goopenai.ChatCompletionRequest, toolCalls []goopenai.ToolCall, runner ResultSpeculatorFunc) *ResultSpeculator {
	ctx, cancel := context.WithCancel(ctx)
	return &ResultSpeculator{
		ctx:                ctx,
		cancel:             cancel,
		completionRequest:  completionRequest,
		toolCalls:          toolCalls,
		toolCallResults:    make([]goopenai.ChatCompletionMessage, len(toolCalls)),
		toolCallsRemaining: len(toolCalls),
		runner:             runner,
		done:               make(chan error, 1),
	}
}
// Receive is called with the index of the tool call and the result of that tool call, if
// the tool has provided a speculative result.
func (s *ResultSpeculator) Receive(idx int, result goopenai.ChatCompletionMessage) {
	if idx < 0 || idx >= len(s.toolCalls) {
		return
	}

	s.Lock()
	defer s.Unlock()

	if s.toolCallsRemaining <= 0 || s.running {
		return
	}

	s.toolCallResults[idx] = result
	s.toolCallsRemaining--

	// Still waiting on other tool calls to provide speculative results.
	if s.toolCallsRemaining > 0 {
		return
	}

	// We have a speculative result for every tool call: build the next completion request
	// and optimistically run it in the background.
	completionRequest := s.completionRequest
	completionRequest.Messages = append(completionRequest.Messages, s.toolCallResults...)

	s.running = true
	safe.Go(func() {
		defer close(s.done)

		var err error
		s.response, err = s.runner(s.ctx, completionRequest)
		if err != nil {
			s.done <- err
		}
	})
}
// WaitIfRunning waits for the speculative result to be collected, or returns immediately
// if the speculative result has already been collected, or one was never started.
func (s *ResultSpeculator) WaitIfRunning(ctx context.Context) (goopenai.ChatCompletionResponse, bool, error) {
	// Whatever happens, we should implicitly cancel our context if this is called, as the
	// speculative result is only collected once and we never want to forget to cancel.
	defer s.cancel()

	if !s.running {
		return s.response, false, nil
	}

	return s.wait(ctx)
}
func (s *ResultSpeculator) wait(ctx context.Context) (goopenai.ChatCompletionResponse, bool, error) {
	select {
	case <-ctx.Done():
		return goopenai.ChatCompletionResponse{}, false, ctx.Err()
	case err := <-s.done:
		if err != nil {
			return goopenai.ChatCompletionResponse{}, false, err
		}

		return s.response, true, nil
	}
}
// Cancel the speculative execution, terminating the context and ending any ongoing
// operations.
func (s *ResultSpeculator) Cancel() {
	s.cancel()
}
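
// speculateNextIteration is a hypothetical sketch, not part of the original file, of how
// the prompt loop might wire the speculator together: create it alongside the real tool
// calls, feed each speculative result into Receive as it arrives, and then ask
// WaitIfRunning for the optimistically computed next response. The runTool callback is an
// assumed helper that executes one real tool call and may hand back a speculative result
// early; in the real flow the actual tool results would still be compared against the
// speculative ones before the optimistic response is adopted.
func speculateNextIteration(
	ctx context.Context,
	completionRequest goopenai.ChatCompletionRequest,
	toolCalls []goopenai.ToolCall,
	runner ResultSpeculatorFunc,
	runTool func(ctx context.Context, idx int, toolCall goopenai.ToolCall, receiveSpeculative func(goopenai.ChatCompletionMessage)),
) (goopenai.ChatCompletionResponse, bool, error) {
	speculator := NewResultSpeculator(ctx, completionRequest, toolCalls, runner)
	defer speculator.Cancel()

	// Kick off the real tool calls: each may provide a speculative result before it
	// finishes, which we feed straight into the speculator.
	for idx, toolCall := range toolCalls {
		idx, toolCall := idx, toolCall
		safe.Go(func() {
			runTool(ctx, idx, toolCall, func(speculative goopenai.ChatCompletionMessage) {
				speculator.Receive(idx, speculative)
			})
		})
	}

	// Once every tool has supplied a speculative result, the speculator has already begun
	// the next prompt iteration in the background; WaitIfRunning returns that response,
	// or reports that no speculation was started.
	return speculator.WaitIfRunning(ctx)
}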
// ToolCallSpeculator allows speculatively executing a tool call, so that if we *do* decide
// to call that tool, we've already started.
//
// The aim is to:
// - Create a speculator at the start of a prompt iteration for any tools we think might
//   be relevant, based on some simple checks (e.g. keyword matching).
// - Run those tools, blocking any system actions (e.g. updating an incident) using a
//   write barrier.
// - If the prompt then decides to call this tool with parameters matching our guess, we
//   can unblock the write barrier and then return the tool result as normal.
//
// This should save ~3s of latency for interactions that need to call tools, provided our
// guesses are reasonably accurate.
type ToolCallSpeculator struct {
	Tool      ToolDefinition
	Arguments string
	AISpanID  *string
	wb        *ToolWriteBarrier
	once      sync.Once
	result    *ToolResult[any]
	err       error
	done      chan struct{}
}
// NewToolCallSpeculator creates a speculator for a guessed tool call, along with a cleanup
// function that must be called once we know whether the guess was used, releasing anyone
// still waiting on the speculator or its write barrier.
func NewToolCallSpeculator(tool ToolDefinition, args string, parent *ToolWriteBarrier) (*ToolCallSpeculator, func()) {
	wb, cancelWb := NewToolWriteBarrier(parent)

	s := &ToolCallSpeculator{
		done:      make(chan struct{}),
		wb:        wb,
		Tool:      tool,
		Arguments: args,
	}

	cleanup := func() {
		cancelWb()
		s.once.Do(func() {
			close(s.done)
		})
	}

	return s, cleanup
}

// ToolName returns the name of the tool this speculator would call.
func (s *ToolCallSpeculator) ToolName() string {
	return s.Tool.Function.Def.Name
}
// Run executes the speculative tool call, storing its result (or error) so that Wait can
// return it later.
func (s *ToolCallSpeculator) Run(ctx context.Context, db *gorm.DB, identity *rbac.Identity, receiveSpeculativeResult func(any)) {
	// If we get a panic, we want to return it like a normal error so we can pass it back
	// up the channel.
	defer func() {
		var err error
		if errors.RecoverPanic(recover(), &err) {
			s.once.Do(func() {
				s.err = err
				close(s.done)
			})
		}
	}()

	result, err := s.Tool.Function.Run(ctx, db, identity, s.Arguments, s.wb, receiveSpeculativeResult)

	// If the context is already cancelled, treat that as the result.
	if ctx.Err() != nil {
		err = ctx.Err()
		result = nil
	}

	s.once.Do(func() {
		s.result = result
		s.err = err
		close(s.done)
	})
}
// Wait blocks until the speculative tool call has finished (or been cleaned up), returning
// its result and error.
func (s *ToolCallSpeculator) Wait(ctx context.Context) (*ToolResult[any], error) {
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-s.done:
		return s.result, s.err
	}
}
// Matches checks if a given tool call (that's been returned by running a prompt in OpenAI)
// matches our speculator's tool and arguments. This tells us whether we can use the result
// of this speculator instead of calling the actual tool.
func (s *ToolCallSpeculator) Matches(ctx context.Context, toolCall goopenai.ToolCall) bool {
	sameTool := s.ToolName() == toolCall.Function.Name
	if !sameTool {
		return false
	}

	var speculatorArgs, toolCallArgs map[string]any
	if err := json.Unmarshal([]byte(s.Arguments), &speculatorArgs); err != nil {
		log.Warn(ctx, errors.Wrap(ctx, err, "unmarshaling speculator arguments"))

		// Safer to assume it didn't match!
		return false
	}
	if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &toolCallArgs); err != nil {
		log.Warn(ctx, errors.Wrap(ctx, err, "unmarshaling tool call arguments"))

		// Safer to assume it didn't match!
		return false
	}

	// Remove any keys with the magic prefix `debug_` from both argument maps before
	// comparing them.
	stripDebugKeys := func(m map[string]any) map[string]any {
		for k := range m {
			if len(k) > 6 && k[:6] == "debug_" {
				delete(m, k)
			}
		}

		return m
	}

	return reflect.DeepEqual(
		stripDebugKeys(speculatorArgs),
		stripDebugKeys(toolCallArgs),
	)
}
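
// useSpeculativeToolCall is a hypothetical sketch, not part of the original file, of the
// lifecycle described above: start a guessed tool call early behind its write barrier,
// then either adopt its result (if the model asked for the same call with matching
// arguments) or abandon it. It reaches into the unexported wb field for simplicity; the
// real orchestration may instead flush a shared parent barrier. guessedTool and
// guessedArgs are assumed to come from the simple pre-checks mentioned above (e.g.
// keyword matching).
func useSpeculativeToolCall(
	ctx context.Context,
	db *gorm.DB,
	identity *rbac.Identity,
	guessedTool ToolDefinition,
	guessedArgs string,
	actual goopenai.ToolCall,
) (*ToolResult[any], error) {
	runCtx, cancelRun := context.WithCancel(ctx)
	defer cancelRun()

	speculator, cleanup := NewToolCallSpeculator(guessedTool, guessedArgs, nil)
	defer cleanup()

	// Start the guessed tool immediately: any side-effective writes it attempts will block
	// on its write barrier until we decide whether to keep this run.
	safe.Go(func() {
		speculator.Run(runCtx, db, identity, func(any) {})
	})

	if !speculator.Matches(ctx, actual) {
		// The model asked for something else: cancel the speculative run and fall back to
		// executing the real tool call (not shown here).
		cancelRun()
		return nil, nil
	}

	// Our guess was right: release the write barrier so the tool can commit its writes,
	// then use its result as if we had called the tool normally.
	speculator.wb.Flush()
	return speculator.Wait(ctx)
}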
// ToolWriteBarrier is passed into tool calls and should be Wait()'d on before the tool
// makes any side-effective writes or actions (e.g. updating database records, sending
// messages, taking incident actions).
//
// It is used to ensure tools don't commit to taking an action before we know if we want
// to call them, helping us safely speculatively execute tools.
type ToolWriteBarrier struct {
	once   sync.Once
	done   chan struct{}
	err    error
	parent *ToolWriteBarrier
}
// NewToolWriteBarrier creates a new barrier that initially blocks writes, until given
// permission to proceed.
func NewToolWriteBarrier(parent *ToolWriteBarrier) (*ToolWriteBarrier, func()) {
	b := &ToolWriteBarrier{
		done:   make(chan struct{}),
		parent: parent,
	}

	cleanup := func() {
		// Use sync.Once to protect the close
		b.once.Do(func() {
			close(b.done)
		})
	}

	return b, cleanup
}
// Wait blocks until we are permitted to make a write (e.g. have a side effect, such as
// sending messages).
//
// Returns a nil error if permission to write is granted, otherwise returns an error which
// should be passed back up the stack.
func (t *ToolWriteBarrier) Wait(ctx context.Context) error {
	if t == nil {
		return nil
	}

	// If we have a parent barrier, it must grant permission first.
	if t.parent != nil {
		err := t.parent.Wait(ctx)
		if err != nil {
			return errors.Wrap(ctx, err, "parent write barrier")
		}
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-t.done:
		return t.err
	}
}
// Flush permits us to continue running the tool, allowing writes to proceed.
func (t *ToolWriteBarrier) Flush() {
	t.once.Do(func() {
		t.err = nil
		close(t.done)
	})
}
// Deny causes all waiting writers to fail with the given error.
func (t *ToolWriteBarrier) Deny(err error) {
	t.once.Do(func() {
		t.err = err
		close(t.done)
	})
}
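
// runSendMessageTool is a hypothetical sketch, not part of the original file, of how a
// tool implementation is expected to use the barrier it is given: do any read-only
// preparation up front, then Wait on the barrier immediately before the first
// side-effective write. sendMessage is an assumed callback, standing in for whatever side
// effect the real tool performs.
func runSendMessageTool(ctx context.Context, wb *ToolWriteBarrier, sendMessage func(ctx context.Context) error) error {
	// Read-only work (building the message, looking up data, etc.) is safe to do
	// speculatively and would happen here.

	// Block until the caller decides whether this speculative run should be kept: a nil
	// barrier or a Flush lets us continue, while a Deny aborts before any writes happen.
	if err := wb.Wait(ctx); err != nil {
		return errors.Wrap(ctx, err, "waiting for write barrier")
	}

	// Only now is it safe to perform the side effect.
	return sendMessage(ctx)
}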