/**
 * Coax an LLM into generating grounded JSON using key/value-level perplexity checks, with a custom
 * `yup` validation test for perplexity, model fallbacks, and a customizable failure-injection callback
 * that injects grounding instructions for specific paths on demand (e.g. for specific keys whose
 * perplexity is too high).
 *
 * Contributed to the community as an example of how to coax an LLM into generating JSON that conforms
 * to a schema, with key-level perplexity checks and fallback models for retries.
 */
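/*
Typical call shape - a minimal sketch assembled from the option defaults and the working example
in `executeMain` at the bottom of this file (not additional API surface):

  const { content, object, objectWithMetadata, failure } = await coaxLlm({
    prompt,
    schema, // yup schema whose fields use .perplexity({ max }) checks
    logger,
    fallbackModels: [ModelIds.Gpt3_5Turbo, ModelIds.Gpt4],
    failureInjectCallback: async ({ type, path }) => [],
  });
*/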
/*
# Example run
Example of a successful run of this script showing that it retried when perplexity was too high for the "formalName"
prop and finally succeeded. See the "executeMain" section at the bottom for the schema and prompt given.
```log
{
  "formalName": "Emily",
  "nickname": "Em",
  "ageGuess": 25
}
(PID 74676) 2024-04-15T04:17:36.875Z coax-llm.js [DEBUG] coax-llm.js:162: Cannot use generated JSON because it failed schema validation: Error: High perplexity detected in "Emily" at "formalName", indicating potential inaccuracies or lack of grounding in facts. Please refine to be more grounded in factual content and closely aligned with the context provided. Target perplexity under: 1.125. {
  path: 'formalName',
  type: 'perplexity',
  errors: [
    'Error: High perplexity detected in "Emily" at "formalName", indicating potential inaccuracies or lack of grounding in facts. Please refine to be more grounded in factual content and closely aligned with the context provided. Target perplexity under: 1.125.'
  ]
}
(PID 74676) 2024-04-15T04:17:36.875Z coax-llm.js [WARN] coax-llm.js:233: Retrying with fallback model: gpt-3.5-turbo-0125
{
  "formalName": "Josiah Bryan",
  "nickname": "Joey",
  "ageGuess": 28
}(PID 74676) 2024-04-15T04:17:37.519Z coax-llm.js [INFO] coax-llm.js:330: Generated content: {
  "formalName": "Josiah Bryan",
  "nickname": "Joey",
  "ageGuess": 28
}
(PID 74676) 2024-04-15T04:17:37.519Z coax-llm.js [INFO] coax-llm.js:334: Generated object: { formalName: 'Josiah Bryan', nickname: 'Joey', ageGuess: 28 }
(PID 74676) 2024-04-15T04:17:37.519Z coax-llm.js [INFO] coax-llm.js:338: Generated object with metadata: {
  formalName: {
    key: 'formalName',
    value: 'Josiah Bryan',
    keyProb: 0.999996,
    valueProb: 0.999957,
    keyPerplexity: 1.000001,
    valuePerplexity: 1.000014,
    finished: true
  },
  nickname: {
    key: 'nickname',
    value: 'Joey',
    keyProb: 0.999996,
    valueProb: 0.872926,
    keyPerplexity: 1.000004,
    valuePerplexity: 1.070314,
    finished: true
  },
  ageGuess: {
    key: 'ageGuess',
    value: 28,
    keyProb: 0.999994,
    valueProb: 0.594872,
    keyPerplexity: 1.000003,
    valuePerplexity: 1.681035,
    finished: true
  }
}
```
*/
/* eslint-disable no-unused-vars */
import { ModelIds } from 'shared/utils/ChatBotDefEnums';
import Logger from 'shared/utils/Logger';
import { yup, convertSchema } from 'shared/requests/utils/AuthoringHelpers';
import { executeMain } from '../../../../../utils/script-cli';
import { llmPredictFactory } from '../../../../chatbot/utils/llmPredictFactory';
import {
  convertSimplifiedJsonToObject,
  onStreamChunkDebugFactory,
} from '../../../../chatbot/utils/withStreamLogProbsToJson';
import {
  safeEndLangfuse,
  traceFactory,
} from '../../../../chatbot/utils/langfuse';
// Expects that schema.validate() will be called with an options arg containing context: { simplified, logger },
// where 'simplified' comes from the metadata of the LLM predict response (e.g. from withStreamLogProbsToJson)
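// Illustrative shape of `simplified`, assumed from the example run above and the usage in
// convertSimplifiedJsonToObject below (not an exact contract): each entry pairs a JSON path
// with per-key log-prob stats, e.g.
//   [
//     ['formalName', { key: 'formalName', value: 'Josiah Bryan', valuePerplexity: 1.000014, ... }],
//     ['nickname', { key: 'nickname', value: 'Joey', valuePerplexity: 1.070314, ... }],
//   ]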
function perplexityTest(options) {
  return this.test({
    name: 'perplexity',
    message: ({ path, value }) =>
      // `The perplexity of "${value}" (located at "${path}") was too high and it should be re-evaluated. The perplexity for this path should be under ${options.max} to be considered safe.`, // Using a function for dynamic message
      // Error message refined with the help of gpt-4
      `Error: High perplexity detected in "${value}" at "${path}", indicating potential inaccuracies or lack of grounding in facts. Please refine to be more grounded in factual content and closely aligned with the context provided. Target perplexity under: ${options.max}.`,
    test(value) {
      const { path } = this; // Access the current path
      const { context } = this.options; // Access the context to get simplified json for further testing
      const { simplified, logger } = context;
      // logger.debug('Testing path:', path); // Optionally log the path for debugging
      const metadata = simplified?.find((x) => x[0] === path);
      if (!metadata) {
        logger.warn(
          'yup.perplexity(): No metadata found in context for path:',
          path,
        );
        return false;
      }
      const { valuePerplexity } = metadata[1];
      // The actual test condition
      const passed = valuePerplexity == null || valuePerplexity <= options.max;
      // logger.debug(`Perplexity test for ${path} (${valuePerplexity}):`, passed);
      return passed;
    },
  });
}
// Attach our custom perplexity test to all the types we want to support
[yup.string, yup.number, yup.object, yup.mixed, yup.array, yup.boolean].forEach(
  (type) => {
    yup.addMethod(type, 'perplexity', perplexityTest);
  },
);
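// Example: attach a per-field perplexity ceiling to any yup field (mirrors the schema used in
// executeMain at the bottom of this file):
//   const schema = yup.object().shape({
//     formalName: yup.string().required().perplexity({ max: 1.125 }),
//   });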
/**
 * Coax an LLM into generating grounded JSON based on key/value-level perplexity checks
 *
 * @param {Object} options - The options for generating content.
 * @param {string} options.prompt - The prompt for generating content.
 * @param {string[]} [options.fallbackModels] - The fallback models to use if the generation fails.
 * @param {yup.ObjectSchema} [options.schema] - The schema to validate the generated JSON against.
 * @param {Logger} [options.logger] - The logger to use for logging.
 * @param {boolean} [options.includeSchemaInPrompt] - Whether to include the schema in the prompt.
 * @param {Function} [options.failureInjectCallback] - The callback function to inject additional instructions on failure.
 * @param {Object} [options.langfuseTrace] - The langfuse trace object.
 * @param {string} [options.templateId] - The template ID.
 * @param {string} [options.spanName] - The span name.
 * @param {number} [options.maxRootPerplexity] - The maximum perplexity allowed for the root content.
 * @param {boolean} [options.enableStreamJsonParse] - Whether to enable stream JSON parsing.
 * @param {string} [options.cacheMode] - The cache mode to use.
 * @returns {Promise<Object>} - The generated content, parsed object, and per-key metadata, or a failure descriptor if every model fails validation.
 */
async function coaxLlm({
  prompt: promptInput,
  fallbackModels = [
    ModelIds.Gpt3_5Turbo,
    ModelIds.Gpt4,
    ModelIds.Claude_3_Opus, // no logprobs!
  ],
  schema = yup.object().shape({}),
  logger = Logger,
  // If true, converts the schema prop to a JSON Schema and injects it at the END of the prompt
  includeSchemaInPrompt = true,
  // Async method that can be used to inject additional instructions on failure
  failureInjectCallback = async ({
    retryCount, // depth into retries
    modelName, // model name the failure was generated with
    message, // error message
    type, // type of failure (e.g. perplexity, typeError, etc.)
    errors, // array of errors if from yup
    data, // data generated that caused a failure
    content, // raw content from LLM if any
    jsonLogProbs, // logprobs for the JSON generation (see simplified or result props)
  }) =>
    // Should return an array of strings to inject into the prompt AFTER the failure message.
    // (The failure message itself is automatically injected at the start of the prompt, before the
    // original prompt; these lines go right after it, before the original prompt is added back in.)
    [],
  langfuseTrace,
  templateId = 'coax-llm',
  spanName = templateId,
  maxRootPerplexity = 3,
  enableStreamJsonParse = true,
  cacheMode = 'auto',
}) {
  const prompt = [promptInput];
  const span = langfuseTrace?.span({
    name: spanName,
    input: {
      prompt: promptInput,
      schema: convertSchema(schema),
      includeSchemaInPrompt,
    },
  });
  if (includeSchemaInPrompt) {
    prompt.push(
      '',
      `### Schema\n\n`,
      '```json\n',
      JSON.stringify(convertSchema(schema), null, 2),
      '```\n\n',
      `Be sure to follow the schema above when generating JSON and output valid JSON that conforms to the schema:`,
    );
  }
  const debugStream = onStreamChunkDebugFactory(logger);
  const generate = async (args) => {
    const { modelName, promptInject = [], retryCount = 0 } = args;
    const predict = llmPredictFactory({
      modelName,
    });
    const { content, metadata } = await predict(
      [...promptInject, ...prompt].join('\n'),
      {
        cacheMode,
        // enableStreamJsonParse,
        onStreamChunk: debugStream,
        returnMetadata: true,
        langfuseTrace: span,
        heliconeCustomProps: {
          templateId,
        },
      },
    );
    let failure;
    let object;
    let objectWithMetadata;
    if (schema) {
      const { json: jsonLogProbs } = metadata || {};
      const { simplified = [] } = jsonLogProbs || {};
      // Use our blunt-force JSON parser output, which converted the tokens + logprobs into a
      // simplified array of entries like ["path.to.key", { valuePerplexity: 1.2 }], and expand
      // it into two objects - one with the original JSON values and one with the metadata
      // for each key in the JSON
      [object, objectWithMetadata] = convertSimplifiedJsonToObject(simplified, {
        dualVersion: true,
      });
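      // Illustrative sketch of the two outputs, assuming the shapes shown in the example run
      // at the top of this file (not an exact contract of convertSimplifiedJsonToObject):
      //   object             => { formalName: 'Josiah Bryan', nickname: 'Joey', ageGuess: 28 }
      //   objectWithMetadata => { formalName: { key, value, keyProb, valueProb, keyPerplexity, valuePerplexity, finished }, ... }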
      await schema
        .validate(object, {
          // Pass the simplified array of JSON entries to the context for the yup schema to use
          // in the custom test for perplexity (above)
          context: {
            simplified,
            logger,
          },
        })
        .catch((ex) => {
          const { path, type, errors } = ex;
          // If the failure came from our custom 'perplexity' test above, type will equal "perplexity"
          // and errors will contain our custom error message from the test function
          logger.debug(
            `Cannot use generated JSON because it failed schema validation: ${ex.message}`,
            {
              path,
              type,
              errors,
            },
          );
          failure = {
            message: `Error in generated JSON: ${ex.message}`,
            type,
            path,
            errors,
            data: object,
            content,
            jsonLogProbs,
          };
        });
    }
    if (!failure && !content) {
      failure = {
        message: 'Error: No content generated',
        type: 'noContent',
        path: '$root',
        data: null,
      };
    }
    if (!failure && metadata?.perplexity > maxRootPerplexity) {
      failure = {
        message: `Error: High perplexity detected in root content, indicating potential inaccuracies or lack of grounding in facts. Please refine to be more grounded in factual content and closely aligned with the context provided. Target perplexity under: ${maxRootPerplexity}.`,
        type: 'perplexity',
        path: '$root',
        data: content,
      };
    }
    if (failure) {
      if (retryCount < fallbackModels.length) {
        const { type, path, message, errors, ...failureRest } = failure;
        const inject = [
          `## IMPORTANT`,
          `Failure detected in previous attempt: ${message}`,
          `Failure type: ${type}`,
          `Failure path: ${path}`,
          `Pay attention to the failure message and attempt the generation again.`,
          '',
          ...(await failureInjectCallback({
            modelName,
            retryCount,
            ...failure,
          })),
          '', // Add a blank line before the original prompt
        ];
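        // For the example run at the top of this file, the retry prompt prefix would look roughly
        // like this (the last line comes from the failureInjectCallback in executeMain below):
        //   ## IMPORTANT
        //   Failure detected in previous attempt: Error in generated JSON: ... High perplexity detected in "Emily" at "formalName" ...
        //   Failure type: perplexity
        //   Failure path: formalName
        //   Pay attention to the failure message and attempt the generation again.
        //
        //   My name is: "Josiah Bryan"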
        span?.event({
          name: 'coax-failure',
          metadata: {
            type,
            path,
            errors,
            simplifiedJsonLogProbs: failureRest.jsonLogProbs?.simplified,
          },
          input: { type, path, errors },
          output: { retryCount, inject, nextModel: fallbackModels[retryCount] },
          level: 'ERROR',
          statusMessage: failure.message,
        });
        logger.warn(
          `Retrying with fallback model: ${fallbackModels[retryCount]}`,
        );
        return generate({
          modelName: fallbackModels[retryCount],
          // promptInject: `### Retry with ${modelName} failed. Trying with ${fallbackModels[retryCount]}`,
          promptInject: inject,
          retryCount: retryCount + 1,
        });
      }
      return { failure };
    }
    return { content, object, objectWithMetadata };
  };
  return safeEndLangfuse(span, {
    output: await generate({ modelName: fallbackModels[0] }),
  });
}
executeMain(async ({ authorization, logger }) => {
  // const prompt = `Write me a poem about the ocean, with a focus on the beauty of the waves and the serenity of the sea. Include vivid imagery and descriptive language to bring the scene to life. Be sure to capture the essence of the ocean in your words. Rate the poem on a scale of 1 to 10, with 10 being the highest rating. Provide a brief explanation for your rating.`;
  // const schema = yup.object().shape({
  //   poem: yup
  //     .string()
  //     .required()
  //     .description('Poem about the ocean')
  //     .perplexity({ max: 1.3 }),
  //   rating: yup
  //     .number()
  //     .min(0)
  //     .max(10)
  //     .description('Rating of the poem')
  //     .perplexity({ max: 1.3 }),
  //   explanation: yup
  //     .string()
  //     .required()
  //     .description('Explanation for the rating')
  //     .perplexity({ max: 1.3 }),
  // });
  const prompt = `What is my name? Generate a nickname based on my name and totally guess my age based on my name`;
  const langfuseTrace = traceFactory({
    logger,
    name: 'coax-llm-test',
    input: { prompt },
  });
  const schema = yup.object().shape({
    formalName: yup
      .string()
      .required()
      .description('Formal name')
      .perplexity({ max: 1.125 }),
    nickname: yup
      .string()
      .required()
      .description('Generated nickname')
      .perplexity({ max: 1.5 }),
    ageGuess: yup
      .number()
      .required()
      .description('Generated age guess')
      .perplexity({ max: 99 }),
  });
  const { content, object, objectWithMetadata, failure } = await coaxLlm({
    prompt,
    schema,
    logger,
    langfuseTrace,
    cacheMode: 'save',
    failureInjectCallback: async ({ type, path }) => {
      if (type === 'perplexity' && ['nickname', 'formalName'].includes(path)) {
        return [`My name is: "${authorization.user.name}"`];
      }
      return [];
    },
  });
  safeEndLangfuse(langfuseTrace, {
    output: {
      content,
      object,
      objectWithMetadata,
      failure,
    },
  });
  if (failure) {
    logger.warn('Failed to generate valid JSON:', failure);
  }
  if (content) {
    logger.info('Generated content:', content);
  }
  if (object) {
    logger.info('Generated object:', object);
  }
  if (objectWithMetadata) {
    logger.info('Generated object with metadata:', objectWithMetadata);
  }
  return { content, object, objectWithMetadata, failure };
});