Skip to content

Instantly share code, notes, and snippets.

@paprikati
Last active March 21, 2025 15:01
Show Gist options
  • Save paprikati/87a3e66374eaaa83b7ff4e1270339aab to your computer and use it in GitHub Desktop.
Save paprikati/87a3e66374eaaa83b7ff4e1270339aab to your computer and use it in GitHub Desktop.
package copilot
import (
	"context"
	"encoding/json"
	"reflect"
	"strings"
	"sync"

	"github.com/incident-io/core/server/lib/errors"
	"github.com/incident-io/core/server/lib/log"
	"github.com/incident-io/core/server/lib/safe"
	"github.com/incident-io/core/server/pkg/rbac"
	goopenai "github.com/sashabaranov/go-openai"
	"gorm.io/gorm"
)
// ResultSpeculator allows speculatively executing the next iteration of prompt based on
// speculative results provided by tool calls.
//
// The aim is to:
// - Create a speculator at the start of a prompt iteration that calls tools
// - Register all the tool calls for this iteration.
// - If those tools provide speculative results, and we get a result from all tools, we
// can proceed to optimistically execute the next iteration of the prompt.
// - If the tools return results that match the speculative results, we can use whatever
// our optimistic execution produced as the next prompt iteration, often allowing us to
// return to the calling prompt immediately where our optimistic evaluation completed
// before the tool call finished.
//
// This tends to save ~3s of latency per tool call.
type ResultSpeculator struct {
	ctx                context.Context                  // lifetime of the speculative run; ended via cancel
	cancel             context.CancelFunc               // cancels ctx (Cancel / WaitIfRunning)
	completionRequest  goopenai.ChatCompletionRequest   // base request that speculative results are appended to
	toolCalls          []goopenai.ToolCall              // tool calls we expect a result for, one per slot
	toolCallResults    []goopenai.ChatCompletionMessage // speculative results, indexed to match toolCalls
	toolCallsRemaining int                              // counts down as Receive fills slots; run starts at zero
	runner             ResultSpeculatorFunc             // executes the speculative next prompt iteration
	running            bool                             // set (under the mutex) once the runner goroutine starts
	response           goopenai.ChatCompletionResponse  // runner's output; valid once done is closed
	done               chan error                       // buffered(1); carries the runner error, closed on completion
	sync.Mutex                                          // guards the mutable fields above
}
// ResultSpeculatorFunc is called with a completionRequest incorporating the speculative tool
// call responses. What it returns is used as the next prompt iteration result, and
// accessed via the WaitIfRunning method.
//
// It runs on a background goroutine (see Receive) with the speculator's own
// cancellable context, so implementations should respect ctx cancellation.
type ResultSpeculatorFunc func(
	ctx context.Context, completionRequest goopenai.ChatCompletionRequest,
) (
	goopenai.ChatCompletionResponse, error,
)
// NewResultSpeculator prepares a speculator for the given prompt iteration: one
// result slot per tool call, with a child context so the speculative run can be
// cancelled independently of the caller.
func NewResultSpeculator(ctx context.Context, completionRequest goopenai.ChatCompletionRequest, toolCalls []goopenai.ToolCall, runner ResultSpeculatorFunc) *ResultSpeculator {
	speculationCtx, cancel := context.WithCancel(ctx)

	speculator := &ResultSpeculator{
		ctx:               speculationCtx,
		cancel:            cancel,
		completionRequest: completionRequest,
		runner:            runner,
		// One slot per registered tool call: Receive fills these by index and
		// counts toolCallsRemaining down to zero.
		toolCalls:          toolCalls,
		toolCallResults:    make([]goopenai.ChatCompletionMessage, len(toolCalls)),
		toolCallsRemaining: len(toolCalls),
		// Buffered so the runner goroutine can report an error without blocking.
		done: make(chan error, 1),
	}

	return speculator
}
// Receive is called with the index of the tool call and the result of that tool call, if
// the tool has provided a speculative result.
//
// Once every slot has reported a result, Receive launches the speculative next
// prompt iteration on a background goroutine; its outcome is collected via
// WaitIfRunning.
func (s *ResultSpeculator) Receive(idx int, result goopenai.ChatCompletionMessage) {
	// Ignore out-of-range indexes rather than panicking on a bad caller.
	if idx < 0 || idx >= len(s.toolCalls) {
		return
	}

	s.Lock()
	defer s.Unlock()

	// Nothing to do if all results were already collected, or the speculative
	// run has already been started.
	if s.toolCallsRemaining <= 0 || s.running {
		return
	}

	// NOTE(review): a duplicate Receive for the same idx decrements the counter
	// twice, which could start the run before all slots are filled — confirm
	// callers invoke this at most once per index.
	s.toolCallResults[idx] = result
	s.toolCallsRemaining--
	if s.toolCallsRemaining > 0 {
		return
	}

	// All results are in: build the speculative request by appending the tool
	// results to the original request's messages.
	completionRequest := s.completionRequest
	completionRequest.Messages = append(completionRequest.Messages, s.toolCallResults...)

	s.running = true
	safe.Go(func() {
		// Closing done (after any error send) wakes wait(); the channel is
		// buffered, so the send never blocks this goroutine.
		defer close(s.done)

		var err error
		s.response, err = s.runner(s.ctx, completionRequest)
		if err != nil {
			s.done <- err
		}
	})
}
// WaitIfRunning waits for the speculative result to be collected, or returns immediately
// if the speculative result has already been collected, or one was never started.
//
// Returns the speculative response, whether a speculative run produced it, and
// any error from the runner or from ctx expiring while waiting.
func (s *ResultSpeculator) WaitIfRunning(ctx context.Context) (goopenai.ChatCompletionResponse, bool, error) {
	// Whatever happens, we should implicitly cancel our context if this is called, as the
	// speculative result is only collected once and we never want to forget to cancel.
	defer s.cancel()

	// Receive sets running while holding the mutex, so we must read it under the
	// same lock: the previous unsynchronized read was a data race when a tool
	// result arrived concurrently with this call.
	s.Lock()
	running := s.running
	s.Unlock()

	if !running {
		// No speculative run was started; s.response is its zero value here.
		return s.response, false, nil
	}
	return s.wait(ctx)
}
// wait blocks until the speculative runner signals completion via the done
// channel, or the caller's context expires — whichever happens first.
func (s *ResultSpeculator) wait(ctx context.Context) (goopenai.ChatCompletionResponse, bool, error) {
	select {
	case err := <-s.done:
		if err == nil {
			// done was closed without an error being sent: the runner finished
			// successfully and s.response is safe to read.
			return s.response, true, nil
		}
		return goopenai.ChatCompletionResponse{}, false, err
	case <-ctx.Done():
		return goopenai.ChatCompletionResponse{}, false, ctx.Err()
	}
}
// Cancel the speculative execution, terminating the context and ending any on-going
// operations.
//
// Safe to call multiple times: a context.CancelFunc is idempotent.
func (s *ResultSpeculator) Cancel() {
	s.cancel()
}
// ToolCallSpeculator allows speculatively executing a tool call, so that if we *do* decide to
// call that tool, we've already started.
//
// The aim is to:
// - Create a speculator at the start of a prompt iteration that calls any tools we think
// might be relevant, based on some simple checks (e.g. keyword matching).
// - Run those tools, blocking any system actions (e.g. updating an incident) using a
// write barrier.
// - If the prompt then decides to call this tool with matching parameters to our guess
// we can unblock the write barrier and then return the tool result as normal.
//
// This should save ~3s of latency for interactions that need to call tools, provided our
// guesses are reasonably accurate.
type ToolCallSpeculator struct {
	Tool      ToolDefinition // the tool we guessed would be called
	Arguments string         // JSON-encoded arguments we guessed the tool would be given
	AISpanID  *string        // presumably the AI span this call is traced under — set by callers, not in this file
	wb        *ToolWriteBarrier // blocks the tool's side effects until we commit to the call
	once      sync.Once         // ensures result/err are published and done closed exactly once
	result    *ToolResult[any]  // tool output; valid once done is closed
	err       error             // tool error (or panic / cancellation); valid once done is closed
	done      chan struct{}     // closed when the speculative run has finished (or is abandoned)
}
// NewToolCallSpeculator builds a speculator for a guessed tool call, wiring it
// to a fresh write barrier chained under parent. The returned cleanup releases
// the barrier and unblocks any Wait-ers; it is safe to call more than once.
func NewToolCallSpeculator(tool ToolDefinition, args string, parent *ToolWriteBarrier) (*ToolCallSpeculator, func()) {
	barrier, releaseBarrier := NewToolWriteBarrier(parent)

	speculator := &ToolCallSpeculator{
		Tool:      tool,
		Arguments: args,
		wb:        barrier,
		done:      make(chan struct{}),
	}

	teardown := func() {
		releaseBarrier()
		// once protects the close so teardown and Run never double-close done.
		speculator.once.Do(func() {
			close(speculator.done)
		})
	}

	return speculator, teardown
}
// ToolName returns the function name from the tool's definition — the name that
// an OpenAI tool call is matched against (see Matches).
func (s *ToolCallSpeculator) ToolName() string {
	return s.Tool.Function.Def.Name
}
// Run executes the speculated tool call, publishing its result (or error)
// exactly once via s.once, after which Wait returns. It is intended to be run
// on its own goroutine.
func (s *ToolCallSpeculator) Run(ctx context.Context, db *gorm.DB, identity *rbac.Identity, receiveSpeculativeResult func(any)) {
	// If we get a panic, we want to return it like a normal error so we can pass it back up the channel
	defer func() {
		var err error
		if errors.RecoverPanic(recover(), &err) {
			s.once.Do(func() {
				s.err = err
				close(s.done)
			})
		}
	}()

	result, err := s.Tool.Function.Run(ctx, db, identity, s.Arguments, s.wb, receiveSpeculativeResult)

	// If context is already cancelled, treat that as the result
	if ctx.Err() != nil {
		err = ctx.Err()
		result = nil
	}

	// Publish at most once: the cleanup from NewToolCallSpeculator (or the
	// panic handler above) may already have closed done.
	s.once.Do(func() {
		s.result = result
		s.err = err
		close(s.done)
	})
}
// Wait blocks until the speculative tool call has published its outcome, or the
// caller's context expires. On success it returns the tool's result and error
// as recorded by Run.
func (s *ToolCallSpeculator) Wait(ctx context.Context) (*ToolResult[any], error) {
	select {
	case <-s.done:
		// done is closed strictly after result/err are assigned, so both reads
		// are safe here.
		return s.result, s.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Matches checks if a given tool call (that's been returned by running a prompt in OpenAI)
// matches our speculator's tool and arguments. This will tell us if we can use the results
// of this speculator instead of calling the actual tool.
//
// Arguments are compared as decoded JSON objects (so key order and whitespace
// don't matter), after stripping debug-only keys from both sides.
func (s *ToolCallSpeculator) Matches(ctx context.Context, toolCall goopenai.ToolCall) bool {
	sameTool := s.ToolName() == toolCall.Function.Name
	if !sameTool {
		return false
	}

	var speculatorArgs, toolCallArgs map[string]any
	if err := json.Unmarshal([]byte(s.Arguments), &speculatorArgs); err != nil {
		log.Warn(ctx, errors.Wrap(ctx, err, "unmarshaling speculator arguments"))
		// Safer to assume it didn't match!
		return false
	}
	if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &toolCallArgs); err != nil {
		log.Warn(ctx, errors.Wrap(ctx, err, "unmarshaling tool call arguments"))
		// Safer to assume it didn't match!
		return false
	}

	// Remove any top-level keys with the magic prefix `debug_` from both sets of
	// arguments, so debug-only parameters never affect matching. Using
	// strings.HasPrefix also strips a key that is exactly "debug_", which the
	// previous `len(k) > 6` check missed (off-by-one: "debug_" has length 6).
	stripDebugKeys := func(m map[string]any) map[string]any {
		for k := range m {
			if strings.HasPrefix(k, "debug_") {
				delete(m, k)
			}
		}
		return m
	}

	return reflect.DeepEqual(
		stripDebugKeys(speculatorArgs),
		stripDebugKeys(toolCallArgs),
	)
}
// ToolWriteBarrier is passed into tool calls and should be Wait()'d on before the tool
// makes any side-effective writes or actions (e.g. updating database records, sending
// messages, taking incident actions).
//
// It is used to ensure tools don't commit to taking an action before we know if we want
// to call them, helping us safely speculatively execute tools.
type ToolWriteBarrier struct {
	once   sync.Once         // guards the single resolution (Flush/Deny/cleanup)
	done   chan struct{}     // closed when the barrier is resolved
	err    error             // outcome handed to waiters; nil grants permission
	parent *ToolWriteBarrier // optional chained barrier that must resolve first
}
// NewToolWriteBarrier creates a new barrier that initially blocks writes, until
// given permission to proceed. The returned cleanup resolves the barrier
// (releasing any waiters with a nil error) and may be called repeatedly.
func NewToolWriteBarrier(parent *ToolWriteBarrier) (*ToolWriteBarrier, func()) {
	barrier := &ToolWriteBarrier{
		parent: parent,
		done:   make(chan struct{}),
	}

	release := func() {
		// sync.Once makes the close idempotent, so cleanup can race with
		// Flush/Deny without double-closing done.
		barrier.once.Do(func() {
			close(barrier.done)
		})
	}

	return barrier, release
}
// Wait blocks until we are permitted to make a write (e.g. have a side-effect, such as
// sending messages).
//
// Returns a nil error if permission to write is granted, otherwise returns an error which
// should be passed back up the stack.
func (t *ToolWriteBarrier) Wait(ctx context.Context) error {
	// A nil barrier never blocks: permission is granted immediately.
	if t == nil {
		return nil
	}

	// Barriers chain upwards: the parent must grant permission before we
	// consult our own state.
	if t.parent != nil {
		if err := t.parent.Wait(ctx); err != nil {
			return errors.Wrap(ctx, err, "parent write barrier")
		}
	}

	select {
	case <-t.done:
		// Resolved: t.err is nil on Flush, or the denial error from Deny.
		return t.err
	case <-ctx.Done():
		return ctx.Err()
	}
}
// Flush permits us to continue running the tool, allowing writes to proceed.
//
// t.err is assigned before close(t.done), so waiters that wake on the close
// observe the nil error; sync.Once means only the first of Flush/Deny/cleanup
// resolves the barrier.
func (t *ToolWriteBarrier) Flush() {
	t.once.Do(func() {
		t.err = nil
		close(t.done)
	})
}
// Deny causes all waiting writers to fail with the given error.
//
// t.err is assigned before close(t.done), so waiters that wake on the close
// observe err; sync.Once means only the first of Flush/Deny/cleanup resolves
// the barrier.
func (t *ToolWriteBarrier) Deny(err error) {
	t.once.Do(func() {
		t.err = err
		close(t.done)
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment