Cloudflare Workers AI Local Polyfill

This is a simple polyfill for Cloudflare Workers AI in `wrangler dev --local` mode. It doesn't actually run the models locally; it just reaches out to Cloudflare via their API. I think this is a reasonable compromise for local development while we wait for native support in local mode.

I've only implemented and tested this for embeddings models; other models shouldn't be too hard, but they're left as an exercise for the reader. If you do implement them, let me know and I'll update the gist :)

To polyfill it, I run `wrangler dev --define IS_LOCAL:true` in local mode, and then have something like:

```ts
import { Ai as RemoteAi } from "@cloudflare/ai";

const Ai = IS_LOCAL ? new LocalAi(env.AI_GATEWAY_URL, env.AI_API_KEY) : new RemoteAi(env.AI);
```
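Because `--define` substitutes `IS_LOCAL` at build time, TypeScript doesn't know about the global on its own. A minimal ambient declaration, assuming you have no other typing for it:

```ts
// Assumed ambient declaration so the `IS_LOCAL` define type-checks.
declare const IS_LOCAL: boolean;
```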

If you're reading this in the future: I built this in October 2023, and it will likely become obsolete pretty soon. Brendan & the Miniflare team are pretty fast at adding stuff to local mode!
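To get `env.AI_GATEWAY_URL` and `env.AI_API_KEY` into local mode, one option is wrangler's `.dev.vars` file. This is an assumption about your setup, not part of the gist; the variable names just match the snippet above:

```
AI_GATEWAY_URL="https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/workers-ai"
AI_API_KEY="YOUR_API_TOKEN"
```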

```ts
/**
 * @file A polyfill for `@cloudflare/ai` that connects to live services in local environments via the Cloudflare API instead of direct bindings.
 *
 * This only implements embeddings for now, but could be expanded to other models just by extending the types (and adding streaming support).
 */
type ModelName = "@cf/baai/bge-small-en-v1.5" | "@cf/baai/bge-base-en-v1.5" | "@cf/baai/bge-large-en-v1.5";

interface AiTextEmbeddingsInput {
  text: string | string[];
}

interface AiTextEmbeddingsOutput {
  shape: number[];
  data: number[][];
}

type ConstructorParametersForModel = AiTextEmbeddingsInput;
type OutputForModel = AiTextEmbeddingsOutput;

interface AiOptions {
  sessionOptions?: {
    extraHeaders?: { [x: string]: unknown };
  };
}

/**
 * A common interface for this and `@cloudflare/ai`.
 *
 * Note that in practice you'll have to cast your `Ai` object to this interface if you only implement a subset of the models, because `Ai#run` doesn't use method overloading.
 */
export interface ImplementedAi {
  run: <M extends ModelName>(model: M, inputs: ConstructorParametersForModel) => Promise<OutputForModel>;
}

interface CloudflareAPIErrorResponse {
  success: false;
  errors: { message: string }[];
}

interface CloudflareAPIResultResponse<T> {
  success: true;
  result: T;
}

type CloudflareAPIResponse<T> = CloudflareAPIErrorResponse | CloudflareAPIResultResponse<T>;

export class LocalAi implements ImplementedAi {
  /**
   * @param baseUrl Your Workers AI API URL, either in the format `https://api.cloudflare.com/client/v4/accounts/ACCOUNT_TAG/ai/run` or `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/workers-ai`.
   * @param apiKey Your Workers AI [API Key](https://dash.cloudflare.com/profile/api-tokens).
   */
  public constructor(
    private readonly baseUrl: string,
    private readonly apiKey: string,
    private readonly options?: AiOptions,
  ) {}

  public async run<M extends ModelName>(model: M, inputs: ConstructorParametersForModel): Promise<OutputForModel> {
    // POST the inputs to the model's endpoint, forwarding any extra session headers.
    const response = await fetch(`${this.baseUrl}/${model}`, {
      method: "POST",
      body: JSON.stringify(inputs),
      headers: {
        ...this.options?.sessionOptions?.extraHeaders,
        "Content-Type": "application/json",
        Authorization: `Bearer ${this.apiKey}`,
      },
    });

    // The Cloudflare API wraps responses in a `{ success, result | errors }` envelope.
    const value = (await response.json()) as CloudflareAPIResponse<OutputForModel>;
    if (!value.success) {
      throw new Error(value.errors.map((error) => error.message).join("\n"));
    }
    return value.result;
  }
}
```
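For completeness, a quick usage sketch with placeholder credentials (the account tag and token are hypothetical; substitute your own):

```ts
const ai = new LocalAi(
  "https://api.cloudflare.com/client/v4/accounts/ACCOUNT_TAG/ai/run",
  "YOUR_API_TOKEN",
);

const { shape, data } = await ai.run("@cf/baai/bge-base-en-v1.5", {
  text: ["Hello, world!"], // accepts a single string or an array of strings
});
console.log(shape); // e.g. [1, 768]; bge-base embeddings are 768-dimensional
```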