Skip to content

Instantly share code, notes, and snippets.

@JosephCatrambone
Last active March 8, 2024 01:32
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save JosephCatrambone/39fb1d7902ffdb73530119b0039855af to your computer and use it in GitHub Desktop.
Save JosephCatrambone/39fb1d7902ffdb73530119b0039855af to your computer and use it in GitHub Desktop.
Embedding GPT-2 in Godot via Rust
mod ml_thread;
use gdnative::prelude::{godot_print, methods, Method, NativeClass, Node as GDNode, InitHandle, godot_init};
use ml_thread::start_language_model_thread;
use std::sync::mpsc::{channel, Receiver, RecvError, Sender, SendError};
/// Sequence length (token positions) of both model inputs; used as the second
/// dimension of the input facts declared on the model in `ml_thread`.
const MAX_INPUT_LENGTH: usize = 512;
/// Number of sequences per model call; used as the first dimension of the model's
/// input facts in `ml_thread`.
const BATCH_SIZE: usize = 1;
/// Godot node class (exported through GDNative) that forwards prompts to a
/// background language-model thread and hands back that thread's replies.
#[derive(NativeClass)]
#[inherit(GDNode)]
pub struct ChatBot {
    /// Sending end of the prompt channel read by the language-model thread.
    message_tx: Sender<String>,
    /// Receiving end of the channel the language-model thread answers on.
    response_rx: Receiver<String>
}
// Only one impl block can have [methods].
#[methods]
impl ChatBot {
    /// The "constructor" of the class used by the `NativeClass` derive.
    ///
    /// Starts the background language-model thread and keeps both ends of its channels.
    pub fn new(_base: &GDNode) -> Self {
        let (tx, rx) = start_language_model_thread();
        ChatBot {
            message_tx: tx,
            response_rx: rx,
        }
    }

    /// Sends `text` to the language-model thread and blocks until its reply arrives.
    ///
    /// `maxent` is intended to select greedy decoding (always pick the most likely token)
    /// instead of probabilistic sampling; there is no beam search. The request channel only
    /// carries the prompt text, so the flag is not forwarded to the worker yet.
    ///
    /// Returns an empty string when the worker thread has stopped (either channel is
    /// disconnected) instead of panicking inside a Godot callback.
    pub fn make_reply(&self, text: &str, maxent: bool) -> String {
        // TODO: forward `maxent` once the request channel carries decoding options.
        let _ = maxent;
        if self.message_tx.send(text.to_owned()).is_err() {
            return String::new();
        }
        // `recv` yields a `Result`; `Err` means the worker hung up before answering.
        // This blocks the calling thread until the model has produced its reply.
        self.response_rx.recv().unwrap_or_default()
    }

    /// Godot `_ready` callback: logs a greeting that names this node.
    #[method]
    fn _ready(&self, #[base] base: &GDNode) {
        // The `godot_print!` macro works like `println!` but prints to the Godot-editor output tab as well.
        godot_print!("Hello world from node {}!", base.to_string());
    }

    /// Godot-callable entry point: returns the language model's reply to `user_str`.
    ///
    /// Calls [`ChatBot::make_reply`] with `maxent = false` and blocks until it returns.
    #[method]
    fn process_user_query(&self, #[base] _base: &GDNode, user_str: String) -> String {
        // Takes a Rust `String` rather than `GodotString`; the two have different
        // performance characteristics.
        self.make_reply(&user_str, false)
    }
}
/// Registers every Rust class this library exposes (currently only [`ChatBot`]) with Godot.
fn init(init_handle: InitHandle) {
    init_handle.add_class::<ChatBot>();
}

// Generates the dynamic library's GDNative entry points, wired to `init` above.
godot_init!(init);
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: runs one prompt through the embedded model without a Godot scene.
    ///
    /// `ChatBot::new` expects a Godot base node, so the struct is assembled directly
    /// from the worker thread's channel ends instead (there is no `ChatBot::load`).
    #[test]
    fn it_works() {
        let (message_tx, response_rx) = start_language_model_thread();
        let cb = ChatBot {
            message_tx,
            response_rx,
        };
        let result = cb.make_reply("I give you one yike:", true);
        println!("{result}");
        // An empty reply means the worker could not process the prompt (or died).
        assert!(!result.is_empty(), "language model produced no reply");
    }
}
use ndarray::s;
use std::{f32, path::{Path, PathBuf}, str::FromStr};
use std::error::Error;
use std::io::Cursor;
use std::sync::mpsc::{channel, Receiver, RecvError, Sender};
use std::thread;
use tokenizers::tokenizer::{Result, Tokenizer};
use tract_onnx::prelude::*;
use tract_onnx::tract_hir::infer::InferenceOp;
use rand::{Rng, thread_rng};
use crate::{BATCH_SIZE, MAX_INPUT_LENGTH};
fn run_language_model(input_channel: Receiver<String>, result_channel: Sender<String>) {
let tokenizer: Tokenizer = Tokenizer::from_str(include_str!("../model/tokenizer_gpt2.json")).expect("Failed to load packed tokenizer. Library may be corrupt.");
// A little info on GPT-2:
// input1 - type: int64[input1_dynamic_axes_1,input1_dynamic_axes_2,input1_dynamic_axes_3]
// output1 - type: float32[input1_dynamic_axes_1,input1_dynamic_axes_2,input1_dynamic_axes_3,50257]
// output dims are [1, 50257].
let mut model_path = PathBuf::from_str("model").unwrap();
model_path.push(Path::new("model.onnx"));
let mut model_buf = Cursor::new(include_bytes!("../model/gpt-neo-2.onnx"));
//model: RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
//model: SimplePlan<InferenceFact, Box<dyn InferenceOp>, Graph<InferenceFact, Box<dyn InferenceOp>>>,
let model = tract_onnx::onnx()
.model_for_read(&mut model_buf).expect("Unable to read from model built into binary. This indicates corruption.")
//.model_for_path(model_path).unwrap()
.with_input_fact(0, i64::fact(&[BATCH_SIZE, MAX_INPUT_LENGTH]).into()).expect("Defining input fact size on preloaded language model.")
.with_input_fact(1, i64::fact(&[BATCH_SIZE, MAX_INPUT_LENGTH]).into()).expect("Defining input mask size on preloaded language model.")
//.with_output_fact(0, f32::fact(&[axis_1_shape, axis_2_shape, axis_3_shape, axis_4_shape]).into()).unwrap()
//.into_optimized().expect("Converting packaged model to optimized build failed.")
.into_runnable().expect("Converting optimized model into runnable model failed.");
loop {
match input_channel.recv() {
Ok(msg) => {
let tokenizer_output = tokenizer.encode(text, true).expect("Unable to encode input string.");
let token_ids = tokenizer_output.get_ids();
let mask: Tensor = tract_ndarray::Array2::from_shape_fn((1, MAX_INPUT_LENGTH), |idx|{ if idx.0 < MAX_INPUT_LENGTH && idx.1 < MAX_INPUT_LENGTH { 1i64 } else { 0i64 } }).into();
let token_tensor: Tensor = tract_ndarray::Array2::from_shape_fn((1, MAX_INPUT_LENGTH),|idx| { if idx.0 < token_ids.len() { token_ids[idx.0] as i64 } else { 0 as i64 } }).into();
//let token_tensor: Tensor = tract_ndarray::Array2::from_shape_vec((1, token_ids.len()),token_ids.iter().map(|&x| x as i64).collect()).unwrap().into();
let outputs = model.run(tvec!(token_tensor, mask)).expect("Failed to run model on token tensor.");
let logits = outputs[0].to_array_view::<f32>().expect("Unable to convert tensor output to f32 array.");
let word_id = if maxent {
logits.iter().zip(0..).max_by(|a, b| a.0.partial_cmp(b.0).unwrap()).unwrap().1
} else {
let mut rng = thread_rng();
let mut energy = rng.gen::<f32>();
let mut selected_token = 0;
for (idx, token_energy_hill) in logits.iter().enumerate() {
if *token_energy_hill > energy {
selected_token = idx;
energy = 0.0;
} else {
energy -= *token_energy_hill;
}
}
selected_token as u32
};
let word = tokenizer.id_to_token(word_id).unwrap_or(" ".into());
result_channel.send(word).expect("Failed to send result.");
}
Err(_) => {
return;
}
}
}
}
/// Spawns the detached language-model worker thread.
///
/// Returns the caller's halves of the two channels: a sender for prompts and a
/// receiver for the worker's replies.
pub fn start_language_model_thread() -> (Sender<String>, Receiver<String>) {
    let (prompt_sender, prompt_receiver) = channel::<String>();
    let (reply_sender, reply_receiver) = channel::<String>();
    // The JoinHandle is dropped, so the worker runs detached until its channels close.
    thread::spawn(move || run_language_model(prompt_receiver, reply_sender));
    (prompt_sender, reply_receiver)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment