Skip to content

Instantly share code, notes, and snippets.

@zhoukekestar
Created March 28, 2024 02:34
Show Gist options
  • Save zhoukekestar/daec061180101892c3b0c239df2e051a to your computer and use it in GitHub Desktop.
Save zhoukekestar/daec061180101892c3b0c239df2e051a to your computer and use it in GitHub Desktop.
@google/generative-ai for china
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Possible roles.
* @public
*/
import { Readable } from 'node:stream';
import fetch from 'node-fetch';
import { HttpsProxyAgent } from 'https-proxy-agent';
// Roles a Content entry may carry; validateChatHistory rejects any other value.
const POSSIBLE_ROLES = ["user", "model", "function"];
/**
 * Harm categories that would cause prompts or candidates to be blocked.
 * Each key maps to an identical string value.
 * @public
 */
var HarmCategory = Object.assign(HarmCategory || {}, {
    HARM_CATEGORY_UNSPECIFIED: "HARM_CATEGORY_UNSPECIFIED",
    HARM_CATEGORY_HATE_SPEECH: "HARM_CATEGORY_HATE_SPEECH",
    HARM_CATEGORY_SEXUALLY_EXPLICIT: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    HARM_CATEGORY_HARASSMENT: "HARM_CATEGORY_HARASSMENT",
    HARM_CATEGORY_DANGEROUS_CONTENT: "HARM_CATEGORY_DANGEROUS_CONTENT",
});
/**
 * Threshold above which a prompt or candidate will be blocked.
 * Each key maps to an identical string value.
 * @public
 */
var HarmBlockThreshold = Object.assign(HarmBlockThreshold || {}, {
    // Threshold is unspecified.
    HARM_BLOCK_THRESHOLD_UNSPECIFIED: "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
    // Content with NEGLIGIBLE will be allowed.
    BLOCK_LOW_AND_ABOVE: "BLOCK_LOW_AND_ABOVE",
    // Content with NEGLIGIBLE and LOW will be allowed.
    BLOCK_MEDIUM_AND_ABOVE: "BLOCK_MEDIUM_AND_ABOVE",
    // Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
    BLOCK_ONLY_HIGH: "BLOCK_ONLY_HIGH",
    // All content will be allowed.
    BLOCK_NONE: "BLOCK_NONE",
});
/**
 * Probability that a prompt or candidate matches a harm category.
 * Each key maps to an identical string value.
 * @public
 */
var HarmProbability = Object.assign(HarmProbability || {}, {
    // Probability is unspecified.
    HARM_PROBABILITY_UNSPECIFIED: "HARM_PROBABILITY_UNSPECIFIED",
    // Content has a negligible chance of being unsafe.
    NEGLIGIBLE: "NEGLIGIBLE",
    // Content has a low chance of being unsafe.
    LOW: "LOW",
    // Content has a medium chance of being unsafe.
    MEDIUM: "MEDIUM",
    // Content has a high chance of being unsafe.
    HIGH: "HIGH",
});
/**
 * Reason that a prompt was blocked.
 * Each key maps to an identical string value.
 * @public
 */
var BlockReason = Object.assign(BlockReason || {}, {
    // A blocked reason was not specified.
    BLOCKED_REASON_UNSPECIFIED: "BLOCKED_REASON_UNSPECIFIED",
    // Content was blocked by safety settings.
    SAFETY: "SAFETY",
    // Content was blocked, but the reason is uncategorized.
    OTHER: "OTHER",
});
/**
 * Reason that a candidate finished.
 * Each key maps to an identical string value.
 * @public
 */
var FinishReason = Object.assign(FinishReason || {}, {
    // Default value. This value is unused.
    FINISH_REASON_UNSPECIFIED: "FINISH_REASON_UNSPECIFIED",
    // Natural stop point of the model or provided stop sequence.
    STOP: "STOP",
    // The maximum number of tokens as specified in the request was reached.
    MAX_TOKENS: "MAX_TOKENS",
    // The candidate content was flagged for safety reasons.
    SAFETY: "SAFETY",
    // The candidate content was flagged for recitation reasons.
    RECITATION: "RECITATION",
    // Unknown reason.
    OTHER: "OTHER",
});
/**
 * Task type for embedding content.
 * Each key maps to an identical string value.
 * @public
 */
var TaskType = Object.assign(TaskType || {}, {
    TASK_TYPE_UNSPECIFIED: "TASK_TYPE_UNSPECIFIED",
    RETRIEVAL_QUERY: "RETRIEVAL_QUERY",
    RETRIEVAL_DOCUMENT: "RETRIEVAL_DOCUMENT",
    SEMANTIC_SIMILARITY: "SEMANTIC_SIMILARITY",
    CLASSIFICATION: "CLASSIFICATION",
    CLUSTERING: "CLUSTERING",
});
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Contains the list of OpenAPI data types
 * as defined by https://swagger.io/docs/specification/data-models/data-types/
 * Each key maps to an identical string value.
 * @public
 */
var FunctionDeclarationSchemaType = Object.assign(FunctionDeclarationSchemaType || {}, {
    /** String type. */
    STRING: "STRING",
    /** Number type. */
    NUMBER: "NUMBER",
    /** Integer type. */
    INTEGER: "INTEGER",
    /** Boolean type. */
    BOOLEAN: "BOOLEAN",
    /** Array type. */
    ARRAY: "ARRAY",
    /** Object type. */
    OBJECT: "OBJECT",
});
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Base error type for this SDK. Every message is prefixed with
 * "[GoogleGenerativeAI Error]: " so SDK failures are recognisable in logs.
 */
class GoogleGenerativeAIError extends Error {
    /**
     * @param message - Human-readable description; stringified and appended to the prefix.
     */
    constructor(message) {
        const prefixed = "[GoogleGenerativeAI Error]: ".concat(message);
        super(prefixed);
    }
}
/**
 * SDK error that also carries the response object it was raised for,
 * exposed as the `response` property.
 */
class GoogleGenerativeAIResponseError extends GoogleGenerativeAIError {
    /**
     * @param message - Description forwarded to GoogleGenerativeAIError.
     * @param response - The response that triggered the error.
     */
    constructor(message, response) {
        super(message);
        Object.assign(this, { response });
    }
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Scheme and host that RequestUrl.toString() prefixes to every endpoint.
const BASE_URL = "https://generativelanguage.googleapis.com";
// API version used when requestOptions.apiVersion is not supplied.
const DEFAULT_API_VERSION = "v1";
/**
* We can't `require` package.json if this runs on web. We will use rollup to
* swap in the version number here at build time.
*/
const PACKAGE_VERSION = "0.3.1";
// Client name; getClientHeaders() joins it with PACKAGE_VERSION for the
// "x-goog-api-client" request header.
const PACKAGE_LOG_HEADER = "genai-js";
// REST method names appended after "<model>:" when RequestUrl builds an endpoint.
var Task = Object.assign(Task || {}, {
    GENERATE_CONTENT: "generateContent",
    STREAM_GENERATE_CONTENT: "streamGenerateContent",
    COUNT_TOKENS: "countTokens",
    EMBED_CONTENT: "embedContent",
    BATCH_EMBED_CONTENTS: "batchEmbedContents",
});
/**
 * Value object describing one API call target. `toString()` renders the
 * endpoint URL; the API key is kept as a property (makeRequest sends it as a
 * header) and is never placed in the URL.
 */
class RequestUrl {
    /**
     * @param model - Model resource name placed before the ":" in the path.
     * @param task - Method name placed after the ":".
     * @param apiKey - API key stored on the instance.
     * @param stream - When truthy, "?alt=sse" is appended to the URL.
     * @param requestOptions - Optional; only `apiVersion` is read here.
     */
    constructor(model, task, apiKey, stream, requestOptions) {
        Object.assign(this, { model, task, apiKey, stream, requestOptions });
    }
    /**
     * @returns The endpoint as `<base>/<version>/<model>:<task>`, plus
     * "?alt=sse" for streaming calls.
     */
    toString() {
        const options = this.requestOptions;
        const requestedVersion = options === null || options === undefined ? undefined : options.apiVersion;
        const apiVersion = requestedVersion || DEFAULT_API_VERSION;
        const endpoint = `${BASE_URL}/${apiVersion}/${this.model}:${this.task}`;
        return this.stream ? endpoint + "?alt=sse" : endpoint;
    }
}
/**
 * Builds the value sent in the "x-goog-api-client" header:
 * the client name and package version separated by "/".
 */
function getClientHeaders() {
    return [PACKAGE_LOG_HEADER, PACKAGE_VERSION].join("/");
}
/**
 * POSTs `body` to the endpoint described by `url` and returns the raw fetch
 * response.
 *
 * When a proxy is configured in the environment the request is tunnelled
 * through it with an HttpsProxyAgent. `http_proxy` keeps precedence, then
 * `https_proxy`, `HTTPS_PROXY` and `HTTP_PROXY` are consulted, so setups that
 * export only the HTTPS or upper-case variants are honoured too.
 *
 * @param url - RequestUrl-like object: `toString()` gives the endpoint and
 *   `apiKey` is sent as the "x-goog-api-key" header.
 * @param body - Already-serialized JSON request body.
 * @param requestOptions - Optional request options passed to buildFetchOptions.
 * @returns The fetch response (only when `response.ok` is true).
 * @throws GoogleGenerativeAIError for network failures and non-OK statuses;
 *   the original error's stack is preserved on the wrapper.
 */
async function makeRequest(url, body, requestOptions) {
    let response;
    try {
        const proxyUrl = process.env.http_proxy ||
            process.env.https_proxy ||
            process.env.HTTPS_PROXY ||
            process.env.HTTP_PROXY;
        const agent = proxyUrl ? new HttpsProxyAgent(proxyUrl) : undefined;
        response = await fetch(url.toString(), Object.assign(Object.assign({}, buildFetchOptions(requestOptions)), { method: "POST", headers: {
                "Content-Type": "application/json",
                "x-goog-api-client": getClientHeaders(),
                "x-goog-api-key": url.apiKey,
            }, body, agent }));
        if (!response.ok) {
            let message = "";
            try {
                const json = await response.json();
                message = json.error.message;
                if (json.error.details) {
                    message += ` ${JSON.stringify(json.error.details)}`;
                }
            }
            catch (e) {
                // Deliberately best-effort: the error body may be missing or not
                // JSON, in which case only the status line is reported.
            }
            throw new Error(`[${response.status} ${response.statusText}] ${message}`);
        }
    }
    catch (e) {
        // Wrap everything (including the status error thrown above) so callers
        // always see the SDK error type with the request URL in the message.
        const err = new GoogleGenerativeAIError(`Error fetching from ${url.toString()}: ${e.message}`);
        err.stack = e.stack;
        throw err;
    }
    return response;
}
/**
 * Generates the request options to be passed to the fetch API.
 *
 * A timeout is installed only when `requestOptions.timeout` is actually
 * provided and non-negative. `null` is treated as "no timeout"; a bare
 * `null >= 0` comparison would be true and abort the request immediately.
 *
 * @param requestOptions - The user-defined request options (may be omitted).
 * @returns An object that is empty, or carries an abort `signal` that fires
 *   after `requestOptions.timeout` milliseconds.
 */
function buildFetchOptions(requestOptions) {
    const fetchOptions = {};
    const timeout = requestOptions === null || requestOptions === undefined
        ? undefined
        : requestOptions.timeout;
    if (timeout !== null && timeout !== undefined && timeout >= 0) {
        const abortController = new AbortController();
        const timer = setTimeout(() => abortController.abort(), timeout);
        // In Node the timer object would otherwise keep the process alive for
        // the full timeout even after the request has completed. Environments
        // whose setTimeout returns a plain number have no unref and skip this.
        if (timer && typeof timer.unref === "function") {
            timer.unref();
        }
        fetchOptions.signal = abortController.signal;
    }
    return fetchOptions;
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Attaches `text()` and `functionCall()` convenience accessors to a response
 * object (including stream chunks, as long as each chunk is a complete
 * GenerateContentResponse JSON) and returns that same object.
 *
 * Both accessors read only the first candidate. They throw a
 * GoogleGenerativeAIResponseError when that candidate had a bad finish
 * reason, or when there are no candidates but prompt feedback is present.
 */
function addHelpers(response) {
    // Both accessors share the same guard logic and differ only in wording,
    // extractor and the value returned when nothing is available.
    const makeAccessor = (noun, unavailableLabel, extract, emptyValue) => () => {
        const candidates = response.candidates;
        const hasCandidates = candidates && candidates.length > 0;
        if (!hasCandidates) {
            if (response.promptFeedback) {
                throw new GoogleGenerativeAIResponseError(`${unavailableLabel} not available. ${formatBlockErrorMessage(response)}`, response);
            }
            return emptyValue;
        }
        if (candidates.length > 1) {
            console.warn(`This response had ${candidates.length} ` +
                `candidates. Returning ${noun} from the first candidate only. ` +
                `Access response.candidates directly to use the other candidates.`);
        }
        if (hadBadFinishReason(candidates[0])) {
            throw new GoogleGenerativeAIResponseError(`${formatBlockErrorMessage(response)}`, response);
        }
        return extract(response);
    };
    response.text = makeAccessor("text", "Text", getText, "");
    response.functionCall = makeAccessor("function call", "Function call", getFunctionCall, undefined);
    return response;
}
/**
 * Returns the concatenated text of every part of the first candidate.
 *
 * Parts without text (for example function calls) contribute nothing, so text
 * that follows a non-text part is still returned. A response with no
 * candidates, no content or no parts yields "".
 *
 * @param response - A GenerateContentResponse-shaped object.
 * @returns The joined text, or "" when there is none.
 */
function getText(response) {
    const candidates = response.candidates;
    const firstCandidate = candidates ? candidates[0] : undefined;
    const content = firstCandidate ? firstCandidate.content : undefined;
    const parts = content ? content.parts : undefined;
    if (!parts) {
        return "";
    }
    return parts
        .map((part) => (part && part.text ? part.text : ""))
        .join("");
}
/**
 * Returns the functionCall of the first part of the first candidate.
 *
 * Every level is optional: a response with no candidates (including an empty
 * `candidates` array), no content or no parts yields `undefined` instead of
 * throwing.
 *
 * @param response - A GenerateContentResponse-shaped object.
 * @returns The function call, or `undefined` when there is none.
 */
function getFunctionCall(response) {
    const candidates = response.candidates;
    const firstCandidate = candidates ? candidates[0] : undefined;
    const parts = firstCandidate && firstCandidate.content
        ? firstCandidate.content.parts
        : undefined;
    const firstPart = parts ? parts[0] : undefined;
    return firstPart ? firstPart.functionCall : undefined;
}
// Finish reasons that mean the candidate's content was withheld.
const badFinishReasons = [FinishReason.RECITATION, FinishReason.SAFETY];
/**
 * Tells whether a candidate stopped for one of the reasons listed in
 * `badFinishReasons`. A missing or empty finishReason counts as not bad.
 */
function hadBadFinishReason(candidate) {
    const reason = candidate.finishReason;
    if (!reason) {
        return false;
    }
    return badFinishReasons.includes(reason);
}
/**
 * Builds a human-readable explanation of why a response carries no usable
 * content.
 *
 * - No candidates but prompt feedback present: "Response was blocked",
 *   followed by the block reason and block reason message when available.
 * - First candidate has a bad finish reason: "Candidate was blocked due to
 *   <reason>", followed by the finish message when available.
 * - Otherwise: "".
 */
function formatBlockErrorMessage(response) {
    const candidates = response.candidates;
    const feedback = response.promptFeedback;
    const noCandidates = !candidates || candidates.length === 0;
    if (noCandidates && feedback) {
        const pieces = ["Response was blocked"];
        if (feedback.blockReason) {
            pieces.push(` due to ${feedback.blockReason}`);
        }
        if (feedback.blockReasonMessage) {
            pieces.push(`: ${feedback.blockReasonMessage}`);
        }
        return pieces.join("");
    }
    const firstCandidate = candidates === null || candidates === undefined ? undefined : candidates[0];
    if (!firstCandidate || !hadBadFinishReason(firstCandidate)) {
        return "";
    }
    const headline = `Candidate was blocked due to ${firstCandidate.finishReason}`;
    return firstCandidate.finishMessage
        ? headline + `: ${firstCandidate.finishMessage}`
        : headline;
}
/******************************************************************************
Copyright (c) Microsoft Corporation.
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
***************************************************************************** */
/* global Reflect, Promise, SuppressedError, Symbol */
/**
 * Wrapper marking a yielded value as "await this" for __asyncGenerator, which
 * detects it with `instanceof __await` and reads the payload from `.v`.
 * Works with or without `new`: when `this` is not already an __await instance
 * a new one is constructed and returned.
 */
function __await(v) {
return this instanceof __await ? (this.v = v, this) : new __await(v);
}
/**
 * Drives a plain generator as an async iterator.
 *
 * `generator` is invoked with `thisArg` and `_arguments`. The returned object
 * exposes `next`/`throw`/`return` (each only if the generator has it) plus a
 * `Symbol.asyncIterator` method returning the object itself. Each call returns
 * a promise and is queued in `q` as `[method, value, resolve, reject]`; calls
 * are processed one at a time, in order.
 *
 * @throws TypeError when `Symbol.asyncIterator` is not available.
 */
function __asyncGenerator(thisArg, _arguments, generator) {
if (!Symbol.asyncIterator) throw new TypeError("Symbol.asyncIterator is not defined.");
var g = generator.apply(thisArg, _arguments || []), i, q = [];
return i = {}, verb("next"), verb("throw"), verb("return"), i[Symbol.asyncIterator] = function () { return this; }, i;
// Installs method `n` on the iterator; only the first queued call starts the
// generator immediately, later calls wait their turn in `q`.
function verb(n) { if (g[n]) i[n] = function (v) { return new Promise(function (a, b) { q.push([n, v, a, b]) > 1 || resume(n, v); }); }; }
// Advances the generator; a synchronous throw rejects the call at the head of the queue.
function resume(n, v) { try { step(g[n](v)); } catch (e) { settle(q[0][3], e); } }
// A yielded __await is awaited and fed back into the generator; any other
// result resolves the call at the head of the queue.
function step(r) { r.value instanceof __await ? Promise.resolve(r.value.v).then(fulfill, reject) : settle(q[0][2], r); }
function fulfill(value) { resume("next", value); }
function reject(value) { resume("throw", value); }
// Completes the head call, removes it, and starts the next queued call if any.
function settle(f, v) { if (f(v), q.shift(), q.length) resume(q[0][0], q[0][1]); }
}
// Selects the native SuppressedError when it exists, otherwise a fallback
// factory building an Error with `name`, `error` and `suppressed` set.
// NOTE(review): the result of this expression is not assigned to anything, so
// the statement has no effect here; presumably left over from bundling tslib.
typeof SuppressedError === "function" ? SuppressedError : function (error, suppressed, message) {
var e = new Error(message);
return e.name = "SuppressedError", e.error = error, e.suppressed = suppressed, e;
};
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Matches one complete event at the START of the buffered text: the literal
// "data: " prefix, a payload captured in group 1 (no line breaks, since `.`
// does not match them), and a terminating blank line in LF, CR or CRLF form.
const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
/**
 * Process a response.body stream from the backend and return an
 * iterator that provides one complete GenerateContentResponse at a time
 * and a promise that resolves with a single aggregated
 * GenerateContentResponse.
 *
 * The body is decoded as UTF-8 (invalid bytes are a fatal error), split into
 * complete responses, and then tee'd so the iterator and the aggregate
 * promise can each consume every response.
 *
 * node-fetch exposes `response.body` as a Node.js Readable, which has no
 * `pipeThrough`; such bodies are converted to a web ReadableStream first.
 * Bodies that already are web streams are used as they are.
 *
 * @param response - Response from a fetch call
 * @returns `{ stream, response }` as described above.
 */
function processStream(response) {
    const webBody = typeof response.body.pipeThrough === "function"
        ? response.body
        : Readable.toWeb(response.body);
    const inputStream = webBody.pipeThrough(new TextDecoderStream("utf8", { fatal: true }));
    const responseStream = getResponseStream(inputStream);
    const [stream1, stream2] = responseStream.tee();
    return {
        stream: generateResponseSequence(stream1),
        response: getResponsePromise(stream2),
    };
}
/**
 * Drains `stream` to the end, then aggregates every response read from it into
 * one response and returns that with the helper accessors attached.
 */
async function getResponsePromise(stream) {
    const reader = stream.getReader();
    const collected = [];
    for (let chunk = await reader.read(); !chunk.done; chunk = await reader.read()) {
        collected.push(chunk.value);
    }
    return addHelpers(aggregateResponses(collected));
}
/**
 * Returns an async iterator over `stream`: every response read from it is
 * passed through addHelpers and yielded; iteration ends when the stream is done.
 * Written against the __asyncGenerator/__await helpers, where `yield __await(x)`
 * stands for awaiting `x` and the outer `yield` emits the value to the consumer.
 */
function generateResponseSequence(stream) {
return __asyncGenerator(this, arguments, function* generateResponseSequence_1() {
const reader = stream.getReader();
while (true) {
const { value, done } = yield __await(reader.read());
if (done) {
break;
}
yield yield __await(addHelpers(value));
}
});
}
/**
 * Re-chunks the raw text stream of a fetch response: incomplete pieces are
 * buffered and joined, and the returned stream emits one parsed
 * GenerateContentResponse object per complete "data: ..." event.
 *
 * The returned stream errors with GoogleGenerativeAIError when an event's
 * payload is not valid JSON, or when the input ends with non-whitespace text
 * still buffered.
 */
function getResponseStream(inputStream) {
    const reader = inputStream.getReader();
    return new ReadableStream({
        async start(controller) {
            let buffered = "";
            for (;;) {
                const { value, done } = await reader.read();
                if (done) {
                    if (buffered.trim()) {
                        // Leftover text never formed a complete event.
                        controller.error(new GoogleGenerativeAIError("Failed to parse stream"));
                    }
                    else {
                        controller.close();
                    }
                    return;
                }
                buffered += value;
                // Peel complete events off the front of the buffer one at a time.
                for (let found = buffered.match(responseLineRE); found; found = buffered.match(responseLineRE)) {
                    let parsed;
                    try {
                        parsed = JSON.parse(found[1]);
                    }
                    catch (e) {
                        controller.error(new GoogleGenerativeAIError(`Error parsing JSON response: "${found[1]}"`));
                        return;
                    }
                    controller.enqueue(parsed);
                    buffered = buffered.substring(found[0].length);
                }
            }
        },
    });
}
/**
 * Aggregates an array of `GenerateContentResponse`s into a single
 * GenerateContentResponse.
 *
 * - `promptFeedback` is taken from the last response.
 * - Candidates are merged by index; citationMetadata, finishReason,
 *   finishMessage and safetyRatings are overwritten by each chunk, so the
 *   last chunk's values win.
 * - Every part of every chunk is appended, in order, as its own object.
 *
 * @param responses - Responses in arrival order (may be empty).
 * @returns The aggregated response.
 */
function aggregateResponses(responses) {
    const lastResponse = responses[responses.length - 1];
    const aggregatedResponse = {
        promptFeedback: lastResponse === null || lastResponse === undefined ? undefined : lastResponse.promptFeedback,
    };
    for (const response of responses) {
        if (!response.candidates) {
            continue;
        }
        for (const candidate of response.candidates) {
            // A candidate without an index is treated as index 0; indexing the
            // array with `undefined` would create a stray property and leave
            // `candidates.length` at 0.
            const i = candidate.index || 0;
            if (!aggregatedResponse.candidates) {
                aggregatedResponse.candidates = [];
            }
            if (!aggregatedResponse.candidates[i]) {
                aggregatedResponse.candidates[i] = { index: i };
            }
            const merged = aggregatedResponse.candidates[i];
            // Keep overwriting, the last one will be final
            merged.citationMetadata = candidate.citationMetadata;
            merged.finishReason = candidate.finishReason;
            merged.finishMessage = candidate.finishMessage;
            merged.safetyRatings = candidate.safetyRatings;
            /**
             * Candidates should always have content and parts, but this handles
             * possible malformed responses.
             */
            if (candidate.content && candidate.content.parts) {
                if (!merged.content) {
                    merged.content = {
                        role: candidate.content.role || "user",
                        parts: [],
                    };
                }
                for (const part of candidate.content.parts) {
                    // A fresh object per part: sharing one object across the loop
                    // would make every pushed entry alias the last part's data.
                    const newPart = {};
                    if (part.text) {
                        newPart.text = part.text;
                    }
                    if (part.functionCall) {
                        newPart.functionCall = part.functionCall;
                    }
                    if (Object.keys(newPart).length === 0) {
                        newPart.text = "";
                    }
                    merged.content.parts.push(newPart);
                }
            }
        }
    }
    return aggregatedResponse;
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Issues a streaming generateContent request and hands the raw response to
 * processStream.
 *
 * @returns Whatever processStream returns for the response.
 */
async function generateContentStream(apiKey, model, params, requestOptions) {
    const isStreaming = true;
    const endpoint = new RequestUrl(model, Task.STREAM_GENERATE_CONTENT, apiKey, isStreaming, requestOptions);
    const payload = JSON.stringify(params);
    return processStream(await makeRequest(endpoint, payload, requestOptions));
}
/**
 * Issues a non-streaming generateContent request.
 *
 * @returns `{ response }`, where `response` is the parsed JSON body with the
 *   helper accessors attached by addHelpers.
 */
async function generateContent(apiKey, model, params, requestOptions) {
    const isStreaming = false;
    const endpoint = new RequestUrl(model, Task.GENERATE_CONTENT, apiKey, isStreaming, requestOptions);
    const payload = JSON.stringify(params);
    const rawResponse = await makeRequest(endpoint, payload, requestOptions);
    const parsed = await rawResponse.json();
    return { response: addHelpers(parsed) };
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Normalizes a message request into a single content item.
 *
 * @param request - Either one string, or an iterable whose items are strings
 *   and/or part objects. Strings become `{ text }` parts; part objects are
 *   passed through untouched.
 * @returns The content produced by
 *   assignRoleToPartsAndValidateSendMessageRequest for those parts.
 */
function formatNewContent(request) {
    const toPart = (item) => (typeof item === "string" ? { text: item } : item);
    const newParts = typeof request === "string"
        ? [toPart(request)]
        : [...request].map((item) => toPart(item));
    return assignRoleToPartsAndValidateSendMessageRequest(newParts);
}
/**
 * Wraps an array of parts in a content item with the right role.
 * Parts carrying a `functionResponse` key need role 'function'; every other
 * part gets role 'user'. The two kinds may not be mixed in one message.
 * @private
 * @param parts Array of parts to pass to the model
 * @returns A single content item `{ role, parts }`
 * @throws GoogleGenerativeAIError when `parts` is empty or mixes
 *   functionResponse parts with other parts
 */
function assignRoleToPartsAndValidateSendMessageRequest(parts) {
    const functionParts = [];
    const userParts = [];
    for (const part of parts) {
        const bucket = "functionResponse" in part ? functionParts : userParts;
        bucket.push(part);
    }
    const hasUserContent = userParts.length > 0;
    const hasFunctionContent = functionParts.length > 0;
    if (hasUserContent && hasFunctionContent) {
        throw new GoogleGenerativeAIError("Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.");
    }
    if (!hasUserContent && !hasFunctionContent) {
        throw new GoogleGenerativeAIError("No content is provided for sending chat message.");
    }
    return hasUserContent
        ? { role: "user", parts: userParts }
        : { role: "function", parts: functionParts };
}
/**
 * Accepts either a full request (anything with a truthy `contents` property,
 * returned unchanged) or a shorthand message, which is converted with
 * formatNewContent and wrapped as `{ contents: [content] }`.
 */
function formatGenerateContentInput(params) {
    return params.contents
        ? params
        : { contents: [formatNewContent(params)] };
}
/**
 * Accepts either a shorthand message (a string or an array), which is
 * converted with formatNewContent and wrapped as `{ content }`, or anything
 * else, which is returned unchanged.
 */
function formatEmbedContentInput(params) {
    const isShorthand = typeof params === "string" || Array.isArray(params);
    if (!isShorthand) {
        return params;
    }
    return { content: formatNewContent(params) };
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// https://ai.google.dev/api/rest/v1beta/Content#part
// Part field names that validateChatHistory looks for on each part.
const VALID_PART_FIELDS = [
"text",
"inlineData",
"functionCall",
"functionResponse",
];
// For each role, the part fields a content item with that role may contain.
const VALID_PARTS_PER_ROLE = {
user: ["text", "inlineData"],
function: ["functionResponse"],
model: ["text", "functionCall"],
};
// For each role, the roles the immediately preceding content item may have.
const VALID_PREVIOUS_CONTENT_ROLES = {
user: ["model"],
function: ["model"],
model: ["user", "function"],
};
/**
 * Validates a chat history before it seeds a ChatSession.
 *
 * For every content item, in order, it checks that: the first item has role
 * 'user'; the role is one of POSSIBLE_ROLES; `parts` is a non-empty array;
 * the parts only use fields allowed for the role; and the role may follow the
 * previous item's role.
 *
 * @param history - Iterable of `{ role, parts }` content items.
 * @throws GoogleGenerativeAIError on the first violated rule.
 */
function validateChatHistory(history) {
    let previous;
    for (const current of history) {
        const role = current.role;
        const parts = current.parts;
        if (!previous && role !== "user") {
            throw new GoogleGenerativeAIError(`First content should be with role 'user', got ${role}`);
        }
        if (!POSSIBLE_ROLES.includes(role)) {
            throw new GoogleGenerativeAIError(`Each item should include role field. Got ${role} but valid roles are: ${JSON.stringify(POSSIBLE_ROLES)}`);
        }
        if (!Array.isArray(parts)) {
            throw new GoogleGenerativeAIError("Content should have 'parts' property with an array of Parts");
        }
        if (parts.length === 0) {
            throw new GoogleGenerativeAIError("Each Content should have at least one part");
        }
        // Tally how many parts carry each known field before judging any of them.
        const occurrences = new Map(VALID_PART_FIELDS.map((field) => [field, 0]));
        for (const part of parts) {
            for (const field of VALID_PART_FIELDS) {
                if (field in part) {
                    occurrences.set(field, occurrences.get(field) + 1);
                }
            }
        }
        const allowedFields = VALID_PARTS_PER_ROLE[role];
        const offending = VALID_PART_FIELDS.find((field) => !allowedFields.includes(field) && occurrences.get(field) > 0);
        if (offending !== undefined) {
            throw new GoogleGenerativeAIError(`Content with role '${role}' can't contain '${offending}' part`);
        }
        if (previous) {
            const allowedPreviousRoles = VALID_PREVIOUS_CONTENT_ROLES[role];
            if (!allowedPreviousRoles.includes(previous.role)) {
                throw new GoogleGenerativeAIError(`Content with role '${role}' can't follow '${previous.role}'. Valid previous roles: ${JSON.stringify(VALID_PREVIOUS_CONTENT_ROLES)}`);
            }
        }
        previous = current;
    }
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Do not log a message for this error.
* ChatSession.sendMessageStream rethrows stream-request failures with this
* message so its final catch can recognise and skip them; the caller already
* receives the original rejection through the returned promise.
*/
const SILENT_ERROR = "SILENT_ERROR";
/**
 * ChatSession class that enables sending chat messages and stores
 * history of sent and received messages so far.
 *
 * Every send is appended to an internal promise chain (`_sendPromise`);
 * `getHistory()` and new sends first wait for that chain to settle.
 *
 * @public
 */
class ChatSession {
    /**
     * @param apiKey - API key used for every request of this session.
     * @param model - Model name passed to the request functions.
     * @param params - Optional start parameters; `history`, `safetySettings`,
     *   `generationConfig` and `tools` are read.
     * @param requestOptions - Optional options forwarded to each request.
     * @throws GoogleGenerativeAIError when `params.history` is invalid.
     */
    constructor(apiKey, model, params, requestOptions) {
        this.model = model;
        this.params = params;
        this.requestOptions = requestOptions;
        this._history = [];
        this._sendPromise = Promise.resolve();
        this._apiKey = apiKey;
        if (params && params.history) {
            validateChatHistory(params.history);
            // NOTE: the caller's array is adopted as-is (not copied), so later
            // sends append to that same array.
            this._history = params.history;
        }
    }
    /**
     * Gets the chat history so far. Blocked prompts are not added to history.
     * Blocked candidates are not added to history, nor are the prompts that
     * generated them.
     */
    async getHistory() {
        await this._sendPromise;
        return this._history;
    }
    /**
     * Sends a chat message and receives a non-streaming
     * {@link GenerateContentResult}
     *
     * On success with at least one candidate, the new message and the first
     * candidate's content (role defaulting to "model") are appended to the
     * history. If the request fails, the error is rethrown to the caller and
     * the session stays usable for later calls.
     */
    async sendMessage(request) {
        await this._sendPromise;
        const newContent = formatNewContent(request);
        const params = this.params;
        const hasParams = params !== null && params !== undefined;
        const generateContentRequest = {
            safetySettings: hasParams ? params.safetySettings : undefined,
            generationConfig: hasParams ? params.generationConfig : undefined,
            tools: hasParams ? params.tools : undefined,
            contents: [...this._history, newContent],
        };
        let finalResult;
        // Add onto the chain.
        this._sendPromise = this._sendPromise
            .then(() => generateContent(this._apiKey, this.model, generateContentRequest, this.requestOptions))
            .then((result) => {
            const candidates = result.response.candidates;
            if (candidates && candidates.length > 0) {
                this._history.push(newContent);
                const responseContent = Object.assign({ parts: [],
                    // Response seems to come back without a role set.
                    role: "model" }, candidates[0].content);
                this._history.push(responseContent);
            }
            else {
                const blockErrorMessage = formatBlockErrorMessage(result.response);
                if (blockErrorMessage) {
                    console.warn(`sendMessage() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.`);
                }
            }
            finalResult = result;
        })
            .catch((e) => {
            // Reset the chain: leaving a rejected promise here would make every
            // later sendMessage/getHistory call fail with this same stale error.
            this._sendPromise = Promise.resolve();
            throw e;
        });
        await this._sendPromise;
        return finalResult;
    }
    /**
     * Sends a chat message and receives the response as a
     * {@link GenerateContentStreamResult} containing an iterable stream
     * and a response promise.
     *
     * History is updated in the background once the aggregated response
     * resolves; failures of the request itself reach the caller through the
     * returned promise.
     */
    async sendMessageStream(request) {
        await this._sendPromise;
        const newContent = formatNewContent(request);
        const params = this.params;
        const hasParams = params !== null && params !== undefined;
        const generateContentRequest = {
            safetySettings: hasParams ? params.safetySettings : undefined,
            generationConfig: hasParams ? params.generationConfig : undefined,
            tools: hasParams ? params.tools : undefined,
            contents: [...this._history, newContent],
        };
        const streamPromise = generateContentStream(this._apiKey, this.model, generateContentRequest, this.requestOptions);
        // Add onto the chain.
        this._sendPromise = this._sendPromise
            .then(() => streamPromise)
            // This must be handled to avoid unhandled rejection, but jump
            // to the final catch block with a label to not log this error.
            .catch((_ignored) => {
            throw new Error(SILENT_ERROR);
        })
            .then((streamResult) => streamResult.response)
            .then((response) => {
            if (response.candidates && response.candidates.length > 0) {
                this._history.push(newContent);
                const responseContent = Object.assign({}, response.candidates[0].content);
                // Response seems to come back without a role set.
                if (!responseContent.role) {
                    responseContent.role = "model";
                }
                this._history.push(responseContent);
            }
            else {
                const blockErrorMessage = formatBlockErrorMessage(response);
                if (blockErrorMessage) {
                    console.warn(`sendMessageStream() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.`);
                }
            }
        })
            .catch((e) => {
            // Errors in streamPromise are already catchable by the user as
            // streamPromise is returned.
            // Avoid duplicating the error message in logs.
            if (e.message !== SILENT_ERROR) {
                // Users do not have access to _sendPromise to catch errors
                // downstream from streamPromise, so they should not throw.
                console.error(e);
            }
        });
        return streamPromise;
    }
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Sends a count-tokens request for the given model and resolves with the
 * parsed JSON body of the reply.
 *
 * @param apiKey - API key handed to the request URL builder.
 * @param model - Model name; also injected into the request payload.
 * @param params - Request fields; copied, never mutated.
 * @param requestOptions - Options forwarded to `makeRequest`.
 * @returns Whatever `response.json()` yields for the reply.
 */
async function countTokens(apiKey, model, params, requestOptions) {
    // NOTE(review): the URL is built with an empty options object while
    // `requestOptions` only reaches `makeRequest` - confirm this is intended.
    const endpoint = new RequestUrl(model, Task.COUNT_TOKENS, apiKey, false, {});
    // Shallow copy of the caller's params with the model name added on top.
    const payload = JSON.stringify(Object.assign({}, params, { model }));
    const reply = await makeRequest(endpoint, payload, requestOptions);
    return reply.json();
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Sends a single embed-content request and resolves with the parsed JSON
 * body of the reply.
 *
 * @param apiKey - API key handed to the request URL builder.
 * @param model - Model name used to build the request URL.
 * @param params - Request payload, serialized as-is.
 * @param requestOptions - Options forwarded to `makeRequest`.
 * @returns Whatever `response.json()` yields for the reply.
 */
async function embedContent(apiKey, model, params, requestOptions) {
    // NOTE(review): the URL is built with an empty options object while
    // `requestOptions` only reaches `makeRequest` - confirm this is intended.
    const endpoint = new RequestUrl(model, Task.EMBED_CONTENT, apiKey, false, {});
    const payload = JSON.stringify(params);
    const reply = await makeRequest(endpoint, payload, requestOptions);
    return reply.json();
}
/**
 * Sends a batch embed-contents request and resolves with the parsed JSON
 * body of the reply. Each entry of `params.requests` is copied and tagged
 * with the model name before being sent.
 *
 * @param apiKey - API key handed to the request URL builder.
 * @param model - Model name; used for the URL and added to every entry.
 * @param params - Object whose `requests` array holds the individual requests.
 * @param requestOptions - Options forwarded to `makeRequest`.
 * @returns Whatever `response.json()` yields for the reply.
 */
async function batchEmbedContents(apiKey, model, params, requestOptions) {
    // NOTE(review): the URL is built with an empty options object while
    // `requestOptions` only reaches `makeRequest` - confirm this is intended.
    const endpoint = new RequestUrl(model, Task.BATCH_EMBED_CONTENTS, apiKey, false, {});
    // Copy each entry so the caller's request objects are left untouched.
    const requests = params.requests.map((entry) => Object.assign({}, entry, { model }));
    const payload = JSON.stringify({ requests });
    const reply = await makeRequest(endpoint, payload, requestOptions);
    return reply.json();
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Class for generative model APIs.
* @public
*/
class GenerativeModel {
    /**
     * @param apiKey - API key passed along with every request.
     * @param modelParams - Model settings: `model` (required), plus optional
     * `generationConfig`, `safetySettings` and `tools`.
     * @param requestOptions - Per-request options forwarded to every call
     * made through this instance; defaults to an empty object.
     */
    constructor(apiKey, modelParams, requestOptions) {
        this.apiKey = apiKey;
        if (modelParams.model.includes("/")) {
            // Models may be named "models/model-name" or "tunedModels/model-name"
            this.model = modelParams.model;
        }
        else {
            // If path is not included, assume it's a non-tuned model.
            this.model = `models/${modelParams.model}`;
        }
        this.generationConfig = modelParams.generationConfig || {};
        this.safetySettings = modelParams.safetySettings || [];
        this.tools = modelParams.tools;
        this.requestOptions = requestOptions || {};
    }
    /**
     * Makes a single non-streaming call to the model
     * and returns an object containing a single {@link GenerateContentResponse}.
     *
     * Fields present in the formatted request take precedence over the
     * model-level `generationConfig`, `safetySettings` and `tools`.
     */
    async generateContent(request) {
        const formattedParams = formatGenerateContentInput(request);
        return generateContent(this.apiKey, this.model, Object.assign({ generationConfig: this.generationConfig, safetySettings: this.safetySettings, tools: this.tools }, formattedParams), this.requestOptions);
    }
    /**
     * Makes a single streaming call to the model
     * and returns an object containing an iterable stream that iterates
     * over all chunks in the streaming response as well as
     * a promise that returns the final aggregated response.
     *
     * Fields present in the formatted request take precedence over the
     * model-level `generationConfig`, `safetySettings` and `tools`.
     */
    async generateContentStream(request) {
        const formattedParams = formatGenerateContentInput(request);
        return generateContentStream(this.apiKey, this.model, Object.assign({ generationConfig: this.generationConfig, safetySettings: this.safetySettings, tools: this.tools }, formattedParams), this.requestOptions);
    }
    /**
     * Gets a new {@link ChatSession} instance which can be used for
     * multi-turn chats. `startChatParams` may override the model-level tools.
     */
    startChat(startChatParams) {
        return new ChatSession(this.apiKey, this.model, Object.assign({ tools: this.tools }, startChatParams), this.requestOptions);
    }
    /**
     * Counts the tokens in the provided request.
     */
    async countTokens(request) {
        const formattedParams = formatGenerateContentInput(request);
        // Forward requestOptions like every other call on this class; without
        // it, options given to the constructor are silently ignored here.
        return countTokens(this.apiKey, this.model, formattedParams, this.requestOptions);
    }
    /**
     * Embeds the provided content.
     */
    async embedContent(request) {
        const formattedParams = formatEmbedContentInput(request);
        // Forward requestOptions like every other call on this class; without
        // it, options given to the constructor are silently ignored here.
        return embedContent(this.apiKey, this.model, formattedParams, this.requestOptions);
    }
    /**
     * Embeds an array of {@link EmbedContentRequest}s.
     */
    async batchEmbedContents(batchEmbedContentRequest) {
        return batchEmbedContents(this.apiKey, this.model, batchEmbedContentRequest, this.requestOptions);
    }
}
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Top-level class for this SDK
* @public
*/
class GoogleGenerativeAI {
    /**
     * @param apiKey - API key handed to every model created by this instance.
     */
    constructor(apiKey) {
        this.apiKey = apiKey;
    }
    /**
     * Builds a {@link GenerativeModel} for the model named in `modelParams`.
     *
     * @param modelParams - Model settings; `model` must be a non-empty name.
     * @param requestOptions - Passed through to the created model unchanged.
     * @throws GoogleGenerativeAIError when `modelParams.model` is falsy.
     */
    getGenerativeModel(modelParams, requestOptions) {
        if (modelParams.model) {
            return new GenerativeModel(this.apiKey, modelParams, requestOptions);
        }
        const missingModelMessage = "Must provide a model name. Example: genai.getGenerativeModel({ model: 'my-model-name' })";
        throw new GoogleGenerativeAIError(missingModelMessage);
    }
}
export { BlockReason, ChatSession, FinishReason, FunctionDeclarationSchemaType, GenerativeModel, GoogleGenerativeAI, HarmBlockThreshold, HarmCategory, HarmProbability, POSSIBLE_ROLES, TaskType };
//# sourceMappingURL=index.mjs.map
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment