Last active
August 7, 2023 16:35
-
-
Save HackyDev/814c6d1c96f259a13dbf5b2dabf98e8f to your computer and use it in GitHub Desktop.
Use Llama2 in a `nodejs` project
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* INFO
This class allows you to use Llama2 in a `nodejs` project.
The class spawns a process which can be set up using instructions from https://replicate.com/blog/run-llama-locally
The communication between llama2 and node is done using the `stdin/stdout` interface.
*/
/* USAGE EXAMPLE
import LlamaSpawner from './LlamaSpawner'

const options = {
  executablePath: '/home/username/llama.cpp/main',
  modelPath: '/home/username/llama.cpp/models/llama-2-7b-chat.ggmlv3.q4_0.bin'
}

const llamaSpawner = new LlamaSpawner(options)

llamaSpawner.start()
  .then(() => llamaSpawner.prompt('How are you?'))
  .then(answer => console.log(answer))
  .then(() => llamaSpawner.prompt('Are you sure?'))
  .then(answer => console.log(answer))
*/
import { spawn, ChildProcessWithoutNullStreams } from 'child_process' | |
enum State { | |
Initial = 'initial', | |
Starting = 'starting', | |
ReadyForInput = 'readyForInput', | |
Inferencing = 'inferencing', | |
Stopping = 'stopping' | |
} | |
enum ErrorCode { | |
MultipleStart = 'multipleStart', | |
AlreadyStarted = 'alreadyStarted', | |
DataInInvalidState = 'dataInInvalidState', | |
ProcessError = 'processError', | |
ClosedWhileStarting = 'closedWhileStarting', | |
ClosedWhileInferencing = 'closedWhileInferencing', | |
Closed = 'closed', | |
UnknownStartError = 'unknownStartError', | |
NoProcessToPrompt = 'noProcessToPrompt', | |
PromptInInvalidState = 'promptInInvalidState', | |
UnknownPromptError = 'unknownPromptError', | |
UnknownStopError = 'UnknownStopError' | |
} | |
interface Constructor { | |
executablePath: string; | |
modelPath: string; | |
} | |
interface CreateArgsOptions { | |
modelPath: string; | |
} | |
class LlamaSpawner { | |
executablePath = '' | |
modelPath= '' | |
private process: ChildProcessWithoutNullStreams | null = null | |
private currentAnswer = '' | |
private promptResolve: ((value: string) => void) | null = null | |
private promptReject: ((reason: Error) => void) | null = null | |
private readyForInputString = '\n> ' | |
private state: State = State.Initial | |
constructor(options: Constructor) { | |
this.executablePath = options.executablePath | |
this.modelPath = options.modelPath | |
} | |
public async start(): Promise<void> { | |
if (this.state !== State.Initial) { | |
const error = this.handleError(ErrorCode.MultipleStart) | |
return Promise.reject(error) | |
} | |
return new Promise<void>((resolve, reject) => { | |
if (this.process) { | |
const error = this.handleError(ErrorCode.AlreadyStarted) | |
reject(error) | |
return | |
} | |
try { | |
this.setState(State.Starting) | |
const args = this.createArgs({ modelPath: this.modelPath }) | |
const process = this.process = spawn(this.executablePath, args) | |
process.stdout.on('data', (data) => { | |
const output = data.toString() | |
if (this.state === State.Inferencing) { | |
if (output === this.readyForInputString) { | |
this.setState(State.ReadyForInput) | |
} else { | |
this.addToAnswer(output) | |
} | |
} else if (this.state === State.Starting) { | |
if (output === this.readyForInputString) { | |
this.setState(State.ReadyForInput) | |
resolve() | |
} | |
} else { | |
const error = this.handleError(ErrorCode.DataInInvalidState) | |
console.error(error) | |
} | |
}) | |
process.on('error', (err) => { | |
const error = this.handleError(ErrorCode.ProcessError, err.message) | |
if (this.state === State.Starting) { | |
reject(error) | |
} else { | |
console.error(err) | |
} | |
}) | |
process.on('close', () => { | |
let error: Error | |
if (this.state === State.Starting) { | |
error = this.handleError(ErrorCode.ClosedWhileStarting) | |
} else if (this.state === State.Inferencing) { | |
error = this.handleError(ErrorCode.ClosedWhileInferencing) | |
} else { | |
error = this.handleError(ErrorCode.Closed) | |
} | |
reject(error) | |
}) | |
} catch (e) { | |
const message = this.getErrorMessage(e) | |
const error = this.handleError(ErrorCode.UnknownStartError, message) | |
reject(error) | |
} | |
}) | |
} | |
public async prompt(question: string): Promise<string> { | |
return new Promise<string>((resolve, reject) => { | |
try { | |
if (this.state === State.ReadyForInput) { | |
if (this.process) { | |
this.promptResolve = resolve | |
this.promptReject = reject | |
this.process.stdin.write(`${question}\r\n`) | |
this.setState(State.Inferencing) | |
} else { | |
const error = this.handleError(ErrorCode.NoProcessToPrompt) | |
reject(error) | |
} | |
} else { | |
const error = this.handleError(ErrorCode.PromptInInvalidState) | |
reject(error) | |
} | |
} catch (e) { | |
const message = this.getErrorMessage(e) | |
const error = this.handleError(ErrorCode.UnknownPromptError, message) | |
reject(error) | |
} | |
}) | |
} | |
public async stop(): Promise<void> { | |
return new Promise((resolve, reject) => { | |
try { | |
const process = this.process | |
if (process) { | |
process.stdin.end() | |
setTimeout(() => { | |
this.setState(State.Initial) | |
resolve() | |
}, 1000) | |
} else { | |
resolve() | |
} | |
} catch (e) { | |
const message = this.getErrorMessage(e) | |
const error = this.handleError(ErrorCode.UnknownStopError, message) | |
reject(error) | |
} | |
}) | |
} | |
private assertNever(value: never): never { | |
throw new Error(`Unknown error code: ${value}`) | |
} | |
private handleError (code: ErrorCode, message = '', data?: Record<string, any>): Error { | |
switch (code) { | |
case ErrorCode.MultipleStart: | |
message = '"start" was called before "stop"' | |
break | |
case ErrorCode.AlreadyStarted: | |
message = '"start" was called multiple times' | |
break | |
case ErrorCode.DataInInvalidState: | |
message = 'needs investigation' | |
data = { | |
expectedStates: [State.ReadyForInput, State.Inferencing], | |
currentState: this.state | |
} | |
break | |
case ErrorCode.ProcessError: | |
this.setState(State.Initial) | |
break | |
case ErrorCode.ClosedWhileStarting: | |
message = 'process closed while starting' | |
this.setState(State.Initial) | |
break | |
case ErrorCode.ClosedWhileInferencing: | |
message = 'process closed while inferencing' | |
this.setState(State.Initial) | |
break | |
case ErrorCode.Closed: | |
message = 'process was closed' | |
this.setState(State.Initial) | |
break | |
case ErrorCode.UnknownStartError: | |
this.setState(State.Initial) | |
break | |
case ErrorCode.NoProcessToPrompt: | |
message = 'must investigate what is going on' | |
this.setState(State.Initial) | |
break | |
case ErrorCode.PromptInInvalidState: | |
message = 'process was possibly was stopped or the inference was going on' | |
break | |
case ErrorCode.UnknownPromptError: | |
this.setState(State.Initial) | |
break | |
case ErrorCode.UnknownStopError: | |
this.setState(State.Initial) | |
break | |
default: | |
return this.assertNever(code) | |
} | |
const error = this.createError(code, message, data) | |
this.promptReject && this.promptReject(error) | |
return error | |
} | |
private getErrorMessage (e: unknown) { | |
let message = '' | |
if (e instanceof Error) message = e.message | |
return message | |
} | |
private addToAnswer (value: string) { | |
this.currentAnswer += value | |
} | |
private setState(state: State): void { | |
this.state = state | |
switch (state) { | |
case State.Initial: { | |
if (this.process) { | |
this.process.stdin.end() | |
this.process.kill() | |
} | |
this.process = null | |
this.currentAnswer = '' | |
this.promptResolve = null | |
this.promptReject = null | |
break | |
} | |
case State.ReadyForInput: { | |
if (this.currentAnswer && this.promptResolve) { | |
this.promptResolve && this.promptResolve(this.currentAnswer) | |
} | |
this.currentAnswer = '' | |
break | |
} | |
} | |
} | |
private createArgs(options: CreateArgsOptions): string[] { | |
return [ | |
'-m', options.modelPath, | |
'--ctx_size', '2048', | |
'-n', '-1', '-ins', | |
'-b', '256', | |
'--top_k', '10000', | |
'--temp', '0.2', | |
'--repeat_penalty', '1.1', | |
'-t', '8' | |
] | |
} | |
private createError (code: string, message: string, data: Record<string, any> = {}) { | |
const debug = { | |
currentState: this.state, | |
...data | |
} | |
const debugStr = JSON.stringify(debug) | |
return new Error(`${code}:${message}:${debugStr}`) | |
} | |
} | |
export default LlamaSpawner |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment