@cnlancehu
Created May 2, 2024 16:39
extern crate intel_mkl_src;
use std::io::Write;
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
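// token_output_stream.rs is assumed to be a local copy of the TokenOutputStream
// helper from candle-examples; it buffers token ids and only prints text once the
// buffered ids decode to valid UTF-8.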
mod token_output_stream;
use token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama::ModelWeights as Phi3;
use tokenizers::Tokenizer;
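// Crates implied by the imports: candle-core, candle-transformers, tokenizers,
// anyhow, plus intel-mkl-src so candle can use Intel MKL kernels on CPU
// (assuming the corresponding `mkl` feature is enabled).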
// Sampling settings: temperature + top-p (nucleus) sampling; a repeat penalty of
// 1.0 disables the penalty, so REPEAT_LAST_N is unused here.
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
const TEMPERATURE: f64 = 1.0;
const REPEAT_PENALTY: f64 = 1.0;
const REPEAT_LAST_N: usize = 0;
const TOP_P: f64 = 0.9;
// When true, the prompt is fed to the model one token at a time instead of in a
// single forward pass.
const SPLIT_PROMPT: bool = false;
fn main() -> anyhow::Result<()> {
    let model_path: &str =
        "D:\\code\\models\\Phi-3-mini-4k-instruct\\Phi-3-mini-4k-instruct-q4.gguf";
    let mut file = std::fs::File::open(model_path)?;
    let device = Device::Cpu;

    // Parse the GGUF container, then load the quantized weights (Phi3 here is
    // just an alias for candle's quantized_llama ModelWeights).
    let mut model: Phi3 = {
        let content = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
        Phi3::from_gguf(content, &mut file, &device)?
    };
    println!("model built");

    let tokenizer = Tokenizer::from_file("D:\\code\\models\\Phi-3-mini-4k-instruct\\tokenizer.json")
        .map_err(anyhow::Error::msg)?;
    let mut tos = TokenOutputStream::new(tokenizer);

    let prompt_str = DEFAULT_PROMPT;
    print!("{prompt_str}");

    // Encode the prompt into token ids.
    let tokens = tos
        .tokenizer()
        .encode(prompt_str, true)
        .map_err(anyhow::Error::msg)?;
    let tokens = tokens.get_ids();
    // Note: the number of tokens to generate is tied to the prompt length here;
    // candle's quantized example uses a separate sample-len setting for this.
    let to_sample = tokens.len().saturating_sub(1);
    let mut all_tokens = vec![];
    // Seeded logits processor doing temperature + top-p sampling.
    let mut logits_processor = LogitsProcessor::new(0, Some(TEMPERATURE), Some(TOP_P));
    let start_prompt_processing = std::time::Instant::now();
    // Process the prompt and sample the first new token. With SPLIT_PROMPT the
    // prompt is fed one token at a time (useful to keep peak memory down);
    // otherwise it goes through the model in a single forward pass.
    let mut next_token = if !SPLIT_PROMPT {
        let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?;
        }
        next_token
    };
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);
    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }

    let eos_token = *tos
        .tokenizer()
        .get_vocab(true)
        .get("<|endoftext|>")
        .unwrap();
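    // Note: Phi-3's chat template usually terminates turns with "<|end|>"; since
    // the prompt here is plain text, the generic "<|endoftext|>" end-of-text
    // token is used as the stop condition instead.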
    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;
    // Autoregressive decoding: feed back the last sampled token at the next
    // position, optionally penalize recently generated tokens, sample, print.
    for index in 0..to_sample {
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
        let logits = if REPEAT_PENALTY == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(REPEAT_LAST_N);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                REPEAT_PENALTY as f32,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if next_token == eos_token {
            break;
        };
    }
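    // decode_rest flushes anything still buffered in the TokenOutputStream,
    // e.g. a multi-byte character whose bytes were split across tokens.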
    if let Some(rest) = tos.decode_rest().map_err(candle_core::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;

    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
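
// A minimal sketch, not part of the original gist: Phi-3 instruct models are
// normally prompted through their chat template rather than with raw text. If
// the usual Phi-3 format applies, the prompt could be wrapped like this before
// encoding (format_phi3_prompt is a hypothetical helper, not used above):
#[allow(dead_code)]
fn format_phi3_prompt(user_prompt: &str) -> String {
    format!("<|user|>\n{user_prompt}<|end|>\n<|assistant|>\n")
}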