Skip to content

Instantly share code, notes, and snippets.

@emonidi
Created May 24, 2023 16:17
Show Gist options
  • Save emonidi/b257014de26497db506850df9998b2bb to your computer and use it in GitHub Desktop.
Save emonidi/b257014de26497db506850df9998b2bb to your computer and use it in GitHub Desktop.
LangChain.js: a naive Tree-of-Thoughts (ToT) implementation
import { models } from '../models'
import { LLMChain, SerializedLLMChain } from "langchain/chains";
import { PromptTemplate } from "langchain/prompts";
import {AgentExecutor, BaseSingleActionAgent, StoppingMethod } from "langchain/agents";
import { CallbackManagerForChainRun, Callbacks } from "langchain/callbacks";
import { BaseMultiActionAgent } from "langchain/dist/agents/agent";
import { BaseMemory } from "langchain/memory";
import { AgentAction, AgentFinish, ChainValues } from "langchain/schema";
import { Tool } from "langchain/tools";
import { OutputParser } from "./output-parser";
import { FormatInstructionsOptions } from "langchain/schema/output_parser";
(async () => {
// LLM used by both the thought generator and the thought evaluator below.
const model = models["oa"];
// Asks the model for `sizeLimit` alternative answers to `input` in a single completion.
const thoughtTemplate = "Given the current instruction: '{input}', generate {sizeLimit} different answers:";
const thoughtPrompt = new PromptTemplate({
  template: thoughtTemplate,
  inputVariables: ["input", "sizeLimit"],
});
/**
 * Output parser for the thought generator: turns a completion that lists one
 * answer per line into an array of answer strings.
 */
class ThoughtOutputParser extends OutputParser {
  /**
   * A list marker at the very start of a line, e.g. "1. ", "12) " or "3: ".
   * Anchored and digit-run aware so numbers inside an answer are left alone.
   */
  private static readonly LIST_MARKER = /^\s*\d+[.):]\s+/;

  promptTemplate: PromptTemplate;

  /**
   * @param fields.promptTemplate - the generator prompt; its template string is
   *   also forwarded to the `OutputParser` base constructor.
   */
  constructor(fields: { promptTemplate: PromptTemplate }) {
    super(fields.promptTemplate.template);
    this.promptTemplate = fields.promptTemplate;
  }

  /**
   * Splits `text` on newlines, strips a leading numeric list marker from each
   * line, trims it (which also drops a trailing `\r`), and discards lines that
   * end up empty. The resulting answers are logged and returned.
   *
   * @param text - raw completion from the generator chain.
   * @param callbacks - accepted for signature compatibility; unused.
   * @returns the answers, in the order they appear in `text`.
   */
  // NOTE(review): the suppression presumably hides a signature mismatch with
  // the project OutputParser base class — confirm before removing it.
  //@ts-ignore
  async parse(text: string, callbacks?: Callbacks | undefined): Promise<string[]> {
    const answers = text
      .split("\n")
      .map((line) => line.replace(ThoughtOutputParser.LIST_MARKER, "").trim())
      .filter((line) => line !== "");
    console.log(answers);
    return answers;
  }

  /** Not implemented; always throws. */
  //@ts-ignore
  getFormatInstructions(options?: FormatInstructionsOptions | undefined): string {
    throw new Error("Method not implemented.");
  }
}
/** Placeholder parser subclass with no overrides; nothing in this script references it. */
class EvaluatorOutputParser extends OutputParser {}
// Chain that proposes candidate thoughts: fills `thoughtPrompt` and hands the
// completion to ThoughtOutputParser, which splits it into one string per answer.
const thoughtParser = new ThoughtOutputParser({ promptTemplate: thoughtPrompt });
const thoughtGenerator = new LLMChain({
  llm: model,
  prompt: thoughtPrompt,
  outputParser: thoughtParser,
});
// Prompt asking the model to rate `state_text` against the original
// instruction `context` as a float between 0 and 1.
const evaluatorPrompt = new PromptTemplate({
  template:
    "Given the following instruction as a context: '{context}' and the current state of reasoning: '{state_text}', critically evaluate its relevance and accuracy as a float between 0 and 1, and NOTHING ELSE:",
  inputVariables: ["state_text", "context"],
});
// NOTE(review): ToTExecutor.run reads the parsed result as `.text.log`, so the
// project OutputParser presumably yields an AgentAction/AgentFinish-like object — confirm.
const thoughtEvaluator = new LLMChain({
  llm: model,
  prompt: evaluatorPrompt,
  outputParser: new OutputParser(""),
});
/** One scored thought, as collected into ToTExecutor's JSON output. */
interface ToT {
  /** The generated thought text. */
  thought: string;
  /** Score parsed from the evaluator's reply. */
  evaluation: number;
  /** Child thoughts; declared but never populated in this script. */
  children?: ToT[] | undefined;
}
/** Search parameters for ToTExecutor (annotated k, T and Vth in the original). */
interface ToTInput {
  /** k: number of candidate thoughts requested from the generator per step. */
  sizeLimit: number;
  /** T: depth limit of the search. */
  stepLimit: number;
  /** Vth: a candidate must score strictly above this to be expanded. */
  threshold: number;
}
/**
 * Naive Tree-of-Thoughts (ToT) executor: a depth-first search over thoughts
 * proposed by `generator` and scored by `evaluator`.
 *
 * The class only mimics the `AgentExecutor` shape. Besides the constructor,
 * `run`, and the `totInput` / `evaluator` / `generator` fields, its members are
 * either bare declarations that are never assigned or methods that throw
 * "Method not implemented.".
 */
//@ts-ignore;
class ToTExecutor implements AgentExecutor {
  agent: BaseSingleActionAgent | BaseMultiActionAgent;
  tools: Tool[];
  /** Search parameters: branching factor, depth limit and pruning threshold. */
  totInput: ToTInput;
  returnIntermediateSteps: boolean;
  maxIterations?: number | undefined;
  earlyStoppingMethod: StoppingMethod;
  memory?: BaseMemory | undefined;
  /** Chain that scores a thought; its result is read as `.text.log`. */
  evaluator: LLMChain<AgentAction|AgentFinish>;
  /** Chain that proposes candidate thoughts; its result is read as `.text`. */
  generator: LLMChain<AgentAction|AgentFinish>;

  /**
   * @param fields.totInput - search parameters.
   * @param fields.evaluator - chain called with `{ state_text, context }`.
   * @param fields.generator - chain called with `{ input, sizeLimit }`.
   */
  constructor(fields: {
    totInput: ToTInput,
    evaluator: LLMChain<AgentAction|AgentFinish>,
    generator: LLMChain<AgentAction|AgentFinish>
  }) {
    this.totInput = fields.totInput;
    this.evaluator = fields.evaluator;
    this.generator = fields.generator;
    // Binding keeps `this` correct if `run` is invoked detached from the instance.
    this.run = this.run.bind(this);
  }

  get inputKeys(): string[] {
    throw new Error("Method not implemented.");
  }

  get outputKeys(): string[] {
    throw new Error("Method not implemented.");
  }

  //@ts-ignore
  _call(inputs: ChainValues, runManager?: CallbackManagerForChainRun | undefined): Promise<ChainValues> {
    throw new Error("Method not implemented.");
  }

  //@ts-ignore
  _chainType(): "agent_executor" {
    throw new Error("Method not implemented.");
  }

  //@ts-ignore
  serialize(): SerializedLLMChain {
    throw new Error("Method not implemented.");
  }

  /**
   * Runs a depth-first Tree-of-Thoughts search starting from `input`.
   *
   * Steps are numbered from 1. At each step up to `totInput.stepLimit`, the
   * generator proposes `totInput.sizeLimit` candidates. Each candidate is
   * scored against the original `input`, and only candidates scoring strictly
   * above `totInput.threshold` are expanded. This gives exactly `stepLimit`
   * levels of branching; the original started at step 0 and expanded one
   * level too many. Past the step limit, a single final thought is generated
   * and scored. It is recorded only when the score is a finite number, so
   * unparseable evaluator replies are dropped, as they already are when
   * pruning.
   *
   * @param input - the initial instruction; also the evaluation context.
   * @param _callbacks - accepted for signature compatibility; unused.
   * @returns JSON text of the recorded leaves as a `ToT[]`.
   */
  async run(input: any, _callbacks?: Callbacks | undefined): Promise<string> {
    const output: ToT[] = [];
    const context = input;
    const { totInput, generator, evaluator } = this;

    // Scores one thought; NaN when the evaluator's reply is not numeric.
    const evaluate = async (thought: string): Promise<number> => {
      const evaluated = await evaluator.call({ state_text: thought, context });
      return parseFloat(evaluated.text.log);
    };

    const dfs = async (step: number, state: unknown): Promise<void> => {
      if (step > totInput.stepLimit) {
        const { text } = await generator.call({ input: state, sizeLimit: 1 });
        const thought = text[0];
        // An empty generation leaves nothing to score or record.
        if (thought === undefined) return;
        const evaluation = await evaluate(thought);
        if (Number.isFinite(evaluation)) {
          output.push({ thought, evaluation });
        }
        return;
      }

      const { text: thoughts } = await generator.call({
        input: state,
        sizeLimit: totInput.sizeLimit
      });
      for (const thought of thoughts) {
        const evaluation = await evaluate(thought);
        console.log(`[step]:${step}`);
        console.log(`[thought]:${thought}`);
        console.log(`[score]:${evaluation}`);
        // NaN never passes this comparison, so unscorable thoughts are pruned.
        if (evaluation > totInput.threshold) {
          await dfs(step + 1, thought);
        }
      }
    };

    await dfs(1, input);
    return JSON.stringify(output);
  }

  //@ts-ignore
  call(values: ChainValues, callbacks?: Callbacks | undefined): Promise<ChainValues> {
    throw new Error("Method not implemented.");
  }

  //@ts-ignore
  apply(inputs: ChainValues[], callbacks?: Callbacks[] | undefined): Promise<ChainValues> {
    throw new Error("Method not implemented.");
  }

  verbose: boolean;
  callbacks?: Callbacks | undefined;
}
// Demo configuration: 3 candidates per step, depth limit 2, expand only when score > 0.5.
const tot = new ToTExecutor({
  generator: thoughtGenerator,
  evaluator: thoughtEvaluator,
  totInput: { threshold: 0.5, stepLimit: 2, sizeLimit: 3 },
});
const res = await tot.run("What steps I need to take to become a good AI engineer?");
// Print the recorded thoughts in ascending score order (highest score last).
const ranked: ToT[] = JSON.parse(res);
ranked.sort((left, right) => left.evaluation - right.evaluation);
console.log(ranked);
})()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment