Skip to content

Instantly share code, notes, and snippets.

@eddking
Last active July 17, 2024 18:41
Show Gist options
  • Save eddking/330ca6f304b2f7293cf4031af123a328 to your computer and use it in GitHub Desktop.
Save eddking/330ca6f304b2f7293cf4031af123a328 to your computer and use it in GitHub Desktop.
Streaming Function Calling with OpenAI assistants API
import { Assistant } from "openai/resources/beta/assistants/assistants";
import { Message } from "openai/resources/beta/threads/messages/messages";
import { RequiredActionFunctionToolCall } from "openai/resources/beta/threads/runs/runs";
import { Thread } from "openai/resources/beta/threads/threads";
import { searchInput } from "src/handlers/search";
import { adminUpdateAiThreadExtra } from "src/operations/ai_thread";
import { openai } from "../lib/openai";
import {
getMessageContent,
parsePlain,
recordAssistantMessage,
updateAssistantMessage,
} from "./message";
import { InternalThread } from "./threads";
import {
handleLinkToolCall,
handleSearchToolCall,
linkInput,
ToolCallContext,
} from "./tools";
import { IdMapper } from "./idMapper";
import { InternalAssistant } from "./assistant";
import { AiRunEvent } from "models";
/**
 * Runs `assistant` on `thread` through the streaming Assistants API, mirroring
 * messages into internal records, forwarding progress to `recordEvent`, and
 * executing requested function tool calls until the run stops requiring action.
 *
 * Each loop iteration consumes one stream: first the initial run, then one
 * stream per round of submitted tool outputs. The loop exits via `return` on a
 * run error or once no further action is required. The IdMapper state is always
 * written back to the internal thread on exit, including when a stream throws.
 *
 * @param thread - OpenAI thread to run.
 * @param assistant - OpenAI assistant to run it with.
 * @param internalAssistant - Internal assistant record; its id is passed to recordAssistantMessage.
 * @param internalThread - Internal thread record; `extra.id_map` seeds the IdMapper.
 * @param toolCallContext - Org/user context passed to the tool handlers.
 * @param recordEvent - Sink for run and message events.
 */
export const runThread = async (
  thread: Thread,
  assistant: Assistant,
  internalAssistant: InternalAssistant,
  internalThread: InternalThread,
  toolCallContext: ToolCallContext,
  recordEvent: (event: AiRunEvent) => void
) => {
  const idMapper = new IdMapper(internalThread.extra.id_map);
  // OpenAI message id -> promise that resolves once the internal message record is created
  const inProgressMessages: Record<
    string,
    Promise<{ id: string } | null | undefined>
  > = {};
  // Async work started from stream listeners. The stream does not await its
  // listeners, so this is drained before acting on a finished stream and before
  // persisting the IdMapper (which these handlers receive and may update).
  const pendingHandlers: Promise<void>[] = [];
  const track = (label: string, work: () => Promise<void>) => {
    // Errors are logged here so a failing listener never becomes an unhandled rejection
    pendingHandlers.push(
      work().catch((e: unknown) => {
        console.error(`Error in ${label} handler: `, e);
      })
    );
  };
  const drainHandlers = async () => {
    // Loop in case a drained handler queued further tracked work
    while (pendingHandlers.length > 0) {
      await Promise.all(pendingHandlers.splice(0));
    }
  };
  const updateMessage = async (message: Message) => {
    try {
      const result = await inProgressMessages[message.id];
      if (!result) {
        console.log("No result for message, skipping: ", message.id);
        return;
      }
      await updateAssistantMessage(message, idMapper, result.id);
    } catch (e) {
      console.error("Error updating message: ", e);
    }
  };
  const currentTime = () => new Date().toISOString();
  // OpenAI `created_at` fields are Unix timestamps in seconds; Date expects milliseconds
  const fromUnixSeconds = (seconds: number) =>
    new Date(seconds * 1000).toISOString();
  /** Dispatches one function tool call by name; throws on bad JSON, invalid args or unknown tools. */
  const executeToolCall = async (toolCall: RequiredActionFunctionToolCall) => {
    const name = toolCall.function.name;
    const args: unknown = JSON.parse(toolCall.function.arguments);
    switch (name) {
      case "search": {
        const searchArgs = searchInput.parse(args);
        const results = await handleSearchToolCall(
          searchArgs,
          toolCallContext,
          idMapper
        );
        return JSON.stringify(results);
      }
      case "link_record": {
        const linkArgs = linkInput.parse(args);
        return await handleLinkToolCall(linkArgs, toolCallContext, idMapper);
      }
      default:
        throw new Error("Unknown tool call: " + name);
    }
  };
  try {
    let stream = openai.beta.threads.runs.createAndStream(thread.id, {
      assistant_id: assistant.id,
    });
    // Exits via `return` once the run errors or no longer requires action
    for (;;) {
      // Listeners capture this iteration's stream; `stream` is reassigned after tool outputs are submitted
      const activeStream = stream;
      activeStream.on("messageCreated", (message) => {
        const currentRunId = activeStream.currentRun()!.id;
        const createPromise = recordAssistantMessage(
          message,
          idMapper,
          internalThread.id,
          internalAssistant.id,
          currentRunId,
          toolCallContext.orgId
        );
        // Registered synchronously so later delta/done events can await it
        inProgressMessages[message.id] = createPromise;
        track("messageCreated", async () => {
          const currentContent = getMessageContent(message);
          const structured = await parsePlain(currentContent, idMapper);
          const internalMessage = await createPromise;
          recordEvent({
            type: "messageCreated",
            content: currentContent,
            structured: structured,
            internal_id: internalMessage?.id!,
            external_id: message.id,
            role: message.role,
            created_at: fromUnixSeconds(message.created_at),
          });
        });
      });
      activeStream.on("messageDelta", (_messageDelta, snapshot) => {
        track("messageDelta", async () => {
          const currentContent = getMessageContent(snapshot);
          const structured = await parsePlain(currentContent, idMapper);
          const createResult = await inProgressMessages[snapshot.id];
          recordEvent({
            type: "messageDelta",
            content: currentContent,
            structured: structured,
            external_id: snapshot.id,
            internal_id: createResult?.id!,
            created_at: fromUnixSeconds(snapshot.created_at),
            role: snapshot.role,
          });
        });
      });
      activeStream.on("messageDone", (message) => {
        track("messageDone", async () => {
          const finalContent = getMessageContent(message);
          const structured = await parsePlain(finalContent, idMapper);
          const createResult = await inProgressMessages[message.id];
          // Update the internal message with the final content; awaited so it
          // completes before the IdMapper is persisted
          const update = updateMessage(message);
          recordEvent({
            type: "messageDone",
            content: finalContent,
            structured: structured,
            external_id: message.id,
            internal_id: createResult?.id!,
            created_at: fromUnixSeconds(message.created_at),
            role: message.role,
          });
          await update;
        });
      });
      activeStream.on("event", (event) => {
        console.log("Event: ", event.event);
        switch (event.event) {
          // Don't log these
          case "thread.run.step.delta":
          case "thread.message.created":
          case "thread.message.in_progress":
          case "thread.message.delta":
          case "thread.message.completed":
          case "thread.message.incomplete":
          case "thread.run.queued":
          case "thread.run.in_progress":
          case "thread.run.cancelling":
          case "thread.run.step.created":
          case "thread.run.step.in_progress":
          case "thread.run.step.completed":
          case "thread.run.step.failed":
          case "thread.run.step.cancelled":
          case "thread.run.step.expired":
          case "thread.created":
            return;
          case "thread.run.requires_action": // Recorded after the stream finishes
            return;
          // Record these for posterity
          case "thread.run.created":
            recordEvent({
              type: "runCreated",
              created_at: fromUnixSeconds(event.data.created_at),
            });
            break;
          case "thread.run.completed":
            recordEvent({ type: "runCompleted", created_at: currentTime() });
            break;
          case "thread.run.failed":
            recordEvent({ type: "runFailed", created_at: currentTime() });
            break;
          case "thread.run.cancelled":
            recordEvent({ type: "runCancelled", created_at: currentTime() });
            break;
          case "thread.run.expired":
            recordEvent({ type: "runExpired", created_at: currentTime() });
            break;
          case "error":
            recordEvent({
              type: "runError",
              error: event.data,
              created_at: currentTime(),
            });
            break;
          default: {
            // Compile-time exhaustiveness check over the SDK's event union
            const _exhaustiveCheck: never = event;
          }
        }
      });
      // NOTE: executing tool calls while they stream in (submitting once
      // requires_action arrives) was considered, but the 'toolCallDone' event
      // appeared to carry empty '' arguments, so tool calls run after the stream ends.
      await activeStream.done();
      // Let in-flight message handlers finish so their events precede the run outcome
      await drainHandlers();
      const currentRun = activeStream.currentRun();
      if (!currentRun) {
        recordEvent({
          type: "runError",
          error: "No Current Run",
          created_at: currentTime(),
        });
        return;
      }
      const lastError = currentRun.last_error;
      if (lastError) {
        recordEvent({
          type: "runError",
          error: lastError,
          created_at: currentTime(),
        });
        return;
      }
      if (!currentRun.required_action) {
        // This happens at the end of every successful run
        recordEvent({ type: "noActionRequired", created_at: currentTime() });
        return;
      }
      // Assumes the only required action is submitting tool outputs; this may change
      if (!currentRun.required_action.submit_tool_outputs) {
        recordEvent({
          type: "runError",
          error: "No tool outputs required",
          created_at: currentTime(),
        });
        return;
      }
      const toolCalls =
        currentRun.required_action.submit_tool_outputs.tool_calls;
      recordEvent({
        type: "requiredAction",
        toolCalls,
        created_at: currentTime(),
      });
      // Failures become error outputs so the model can see them and recover
      const allToolResults = await Promise.all(
        toolCalls.map(async (toolCall) => {
          const toolCallId = toolCall.id;
          try {
            const result = await executeToolCall(toolCall);
            console.log("Tool call result: ", toolCallId, result);
            return { tool_call_id: toolCallId, output: result };
          } catch (e: unknown) {
            console.error("Error executing tool call: ", e);
            const reason =
              e instanceof Error && e.message ? e.message : String(e);
            return {
              tool_call_id: toolCallId,
              output: "Error: " + reason,
            };
          }
        })
      );
      recordEvent({
        type: "toolCallResults",
        results: allToolResults,
        created_at: currentTime(),
      });
      // Set up the next stream and loop back up to the top
      stream = openai.beta.threads.runs.submitToolOutputsStream(
        thread.id,
        currentRun.id,
        {
          stream: true,
          tool_outputs: allToolResults,
        }
      );
    }
  } finally {
    // Handlers may still be using the IdMapper if the stream threw
    await drainHandlers();
    await adminUpdateAiThreadExtra(
      { id_map: idMapper.getState() },
      internalThread.id
    );
  }
};
import {
AssistantTool,
FunctionTool,
} from "openai/resources/beta/assistants/assistants";
import { zodToJsonSchema } from "zod-to-json-schema";
import { searchInput, performSearch } from "src/handlers/search";
import { z } from "zod";
import { Language_Enum } from "src/generated/gql/graphql";
import { IdMapper } from "./idMapper";
import { isMissing } from "@util";
/**
 * Caller context handed to each tool handler during an assistant run.
 */
export interface ToolCallContext {
  // Organization id; forwarded to search and passed to recordAssistantMessage by the run loop
  orgId: string;
  // Requesting user's id; forwarded to search
  userId: string;
  // Organization language; forwarded to search
  orgLanguage: Language_Enum;
}
/**
 * Arguments the assistant may supply to the `search` tool: the full search
 * input minus UI-oriented options it should not control. `filter` is also
 * excluded because filters are complicated (possibly revisit after testing).
 */
const searchToolInput = searchInput.omit({
  tags: true,
  filter: true,
  typeahead: true,
  includeFacets: true,
});

/** JSON Schema for the `search` tool's parameters, generated from `searchToolInput`. */
export const searchSchema = zodToJsonSchema(searchToolInput);

/** Function-tool definition advertised to the assistant for organization-wide search. */
export const searchTool: FunctionTool = {
  type: "function",
  function: {
    name: "search",
    description: [
      "Search everything within the organization. If a search result has a target_id, ",
      " and this differs from it's own formatted_id, it means it is embedded in another page. ",
    ].join(""),
    parameters: searchSchema,
  },
};
/**
 * Executes a `search` tool call: runs the search with the caller's org, user
 * and language (flagged `ai: true`, `admin: false`), then shortens ids and
 * strips noise from every hit via `processSearchResult`.
 */
export const handleSearchToolCall = async (
  input: z.infer<typeof searchToolInput>,
  context: ToolCallContext,
  idMapper: IdMapper
) => {
  const searchContext = {
    orgId: context.orgId,
    userId: context.userId,
    orgLanguage: context.orgLanguage,
    admin: false,
    ai: true,
  };
  const { results } = await performSearch(input, searchContext);
  return results.map((hit) => processSearchResult(hit, idMapper));
};
/** Document fields holding internal ids, swapped for short ids before reaching the model. */
const MAPPED_ID_FIELDS = [
  "formatted_id",
  "target_id",
  "org_id",
  "created_by",
  "updated_by",
  "id",
] as const;

/**
 * Document fields dropped before reaching the model. Timestamps are not worth
 * the tokens, and search results never include deleted records, so
 * `deleted_at` carries no information.
 */
const DROPPED_FIELDS = ["created_at", "updated_at", "deleted_at"] as const;

/**
 * Prepares one search hit for the model. It copies the hit's `document`, maps
 * the id fields through `idMapper`, and drops the timestamp fields. It also
 * falls back to the hit's own `target_id` when the document has none.
 *
 * The input `result` is not modified.
 *
 * @param result - A search hit with a `document` object and an optional
 *   top-level `target_id` (NOTE(review): shape inferred from usage; untyped here).
 * @param idMapper - Maps internal ids to the short ids shown to the model.
 * @returns The document's remaining fields, with mapped ids, plus `target_id`.
 */
const processSearchResult = (
  result: any,
  idMapper: IdMapper
): Record<string, unknown> => {
  // Shallow copy: rewriting ids in place would corrupt the caller's result object
  const document: Record<string, unknown> = { ...result.document };
  for (const field of MAPPED_ID_FIELDS) {
    const value = document[field];
    if (value) {
      document[field] = idMapper.get(value as string);
    }
  }
  for (const field of DROPPED_FIELDS) {
    delete document[field];
  }
  // Sometimes target_id is not in the document but is on the hit itself,
  // when it was mapped from a different column during indexing (e.g. an association)
  const targetId =
    document["target_id"] ||
    (result.target_id ? idMapper.get(result.target_id) : undefined);
  return { ...document, target_id: targetId };
};
/** Arguments for the `link_record` tool: the short id of the record to link to. */
export const linkInput = z.object({
  formatted_id: z
    .string()
    .describe("The formatted_id or target_id of the record"),
});

/** JSON Schema for the `link_record` tool's parameters, generated from `linkInput`. */
export const linkSchema = zodToJsonSchema(linkInput);

/**
 * Function-tool definition that lets the assistant obtain a displayable URL for
 * a record. The description's two sentences are concatenated, so each piece
 * must end with its own separator; previously the model saw "formatted_idThe".
 */
export const linkTool: FunctionTool = {
  type: "function",
  function: {
    name: "link_record",
    description:
      "Returns a url that you can use to display a record to the user, you must provide a formatted_id or target_id. " +
      "The returned url can be turned into a link with markdown syntax e.g. [testing](https://example.com).",
    parameters: linkSchema,
  },
};
/** Base URL of the record links returned by the `link_record` tool. */
export const RECORD_LINK_PREFIX = "https://nascent.com/link/";

/**
 * Handles a `link_record` tool call. It resolves the short id through
 * `idMapper.getInverse` and accepts it only when that yields a value containing
 * ":". It then returns `RECORD_LINK_PREFIX` followed by the short id exactly as
 * given.
 *
 * `context` is unused; presumably kept for signature parity with the other tool handlers.
 *
 * @throws Error when the id does not resolve (per `isMissing`) or resolves to a
 *   value without ":".
 */
export const handleLinkToolCall = async (
  input: z.infer<typeof linkInput>,
  context: ToolCallContext,
  idMapper: IdMapper
) => {
  const shortId = input.formatted_id;
  const resolved = idMapper.getInverse(shortId);
  const isRecordId = !isMissing(resolved) && resolved.includes(":");
  if (!isRecordId) {
    throw new Error(
      "Invalid id. it must be from a formatted_id or target_id field"
    );
  }
  return `${RECORD_LINK_PREFIX}${shortId}`;
};
/**
 * Every function tool offered to the assistant, in registration order:
 * organization search, then record linking.
 */
export const tools: AssistantTool[] = [
  searchTool,
  linkTool,
];
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment