simonw/lm.rs.md

Created October 11, 2024 19:30

2024-10-11T19:25:19 conversation: 01j9ygp426wmx6prmt02e0km0h

Model: claude-3-5-sonnet-20240620

Prompt:

files-to-prompt src Cargo.toml -c
<documents>
<document index="1">
<source>src/functional.rs</source>
<document_content>

use crate::quantization::{QuantizedTensor, MutableQuantizedTensor};

use std::convert::TryInto;
use rayon::prelude::*;
use wide::{f32x8, i32x8};

// Some helper functions 

pub fn slice_to_u32(slice: &[u8]) -> u32 {
    assert!(slice.len() == 4, "Slice must be exactly 4 bytes long");
    u32::from_ne_bytes(slice.try_into().expect("Slice with incorrect length"))
}

pub fn slice_to_f32(slice: &[u8]) -> f32 {
    assert!(slice.len() == 4, "Slice must be exactly 4 bytes long");
    f32::from_ne_bytes(slice.try_into().expect("Slice with incorrect length"))
}

pub fn u8_to_f32_slice(data: &[u8]) -> &[f32] {
    let (prefix, f32data, suffix) = unsafe { data.align_to::<f32>() };
    assert!(prefix.is_empty(), "Data was not aligned correctly");
    assert!(suffix.is_empty(), "Data was not aligned correctly");
    f32data
}

pub fn u8_to_i8_slice(data: &[u8]) -> &[i8] {
    let (prefix, i8data, suffix) = unsafe { data.align_to::<i8>() };
    assert!(prefix.is_empty(), "Data was not aligned correctly");
    assert!(suffix.is_empty(), "Data was not aligned correctly");
    i8data
}

// xorshift64* PRNG. Takes the state by mutable reference so the sequence
// actually advances between calls; the final multiply intentionally wraps.
pub fn random_u32(state: &mut u64) -> u32 {
    *state ^= *state >> 12;
    *state ^= *state << 25;
    *state ^= *state >> 27;

    (state.wrapping_mul(0x2545F4914F6CDD1Du64) >> 32) as u32
}

pub fn random_f32(state: &mut u64) -> f32 {
    // Map the top 24 bits to [0.0, 1.0), since 16777216.0 == 2^24.
    (random_u32(state) >> 8) as f32 / 16777216.0f32
}
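
// A small sanity sketch for the helpers above: draws land in [0.0, 1.0) and
// the caller's state advances in place between calls.
#[cfg(test)]
mod rng_sketch {
    use super::random_f32;

    #[test]
    fn unit_interval_and_advancing_state() {
        let mut state = 42u64;
        let start = state;
        let r = random_f32(&mut state);
        assert!((0.0..1.0).contains(&r));
        assert_ne!(state, start); // the state was mutated in place
    }
}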

// Functions used in NNs

pub fn rmsnorm(o: &mut [f32], x: &[f32], weight: &[f32], size: usize, eps: f32, add_unit_offset: bool) {
    let n_simd = size/8;

    let mut ss_sim = f32x8::ZERO;

    for j in 0..n_simd {
        let x_vec = f32x8::from(&x[j*8..j*8+8]); 
        ss_sim += x_vec * x_vec;
    } 

    let mut ss = ss_sim.reduce_add();

    ss /= size as f32;
    ss += eps;
    ss = 1.0 / ss.sqrt();

    for j in 0..n_simd {
        let x_vec = f32x8::from(&x[j*8..j*8+8]);
        let w_vec = f32x8::from(&weight[j*8..j*8+8]);
        
        let r = if add_unit_offset {
            ((1.0 + w_vec) * (ss * x_vec)).to_array()
        } else {
            (w_vec * (ss * x_vec)).to_array()
        };

        for k in 0..8 {
            o[(j*8) + k] = r[k];
        } 
    } 
}
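
// For reference, a scalar sketch of the same computation (note the SIMD
// loops above assume `size` is a multiple of 8):
//
//     let ss = x[..size].iter().map(|v| v * v).sum::<f32>() / size as f32;
//     let inv = 1.0 / (ss + eps).sqrt();
//     for i in 0..size {
//         let w = if add_unit_offset { 1.0 + weight[i] } else { weight[i] };
//         o[i] = w * (inv * x[i]);
//     }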

pub fn layernorm(o: &mut [f32], x: &[f32], weight: &[f32], bias: &[f32], size: usize, eps: f32) {
    let n_simd = size / 8;

    let mut mean_sim = f32x8::ZERO;
    let mut var_sim = f32x8::ZERO;

    for j in 0..n_simd {
        let x_vec = f32x8::from(&x[j * 8..j * 8 + 8]);
        mean_sim += x_vec;
    }

    let mean = mean_sim.reduce_add() / size as f32;

    for j in 0..n_simd {
        let x_vec = f32x8::from(&x[j * 8..j * 8 + 8]);
        let diff = x_vec - f32x8::splat(mean);
        var_sim += diff * diff;
    }

    let variance = var_sim.reduce_add() / size as f32 + eps;
    let inv_std = 1.0 / variance.sqrt();

    for j in 0..n_simd {
        let x_vec = f32x8::from(&x[j * 8..j * 8 + 8]);
        let w_vec = f32x8::from(&weight[j * 8..j * 8 + 8]);
        let b_vec = f32x8::from(&bias[j * 8..j * 8 + 8]);

        let normalized = (x_vec - f32x8::splat(mean)) * f32x8::splat(inv_std);
        let r = (normalized * w_vec + b_vec).to_array(); 

        for k in 0..8 {
            o[(j * 8) + k] = r[k];
        }
    }
}

pub fn softmax(x: &mut [f32]) {
    let mut sum: f32 = 0.0;
    let mut max_val: f32 = x[0];

    for i in x.iter() {
        if *i > max_val {
            max_val = *i;
        }
    }

    for i in x.iter_mut() {
        *i = (*i - max_val).exp();
        sum += *i;
    }
    
    for i in x.iter_mut() {
        *i /= sum;
    } 
}
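
// Sanity sketch: softmax output is a probability distribution, and
// subtracting the running max first keeps exp() from overflowing.
#[cfg(test)]
mod softmax_sketch {
    use super::softmax;

    #[test]
    fn sums_to_one() {
        let mut x = [1.0f32, 2.0, 3.0, 4.0];
        softmax(&mut x);
        assert!((x.iter().sum::<f32>() - 1.0).abs() < 1e-6);
        assert!(x.iter().all(|p| *p > 0.0));
    }
}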

pub fn matmul(xout: &mut [f32], x: &[f32], w: &[f32]) {
    let n = x.len();
    let n_simd = n / 8;
    
    let rest = n_simd * 8; // first index not covered by the SIMD loop

    xout.par_iter_mut().enumerate().for_each(|(i, val)| {
        let mut sum = f32x8::ZERO;
        let mut final_sum: f32 = 0.0;
        let w_slice = &w[i * n..i * n + n];

        for j in 0..n_simd {
            let x_vec = f32x8::from(&x[j*8..j*8+8]);
            let w_vec = f32x8::from(&w_slice[j*8..j*8+8]);
            sum += w_vec * x_vec;
        }

        final_sum += sum.reduce_add();
        
        for r in rest..n {
            final_sum += w_slice[r] * x[r];
        }

        *val = final_sum;
    });
}
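
// Layout sketch: `w` is row-major with one length-n row per output element,
// so xout[i] = dot(w[i*n..(i+1)*n], x). A tiny 2x3 example under that
// assumption:
#[cfg(test)]
mod matmul_sketch {
    use super::matmul;

    #[test]
    fn two_by_three() {
        let w = [1.0f32, 0.0, 0.0, 0.0, 1.0, 1.0]; // rows [1 0 0] and [0 1 1]
        let x = [2.0f32, 3.0, 4.0];
        let mut out = [0.0f32; 2];
        matmul(&mut out, &x, &w);
        assert_eq!(out, [2.0, 7.0]);
    }
}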

pub fn matmul_q8(xout: &mut [f32], x: &MutableQuantizedTensor, w: &QuantizedTensor, n: usize, gs: usize) {
    let n_simd = gs / 8;
    let rest = n_simd * 8;
    
    xout.par_iter_mut().enumerate().for_each(|(i, xout_elem)| {
        let ni: usize = i * n;

        *xout_elem = (0..=(n - gs)).step_by(gs).map(|j| {
            let mut ival = i32x8::ZERO;
            let mut sum: f32 = 0.0;

            for k in 0..n_simd {
                let x_vec = i32x8::from(&x.q[j+k*8..j+k*8+8]);
                let w_vec = i32x8::from(&w.q[ni+j+k*8..ni+j+k*8+8]);

                ival += x_vec * w_vec;
            }

            sum += ival.reduce_add() as f32;
            
            for r in rest..gs {
                sum += (x.q[j+r] as i32 * w.q[ni+j+r] as i32) as f32;
            }
            
            sum *= w.s[(ni + j) / gs] * x.s[j / gs];

            sum
        }).sum();
    });
}
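
// How the Q8_0 kernel above recombines: within each group of `gs` values the
// int8 products accumulate in i32 and are scaled once per group, roughly
//
//     xout[i] = Σ_groups ( Σ_k x.q[j+k] * w.q[ni+j+k] ) * x.s[j/gs] * w.s[(ni+j)/gs]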

pub fn matmul_q4(xout: &mut [f32], x: &MutableQuantizedTensor, w: &QuantizedTensor, n: usize, gs: usize) {
    let group_size = gs / 2;
    let n_simd = group_size / 8;
    let rest = n_simd * 8;

    let mask_a = i32x8::new([0x0F; 8]);
    let mask_b = i32x8::new([0xF0; 8]);
    
    xout.par_iter_mut().enumerate().for_each(|(i, xout_elem)| {
        let ni: usize = i * n / 2;

        *xout_elem = (0..=(n/2 - group_size)).step_by(group_size).map(|j| {
            let mut ival = i32x8::ZERO;
            let mut sum: f32 = 0.0;

            for k in 0..n_simd {
                let x_vec = i32x8::from(&x.q[j+k*8..j+k*8+8]);
                let w_vec = i32x8::from(&w.q[ni+j+k*8..ni+j+k*8+8]);

                let x_a = (x_vec & mask_a) - 8;
                let w_a = (w_vec & mask_a) - 8;
                
                let x_b = (mask_a & ((x_vec & mask_b) >> 4)) - 8;
                let w_b = (mask_a & ((w_vec & mask_b) >> 4)) - 8;

                ival += x_a * w_a;
                ival += x_b * w_b;
            }

            sum += ival.reduce_add() as f32;

            for r in rest..group_size {
                let x_vec = x.q[j+r] as i32;
                let w_vec = w.q[ni+j+r] as i32;

                let x_a = (x_vec & 0x0F) - 8;
                let w_a = (w_vec & 0x0F) - 8;
                
                let x_b = (0x0F & ((x_vec & 0xF0) >> 4)) - 8;
                let w_b = (0x0F & ((w_vec & 0xF0) >> 4)) - 8;

                sum += (x_a * w_a) as f32;
                sum += (x_b * w_b) as f32;
            }
            
            sum *= w.s[(ni + j) / group_size] * x.s[j / group_size];

            sum
        }).sum();
    });
}
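
// Q4_0 packing sketch assumed by the masks above: each i8 holds two 4-bit
// values with a +8 offset, low nibble first, so
//
//     byte = (a + 8) | ((b + 8) << 4)     // a, b in [-8, 7]
//     a    = (byte & 0x0F) - 8
//     b    = ((byte >> 4) & 0x0F) - 8
//
// which is why both lanes subtract 8 after masking.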

pub fn matmul_conv(xout: &mut [f32], x: &[f32], w: &[f32], n: usize, patches_per_row: u32) {
    let n_simd = n / 8;
    let rest = n_simd * 8;
    let x_len = x.len();
    
    xout.par_iter_mut().enumerate().for_each(|(i, xout_elem)| {
        let ni: usize = (i / (patches_per_row*patches_per_row) as usize) * n;
        let x_idx = i*n % x_len;

        let mut val = f32x8::ZERO;
        let mut sum: f32 = 0.0;

        for k in 0..n_simd {
            let x_vec = f32x8::from(&x[x_idx+k*8..x_idx+k*8+8]);
            let w_vec = f32x8::from(&w[ni+k*8..ni+k*8+8]);

            val += x_vec * w_vec;
        }

        sum += val.reduce_add();
        
        for r in rest..n {
            sum += x[x_idx+r] * w[ni+r];
        }

        *xout_elem = sum;
    });
}
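
// Indexing sketch for the kernel above: each output element i dots one
// n-length weight row (ni: one row shared by patches_per_row^2 consecutive
// outputs) with one n-length window of x (x_idx: i*n wrapped around x).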

pub fn concat<T: Clone>(arr0: &[T], arr1: &[T]) -> Vec<T> {
    let mut concat_arr: Vec<T> = Vec::with_capacity(arr0.len() + arr1.len());

    concat_arr.extend_from_slice(arr0);
    concat_arr.extend_from_slice(arr1);

    concat_arr
}
</document_content>
</document>
<document index="2">
<source>src/lib.rs</source>
<document_content>
pub mod tokenizer;
pub mod functional;
pub mod transformer;
pub mod sampler;
pub mod quantization;
#[cfg(any(feature = "multimodal", feature="backend-multimodal"))]
pub mod vision;
#[cfg(any(feature = "multimodal", feature="backend-multimodal"))]
pub mod processor;
</document_content>
</document>
<document index="3">
<source>src/processor.rs</source>
<document_content>
use crate::quantization::{QuantizedTensor, QuantType, MutableQuantizedTensor, quantize, quantize_q4};
use crate::transformer::{init_param, init_param_quant};
use crate::functional::{matmul, matmul_q8, matmul_q4, concat};

use std::mem::MaybeUninit;
use std::alloc::dealloc;
use std::alloc::Layout;

use image::imageops::resize;
use image::{ImageBuffer, Rgb};
use wide::f32x8;
use rayon::prelude::*;

fn transpose_img(image: &[u8], width: u32, height: u32) -> Vec<u8> {
    let mut transposed_image = Vec::new();

    for x in 0..width {
        for y in 0..height {
            transposed_image.push(image[((x*3) + ((y*3) * width)) as usize]);
            transposed_image.push(image[((x*3) + 1 + ((y*3) * width)) as usize]);
            transposed_image.push(image[((x*3) + 2 + ((y*3) * width)) as usize]);
        }
    }

    return transposed_image;
}

fn normalize_img(image: &[u8], mean: [f32; 3], std: [f32; 3]) -> Vec<f32> {
    let mut normalized_image = Vec::new();

    for x in (0..image.len()).step_by(3) {
        let r = image[x];
        let g = image[x+1];
        let b = image[x+2];

        let norm_r = ((r as f32)/255.0 - mean[0])/std[0];
        let norm_g = ((g as f32)/255.0 - mean[1])/std[1];
        let norm_b = ((b as f32)/255.0 - mean[2])/std[2];

        normalized_image.push(norm_r);
        normalized_image.push(norm_g);
        normalized_image.push(norm_b);
    }

    return normalized_image;
}

fn pad_img(image: &[u8], width: u32, height: u32, pad_left: u32, pad_top: u32, pad_right: u32, pad_bottom: u32, value: [u8; 3]) -> Vec<u8> {
    let mut padded_image = Vec::new();
    
    for _p in 0..pad_top {
        for _x in 0..width {
            padded_image.push(value[0]);
            padded_image.push(value[1]);
            padded_image.push(value[2]);
        }
    }

    for y in 0..height {
        for _p in 0..pad_left {
            padded_image.push(value[0]);
            padded_image.push(value[1]);
            padded_image.push(value[2]);
        }

        for x in 0..width {
            padded_image.push(image[((x*3) + ((y*3) * width)) as usize]);
            padded_image.push(image[((x*3) + 1 + ((y*3) * width)) as usize]);
            padded_image.push(image[((x*3) + 2 + ((y*3) * width)) as usize]);
        }
        
        for _p in 0..pad_right {
            padded_image.push(value[0]);
            padded_image.push(value[1]);
            padded_image.push(value[2]);
        }
    }
    
    for _p in 0..pad_bottom {
        for _x in 0..width {
            padded_image.push(value[0]);
            padded_image.push(value[1]);
            padded_image.push(value[2]);
        }
    }

    padded_image
}

fn crop_img(img: &[u8], size: (u32, u32), num_crops: u32) -> Vec<u8> {
    let mut cropped_img: Vec<u8> = Vec::new();
    let crop_size = 336;
    let crops_per_side_x = size.0/crop_size;

    println!("{:?}", size);
    println!("{}", num_crops);

    for c in 0..num_crops {
        let grid_y = c/crops_per_side_x;
        let grid_x = c%crops_per_side_x;

        for y in 0..crop_size {
            for x in (0..crop_size*3).step_by(3) {
                cropped_img.push(img[((grid_x*3)*crop_size + x + (y*3)*size.0 + (grid_y*3)*crop_size*size.0) as usize]);
                cropped_img.push(img[((grid_x*3)*crop_size + x + 1 + (y*3)*size.0 + (grid_y*3)*crop_size*size.0) as usize]);
                cropped_img.push(img[((grid_x*3)*crop_size + x + 2 + (y*3)*size.0 + (grid_y*3)*crop_size*size.0) as usize]);
            }
        }
    }

    cropped_img
}

fn view_as_patches(img: &[f32], size: u32, patch_size: u32, num_crops: u32) -> Vec<f32> {
    let mut out: Vec<f32> = Vec::new();
    let patches_per_row = size/patch_size;

    for c in 0..num_crops {
        for y in 0..patches_per_row {
            for x in 0..patches_per_row {
                let mut r: Vec<f32> = Vec::new();
                let mut g: Vec<f32> = Vec::new();
                let mut b: Vec<f32> = Vec::new();
                for py in 0..patch_size {
                    for px in (0..patch_size*3).step_by(3) {
                        r.push(img[(c*3*size*size) as usize + ((x*3)*patch_size + px) as usize + ((y*3)*size*patch_size + (py*3)*size) as usize]);
                        g.push(img[(c*3*size*size) as usize + ((x*3)*patch_size + px + 1) as usize + ((y*3)*size*patch_size + (py*3)*size) as usize]);
                        b.push(img[(c*3*size*size) as usize + ((x*3)*patch_size + px + 2) as usize + ((y*3)*size*patch_size + (py*3)*size) as usize]);
                    }
                }
                out.extend(r);
                out.extend(g);
                out.extend(b);
            }
        }
    }

    out
}

#[repr(C, packed)]
#[derive(Debug, Copy, Clone)]
struct ProcessorArgs {
    hidden_dim: u32,
    text_dim: u32,
    q_type: QuantType,
    group_size: u32,
}

struct ProcessorWeights<'a> {
    glb_gn: &'a [f32],
    sub_gn: &'a [f32],
    
    img_projection0: MaybeUninit<&'a [f32]>,
    img_projection0_bias: &'a [f32],
    img_projection1: MaybeUninit<&'a [f32]>,
    img_projection1_bias: &'a [f32],
    
    img_projection0_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    img_projection1_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
}

pub struct PHI3VProcessor<'a> {
    weights: ProcessorWeights<'a>,
    args: ProcessorArgs
}

impl<'a> PHI3VProcessor<'a> {
    pub fn new(data: &'a [u8]) -> PHI3VProcessor<'a> { 
        let (head, body, _) = unsafe { data[..13].align_to::<ProcessorArgs>() };

        assert!(head.is_empty(), "Data was not aligned");
        
        let cfg = &body[0];

        let mut offset: usize = 128;

        let hidden_dim = cfg.hidden_dim;
        let text_dim = cfg.text_dim;
        let quantized = cfg.q_type != QuantType::None;

        let glb_gn = init_param(data, &mut offset, 1, hidden_dim);
        let sub_gn = init_param(data, &mut offset, 1, hidden_dim);

        if !quantized {
            let img_projection0 = init_param(data, &mut offset, 1, hidden_dim * text_dim);
            let img_projection1 = init_param(data, &mut offset, 1, text_dim * text_dim);
            let img_projection0_bias = init_param(data, &mut offset, 1, text_dim);
            let img_projection1_bias = init_param(data, &mut offset, 1, text_dim);

            let weights = ProcessorWeights {
                glb_gn,
                sub_gn,
                img_projection0: MaybeUninit::new(img_projection0),
                img_projection1: MaybeUninit::new(img_projection1),
                img_projection0_quant: MaybeUninit::uninit(),
                img_projection1_quant: MaybeUninit::uninit(),
                img_projection0_bias,
                img_projection1_bias
            };

            return PHI3VProcessor {
                weights,
                args: *cfg
            }
        }

        println!("Loading processor weights...");

        let img_projection0_quant = init_param_quant(data, &mut offset, 1, text_dim * hidden_dim, cfg.group_size, cfg.q_type);
        let img_projection1_quant = init_param_quant(data, &mut offset, 1, text_dim * text_dim, cfg.group_size, cfg.q_type);
        let img_projection0_bias = init_param(data, &mut offset, 1, text_dim);
        let img_projection1_bias = init_param(data, &mut offset, 1, text_dim);
        
        let weights = ProcessorWeights {
            glb_gn,
            sub_gn,
            img_projection0: MaybeUninit::uninit(),
            img_projection1: MaybeUninit::uninit(),
            img_projection0_quant: MaybeUninit::new(img_projection0_quant),
            img_projection1_quant: MaybeUninit::new(img_projection1_quant),
            img_projection0_bias,
            img_projection1_bias
        };

        println!("Done.\n");

        return PHI3VProcessor {
            weights,
            args: cfg.clone()
        }
    }
    
    pub fn forward(&self, out_patches: &[f32], new_shape: u32, patch_side: u32, w_crop: u32, h_crop: u32) -> Vec<f32> {
        let p = self.args;
        let w = &self.weights;
        let hidden_dim = p.hidden_dim;

        let quantized = p.q_type != QuantType::None;

        let mut global_features = PHI3VProcessor::reshape_hd_patches_2x2merge(&out_patches[..new_shape as usize], 1, 1);
        PHI3VProcessor::add_image_newline(&mut global_features, w.sub_gn, patch_side as usize, patch_side as usize, hidden_dim as usize);
        
        let mut sub_image_features = PHI3VProcessor::reshape_hd_patches_2x2merge(&out_patches[new_shape as usize..], h_crop as usize, w_crop as usize);
        PHI3VProcessor::add_image_newline(&mut sub_image_features, w.sub_gn, (h_crop*patch_side) as usize, (w_crop*patch_side) as usize, hidden_dim as usize);

        let sub_len = sub_image_features.len();
        let glb_len = global_features.len();

        let mut out_embeddings = Vec::with_capacity(sub_len + glb_len + hidden_dim as usize);

        out_embeddings.extend(sub_image_features);
        out_embeddings.extend(w.glb_gn);
        out_embeddings.extend(global_features);
        
        let num_embeds = (h_crop * patch_side) * ((w_crop * patch_side + 1)) + (patch_side*(patch_side + 1)) + 1;
        
        let out_shape = num_embeds * p.text_dim;

        let mut out_features = vec![0.0; out_shape as usize];

        // Img projection mlp with GELU activation
        out_features.par_chunks_mut(p.text_dim as usize).enumerate().for_each( |(h, xb)| {
            let mut hidden_emb = vec![0.0; p.text_dim as usize];

            unsafe {
                if !quantized {
                    matmul(&mut hidden_emb, &out_embeddings[((h as u32*hidden_dim)) as usize..((h as u32*hidden_dim) + hidden_dim) as usize], &w.img_projection0.assume_init());
                } else {
                    let mut sxq = MutableQuantizedTensor { q: &mut vec![0; (hidden_dim) as usize], s: &mut vec![0.0; hidden_dim as usize]};

                    if p.q_type == QuantType::Q8_0 {
                        quantize(&mut sxq, &out_embeddings[((h as u32*hidden_dim)) as usize..((h as u32*hidden_dim) + hidden_dim) as usize], hidden_dim as usize, p.group_size);
                        
                        matmul_q8(&mut hidden_emb, &sxq, &w.img_projection0_quant.assume_init()[0], hidden_dim as usize, p.group_size as usize);
                    } else if p.q_type == QuantType::Q4_0 {
                        quantize_q4(&mut sxq, &out_embeddings[((h as u32*hidden_dim)) as usize..((h as u32*hidden_dim) + hidden_dim) as usize], hidden_dim as usize, p.group_size);
                        
                        matmul_q4(&mut hidden_emb, &sxq, &w.img_projection0_quant.assume_init()[0], hidden_dim as usize, p.group_size as usize);
                    }
                }
            }
            
            // Add bias
            let mut n_simd = p.text_dim/8;
            
            for k in 0..n_simd {
                let w1_bias_vec = f32x8::from(&w.img_projection0_bias[(k*8) as usize..(k*8+8) as usize]);

                let mut x1_vec = f32x8::from(&hidden_emb[(k*8) as usize..(k*8+8) as usize]);

                x1_vec += w1_bias_vec;
                
                let x1 = x1_vec.to_array();

                for j in 0..8 {
                    let idx = (k*8 + j) as usize;

                    hidden_emb[idx] = x1[j as usize];
                    
                    // GELU (tanh approximation): 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))),
                    // where sqrt(2/pi) ~= 0.7978845608028654
                    hidden_emb[idx] *= 0.5 * (1.0 + ((0.7978845608028654 * (hidden_emb[idx] + 0.044715 * hidden_emb[idx] * hidden_emb[idx] * hidden_emb[idx]) as f64).tanh()) as f32);   
                }
            }
            
            unsafe {
                if !quantized {
                    matmul(xb, &hidden_emb, &w.img_projection1.assume_init());
                } else {
                    let mut sxq = MutableQuantizedTensor { q: &mut vec![0; (p.text_dim) as usize], s: &mut vec![0.0; p.text_dim as usize]};

                    if p.q_type == QuantType::Q8_0 {
                        quantize(&mut sxq, &hidden_emb, p.text_dim as usize, p.group_size);
                        
                        matmul_q8(xb, &sxq, &w.img_projection1_quant.assume_init()[0], p.text_dim as usize, p.group_size as usize);
                    } else if p.q_type == QuantType::Q4_0 {
                        quantize_q4(&mut sxq, &hidden_emb, p.text_dim as usize, p.group_size);
                        
                        matmul_q4(xb, &sxq, &w.img_projection1_quant.assume_init()[0], p.text_dim as usize, p.group_size as usize);
                    }
                }
            }

            n_simd = p.text_dim/8;
            
            for k in 0..n_simd {
                let w2_bias_vec = f32x8::from(&w.img_projection1_bias[(k*8) as usize..(k*8+8) as usize]);

                let mut x2_vec = f32x8::from(&xb[(k*8) as usize..(k*8+8) as usize]);

                x2_vec += w2_bias_vec;
                
                let x2 = x2_vec.to_array();

                for j in 0..8 {
                    xb[(k*8 + j) as usize] = x2[j as usize];
                }
            }
        });

        out_features
    }

    pub fn process(&self, pixels: &[u8], width: u32, height: u32, patch_size: u32, mut num_crops: u32) -> (Vec<f32>, u32, u32, u32) {
        let mean = [0.48145466, 0.4578275, 0.40821073];
        let std = [0.26862954, 0.26130258, 0.27577711];

        let (transposed_image, new_w, new_h) = PHI3VProcessor::hd_transform(pixels, width, height, num_crops);
        
        let resized_img: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_raw(new_w, new_h, transposed_image.clone()).expect("Failed to create image");

        let global_img = resize(&resized_img, 336, 336, image::imageops::FilterType::Nearest);

        let cropped_img;

        let w_crop = new_w/336;
        let h_crop = new_h/336;

        num_crops = w_crop*h_crop;

        if num_crops > 2 {
            cropped_img = crop_img(&transposed_image, (new_w, new_h), num_crops);
        } else {
            cropped_img = transposed_image;
        }

        num_crops += 1;

        let final_img = concat::<u8>(&global_img, &cropped_img);
        let normalized_img = normalize_img(&final_img, mean, std);
        
        let patches = view_as_patches(&normalized_img, 336, patch_size, num_crops);

        (patches, w_crop, h_crop, num_crops)
    }
    
    fn reshape_hd_patches_2x2merge(image_features: &[f32], h_crop: usize, w_crop: usize) -> Vec<f32> {
        const C: usize = 1024;
        const H: usize = 24;
        const L: usize = H * H;
        
        let n = image_features.len() / (L * C);
        let num_images = n / (h_crop * w_crop);
        let out_h = h_crop * H / 2;
        let out_w = w_crop * H / 2;
        let out_c = 4 * C;
        
        let mut result = vec![0.0; num_images * out_h * out_w * out_c];
        
        for img_idx in 0..num_images {
            for hc in 0..h_crop {
                for wc in 0..w_crop {
                    let patch_idx = img_idx * h_crop * w_crop + hc * w_crop + wc;
                    
                    for i in 0..H/2 {
                        for j in 0..H/2 {
                            let mut merged_features = Vec::with_capacity(4 * C);
                            for di in 0..2 {
                                for dj in 0..2 {
                                    let old_y = i * 2 + di;
                                    let old_x = j * 2 + dj;
                                    let old_start = patch_idx * L * C + (old_y * H + old_x) * C;
                                    merged_features.extend_from_slice(&image_features[old_start..old_start + C]);
                                }
                            }
                            
                            let new_h = hc * H/2 + i;
                            let new_w = wc * H/2 + j;
                            let new_idx = ((img_idx * out_h + new_h) * out_w + new_w) * out_c;
                            result[new_idx..new_idx + out_c].copy_from_slice(&merged_features);
                        }
                    }
                }
            }
        }
        
        result
    }
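
    // Shape sketch for the merge above: [num_patches, 24, 24, 1024] features
    // are 2x2-pooled by channel concatenation into [out_h, out_w, 4096] per
    // image, with crops tiled back into one grid.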
    
    fn padding_336(image: &[u8], width: u32, height: u32) -> (Vec<u8>, u32){
        let tar = ((height as f32 / 336.0).ceil() * 336.0) as u32;
        let top_padding = (tar - height)/2;
        let bottom_padding = tar - height - top_padding;
        
        let b = pad_img(image, width, height, 0, top_padding, 0, bottom_padding, [255, 255, 255]);

        (b, height + top_padding + bottom_padding)
    }

    fn hd_transform(img: &[u8], width: u32, height: u32, hd_num: u32) -> (Vec<u8>, u32, u32) {
        let mut trans = false;
        let mut new_img = Vec::from(img);
        let mut new_width = width;
        let mut new_height = height;

        if width < height {
            new_img = transpose_img(img, width, height);
            trans = true;
            new_width = height;
            new_height = width;
        }

        let ratio: f32 = new_width as f32 / new_height as f32;
        let mut scale: f32 = 1.0;

        while scale*(scale/ratio).ceil() <= hd_num as f32 {
            scale += 1.0;
        }

        scale -= 1.0;

        let mut new_w = (scale * 336.0) as u32;
        let mut new_h = (new_w as f32 / ratio) as u32;

        let img_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = ImageBuffer::from_raw(new_width, new_height, new_img).expect("Failed to create image");
        let resized_img = resize(&img_buffer, new_w, new_h, image::imageops::FilterType::Nearest);

        (new_img, new_h) = PHI3VProcessor::padding_336(resized_img.as_raw() as &[u8], new_w, new_h);

        if trans {
            new_img = transpose_img(&new_img, new_w, new_h);
            let temp_w = new_w;
            new_w = new_h;
            new_h = temp_w;
        }

        (new_img, new_w, new_h)
    }

    fn insert_slice_at_position(vec: &mut Vec<f32>, index: usize, slice: &[f32]) {
        vec.reserve(slice.len());

        let tail = vec.split_off(index);

        vec.extend(slice);

        vec.extend(tail);
    }

    fn add_image_newline(img: &mut Vec<f32>, separator: &[f32], h: usize, w: usize, dim: usize) {
        for i in 0..h {
            PHI3VProcessor::insert_slice_at_position(img, i*w*dim + i*dim + w*dim, separator);
        }
    }
}

impl<'a> Drop for PHI3VProcessor<'a> {
    fn drop(&mut self) {
        if self.args.q_type != QuantType::None {
            unsafe {
                let weights_layout = Layout::array::<QuantizedTensor>(1).unwrap();
                dealloc(self.weights.img_projection0_quant.assume_init().as_ptr() as *mut u8, weights_layout);
                dealloc(self.weights.img_projection1_quant.assume_init().as_ptr() as *mut u8, weights_layout);
            }
        }
    }
}
</document_content>
</document>
<document index="4">
<source>src/quantization.rs</source>
<document_content>
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum QuantType {
    None = 0,
    Q8_0 = 1,
    Q4_0 = 2,
}

pub struct QuantizedTensor<'a>{
    pub q: &'a [i8],
    pub s: &'a [f32],
}
pub struct MutableQuantizedTensor<'a>{
    pub q: &'a mut [i8],
    pub s: &'a mut [f32],
}

fn unpack(value: i8) -> (i8, i8) {
    let a: i8 = (value & 0x0F) - 8;

    let b: i8 = (0x0F & ((value & 0xF0u8 as i8) >> 4)) - 8;

    (a, b)
}
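
// Roundtrip sketch for the packing scheme `unpack` assumes: two values in
// [-8, 7], stored low nibble first with a +8 offset.
#[cfg(test)]
mod unpack_sketch {
    use super::unpack;

    #[test]
    fn roundtrips() {
        for a in -8i8..=7 {
            for b in -8i8..=7 {
                let byte = ((a + 8) as u8 | (((b + 8) as u8) << 4)) as i8;
                assert_eq!(unpack(byte), (a, b));
            }
        }
    }
}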

pub fn dequantize(qx: &QuantizedTensor, x: &mut [f32], n: usize, gs: u32, q_type: QuantType) {
    match q_type {
        QuantType::Q8_0 => { 
            for (i, value) in x.iter_mut().enumerate().take(n) {
                *value = qx.q[i] as f32 * qx.s[(i as u32 / gs) as usize];
            }
        },
        QuantType::Q4_0 => { 
            for i in 0..(n/2) {
                let (a, b) = unpack(qx.q[i]);
                let scale = qx.s[((i*2) as u32 / gs) as usize];
                x[i*2] = a as f32 * scale;
                x[i*2+1] = b as f32 * scale;
            }
        },
        _ => (),
    }  
}

pub fn quantize(qx: &mut MutableQuantizedTensor, x: &[f32], n: usize, gs: u32) {
    let num_groups: u32 = n as u32 / gs;
    let q_max: f32 = 127.0f32;

    for group in 0..num_groups {
        let mut wmax: f32 = 0.0;
        for i in 0..gs {
            let val: f32 = x[(group * gs + i) as usize].abs();
            if val > wmax {
                wmax = val;
            }
        }

        let scale = wmax / q_max;
        
        qx.s[group as usize] = scale;

        for i in 0..gs {
            let quant_value = x[(group * gs + i) as usize] / scale;
            let quantized: i8 = quant_value.round() as i8;
            qx.q[(group * gs + i) as usize] = quantized;
        }
    }
}

pub fn quantize_q4(qx: &mut MutableQuantizedTensor, x: &[f32], n: usize, gs: u32) {
    let num_groups: u32 = n as u32 / gs;
    // Negative on purpose: scale = wmax / -8.0 maps +wmax to the nibble -8,
    // so the full [-8, 7] range of a 4-bit two's-complement value is used.
    let q_max: f32 = -8.0f32;

    for group in 0..num_groups {
        let mut wmax: f32 = 0.0;
        for i in 0..gs {
            let val: f32 = x[(group * gs + i) as usize].abs();
            if val > wmax {
                wmax = val;
            }
        }

        let scale = wmax / q_max;
        
        qx.s[group as usize] = scale;

        for i in 0..(gs/2) {
            let quant_value_a = x[(group * gs + i*2) as usize] / scale;
            let quant_value_b = x[(group * gs + i*2 + 1) as usize] / scale;
            let quantized_a: i8 = ((quant_value_a + 8.0).round() as u8).clamp(0, 15) as i8;
            let quantized_b: i8 = ((quant_value_b + 8.0).round() as u8).clamp(0, 15) as i8;
        
            qx.q[(group * gs / 2 + i) as usize] = quantized_a | (quantized_b << 4);
        }
    }
}
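
// Roundtrip sketch for Q8_0: quantizing then dequantizing reproduces the
// input to within one scale step (wmax / 127 per group).
#[cfg(test)]
mod quantize_sketch {
    use super::{dequantize, quantize, MutableQuantizedTensor, QuantType, QuantizedTensor};

    #[test]
    fn q8_roundtrip_is_close() {
        let x = [0.5f32, -1.0, 0.25, 1.0];
        let mut q = [0i8; 4];
        let mut s = [0.0f32; 1];
        let mut qt = MutableQuantizedTensor { q: &mut q, s: &mut s };
        quantize(&mut qt, &x, 4, 4);

        let mut y = [0.0f32; 4];
        dequantize(&QuantizedTensor { q: &q, s: &s }, &mut y, 4, 4, QuantType::Q8_0);

        for (a, b) in x.iter().zip(y.iter()) {
            assert!((a - b).abs() <= 1.0 / 127.0);
        }
    }
}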
</document_content>
</document>
<document index="5">
<source>src/sampler.rs</source>
<document_content>
use crate::functional::softmax;
use crate::functional::random_f32;

#[derive(Debug, Copy, Clone)]
struct ProbIndex {
    prob: f32,
    index: u32,
}

pub struct Sampler {
    vocab_size: u32,
    probindex: Vec<ProbIndex>,
    temperature: f32,
    top_p: f32,
    seed: u64,
}

impl Sampler {
    pub fn new(vocab_size: u32, temperature: f32, top_p: f32, seed: u64) -> Sampler {
        Sampler {
            vocab_size,
            probindex: vec![ProbIndex { prob: 0.0, index: 0 }; vocab_size as usize],
            temperature,
            top_p,
            seed
        }
    }
    
    fn sample_argmax(probabilities: &[f32]) -> u32 {
        let mut max_i: u32 = 0;
        let mut max_p = probabilities[0];

        for (i, p) in probabilities.iter().enumerate().skip(1) {
            if *p > max_p {
                max_i = i as u32;
                max_p = *p;
            }
        }

        max_i
    }

    fn sample_mult(probabilities: &[f32], rand: f32) -> u32 {
        let mut cdf: f32 = 0.0;
        let n = probabilities.len();

        for (i, p) in probabilities.iter().enumerate() {
            cdf += *p;
            if rand < cdf {
                return i as u32;
            }
        }

        (n - 1) as u32
    }

    fn compare(a: &ProbIndex, b: &ProbIndex) -> std::cmp::Ordering {
        if a.prob > b.prob {
            std::cmp::Ordering::Less
        } else if a.prob < b.prob {
            std::cmp::Ordering::Greater
        } else {
            std::cmp::Ordering::Equal
        }
    }

    fn sample_topp(&mut self, probabilities: &[f32], top_p: f32, rand: f32) -> u32 {
        let n = probabilities.len();
        let mut n0 = 0;

        let cutoff: f32 = (1.0f32 - top_p) / (n - 1) as f32;

        for (i, p) in probabilities.iter().enumerate() {
            if *p >= cutoff {
                self.probindex[n0].index = i as u32;
                self.probindex[n0].prob = *p;
                n0 += 1;
            }
        }
        
        // Sort only the entries written this call; stale entries from a
        // previous call could otherwise outrank the fresh ones.
        self.probindex[..n0].sort_by(Sampler::compare);

        let mut cumulative_prob: f32 = 0.0;

        let mut last_idx = n0 - 1;

        for i in 0..n0 {
            cumulative_prob += self.probindex[i].prob;
            if cumulative_prob > top_p {
                last_idx = i;
                break;
            }
        }

        let r = rand * cumulative_prob;
        let mut cdf: f32 = 0.0;

        for i in 0..last_idx+1 {
            cdf += self.probindex[i].prob;
            if r < cdf {
                return self.probindex[i].index;
            }
        }

        self.probindex[last_idx].index
    }


    pub fn sample(&mut self, logits: &mut [f32]) -> u32 {
        let next: u32;
        
        if self.temperature == 0.0f32 {
            next = Sampler::sample_argmax(logits);
        } else {
            for q in 0..self.vocab_size { logits[q as usize] /= self.temperature; }

            softmax(logits);

            let rand: f32 = random_f32(&mut self.seed); // advance the RNG state so successive samples differ

            if self.top_p <= 0.0 || self.top_p >= 1.0 {
                next = Sampler::sample_mult(logits, rand);
            } else {
                next = self.sample_topp(logits, self.top_p, rand);
            }
        }

        next
    }
}
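
// Usage sketch: temperature 0.0 degenerates to greedy argmax; otherwise the
// logits are temperature-scaled, softmaxed, then sampled (nucleus sampling
// when 0 < top_p < 1). Roughly:
//
//     let mut sampler = Sampler::new(vocab_size, 0.8, 0.9, 1337);
//     let next = sampler.sample(&mut logits);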

</document_content>
</document>
<document index="6">
<source>src/tokenizer.rs</source>
<document_content>
use crate::functional::slice_to_u32;
use crate::functional::slice_to_f32;
use crate::transformer::ModelType;
use std::fs;

#[derive(Debug, Clone)]
struct TokenIndex {
    text: String,
    id: u32,
}

pub struct Tokenizer {
    vocab_size: u32,
    vocab: Vec<String>,
    pub bos: u32,
    pub eos: u32,
    vocab_scores: Vec<f32>,
    sorted_vocab: Vec<TokenIndex>,
    //Not used for now; sequences are limited to this max size. Future work:
    //max_token_len: u32, 
}

impl Tokenizer {
    pub fn new(path: &str) -> Tokenizer {
        let data: Vec<u8> = fs::read(path).expect("Error reading tokenizer file.");

        let vocab_size = slice_to_u32(&data[0..4]);
        //let max_token_len = slice_to_u32(&data[4..8]);
        let bos = slice_to_u32(&data[8..12]);
        let eos = slice_to_u32(&data[12..16]);
        let mut vocab: Vec<String> = vec![];
        let mut vocab_scores: Vec<f32> = vec![];
        let sorted_vocab: Vec<TokenIndex> = vec![];

        let mut offset: usize = 16;

        for _ in 0..vocab_size {
            let score = slice_to_f32(&data[offset..offset + 4]);

            vocab_scores.push(score);

            offset += 4;

            let str_len = slice_to_u32(&data[offset..offset + 4]);

            offset += 4;

            let token_str = String::from_utf8(data[offset..offset + str_len as usize].to_vec()).expect("Error reading token string");

            vocab.push(token_str);

            offset += str_len as usize;
        }

        Tokenizer {
            vocab_size,
            //max_token_len,
            vocab,
            bos,
            eos,
            vocab_scores,
            sorted_vocab,
        }
    }

    pub fn encode(&mut self, text: &str, bos: bool, eos: bool, chat_format: bool, model_type: ModelType) -> Vec<u32> {
        assert!(!text.is_empty(), "Text to encode should not be empty");

        if self.sorted_vocab.is_empty() {
            for i in 0..self.vocab_size as usize {
                self.sorted_vocab.push(
                    // Clones for now; ideally this would borrow the vocab string instead
                    TokenIndex {
                        text: self.vocab[i].clone(),
                        id: i as u32,
                    }
                )
            }
            self.sorted_vocab.sort_by(|a, b| a.text.cmp(&b.text));
        }

        let mut tokens: Vec<u32> = Vec::new();

        if bos {
            tokens.push(self.bos)
        }

        if chat_format {
            if model_type == ModelType::GEMMA {
                tokens.extend([self.bos, 106, 1645, 108]);
            } else if model_type == ModelType::LLAMA {
                tokens.extend([128006, 882, 128007, 271]);
            } else if model_type == ModelType::PHI {
                tokens.extend([self.bos, 32010, 29871, 13]);
            }
        }

        for c in text.chars() {
            let c_str = c.to_string();
            match self.sorted_vocab.binary_search_by(|token| token.text.cmp(&c_str)) {
                Ok(index) => tokens.push(self.sorted_vocab[index].id),
                Err(_) => {
                    for b in c_str.into_bytes().iter() {
                        tokens.push(*b as u32 + 3)
                    }
                },
            }
        }

        loop {
            let mut best_score: f32 = -1e10;
            let mut best_id: u32 = 0;
            let mut best_idx: i32 = -1;

            for idx in 0..tokens.len() - 1 {
                let new_t = self.vocab[tokens[idx] as usize].clone() + &self.vocab[tokens[idx + 1] as usize];
                
                if let Ok(index) = self.sorted_vocab.binary_search_by(|token| token.text.cmp(&new_t)) {
                    let temp_t = &self.sorted_vocab[index];
                    if self.vocab_scores[temp_t.id as usize] > best_score {
                        best_score = self.vocab_scores[temp_t.id as usize];
                        best_id = temp_t.id;
                        best_idx = idx as i32;
                    }
                }
            }

            if best_idx == -1 {
                break;
            }

            tokens[best_idx as usize] = best_id;
            tokens.remove((best_idx+1) as usize);
        }
        
        if chat_format {
            if model_type == ModelType::GEMMA {
                tokens.extend([107, 108, 106, 2516, 108]);
            } else if model_type == ModelType::LLAMA {
                tokens.extend([128009, 128006, 78191, 128007, 271]);
            } else if model_type == ModelType::PHI {
                tokens.extend([32007, 29871, 13, 32001, 29871, 13]);
            }
        }

        if eos {
            tokens.push(self.eos)
        }

        tokens
    }

    pub fn decode(&self, token: u32) -> String {
        let piece = self.vocab[token as usize].to_string();
        
        if piece.starts_with("<0x") && piece.ends_with('>') && piece.len() == 6 {
            if let Ok(byte_val) = u8::from_str_radix(&piece[3..5], 16) {
                return char::from(byte_val).to_string();
            }
        }
        piece
    }
}
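
// Usage sketch (the path here is illustrative, not from this repo): encode a
// prompt with BOS and no chat template, then decode token by token.
//
//     let mut tok = Tokenizer::new("tokenizer.bin");
//     let ids = tok.encode("Hello", true, false, false, ModelType::LLAMA);
//     let text: String = ids.iter().map(|t| tok.decode(*t)).collect();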
</document_content>
</document>
<document index="7">
<source>src/transformer.rs</source>
<document_content>
use crate::functional::slice_to_u32;
use crate::functional::u8_to_f32_slice;
use crate::functional::u8_to_i8_slice;
use crate::functional::rmsnorm;
use crate::functional::matmul;
use crate::functional::matmul_q4;
use crate::functional::matmul_q8;
use crate::functional::softmax;

use crate::quantization::*;

use memmap2::Mmap;
use rayon::prelude::*;
use std::alloc::{dealloc, Layout};
use std::mem::{MaybeUninit, size_of};

pub fn init_param<'a>(data: &'a [u8], offset: &mut usize, n: u32, size_each: u32) -> &'a [f32]{
    let ptr: &[f32]  = u8_to_f32_slice(&data[*offset..(*offset + ((n * size_each) as usize * size_of::<f32>()))]);

    *offset += (n * size_each) as usize * size_of::<f32>();

    ptr
}

pub fn init_param_quant<'a>(data: &'a [u8], offset: &mut usize, n: u32, size_each: u32, gs: u32, q_type: QuantType) -> &'a [QuantizedTensor<'a>]{
    let mut res: Vec<QuantizedTensor> = Vec::with_capacity(n as usize);
    let groups = (size_each / gs) as usize;
    let mut size = size_each;
    
    if q_type == QuantType::Q4_0 {
        size /= 2;
    }

    for _ in 0..n {
        let mut qt = QuantizedTensor { q: &mut [], s: &mut [] };

        qt.q = u8_to_i8_slice(&data[*offset..(*offset + (size as usize * size_of::<i8>()))]);
        
        *offset += size as usize * size_of::<i8>() ;

        qt.s = u8_to_f32_slice(&data[*offset..(*offset + (groups * size_of::<f32>()))]);
        
        *offset += groups * size_of::<f32>();

        res.push(qt);
    }

    Box::leak(res.into_boxed_slice())
}
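
// On-disk layout sketch implied by the reads above: each quantized tensor is
// stored as its i8 quants followed by its f32 per-group scales,
//
//     [ q: size bytes (size_each, halved for Q4_0) ][ s: size_each/gs f32 values ]
//
// repeated n times, with `offset` advanced past both sections for each tensor.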

#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ModelType {
    GEMMA,
    LLAMA,
    PHI
}

#[repr(C, packed)]
#[derive(Debug, Copy, Clone)]
pub struct TransformerArgs {
    dim: u32,
    hidden_dim: u32,
    n_layers: u32,
    n_heads: u32,
    head_size: u32,
    n_kv_heads: u32,
    pub vocab_size: u32,
    seq_len: u32,
    rms_norm_eps: f32,
    rope_theta: f32,
    q_type: QuantType,
    pub model_type: ModelType,
    group_size: u32,
    pub multimodal: bool,
}

pub struct TransformerWeights<'a> {
    token_embedding_table: &'a [f32],

    // Attention

    wq: MaybeUninit<&'a [f32]>,
    wk: MaybeUninit<&'a [f32]>,
    wv: MaybeUninit<&'a [f32]>,
    wo: MaybeUninit<&'a [f32]>,
    
    wq_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    wk_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    wv_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    wo_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,

    w_rms_att: &'a [f32],

    // FFN
    w1: MaybeUninit<&'a [f32]>,
    w2: MaybeUninit<&'a [f32]>,
    w3: MaybeUninit<&'a [f32]>,

    w1_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    w2_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    w3_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,

    w_rms_post_att: &'a [f32],

    w_rms_pre_ffn: MaybeUninit<&'a [f32]>,
    w_rms_post_ffn: MaybeUninit<&'a [f32]>,

    w_rms_final: &'a [f32],

    w_cls: MaybeUninit<&'a [f32]>,
    w_cls_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,

    lm_head: MaybeUninit<&'a [f32]>,
    lm_head_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>
}

pub struct TransformerState<'a>
{
    xb: Vec<f32>,
    xb2: Vec<f32>, 
    xb3: Vec<f32>, 
    hb: Vec<f32>,
    hb2: Vec<f32>,
    q: Vec<f32>,
    xq: MaybeUninit<MutableQuantizedTensor<'a>>,
    xq1: MaybeUninit<MutableQuantizedTensor<'a>>,
    hq: MaybeUninit<MutableQuantizedTensor<'a>>,
    logits: Vec<f32>, 

    // kv cache
    key_cache: Vec<f32>,
    value_cache: Vec<f32>, 
}

pub struct Transformer<'a> {
    pub args: TransformerArgs,
    weights: TransformerWeights<'a>,
    state: TransformerState<'a>,
}

impl<'a> Transformer<'a> {
    pub fn new(data: &'a Mmap) -> (Transformer<'a>, usize) {
        assert_eq!(data[0..4], [0x6c, 0x6d, 0x72, 0x73], "Model not in lm.rs format.");

        let lmrs_version = slice_to_u32(&data[4..8]);

        println!("LMRS version: {}", lmrs_version);
        
        let (head, body, _) = unsafe { data[8..55].align_to::<TransformerArgs>() };

        assert!(head.is_empty(), "Data was not aligned");
        
        let mut cfg = body[0];

        println!("Model type: {:?}\n", cfg.model_type);

        let head_size = cfg.head_size;
        
        let mut offset: usize = 256;

        let quantized = cfg.q_type != QuantType::None;
        
        if quantized { println!("Using {:?} quantization.", cfg.q_type) };
        
        // For now this will do so we don't run out of memory
        if cfg.seq_len > 8192 {
            cfg.seq_len = 8192;
        }

        let kv_dim = cfg.head_size * cfg.n_kv_heads;

        let mut rms_pre_ffn = MaybeUninit::uninit();
        let mut rms_post_ffn = MaybeUninit::uninit();
        let mut lm_head = MaybeUninit::uninit();
        let mut lm_head_quant = MaybeUninit::uninit();

        if !quantized {
            
            let emb_tab = init_param(data, &mut offset, 1, cfg.vocab_size * cfg.dim);
            let rms_att = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            let wq = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size);
            let wk = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_kv_heads * head_size);
            let wv = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_kv_heads * head_size);
            let wo = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size);
            let rms_post_att = init_param(data, &mut offset, cfg.n_layers, cfg.dim);

            if cfg.model_type == ModelType::GEMMA {
                rms_pre_ffn = MaybeUninit::new(init_param(data, &mut offset, cfg.n_layers, cfg.dim));
            }

            let w1 = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim);
            let w2 = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim);
            let w3 = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim);
            
            if cfg.model_type == ModelType::GEMMA {
                rms_post_ffn = MaybeUninit::new(init_param(data, &mut offset, cfg.n_layers, cfg.dim));
            }

            let rms_final = init_param(data, &mut offset, 1, cfg.dim);

            if cfg.model_type == ModelType::PHI {
                lm_head = MaybeUninit::new(init_param(data, &mut offset, 1, cfg.dim * cfg.vocab_size));
            }
            
            let weights = TransformerWeights {
                token_embedding_table: emb_tab,
                wq: MaybeUninit::new(wq),
                wk: MaybeUninit::new(wk),
                wv: MaybeUninit::new(wv),
                wo: MaybeUninit::new(wo),
                wq_quant: MaybeUninit::uninit(),
                wk_quant: MaybeUninit::uninit(),
                wv_quant: MaybeUninit::uninit(),
                wo_quant: MaybeUninit::uninit(),
                w_rms_att: rms_att,
                w1: MaybeUninit::new(w1),
                w2: MaybeUninit::new(w2),
                w3: MaybeUninit::new(w3),
                w1_quant: MaybeUninit::uninit(),
                w2_quant: MaybeUninit::uninit(),
                w3_quant: MaybeUninit::uninit(),
                w_rms_post_att: rms_post_att,
                w_rms_pre_ffn: rms_pre_ffn,
                w_rms_post_ffn: rms_post_ffn,
                w_rms_final: rms_final,
                w_cls: MaybeUninit::new(emb_tab),
                w_cls_quant: MaybeUninit::uninit(),
                lm_head,
                lm_head_quant
            };

            let state = TransformerState {
                xb: vec![0.0; cfg.dim as usize],
                xb2: vec![0.0; cfg.dim as usize],
                xb3: vec![0.0; (cfg.head_size*cfg.n_heads) as usize],
                hb: vec![0.0; cfg.hidden_dim as usize],
                hb2: vec![0.0; cfg.hidden_dim as usize],
                q: vec![0.0; (cfg.head_size*cfg.n_heads) as usize],
                xq: MaybeUninit::uninit(),
                xq1: MaybeUninit::uninit(),
                hq: MaybeUninit::uninit(),
                key_cache: vec![0.0; (cfg.n_layers * cfg.seq_len * kv_dim) as usize],
                value_cache: vec![0.0; (cfg.n_layers * cfg.seq_len * kv_dim) as usize],
                logits: vec![0.0; cfg.vocab_size as usize],
            };
            
            return (Transformer {
                args: cfg,
                weights,
                state,
            }, offset)
        } 

        println!("Loading weights...");

        let emb_tab_quant = init_param_quant(data, &mut offset, 1, cfg.vocab_size * cfg.dim, cfg.group_size, cfg.q_type);

        let mut emb_tab: Vec<f32> = vec![0.0; (cfg.vocab_size * cfg.dim) as usize];

        dequantize(&emb_tab_quant[0], &mut emb_tab, (cfg.vocab_size * cfg.dim) as usize, cfg.group_size, cfg.q_type);

        let rms_att = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        let wq_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size, cfg.group_size, cfg.q_type);
        let wk_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_kv_heads * head_size, cfg.group_size, cfg.q_type);
        let wv_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_kv_heads * head_size, cfg.group_size, cfg.q_type);
        let wo_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size, cfg.group_size, cfg.q_type);
        let rms_post_att = init_param(data, &mut offset, cfg.n_layers, cfg.dim);

        if cfg.model_type == ModelType::GEMMA {
            rms_pre_ffn = MaybeUninit::new(init_param(data, &mut offset, cfg.n_layers, cfg.dim));
        }

        let w1_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim, cfg.group_size, cfg.q_type);
        let w2_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim, cfg.group_size, cfg.q_type);
        let w3_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim, cfg.group_size, cfg.q_type);

        if cfg.model_type == ModelType::GEMMA {
            rms_post_ffn = MaybeUninit::new(init_param(data, &mut offset, cfg.n_layers, cfg.dim));
        }

        let rms_final = init_param(data, &mut offset, 1, cfg.dim); 
        
        if cfg.model_type == ModelType::PHI {
            lm_head_quant = MaybeUninit::new(init_param_quant(data, &mut offset, 1, cfg.dim * cfg.vocab_size, cfg.group_size, cfg.q_type));
        }
        
        let weights = TransformerWeights {
            token_embedding_table: Box::leak(emb_tab.into_boxed_slice()),
            wq: MaybeUninit::uninit(),
            wk: MaybeUninit::uninit(),
            wv: MaybeUninit::uninit(),
            wo: MaybeUninit::uninit(),
            wq_quant: MaybeUninit::new(wq_quant),
            wk_quant: MaybeUninit::new(wk_quant),
            wv_quant: MaybeUninit::new(wv_quant),
            wo_quant: MaybeUninit::new(wo_quant),
            w_rms_att: rms_att,
            w1: MaybeUninit::uninit(),
            w2: MaybeUninit::uninit(),
            w3: MaybeUninit::uninit(),
            w1_quant: MaybeUninit::new(w1_quant),
            w2_quant: MaybeUninit::new(w2_quant),
            w3_quant: MaybeUninit::new(w3_quant),
            w_rms_post_att: rms_post_att,
            w_rms_pre_ffn: rms_pre_ffn,
            w_rms_post_ffn: rms_post_ffn,
            w_rms_final: rms_final,
            w_cls: MaybeUninit::uninit(),
            w_cls_quant: MaybeUninit::new(emb_tab_quant),
            lm_head,
            lm_head_quant
        };

        let state = TransformerState {
            xb: vec![0.0; cfg.dim as usize],
            xb2: vec![0.0; cfg.dim as usize],
            xb3: vec![0.0; (cfg.head_size*cfg.n_heads) as usize],
            hb: vec![0.0; cfg.hidden_dim as usize],
            hb2: vec![0.0; cfg.hidden_dim as usize],
            q: vec![0.0; (cfg.head_size*cfg.n_heads) as usize],
            xq: MaybeUninit::new(MutableQuantizedTensor { q: Box::leak(vec![0; (cfg.dim) as usize].into_boxed_slice()), s: Box::leak(vec![0.0; (cfg.dim) as usize].into_boxed_slice())}),
            xq1: MaybeUninit::new(MutableQuantizedTensor { q: Box::leak(vec![0; (cfg.head_size*cfg.n_heads) as usize].into_boxed_slice()), s: Box::leak(vec![0.0; (cfg.head_size*cfg.n_heads) as usize].into_boxed_slice())}),
            hq: MaybeUninit::new(MutableQuantizedTensor { q: Box::leak(vec![0; (cfg.hidden_dim) as usize].into_boxed_slice()), s: Box::leak(vec![0.0; (cfg.hidden_dim) as usize].into_boxed_slice())}),
            key_cache: vec![0.0; (cfg.n_layers * cfg.seq_len * kv_dim) as usize],
            value_cache: vec![0.0; (cfg.n_layers * cfg.seq_len * kv_dim) as usize],
            logits: vec![0.0; cfg.vocab_size as usize],
        };
        
        println!("Done.\n");
        
        (Transformer {
            args: cfg,
            weights,
            state,
        }, offset)
    }

    pub fn forward(&mut self, token: u32, pos: u32) -> &mut [f32] {
        let p = self.args;
        let x = &mut vec![0.0; (p.dim) as usize];
        let dim = p.dim;
        let gs = p.group_size;

        let quantized = p.q_type != QuantType::None;

        x.copy_from_slice(&self.weights.token_embedding_table[(token * dim) as usize..(token * dim + dim) as usize]);
        
        // Gemma normalizes the token embeddings by sqrt(dim)
        if p.model_type == ModelType::GEMMA {
            let normalizer = (dim as f32).sqrt();
            for i in x.iter_mut() {
                *i *= normalizer;
            }
        }

        for l in 0..p.n_layers {
            self.forward_layer(x, l, pos);
        }
        
        let s = &mut self.state;
        let w = &self.weights;
        
        s.xb.copy_from_slice(x);

        rmsnorm(x, &s.xb, w.w_rms_final, dim as usize, p.rms_norm_eps, p.model_type == ModelType::GEMMA);
        
        unsafe {
            if !quantized {
                if p.model_type != ModelType::PHI {
                    matmul(&mut s.logits, x, w.w_cls.assume_init());
                } else {
                    matmul(&mut s.logits, x, w.lm_head.assume_init());
                }
            } else {
                let sxq = &mut *s.xq.as_mut_ptr();
                
                if p.q_type == QuantType::Q8_0 {
                    quantize(sxq, x, dim as usize, gs);
                    
                    if p.model_type != ModelType::PHI {
                        matmul_q8(&mut s.logits, sxq, &w.w_cls_quant.assume_init()[0], dim as usize, gs as usize);
                    } else {
                        matmul_q8(&mut s.logits, sxq, &w.lm_head_quant.assume_init()[0], dim as usize, gs as usize);
                    }
                } else if p.q_type == QuantType::Q4_0 {
                    quantize_q4(sxq, x, dim as usize, gs);

                    if p.model_type != ModelType::PHI {
                        matmul_q4(&mut s.logits, sxq, &w.w_cls_quant.assume_init()[0], dim as usize, gs as usize);
                    } else {
                        matmul_q4(&mut s.logits, sxq, &w.lm_head_quant.assume_init()[0], dim as usize, gs as usize);
                    }
                }
            }
        }

        if p.model_type == ModelType::GEMMA {
            // Final logit soft-capping over the full vocabulary (not just `dim` entries)
            for d in 0..p.vocab_size {
                s.logits[d as usize] /= 30.0;
                s.logits[d as usize] = (s.logits[d as usize] as f64).tanh() as f32;
                s.logits[d as usize] *= 30.0;
            }
        }
        
        &mut s.logits
    }

    fn forward_layer(&mut self, x: &mut [f32], l: u32, pos: u32) {
        let p = self.args;
        let w = &self.weights;
        let s = &mut self.state;
        let dim = p.dim;
        let head_size = p.head_size;
        let att_dim = p.n_heads * head_size;
        let kv_dim = head_size * p.n_kv_heads;
        let kv_mul = p.n_heads / p.n_kv_heads;
        let hidden_dim = p.hidden_dim;
        let gs = p.group_size;

        let quantized = p.q_type != QuantType::None;

        rmsnorm(&mut s.xb, x, &w.w_rms_att[(l*dim) as usize..(l*dim + dim) as usize], dim as usize, p.rms_norm_eps, p.model_type == ModelType::GEMMA);
        
        let loff = l * p.seq_len * kv_dim; 
        let k = &mut s.key_cache[(loff + pos * kv_dim) as usize..(loff + pos * kv_dim + kv_dim) as usize];
        let v = &mut s.value_cache[(loff + pos * kv_dim) as usize..(loff + pos * kv_dim + kv_dim) as usize];
        
        unsafe {
            if !quantized {
                matmul(&mut s.q, &s.xb, &w.wq.assume_init()[(l*dim*att_dim) as usize..(l*dim*att_dim + dim*att_dim) as usize]);
                matmul(k, &s.xb, &w.wk.assume_init()[(l*dim*kv_dim) as usize..(l*dim*kv_dim + dim*kv_dim) as usize]);
                matmul(v, &s.xb, &w.wv.assume_init()[(l*dim*kv_dim) as usize..(l*dim*kv_dim + dim*kv_dim) as usize]);
            } else {
                let sxq = &mut *s.xq.as_mut_ptr();

                if p.q_type == QuantType::Q8_0 {
                    quantize(sxq, &s.xb, dim as usize, gs);
                    
                    matmul_q8(&mut s.q, sxq, &w.wq_quant.assume_init()[l as usize], dim as usize, gs as usize);
                    matmul_q8(k, sxq, &w.wk_quant.assume_init()[l as usize], dim as usize, gs as usize);
                    matmul_q8(v, sxq, &w.wv_quant.assume_init()[l as usize], dim as usize, gs as usize);
                } else if p.q_type == QuantType::Q4_0 {
                    quantize_q4(sxq, &s.xb, dim as usize, gs);
                    
                    matmul_q4(&mut s.q, sxq, &w.wq_quant.assume_init()[l as usize], dim as usize, gs as usize);
                    matmul_q4(k, sxq, &w.wk_quant.assume_init()[l as usize], dim as usize, gs as usize);
                    matmul_q4(v, sxq, &w.wv_quant.assume_init()[l as usize], dim as usize, gs as usize);
                }
            }
        }
        
        // RoPE
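        // Rotary position embedding: for each head, the pair
        // (x[i], x[i + head_size/2]) is rotated in its plane by the angle
        // pos * freq, where freq falls off geometrically with the pair index,
        // so dot products between q and k depend on relative position.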
        for i in 0..p.n_heads {
            for j in 0..(head_size/2) {
                let head_dim: u32 = j * 2;
                let mut freq: f32 = 1.0 / p.rope_theta.powf(head_dim as f32/head_size as f32);

                let mut scaling_factor = 1.0;

                if p.model_type == ModelType::LLAMA {
                    let wavelen = (2.0 * std::f32::consts::PI) / freq;
                    
                    // Llama 3.1 rope-scaling parameters; ideally these would live in the model args
                    let factor = 32.0;
                    let low_freq_factor = 1.0;
                    let high_freq_factor = 4.0;
                    let old_context_len = 8192.0;
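                    // Llama 3.1 rope scaling: frequencies whose wavelength
                    // exceeds old_context_len / low_freq_factor are slowed by
                    // `factor`; wavelengths between the two thresholds blend
                    // linearly between the scaled and unscaled frequency.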

                    let low_freq_wavelen = old_context_len / low_freq_factor;
                    let high_freq_wavelen = old_context_len / high_freq_factor;

                    if wavelen > low_freq_wavelen {
                        freq /= factor;
                    } else if wavelen <= low_freq_wavelen && wavelen >= high_freq_wavelen {
                        let smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
                        
                        freq = (1.0 - smooth_factor) * freq / factor + smooth_factor * freq
                    }
                }

                if p.model_type == ModelType::PHI {
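                    // LongRoPE rescaling: these appear to be the "short" rope
                    // factors from the Phi-3 config, one per frequency pair
                    // (48 = head_size / 2 entries).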
                    let short_factor: [f64; 48] = [1.08, 1.1, 1.1300000000000001, 1.2800000000000002, 1.3100000000000003, 1.4500000000000004, 1.4500000000000004, 1.9500000000000008, 2.030000000000001, 2.4299999999999926, 2.5699999999999896, 2.9499999999999815, 3.729999999999965, 3.869999999999962, 4.189999999999955, 4.43999999999995, 4.6399999999999455, 4.979999999999938, 5.159999999999934, 5.279999999999932, 5.759999999999922, 5.889999999999919, 5.889999999999919, 5.969999999999917, 6.089999999999915, 6.2799999999999105, 6.7699999999999, 6.8899999999998975, 7.109999999999893, 7.129999999999892, 7.179999999999891, 7.289999999999889, 7.339999999999888, 7.559999999999883, 7.619999999999882, 7.69999999999988, 7.879999999999876, 7.879999999999876, 7.879999999999876, 7.939999999999875, 7.949999999999875, 7.979999999999874, 8.19999999999987, 8.439999999999864, 8.469999999999864, 8.589999999999861, 8.809999999999857, 8.999999999999853];

                    freq *= (1.0/short_factor[j as usize]) as f32;
                    let scale = 131072f32/4096f32;
                    scaling_factor = (1.0 + scale.ln() / 4096f32.ln()).sqrt();
                }

                let val: f32 = pos as f32 * freq;
                let fcr = val.cos() * scaling_factor;
                let fci = val.sin() * scaling_factor;
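                // Rotate both q and k while the second pair element still lies
                // inside the kv dimension; beyond that, only q is rotated.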
                let rotn: u32 = if (i*head_size) + j + head_size/2 < kv_dim {2} else {1};

                for v in 0..rotn{
                    let vec: &mut [f32] = if v == 0 {&mut s.q} else {k};
                    let v0: f32 = vec[((i*head_size) + j) as usize];
                    let v1: f32 = vec[(((i*head_size) + j)+(head_size/2)) as usize];
                    
                    vec[((i*head_size) + j) as usize] = v0 * fcr - v1 * fci;
                    vec[(((i*head_size) + j)+(head_size/2)) as usize]= v0 * fci + v1 * fcr;
                }
            }
        }
            
        // One rayon task per attention head
        s.xb3.par_chunks_mut(head_size as usize).enumerate().for_each( |(h, xb)| {
            let q = &s.q[(h as u32 * head_size) as usize..(h as u32 * head_size + head_size) as usize];

            // attention scores for this head, one per cached position
            let mut att = vec![0.0; p.seq_len as usize];

            for t in 0..pos+1 {
                let k = &s.key_cache[(loff + t * kv_dim + (h as u32 / kv_mul) * head_size) as usize..(loff + t * kv_dim + (h as u32 / kv_mul) * head_size + head_size) as usize];
                
                let mut score: f32 = 0.0;

                for i in 0..head_size {
                    score += q[i as usize] * k[i as usize];
                }
                
                score /= (head_size as f32).sqrt();
                
                if p.model_type == ModelType::GEMMA {
                    // Softcapping
                    score /= 50.0f32;
                    score = (score as f64).tanh() as f32;
                    score *= 50.0f32;
                    
                    // Sliding-window ("local") attention: positions more than
                    // 4096 tokens in the past get a large negative score
                    score += if pos - t <= 4096 {0.0} else {-2.3819763e38};
                }

                att[t as usize] = score;
            }

            softmax(&mut att[..(pos+1) as usize]);

            xb.fill(0.0);

            for t in 0..pos+1 {
                let v = &s.value_cache[(loff + t * kv_dim + (h as u32 / kv_mul) * head_size) as usize..(loff + t * kv_dim + (h as u32 / kv_mul) * head_size + head_size) as usize];
                let a = att[t as usize];

                for i in 0..head_size {
                    xb[i as usize] += a * v[i as usize];
                }
            }
        });

        unsafe {
            if !quantized {
                matmul(&mut s.xb2, &s.xb3, &w.wo.assume_init()[(l*dim*att_dim) as usize..(l*dim*att_dim + dim*att_dim) as usize]);
            } else {
                let sxq1 = &mut *s.xq1.as_mut_ptr();
                
                if p.q_type == QuantType::Q8_0 {
                    quantize(sxq1, &s.xb3, att_dim as usize, gs);
                    matmul_q8(&mut s.xb2, sxq1, &w.wo_quant.assume_init()[l as usize], att_dim as usize, gs as usize)
                } else {
                    quantize_q4(sxq1, &s.xb3, att_dim as usize, gs);
                    matmul_q4(&mut s.xb2, sxq1, &w.wo_quant.assume_init()[l as usize], att_dim as usize, gs as usize)
                }
            }
        }
        
        if p.model_type == ModelType::GEMMA {
            rmsnorm(&mut s.xb, &s.xb2, &w.w_rms_post_att[(l*dim) as usize..(l*dim + dim) as usize], dim as usize, p.rms_norm_eps, p.model_type == ModelType::GEMMA);
        
            for i in 0..dim {
                x[i as usize] += s.xb[i as usize];
            }
            
            unsafe {
                rmsnorm(&mut s.xb, x, &w.w_rms_pre_ffn.assume_init()[(l*dim) as usize..(l*dim + dim) as usize], dim as usize, p.rms_norm_eps, true);
            }
        } else {
            for i in 0..dim {
                x[i as usize] += s.xb2[i as usize];
            }
            
            rmsnorm(&mut s.xb, x, &w.w_rms_post_att[(l*dim) as usize..(l*dim + dim) as usize], dim as usize, p.rms_norm_eps, p.model_type == ModelType::GEMMA);
        }
        
        // GeGLU is w2(GELU(w1(x)) * w3(x)) 
        // w1 -> gate_proj weights
        // w2 -> down_proj weights
        // w3 -> up_proj weights
        // GELU using tanh as the approximation
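        // GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)));
        // sqrt(2/pi) ~= 0.7978845608, the constant that appears below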

        unsafe {
            if !quantized {
                matmul(&mut s.hb, &s.xb, &w.w1.assume_init()[(l*dim*hidden_dim) as usize..(l*dim*hidden_dim + dim*hidden_dim) as usize]);
                matmul(&mut s.hb2, &s.xb, &w.w3.assume_init()[(l*dim*hidden_dim) as usize..(l*dim*hidden_dim + dim*hidden_dim) as usize]);
            } else {
                let sxq = &mut *s.xq.as_mut_ptr();
                
                if p.q_type == QuantType::Q8_0 {
                    quantize(sxq, &s.xb, dim as usize, gs);
                    matmul_q8(&mut s.hb, sxq, &w.w1_quant.assume_init()[l as usize], dim as usize, gs as usize);
                    matmul_q8(&mut s.hb2, sxq, &w.w3_quant.assume_init()[l as usize], dim as usize, gs as usize);
                } else if p.q_type == QuantType::Q4_0{
                    quantize_q4(sxq, &s.xb, dim as usize, gs);
                    matmul_q4(&mut s.hb, sxq, &w.w1_quant.assume_init()[l as usize], dim as usize, gs as usize);
                    matmul_q4(&mut s.hb2, sxq, &w.w3_quant.assume_init()[l as usize], dim as usize, gs as usize);
                }
            }
        }
            
        for i in 0..hidden_dim {
            let mut val = s.hb[i as usize];

            // Ideally the activation function would be part of the model args; for now we switch on model type
            if p.model_type == ModelType::GEMMA {
                // GELU
                val *= 0.5 * (1.0 + ((0.7978845608028654 * (val + 0.044715 * val * val * val) as f64).tanh()) as f32);   
            } else {
                // SiLU
                val *= 1.0 / (1.0 + (-val).exp());
            }

            val *= s.hb2[i as usize];
            
            s.hb[i as usize] = val;
        }

        unsafe {
            if !quantized {
                matmul(&mut s.xb, &s.hb, &w.w2.assume_init()[(l*dim*hidden_dim) as usize..(l*dim*hidden_dim + dim*hidden_dim) as usize]);
            } else {
                let shq = &mut *s.hq.as_mut_ptr();

                if p.q_type == QuantType::Q8_0 {
                    quantize(shq, &s.hb, hidden_dim as usize, gs);
                    matmul_q8(&mut s.xb, shq, &w.w2_quant.assume_init()[l as usize], hidden_dim as usize, gs as usize);
                } else if p.q_type == QuantType::Q4_0 {
                    quantize_q4(shq, &s.hb, hidden_dim as usize, gs);
                    matmul_q4(&mut s.xb, shq, &w.w2_quant.assume_init()[l as usize], hidden_dim as usize, gs as usize);
                }
            }
        }

        if p.model_type == ModelType::GEMMA {
            unsafe {
                rmsnorm(&mut s.xb2, &s.xb, &w.w_rms_post_ffn.assume_init()[(l*dim) as usize..(l*dim + dim) as usize], dim as usize, p.rms_norm_eps, true);
            }
            
            for i in 0..dim {
                x[i as usize] += s.xb2[i as usize];
            }
        } else {
            for i in 0..dim {
                x[i as usize] += s.xb[i as usize];
            }
        }
    }

    pub fn get_embeddings(&self, tokens: &[u32]) -> Vec<f32> {
        let n_tokens = tokens.len();
        let dim = self.args.dim;
        let mut out_embeddings: Vec<f32> = Vec::with_capacity(dim as usize * n_tokens);

        for t in tokens.iter() {
            out_embeddings.extend(&self.weights.token_embedding_table[(t * dim) as usize..(t * dim + dim) as usize]);
        }

        out_embeddings
    }

    // For now we process one embedding at a time (batch = 1); a batched implementation would likely be faster
    pub fn fill_kv_cache(&mut self, embeddings: &mut [f32], curr_pos: u32) -> u32 {
        let p = self.args;
        let dim = p.dim;    
        let mut pos = curr_pos;

        embeddings.chunks_mut(dim as usize).for_each( |x| {
            for l in 0..p.n_layers {
                self.forward_layer(x, l, pos)
            }
            
            pos += 1;
        });

        pos
    }
}

// Deallocate fields created with Box::leak
impl<'a> Drop for Transformer<'a> {
    fn drop(&mut self) {
        if self.args.q_type != QuantType::None {
            unsafe {
                // Weights
                dealloc(self.weights.token_embedding_table.as_ptr() as *mut u8, Layout::array::<f32>(self.weights.token_embedding_table.len()).unwrap());
                
                let layer_weights_layout = Layout::array::<QuantizedTensor>(self.args.n_layers as usize).unwrap();
                dealloc(self.weights.wq_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.wk_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.wv_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.wo_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.w1_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.w2_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.w3_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.w_cls_quant.assume_init().as_ptr() as *mut u8, Layout::array::<QuantizedTensor>(self.weights.w_cls_quant.assume_init().len()).unwrap());

                // State
                let sxq = &mut *self.state.xq.as_mut_ptr();
                dealloc(sxq.q.as_ptr() as *mut u8, Layout::array::<i8>(sxq.q.len()).unwrap());
                dealloc(sxq.s.as_ptr() as *mut u8, Layout::array::<f32>(sxq.s.len()).unwrap());
                
                let sxq1 = &mut *self.state.xq1.as_mut_ptr();
                dealloc(sxq1.q.as_ptr() as *mut u8, Layout::array::<i8>(sxq1.q.len()).unwrap());
                dealloc(sxq1.s.as_ptr() as *mut u8, Layout::array::<f32>(sxq1.s.len()).unwrap());
                
                let shq = &mut *self.state.hq.as_mut_ptr();
                dealloc(shq.q.as_ptr() as *mut u8, Layout::array::<i8>(shq.q.len()).unwrap());
                dealloc(shq.s.as_ptr() as *mut u8, Layout::array::<f32>(shq.s.len()).unwrap());
            }
        }
    }
}

</document_content>
</document>
<document index="8">
<source>src/vision.rs</source>
<document_content>
use crate::quantization::{QuantizedTensor, MutableQuantizedTensor, QuantType, quantize, quantize_q4};
use crate::transformer::{init_param, init_param_quant};
use crate::functional::{matmul, matmul_q8, matmul_q4, matmul_conv, concat, layernorm, softmax};

use rayon::prelude::*;
use wide::f32x8;
use std::mem::MaybeUninit;
use std::alloc::dealloc;
use std::alloc::Layout;

#[repr(C, packed)]
#[derive(Debug, Copy, Clone)]
pub struct VisionTransformerArgs {
    dim: u32,
    hidden_dim: u32,
    n_layers: u32,
    n_heads: u32,
    head_size: u32,
    layernorm_eps: f32,
    pub patch_size: u32,
    pub image_size: u32,
    q_type: QuantType,
    group_size: u32,
}

pub struct VisionTransformerWeights<'a> {
    class_embedding: &'a [f32],

    patch_embedding: &'a [f32],
    
    position_embedding: &'a [f32],

    // Attention

    wq: MaybeUninit<&'a [f32]>,
    wq_bias: &'a [f32],
    wk: MaybeUninit<&'a [f32]>,
    wk_bias: &'a [f32],
    wv: MaybeUninit<&'a [f32]>,
    wv_bias: &'a [f32],
    wo: MaybeUninit<&'a [f32]>,
    wo_bias: &'a [f32],
    
    wq_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    wk_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    wv_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    wo_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,

    layer_norm1: &'a [f32],
    layer_norm2: &'a [f32],
    layer_norm1_bias: &'a [f32],
    layer_norm2_bias: &'a [f32],

    // FFN

    w1: MaybeUninit<&'a [f32]>,
    w1_bias: &'a [f32],
    w2: MaybeUninit<&'a [f32]>,
    w2_bias: &'a [f32],

    w1_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,
    w2_quant: MaybeUninit<&'a [QuantizedTensor<'a>]>,

    pre_layer_norm: &'a [f32],
    pre_layer_norm_bias: &'a [f32],
}

pub struct VisionTransformer<'a> {
    weights: VisionTransformerWeights<'a>,
    pub args: VisionTransformerArgs
}

pub fn qkv_split(qkv: &[f32], dim: u32, num_crops: u32, n_heads: u32, n_patches: u32, out_shape: u32) -> (Vec<f32>, Vec<f32>, Vec<f32>){
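    // Per token the fused buffer holds [q | k | v], each of length `dim`.
    // q and k come out head-major (crop, head, token, feature) and v comes out
    // transposed (crop, head, feature, token) so each head's attention matmuls
    // read contiguous memory.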
    let mut q: Vec<f32> = Vec::with_capacity((num_crops*out_shape) as usize);
    let mut k: Vec<f32> = Vec::with_capacity((num_crops*out_shape) as usize);
    let mut v: Vec<f32> = Vec::with_capacity((num_crops*out_shape) as usize);

    let head_size = dim / n_heads;

    for i in 0..num_crops {
        for h in 0..n_heads {
            for t in 0..n_patches {
                q.extend(&qkv[((t*dim*3) + h * head_size + (i*3*out_shape)) as usize..((t*dim*3) + h * head_size + (i*3*out_shape) + head_size) as usize]);
                k.extend(&qkv[((t*dim*3 + dim) + h * head_size + (i*3*out_shape)) as usize..((t*dim*3 + dim) + h * head_size + (i*3*out_shape) + head_size) as usize]);
            }
            
            for j in 0..head_size {
                for t in 0..n_patches {
                    v.push(qkv[((t*dim*3 + 2*dim) + j + h*head_size + (i*3*out_shape)) as usize]);
                }
            }
        }
    }

    (q, k, v)
}

impl<'a> VisionTransformer<'a> {
    pub fn new(data: &'a [u8]) -> (VisionTransformer<'a>, usize) {
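        // The packed VisionTransformerArgs header is 37 bytes: nine 4-byte
        // u32/f32 fields plus the 1-byte QuantType, hence the slice below.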
        let (head, body, _) = unsafe { data[..37].align_to::<VisionTransformerArgs>() };

        assert!(head.is_empty(), "Data was not aligned");
        
        let cfg = &body[0];

        let head_size = cfg.head_size;
        
        let mut offset: usize = 128; // parameter data starts after the 128-byte file header

        let quantized = cfg.q_type != QuantType::None;
        
        let class_embedding = init_param(data, &mut offset, 1, cfg.dim);
        let patch_embedding = init_param(data, &mut offset, 1, cfg.dim*3*cfg.patch_size*cfg.patch_size);
        
        if !quantized {
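            // 577 positions = (image_size / patch_size)^2 patches + 1 CLS
            // token, presumably 24 * 24 + 1 for a 336px image with 14px patches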
            let position_embedding = init_param(data, &mut offset, 1, cfg.dim*577);

            let layer_norm1 = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            let layer_norm1_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            let layer_norm2 = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            let layer_norm2_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);

            let wq = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size);
            let wq_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            let wk = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size);
            let wk_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            let wv = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size);
            let wv_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            
            let wo = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size);
            let wo_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            
            let w1 = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim);
            let w1_bias = init_param(data, &mut offset, cfg.n_layers, cfg.hidden_dim);
            
            let w2 = init_param(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim);
            let w2_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
            
            let pre_layer_norm = init_param(data, &mut offset, 1, cfg.dim);
            let pre_layer_norm_bias = init_param(data, &mut offset, 1, cfg.dim);
            
            let weights = VisionTransformerWeights {
                class_embedding,
                patch_embedding,
                position_embedding,
                layer_norm1,
                layer_norm1_bias,
                layer_norm2,
                layer_norm2_bias,
                wq: MaybeUninit::new(wq),
                wk: MaybeUninit::new(wk),
                wv: MaybeUninit::new(wv),
                wo: MaybeUninit::new(wo),
                wq_bias,
                wk_bias,
                wv_bias,
                wo_bias,
                wq_quant: MaybeUninit::uninit(),
                wk_quant: MaybeUninit::uninit(),
                wv_quant: MaybeUninit::uninit(),
                wo_quant: MaybeUninit::uninit(),
                w1: MaybeUninit::new(w1),
                w2: MaybeUninit::new(w2),
                w1_bias,
                w2_bias,
                w1_quant: MaybeUninit::uninit(),
                w2_quant: MaybeUninit::uninit(),
                pre_layer_norm,
                pre_layer_norm_bias,
            };

            return (VisionTransformer {
                args: *cfg,
                weights,
            }, offset)
        } 

        println!("Loading vision encoder weights...");

        let position_embedding = init_param(data, &mut offset, 1, cfg.dim*577);

        let layer_norm1 = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        let layer_norm1_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        let layer_norm2 = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        let layer_norm2_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);

        let wq_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size, cfg.group_size, cfg.q_type);
        let wq_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        let wk_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size, cfg.group_size, cfg.q_type);
        let wk_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        let wv_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size, cfg.group_size, cfg.q_type);
        let wv_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        
        let wo_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.n_heads * head_size, cfg.group_size, cfg.q_type);
        let wo_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        
        let w1_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim, cfg.group_size, cfg.q_type);
        let w1_bias = init_param(data, &mut offset, cfg.n_layers, cfg.hidden_dim);
        
        let w2_quant = init_param_quant(data, &mut offset, cfg.n_layers, cfg.dim * cfg.hidden_dim, cfg.group_size, cfg.q_type);
        let w2_bias = init_param(data, &mut offset, cfg.n_layers, cfg.dim);
        
        let pre_layer_norm = init_param(data, &mut offset, 1, cfg.dim);
        let pre_layer_norm_bias = init_param(data, &mut offset, 1, cfg.dim);
        
        let weights = VisionTransformerWeights {
            class_embedding,
            patch_embedding,
            position_embedding,
            layer_norm1,
            layer_norm1_bias,
            layer_norm2,
            layer_norm2_bias,
            wq: MaybeUninit::uninit(),
            wk: MaybeUninit::uninit(),
            wv: MaybeUninit::uninit(),
            wo: MaybeUninit::uninit(),
            wq_bias,
            wk_bias,
            wv_bias,
            wo_bias,
            wq_quant: MaybeUninit::new(wq_quant),
            wk_quant: MaybeUninit::new(wk_quant),
            wv_quant: MaybeUninit::new(wv_quant),
            wo_quant: MaybeUninit::new(wo_quant),
            w1: MaybeUninit::uninit(),
            w2: MaybeUninit::uninit(),
            w1_bias,
            w2_bias,
            w1_quant: MaybeUninit::new(w1_quant),
            w2_quant: MaybeUninit::new(w2_quant),
            pre_layer_norm,
            pre_layer_norm_bias,
        };
        
        println!("Done.\n");
        
        (VisionTransformer {
            args: *cfg,
            weights,
        }, offset)
    }

    pub fn forward(&mut self, pixel_values: &[f32], num_crops: u32) -> (Vec<f32>, u32) {
        let p = self.args;
        let w = &self.weights;
        let dim = p.dim;
        let head_size = p.head_size;
        let att_dim = p.n_heads * head_size;
        let hidden_dim = p.hidden_dim;
        let gs = p.group_size;

        let quantized = p.q_type != QuantType::None;
        
        let patches_per_row = p.image_size/p.patch_size; 
        let mut n_patches = patches_per_row*patches_per_row;
        let mut out_shape = p.dim*n_patches;
        let img_pixels = 3*p.image_size*p.image_size;
        let mut patch_embeds: Vec<f32> = vec![0.0; (num_crops*out_shape) as usize];
        let patch_shape = p.patch_size*p.patch_size;
        
        for b in 0..num_crops {
            matmul_conv(&mut patch_embeds[(b*out_shape) as usize..(b*out_shape + out_shape) as usize], &pixel_values[(b*img_pixels) as usize..(b*img_pixels + img_pixels) as usize], &w.patch_embedding, (patch_shape*3) as usize, patches_per_row);
        }

        // Cat class embedding
        let mut embeddings = Vec::with_capacity((num_crops*out_shape + num_crops*dim) as usize);

        for i in 0..num_crops {
            let mut transposed: Vec<f32> = Vec::with_capacity((n_patches*dim) as usize);

            for k in 0..n_patches {
                for c in (0..dim*n_patches).step_by(n_patches as usize) {
                    transposed.push(patch_embeds[(i*(out_shape) + c + k) as usize]);
                }
            }

            embeddings.extend(concat::<f32>(w.class_embedding, &transposed));
        }

        n_patches += 1;
        out_shape += p.dim;

        // Add position embeddings
        for i in 0..embeddings.len() {
            let p_idx = i % out_shape as usize;

            embeddings[i] += w.position_embedding[p_idx];
        }

        // Input layernorm
        let mut norm_embeddings = vec![0.0; embeddings.len()];
        
        for i in 0..num_crops {
            for k in 0..n_patches {
                layernorm(&mut norm_embeddings[(i*out_shape+k*dim) as usize..(i*out_shape+k*dim + p.dim) as usize], 
                    &embeddings[(i*out_shape+k*dim) as usize..(i*out_shape+k*dim + p.dim) as usize], w.pre_layer_norm, w.pre_layer_norm_bias, dim as usize, p.layernorm_eps);
            }
        }
        
        let mut qkv: Vec<f32> = vec![0.0; norm_embeddings.len() * 3];

        // PHI-3-vision takes the hidden states from the penultimate layer, so the last layer is skipped
        for l in 0..p.n_layers-1 {
            let mut x = norm_embeddings.clone();
            
            for i in 0..num_crops {
                for k in 0..n_patches {
                    layernorm(&mut embeddings[(i*out_shape+k*dim) as usize..(i*out_shape+k*dim + p.dim) as usize], 
                        &norm_embeddings[(i*out_shape+k*dim) as usize..(i*out_shape+k*dim + p.dim) as usize], &w.layer_norm1[(l*dim) as usize..(l*dim + dim) as usize], &w.layer_norm1_bias[(l*dim) as usize..(l*dim + dim) as usize], dim as usize, p.layernorm_eps);
                }
            }

            for i in 0..num_crops {
                qkv[(i*out_shape*3) as usize..(i*out_shape*3 + out_shape*3) as usize].par_chunks_mut((dim*3) as usize).enumerate().for_each( |(h, xb)| {
                    unsafe {
                        if !quantized {
                            matmul(&mut xb[..dim as usize], &embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], &w.wq.assume_init()[(l*dim*att_dim) as usize..(l*dim*att_dim + dim*att_dim) as usize]);
                            matmul(&mut xb[dim as usize..(dim*2) as usize], &embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], &w.wk.assume_init()[(l*dim*att_dim) as usize..(l*dim*att_dim + dim*att_dim) as usize]);
                            matmul(&mut xb[(dim*2) as usize..(dim*3) as usize], &embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], &w.wv.assume_init()[(l*dim*att_dim) as usize..(l*dim*att_dim + dim*att_dim) as usize]);
                        } else {
                            let mut sxq = MutableQuantizedTensor { q: &mut vec![0; (dim) as usize], s: &mut vec![0.0; dim as usize]};

                            if p.q_type == QuantType::Q8_0 {
                                quantize(&mut sxq, &embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], dim as usize, gs);
                                
                                matmul_q8(&mut xb[..dim as usize], &sxq, &w.wq_quant.assume_init()[l as usize], dim as usize, gs as usize);
                                matmul_q8(&mut xb[dim as usize..(dim*2) as usize], &sxq, &w.wk_quant.assume_init()[l as usize], dim as usize, gs as usize);
                                matmul_q8(&mut xb[(dim*2) as usize..(dim*3) as usize], &sxq, &w.wv_quant.assume_init()[l as usize], dim as usize, gs as usize);
                            } else if p.q_type == QuantType::Q4_0 {
                                quantize_q4(&mut sxq, &embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], dim as usize, gs);
                                
                                matmul_q4(&mut xb[..dim as usize], &sxq, &w.wq_quant.assume_init()[l as usize], dim as usize, gs as usize);
                                matmul_q4(&mut xb[dim as usize..(dim*2) as usize], &sxq, &w.wk_quant.assume_init()[l as usize], dim as usize, gs as usize);
                                matmul_q4(&mut xb[(dim*2) as usize..(dim*3) as usize], &sxq, &w.wv_quant.assume_init()[l as usize], dim as usize, gs as usize);
                            }
                        }
                    }
                    
                    // Add bias
                    let n_simd = dim/8;
                    let scale = f32x8::splat((head_size as f32).sqrt());
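                    // q is pre-divided by sqrt(head_size) here so the later
                    // Q*K matmul needs no extra scaling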
                    
                    for k in 0..n_simd {
                        let wq_bias_vec = f32x8::from(&w.wq_bias[(l*dim+k*8) as usize..(l*dim+k*8+8) as usize]);
                        let wk_bias_vec = f32x8::from(&w.wk_bias[(l*dim+k*8) as usize..(l*dim+k*8+8) as usize]);
                        let wv_bias_vec = f32x8::from(&w.wv_bias[(l*dim+k*8) as usize..(l*dim+k*8+8) as usize]);

                        let mut xq_vec = f32x8::from(&xb[(k*8) as usize..(k*8+8) as usize]);
                        let mut xk_vec = f32x8::from(&xb[(dim+k*8) as usize..(dim+k*8+8) as usize]);
                        let mut xv_vec = f32x8::from(&xb[(dim*2+k*8) as usize..(dim*2+k*8+8) as usize]);

                        xq_vec += wq_bias_vec;
                        xk_vec += wk_bias_vec;
                        xv_vec += wv_bias_vec;
                        
                        // Apply scale here for convenience
                        let xq = (xq_vec/scale).to_array();
                        let xk = xk_vec.to_array();
                        let xv = xv_vec.to_array();

                        for j in 0..8 {
                            xb[(k*8 + j) as usize] = xq[j as usize];
                            xb[(k*8 + dim + j) as usize] = xk[j as usize];
                            xb[(k*8 + 2*dim + j) as usize] = xv[j as usize];
                        }
                    }
                })
            }
            
            // Split into q, k, v, and reshape so all the heads are contiguous

            let (q, k, v) = qkv_split(&qkv, dim, num_crops, p.n_heads, n_patches, out_shape);

            let att_size = p.n_heads*n_patches*n_patches;
            let mut att: Vec<f32> = vec![0.0; (att_size*num_crops) as usize];
            
            // Q * K^T
            // q and k are laid out as (crops*heads, n_patches, head_size); att becomes (crops*heads, n_patches, n_patches)
             
            for i in 0..num_crops {
                att[(i*att_size) as usize..(i*att_size + att_size) as usize].par_chunks_mut((n_patches) as usize).enumerate().for_each( |(h, xb)| {
                    let curr_head = h as u32 / n_patches;
                    matmul(xb, &q[(i*out_shape + (h as u32 * head_size)) as usize..(i*out_shape + (h as u32 * head_size) + head_size) as usize], &k[(i*out_shape + (curr_head*head_size*n_patches)) as usize..((i*out_shape) + (curr_head*head_size*n_patches) + head_size*n_patches) as usize]);
                })
            }

            // Softmax

            for i in 0..(num_crops*p.n_heads) {
                for k in 0..n_patches {
                    softmax(&mut att[(i*n_patches*n_patches + k*n_patches) as usize..(i*n_patches*n_patches + k*n_patches + n_patches) as usize])
                }
            }

            // Attention weights * v

            for i in 0..num_crops {
                embeddings[(i*out_shape) as usize..(i*out_shape + out_shape) as usize].par_chunks_mut((head_size) as usize).enumerate().for_each( |(h, xb)| {
                    let curr_head = h as u32 / n_patches;
                    matmul(xb, &att[(i*att_size + (h as u32 * n_patches)) as usize..(i*att_size + (h as u32 * n_patches) + n_patches) as usize], &v[(i*out_shape + curr_head*n_patches*head_size) as usize..((i*out_shape + curr_head*n_patches*head_size) + n_patches*head_size) as usize]);
                })
            }
            
            // Transpose the attention output from head-major (t1h1,t2h1,t3h1,...) to token-major (t1h1,t1h2,t1h3,...,t2h1,t2h2,...)
            
            for i in 0..num_crops {
                for k in 0..n_patches {
                    for j in 0..p.n_heads {
                        norm_embeddings[(i*out_shape + k*dim + j*head_size) as usize..(i*out_shape + k*dim + j*head_size + head_size) as usize].copy_from_slice(&embeddings[(i*out_shape + j*head_size*n_patches + k*head_size) as usize..(i*out_shape + j*head_size*n_patches + k*head_size + head_size) as usize]);
                    }
                }
            }
            
            // Out linear projection
            
            for i in 0..num_crops {
                embeddings[(i*out_shape) as usize..(i*out_shape + out_shape) as usize].par_chunks_mut((dim) as usize).enumerate().for_each( |(h, xb)| {
                    unsafe {
                        if !quantized {
                            matmul(xb, &norm_embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], &w.wo.assume_init()[(l*dim*att_dim) as usize..(l*dim*att_dim + dim*att_dim) as usize]);
                        } else {
                            let mut sxq = MutableQuantizedTensor { q: &mut vec![0; (dim) as usize], s: &mut vec![0.0; dim as usize]};

                            if p.q_type == QuantType::Q8_0 {
                                quantize(&mut sxq, &norm_embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], dim as usize, gs);
                                
                                matmul_q8(xb, &sxq, &w.wo_quant.assume_init()[l as usize], dim as usize, gs as usize);
                            } else if p.q_type == QuantType::Q4_0 {
                                quantize_q4(&mut sxq, &norm_embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], dim as usize, gs);
                                
                                matmul_q4(xb, &sxq, &w.wo_quant.assume_init()[l as usize], dim as usize, gs as usize);
                            }
                        }
                    }
                    
                    // Add bias
                    let n_simd = dim/8;
                    
                    for k in 0..n_simd {
                        let wo_bias_vec = f32x8::from(&w.wo_bias[(l*dim+k*8) as usize..(l*dim+k*8+8) as usize]);

                        let mut xo_vec = f32x8::from(&xb[(k*8) as usize..(k*8+8) as usize]);

                        xo_vec += wo_bias_vec;
                        
                        let xo = xo_vec.to_array();

                        for j in 0..8 {
                            xb[(k*8 + j) as usize] = xo[j as usize];
                        }
                    }
                })
            }
            
            // Add residual 

            for i in 0..num_crops {
                for t in 0..n_patches {
                    for d in 0..dim {
                        embeddings[(i*out_shape + t * dim + d) as usize] += x[(i*out_shape + t * dim + d) as usize];
                    }
                }
            }

            x.copy_from_slice(&embeddings);
            
            for i in 0..num_crops {
                for k in 0..n_patches {
                    layernorm(&mut norm_embeddings[(i*out_shape+k*dim) as usize..(i*out_shape+k*dim + p.dim) as usize], 
                        &embeddings[(i*out_shape+k*dim) as usize..(i*out_shape+k*dim + p.dim) as usize], &w.layer_norm2[(l*dim) as usize..(l*dim + dim) as usize], &w.layer_norm2_bias[(l*dim) as usize..(l*dim + dim) as usize], dim as usize, p.layernorm_eps);
                }
            }
            
            // MLP with QuickGELU activation w2(QuickGELU(w1(x)))

            for i in 0..num_crops {
                embeddings[(i*out_shape) as usize..(i*out_shape + out_shape) as usize].par_chunks_mut((dim) as usize).enumerate().for_each( |(h, xb)| {
                    let mut hidden_emb = vec![0.0; hidden_dim as usize];

                    unsafe {
                        if !quantized {
                            matmul(&mut hidden_emb, &norm_embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], &w.w1.assume_init()[(l*dim*hidden_dim) as usize..(l*dim*hidden_dim + dim*hidden_dim) as usize]);
                        } else {
                            let mut sxq = MutableQuantizedTensor { q: &mut vec![0; (dim) as usize], s: &mut vec![0.0; dim as usize]};

                            if p.q_type == QuantType::Q8_0 {
                                quantize(&mut sxq, &norm_embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], dim as usize, gs);
                                
                                matmul_q8(&mut hidden_emb, &sxq, &w.w1_quant.assume_init()[l as usize], dim as usize, gs as usize);
                            } else if p.q_type == QuantType::Q4_0 {
                                quantize_q4(&mut sxq, &norm_embeddings[(i*out_shape+(h as u32*dim)) as usize..(i*out_shape+(h as u32*dim) + dim) as usize], dim as usize, gs);
                                
                                matmul_q4(&mut hidden_emb, &sxq, &w.w1_quant.assume_init()[l as usize], dim as usize, gs as usize);
                            }
                        }
                    }
                    
                    // Add bias
                    let mut n_simd = hidden_dim/8;
                    
                    for k in 0..n_simd {
                        let w1_bias_vec = f32x8::from(&w.w1_bias[(l*hidden_dim+k*8) as usize..(l*hidden_dim+k*8+8) as usize]);

                        let mut x1_vec = f32x8::from(&hidden_emb[(k*8) as usize..(k*8+8) as usize]);

                        x1_vec += w1_bias_vec;
                        
                        let x1 = x1_vec.to_array();

                        for j in 0..8 {
                            hidden_emb[(k*8 + j) as usize] = x1[j as usize];
                            
                            // QuickGELU
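                            // QuickGELU(x) = x * sigmoid(1.702 * x), a cheap GELU approximation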
                            hidden_emb[(k*8 + j) as usize] *= 1.0 / (1.0 + (-(1.702*hidden_emb[(k*8 + j) as usize])).exp());
                        }
                    }
                    
                    unsafe {
                        if !quantized {
                            matmul(xb, &hidden_emb, &w.w2.assume_init()[(l*dim*hidden_dim) as usize..(l*dim*hidden_dim + dim*hidden_dim) as usize]);
                        } else {
                            let mut sxq = MutableQuantizedTensor { q: &mut vec![0; (hidden_dim) as usize], s: &mut vec![0.0; hidden_dim as usize]};

                            if p.q_type == QuantType::Q8_0 {
                                quantize(&mut sxq, &hidden_emb, hidden_dim as usize, gs);
                                
                                matmul_q8(xb, &sxq, &w.w2_quant.assume_init()[l as usize], hidden_dim as usize, gs as usize);
                            } else if p.q_type == QuantType::Q4_0 {
                                quantize_q4(&mut sxq, &hidden_emb, hidden_dim as usize, gs);
                                
                                matmul_q4(xb, &sxq, &w.w2_quant.assume_init()[l as usize], hidden_dim as usize, gs as usize);
                            }
                        }
                    }

                    n_simd = dim/8;
                    
                    for k in 0..n_simd {
                        let w2_bias_vec = f32x8::from(&w.w2_bias[(l*dim+k*8) as usize..(l*dim+k*8+8) as usize]);

                        let mut x2_vec = f32x8::from(&xb[(k*8) as usize..(k*8+8) as usize]);

                        x2_vec += w2_bias_vec;
                        
                        let x2 = x2_vec.to_array();

                        for j in 0..8 {
                            xb[(k*8 + j) as usize] = x2[j as usize];
                        }
                    }
                })
            }
            
            // Add residual 

            for i in 0..num_crops {
                for t in 0..n_patches {
                    for d in 0..dim {
                        embeddings[(i*out_shape + t * dim + d) as usize] += x[(i*out_shape + t * dim + d) as usize];
                    }
                }
            }

            norm_embeddings.copy_from_slice(&embeddings);
        }

        // Remove CLS embedding
        let new_shape = dim*(n_patches - 1);
        let mut out_patches = vec![0.0; (num_crops*new_shape) as usize];
        
        for i in 0..num_crops {
            for p in 1..n_patches {
                out_patches[(i*new_shape + (p-1)*dim) as usize..(i*new_shape + (p-1)*dim + dim) as usize].copy_from_slice(&norm_embeddings[(i*out_shape + p*dim) as usize..(i*out_shape + p*dim + dim) as usize]);
            }
        }
         
        (out_patches, new_shape)
    }
}

// Deallocate fields created with Box::leak
impl<'a> Drop for VisionTransformer<'a> {
    fn drop(&mut self) {
        if self.args.q_type != QuantType::None {
            unsafe {
                let layer_weights_layout = Layout::array::<QuantizedTensor>(self.args.n_layers as usize).unwrap();
                dealloc(self.weights.wq_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.wk_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.wv_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.wo_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.w1_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
                dealloc(self.weights.w2_quant.assume_init().as_ptr() as *mut u8, layer_weights_layout);
            }
        }
    }
}
</document_content>
</document>
<document index="9">
<source>src/bin/backend.rs</source>
<document_content>
use clap::Parser;
use tokio::net::TcpListener;
use tokio_tungstenite::accept_async;
use futures_util::{StreamExt, SinkExt};
use tokio_tungstenite::tungstenite::Result;
use tokio_tungstenite::tungstenite::Message;
#[cfg(feature = "backend-multimodal")]
use image::ImageReader;
use serde::{Deserialize, Serialize};
use base64::prelude::*;

use lmrs::transformer::Transformer;
use lmrs::transformer::ModelType;
use lmrs::tokenizer::Tokenizer;
use lmrs::sampler::Sampler;

#[cfg(feature = "backend-multimodal")]
use lmrs::vision::VisionTransformer;
#[cfg(feature = "backend-multimodal")]
use lmrs::processor::PHI3VProcessor;

use std::time::{SystemTime, UNIX_EPOCH};
use std::fs::File;
use memmap2::Mmap;
use chrono::Local;
use std::fs;
#[cfg(feature = "backend-multimodal")]
use std::io::Cursor;

#[derive(Parser)]
#[command(name = "lmrs-api")]
#[command(version = "1.0")]
struct Args {
    #[arg(short, long, default_value = "127.0.0.1")]
    ip: String,
    #[arg(short, long, default_value = "8080")]
    port: u16,
    #[arg(long)]
    model: String,
    #[arg(long, default_value_t = String::from("tokenizer.bin"))]
    tokenizer: String,
    #[arg(long, default_value_t = 1.0f32)]
    temperature: f32,
    #[arg(long, default_value_t = 0.9f32)]
    top_p: f32,
    #[arg(long)]
    seed: Option<u64>,
    #[arg(long, default_value_t = false)]
    multimodal: bool,
    #[arg(long, default_value_t = 1)]
    num_crops: u32,
}

#[derive(Deserialize)]
struct ChatMessage {
    image: Option<String>, // Base64-encoded image, optional
    text: String,          // Text prompt
}

#[derive(Serialize)]
struct ResponseMessage {
    category: MessageCategory,
    text: String,
}

#[derive(Serialize)]
enum MessageCategory {
    STATUS,
    OUTPUT,
    FEATURE
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    let addr = format!("{}:{}", args.ip, args.port);
    let listener = TcpListener::bind(&addr).await.expect("Failed to bind");

    println!("Listening on: {}", addr);

    while let Ok((stream, _)) = listener.accept().await {
        tokio::spawn(async move {
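            // Args is re-parsed inside the task so the spawned closure owns its own copy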
            let args = Args::parse();
            
            let model_path: &str = args.model.as_str();
            let tokenizer_path: &str = args.tokenizer.as_str();
            
            assert!(fs::metadata(tokenizer_path).is_ok(), "Tokenizer file not found: {}", tokenizer_path);
            assert!(fs::metadata(model_path).is_ok(), "Model file not found: {}", model_path);
            
            let file = File::open(&model_path).expect("Error opening model file");
            let data = unsafe { Mmap::map(&file).expect("MMap failed")  };

            let ws_stream = accept_async(stream).await.expect("Failed to accept");
            let (mut write, mut read) = ws_stream.split();

            let mut response = ResponseMessage {
                category: MessageCategory::STATUS,
                text: String::new(),
            };

            let mut tokenizer = Tokenizer::new(tokenizer_path);

            let (mut model, _offset_transformer) = Transformer::new(&data);

            #[cfg(feature = "backend-multimodal")]
            let mut vision_model: Option<VisionTransformer> = None;
            #[cfg(feature = "backend-multimodal")]
            let mut processor: Option<PHI3VProcessor> = None;
            #[cfg(feature = "backend-multimodal")]
            let mut image_pos = 0;
             
            #[cfg(feature = "backend-multimodal")]
            if !model.args.multimodal && args.multimodal {
                eprintln!("Current model doesn't support multimodality.");
                std::process::exit(1);
            } 
            
            #[cfg(feature = "backend-multimodal")]
            if args.multimodal {
                let vision_result = VisionTransformer::new(&data[_offset_transformer..]);
                vision_model = Some(vision_result.0);
                let processor_result = PHI3VProcessor::new(&data[vision_result.1 + _offset_transformer..]);
                processor = Some(processor_result);
                
                response.category = MessageCategory::FEATURE;
                response.text = String::from("multimodal");

                if write.send(serde_json::to_string(&response).unwrap().into()).await.is_err() {
                    return;
                }
            }

            let seed: u64 = match args.seed {
                Some(seed_value) => seed_value,
                None => {
                    let start = SystemTime::now();
                    let since_epoch = start.duration_since(UNIX_EPOCH).expect("Error getting time since epoch");
                    since_epoch.as_millis() as u64
                }
            };

            let mut sampler = Sampler::new(model.args.vocab_size, args.temperature, args.top_p, seed);
            let mut pos = 0;
 
            while let Some(msg) = read.next().await {
                let mut message_text: String = String::from("");

                match msg {
                    Ok(Message::Text(text)) => {
                        match serde_json::from_str::<ChatMessage>(&text) {
                            Ok(parsed_message) => {
                                message_text = parsed_message.text; 

                                #[cfg(feature = "backend-multimodal")]
                                if let Some(vision_model) = &mut vision_model {
                                    if let Some(processor) = &mut processor {
                                        if let Some(image_data) = parsed_message.image {
                                            let image_data = BASE64_STANDARD.decode(image_data).expect("Image must be in base64");

                                            let cursor = Cursor::new(image_data);

                                            let img = ImageReader::new(cursor).with_guessed_format().expect("Image format not supported.").decode().expect("Error decoding image.");

                                            let rgb_image = img.to_rgb8();

                                            let (width, height) = rgb_image.dimensions();

                                            let pixels: &[u8] = rgb_image.as_raw();

                                            response.category = MessageCategory::STATUS;
                                            response.text = String::from("Preprocessing the image");

                                            if write.send(serde_json::to_string(&response).unwrap().into()).await.is_err() {
                                                break;
                                            }

                                            let (patches, w_crop, h_crop, num_crops_processed) = processor.process(pixels, width, height, vision_model.args.patch_size, args.num_crops);

                                            response.category = MessageCategory::STATUS;
                                            response.text = String::from("Encoding the image");

                                            if write.send(serde_json::to_string(&response).unwrap().into()).await.is_err() {
                                                break;
                                            }

                                            let (patch_embeddings, patch_emb_shape) = vision_model.forward(&patches, num_crops_processed);

                                            let image_features = processor.forward(&patch_embeddings, patch_emb_shape, vision_model.args.image_size/vision_model.args.patch_size/2, w_crop, h_crop);

                                            let mut prefix = model.get_embeddings(&[1, 32010, 29871, 13]); // <s>, <|user|>, "▁", "\n" (Phi-3 template ids)

                                            let suffix = model.get_embeddings(&[1, 29871, 13]); // <s>, "▁", "\n"

                                            prefix.extend(image_features);
                                            prefix.extend(suffix);

                                            response.category = MessageCategory::STATUS;
                                            response.text = String::from("Filling KV cache");

                                            if write.send(serde_json::to_string(&response).unwrap().into()).await.is_err() {
                                                break;
                                            }

                                            pos += model.fill_kv_cache(&mut prefix, pos);
                                            image_pos += pos;
                                        }
                                    }
                                }
                            }
                            Err(e) => {
                                println!("Error parsing message: {:?}", e);
                            }
                        }
                    }
                    Ok(Message::Close(reason)) => {
                        println!("Connection closed: {:?}", reason);
                        break;
                    }
                    Err(e) => {
                        eprintln!("Error receiving message: {}", e);
                        break;
                    }
                    _ => {
                        continue;
                    }
                }
                
                let mut token: u32;
                let mut next: u32 = 0;
                let num_prompt_tokens;
                let mut user_idx = 0;
                
                let mut prompt_tokens: Vec<u32> = Vec::new();

                println!("Processing prompt: {}", message_text);
                
                // System prompt
                if model.args.model_type == ModelType::LLAMA && pos == 0 {
                    // First part of the chat template: initial tags and the knowledge-cutoff date
                    prompt_tokens.extend([128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220]);
                    
                    let today = Local::now().date_naive();
                    let formatted_date = today.format("%d %b %Y").to_string();
                    prompt_tokens.extend(tokenizer.encode(&formatted_date, false, false, false, model.args.model_type));

                    prompt_tokens.extend([271, 128009])
                }
                
                #[cfg(feature = "backend-multimodal")]
                if args.multimodal && pos == image_pos {
                    prompt_tokens.extend(tokenizer.encode(message_text.trim(), false, false, false, model.args.model_type));
                    prompt_tokens.extend([32007, 29871, 13, 32001, 29871, 13]); // <|end|>, "▁", "\n", <|assistant|>, "▁", "\n" (Phi-3 ids)
                } else {
                    prompt_tokens.extend(tokenizer.encode(message_text.trim(), false, false, true, model.args.model_type));
                }
                
                #[cfg(not(feature = "backend-multimodal"))]
                prompt_tokens.extend(tokenizer.encode(message_text.trim(), false, false, true, model.args.model_type));
                num_prompt_tokens = prompt_tokens.len();

                while next != tokenizer.eos || user_idx < num_prompt_tokens {
                    if user_idx < num_prompt_tokens {
                        token = prompt_tokens[user_idx];
                        user_idx += 1;
                    } else {
                        token = next;
                    }

                    if token == tokenizer.eos && user_idx >= num_prompt_tokens {
                        break;
                    }

                    let logits: &mut [f32] = model.forward(token, pos);
                    next = sampler.sample(logits);
                    pos += 1;

                    // 107 is Gemma's <end_of_turn> token id
                    if user_idx >= num_prompt_tokens && next != tokenizer.eos && !(model.args.model_type == ModelType::GEMMA && next == 107) {
                        let piece = tokenizer.decode(next);
                        
                        response.category = MessageCategory::OUTPUT;
                        response.text = piece;

                        if write.send(serde_json::to_string(&response).unwrap().into()).await.is_err() {
                            break;
                        }
                    }   
                } 
                
                response.category = MessageCategory::OUTPUT;
                response.text = String::from("<eos>");

                if write.send(serde_json::to_string(&response).unwrap().into()).await.is_err() {
                    break;
                }
                
                println!("Done!\n");
            }
        });
    }
    
    Ok(())
}

</document_content>
</document>
<document index="10">
<source>src/bin/chat.rs</source>
<document_content>
use lmrs::transformer::ModelType;
use lmrs::transformer::Transformer;
use lmrs::tokenizer::Tokenizer;
use lmrs::sampler::Sampler;

#[cfg(feature = "multimodal")]
use lmrs::vision::VisionTransformer;
#[cfg(feature = "multimodal")]
use lmrs::processor::PHI3VProcessor;
#[cfg(feature = "multimodal")]
use image::open;

use std::fs;
use std::io;
use std::io::Write;
use std::fs::File;
use clap::Parser;
use std::time::{SystemTime, UNIX_EPOCH, Instant};
use memmap2::Mmap;
use chrono::Local;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: String,
    #[arg(long, default_value_t = String::from("tokenizer.bin"))]
    tokenizer: String,
    #[arg(long, default_value_t = 1.0f32)]
    temperature: f32,
    #[arg(long, default_value_t = 0.9f32)]
    top_p: f32,
    #[arg(long)]
    seed: Option<u64>,
    #[arg(long, default_value_t = false)]
    show_metrics: bool,
    #[arg(long)]
    image: Option<String>,
    #[arg(long, default_value_t = 1)]
    num_crops: u32,
}

fn main() {
    let logo = r#"
    L      M     M  RRRR    ssss
    L      MM   MM  R   R  s
    L      M M M M  RRRR    sss
    L      M  M  M  R  R       s
    LLLL   M     M  R   R  sssss
    "#;

    println!("{}", logo);

    let args = Args::parse();
    let model_path: &str = args.model.as_str();
    let tokenizer_path: &str = args.tokenizer.as_str();

    assert!(fs::metadata(tokenizer_path).is_ok(), "Tokenizer file not found: {}", tokenizer_path);
    assert!(fs::metadata(model_path).is_ok(), "Model file not found: {}", model_path);

    let mut tokenizer = Tokenizer::new(args.tokenizer.as_str());

    let file = File::open(model_path).expect("Error opening model file");
    let data = unsafe { Mmap::map(&file).expect("MMap failed")  };

    let (mut model, _offset_transformer) = Transformer::new(&data);
    
    let mut pos = 0;

    #[cfg(feature = "multimodal")]
    let mut image_path: String = String::new();
    #[cfg(feature = "multimodal")]
    let mut image_pos = 0;
    
    #[cfg(feature = "multimodal")]
    if let Some(image_value) = args.image {
        image_path = image_value;
    }
    
    #[cfg(feature = "multimodal")]
    if !image_path.is_empty() {
        if !model.args.multimodal {
            eprintln!("Cannot use images in a non-multimodal model.");
            std::process::exit(1);
        }

        let (mut vision_model, offset_vision) = VisionTransformer::new(&data[_offset_transformer..]);
        let processor = PHI3VProcessor::new(&data[offset_vision + _offset_transformer..]);

        let img = open(image_path.clone()).expect("Image file not found!");
        let rgb_image = img.to_rgb8();

        let (width, height) = rgb_image.dimensions();

        let pixels: &[u8] = rgb_image.as_raw();

        println!("Preprocessing the image...");

        let (patches, w_crop, h_crop, num_crops_processed) = processor.process(pixels, width, height, vision_model.args.patch_size, args.num_crops);

        println!("Encoding the image...");

        let (patch_embeddings, patch_emb_shape) = vision_model.forward(&patches, num_crops_processed);

        let image_features = processor.forward(&patch_embeddings, patch_emb_shape, vision_model.args.image_size/vision_model.args.patch_size/2, w_crop, h_crop);

        let mut prefix = model.get_embeddings(&[1, 32010, 29871, 13]);

        let suffix = model.get_embeddings(&[1, 29871, 13]);

        prefix.extend(image_features);
        prefix.extend(suffix);

        println!("Filling KV Cache...\n");

        pos = model.fill_kv_cache(&mut prefix, pos);
        image_pos = pos;
    }

    let seed: u64 = match args.seed {
        Some(seed_value) => {
            seed_value
        }
        None => {
            let start = SystemTime::now();
            let since_epoch = start.duration_since(UNIX_EPOCH).expect("Error getting time since epoch");
            
            since_epoch.as_millis() as u64
        }
    };

    let mut sampler = Sampler::new(model.args.vocab_size, args.temperature, args.top_p, seed);

    let mut user_turn = true;
    let mut user_idx: usize = 0;
    let mut token: u32;
    let mut next: u32 = 0;
    let mut num_prompt_tokens = 0;
    let mut total_tokens: f32 = 0.0;
    let mut total_duration: f32 = 0.0;
    
    let mut prompt_tokens: Vec<u32> = Vec::new();
    let mut user_prompt: String;

    loop {
        if user_turn { 
            user_prompt = String::from("");

            print!("You: ");
            io::stdout().flush().unwrap();

            io::stdin().read_line(&mut user_prompt).expect("Failed to read line");
            
            // System prompt
            if model.args.model_type == ModelType::LLAMA && pos == 0 {
                // First part of chat template with initial tags and cut off date
                prompt_tokens.extend([128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220]);
                
                let today = Local::now().date_naive();
                let formatted_date = today.format("%d %b %Y").to_string();
                prompt_tokens.extend(tokenizer.encode(&formatted_date, false, false, false, model.args.model_type));

                prompt_tokens.extend([271, 128009])
            }
            
            #[cfg(feature = "multimodal")]
            if !image_path.is_empty() && pos == image_pos {
                prompt_tokens.extend(tokenizer.encode(user_prompt.trim(), false, false, false, model.args.model_type));
                prompt_tokens.extend([32007, 29871, 13, 32001, 29871, 13]);
            } else {
                prompt_tokens.extend(tokenizer.encode(user_prompt.trim(), false, false, true, model.args.model_type));
            }

            #[cfg(not(feature = "multimodal"))]
            prompt_tokens.extend(tokenizer.encode(user_prompt.trim(), false, false, true, model.args.model_type));
            
            num_prompt_tokens = prompt_tokens.len();

            user_turn = false; 
            user_idx = 0;
            
            println!("Assistant:");
        }

        if user_idx < num_prompt_tokens {
            token = prompt_tokens[user_idx];
            user_idx += 1;
        } else {
            token = next;
        }
        
        if token == tokenizer.eos && user_idx >= num_prompt_tokens { 
            user_turn = true; 
            println!();
            prompt_tokens = Vec::new();
            
            if args.show_metrics {
                let toks = total_tokens/(total_duration/1000.0);
                
                println!("Speed: {:.2} tok/s", toks);

                total_duration = 0.0;
                total_tokens = 0.0;
            } 

            continue;
        }
        
        let processing_start = Instant::now();

        let logits: &mut [f32] = model.forward(token, pos);
        next = sampler.sample(logits);
        pos += 1;

        if user_idx >= num_prompt_tokens && next != tokenizer.eos && !(model.args.model_type == ModelType::GEMMA && next == 107) {
            let piece = tokenizer.decode(next);
            print!("{}", piece);
            io::stdout().flush().unwrap();
        }   

        let duration = processing_start.elapsed();
        total_duration += duration.as_millis() as f32;
        total_tokens += 1.0;
    }
}
</document_content>
</document>
<document index="11">
<source>src/bin/desktop.rs</source>
<document_content>
use dioxus::prelude::*;

use dioxus_desktop::Config;
use lmrs::sampler::Sampler;
use lmrs::tokenizer::Tokenizer;
use lmrs::transformer::ModelType;
use lmrs::transformer::Transformer;

use chrono::Local;
use clap::Parser;
use memmap2::Mmap;
use std::fs;
use std::fs::File;
use std::sync::mpsc::channel;
use std::thread;
use std::time::{Instant, SystemTime, UNIX_EPOCH};

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: String,
    #[arg(long, default_value_t = String::from("tokenizer.bin"))]
    tokenizer: String,
    #[arg(long, default_value_t = 1.0f32)]
    temperature: f32,
    #[arg(long, default_value_t = 0.9f32)]
    top_p: f32,
    #[arg(long)]
    seed: Option<u64>,
    #[arg(long, default_value_t = false)]
    show_metrics: bool,
}

enum ModelResponse {
    Piece(String),
    Termination,
}

struct Context<'a> {
    user_idx: usize,
    pos: u32,
    token: u32,
    next: u32,
    num_prompt_tokens: usize,
    total_tokens: f32,
    total_duration: f32,
    model: Transformer<'a>,
    prompt_tokens: Vec<u32>,
    tokenizer: Tokenizer,
    sampler: Sampler,
}

impl<'a> Context<'a> {
    fn handle_user_prompt(&mut self, user_prompt: &str) {
        // System prompt
        if self.model.args.model_type == ModelType::LLAMA && self.pos == 0 {
            // First part of chat template with initial tags and cut off date
            self.prompt_tokens.extend([
                128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366,
                18, 198, 15724, 2696, 25, 220,
            ]);

            let today = Local::now().date_naive();
            let formatted_date = today.format("%d %b %Y").to_string();
            self.prompt_tokens.extend(self.tokenizer.encode(
                &formatted_date,
                false,
                false,
                false,
                self.model.args.model_type,
            ));

            self.prompt_tokens.extend([271, 128009])
        }

        self.prompt_tokens.extend(self.tokenizer.encode(
            user_prompt.trim(),
            false,
            false,
            true,
            self.model.args.model_type,
        ));
        self.num_prompt_tokens = self.prompt_tokens.len();

        self.user_idx = 0;
    }

    fn compute_model_response(&mut self, show_metrics: bool) -> ModelResponse {
        loop {
            if self.user_idx < self.num_prompt_tokens {
                self.token = self.prompt_tokens[self.user_idx];
                self.user_idx += 1;
            } else {
                self.token = self.next;
            }

            if self.token == self.tokenizer.eos && self.user_idx >= self.num_prompt_tokens {
                self.prompt_tokens.clear();

                if show_metrics {
                    let toks = self.total_tokens / (self.total_duration / 1000.0);

                    println!("Speed: {:.2} tok/s", toks);

                    self.total_duration = 0.0;
                    self.total_tokens = 0.0;
                }
                return ModelResponse::Termination;
            }

            let processing_start = Instant::now();
            let logits: &mut [f32] = self.model.forward(self.token, self.pos);
            self.next = self.sampler.sample(logits);
            self.pos += 1;
            if self.user_idx >= self.num_prompt_tokens
                && self.next != self.tokenizer.eos
                && !(self.model.args.model_type == ModelType::GEMMA && self.next == 107)
            {
                let piece = self.tokenizer.decode(self.next);
                return ModelResponse::Piece(piece);
            }
            let duration = processing_start.elapsed();
            self.total_duration += duration.as_millis() as f32;
            self.total_tokens += 1.0;
        }
    }
}

fn main() {
    dioxus_desktop::launch::launch(app, Vec::new(), Config::default());
}

enum Message {
    Bot(String),
    User(String),
}

fn app() -> Element {
    let (to_model_sender, to_model_receiver) = channel::<String>();
    let to_model_channel = use_signal(move || to_model_sender);

    let mut user_input = use_signal(|| String::new());
    let mut is_ctrl_pressed = use_signal(|| false);
    let mut is_waiting_for_response = use_signal_sync(|| false);

    let mut conversation = use_signal_sync(|| Vec::<Message>::new());

    use_hook(|| {
        thread::spawn(move || loop {
            let args = Args::parse();
            let model_path: &str = args.model.as_str();
            let tokenizer_path: &str = args.tokenizer.as_str();

            assert!(
                fs::metadata(tokenizer_path).is_ok(),
                "Tokenizer file not found: {}",
                tokenizer_path
            );
            assert!(
                fs::metadata(model_path).is_ok(),
                "Model file not found: {}",
                model_path
            );

            let tokenizer = Tokenizer::new(args.tokenizer.as_str());

            let file = File::open(model_path).expect("Error opening model file");
            let data = unsafe { Mmap::map(&file).expect("MMap failed") };

            let model = Transformer::new(&data);

            let seed: u64 = match args.seed {
                Some(seed_value) => seed_value,
                None => {
                    let start = SystemTime::now();
                    let since_epoch = start
                        .duration_since(UNIX_EPOCH)
                        .expect("Error getting time since epoch");

                    since_epoch.as_millis() as u64
                }
            };

            let sampler = Sampler::new(model.args.vocab_size, args.temperature, args.top_p, seed);

            let mut context = Context {
                user_idx: 0,
                pos: 0,
                token: 0,
                next: 0,
                num_prompt_tokens: 0,
                total_tokens: 0.0,
                total_duration: 0.0,
                model,
                prompt_tokens: Vec::new(),
                tokenizer,
                sampler,
            };

            loop {
                let Ok(user_prompt) = to_model_receiver.recv() else {
                    // The sender is dropped when the app is closed
                    break;
                };

                context.handle_user_prompt(&user_prompt);

                conversation.write().push(Message::Bot(String::new()));

                loop {
                    let ModelResponse::Piece(resp_piece) = context.compute_model_response(true)
                    else {
                        is_waiting_for_response.set(false);
                        break;
                    };
                    let mut conv = conversation.write();
                    let Message::Bot(response) = conv.last_mut().expect("The bot response exists")
                    else {
                        panic!("The user cannot add messages to the conversation while the model is processing its response");
                    };
                    response.push_str(&resp_piece);
                }
            }
        });
    });

    let mut handle_submit = {
        let mut user_input = user_input.clone();

        move || {
            if is_waiting_for_response() {
                // do nothing when the model is already processing the response
                return;
            }

            let input_string = user_input();
            conversation
                .write()
                .push(Message::User(input_string.clone()));
            let input_trimmed = input_string.trim().to_owned();
            if input_trimmed.is_empty() {
                return;
            }
            is_waiting_for_response.set(true);
            user_input.write().clear();
            to_model_channel()
                .send(input_trimmed)
                .expect("Failed to send the input to the model thread");
        }
    };

    rsx! {
        head {
            link {
                rel: "stylesheet",
                href: "https://samuel-vitorino.github.io/lm.rs-webui/assets/index-DTGPVQCS.css",
            }
        }
        body {
            class: "dark:bg-[#212121]",
            div {
                id: "root",
                class: "flex flex-col",
                nav {
                    class: "bg-white px-2 py-2.5 dark:border-gray-700 sm:px-4 dark:bg-[#212121]",
                    div {
                        class: "mx-auto flex flex-wrap items-center justify-between",
                        span {
                            class: "flex items-center",
                            img {
                                src: "https://samuel-vitorino.github.io/lm.rs-webui/rust.svg",
                                class: "mr-3 h-6 sm:h-9",
                                alt: "lm.rs logo",
                            }
                            span {
                                class: "self-center whitespace-nowrap text-xl font-semibold dark:text-white",
                                "lm.rs",
                            }
                        }
                    }
                }
                div {
                    class: "flex flex-col justify-between mx-3 sm:mx-20 md:mx-44 xl:mx-96",
                    style: "max-height: 78%; overflow-y: auto;",
                    id: "chat-container",
                    for message in conversation.read().iter(){
                        {
                            match message {
                                Message::Bot(msg) => {
                                    rsx!{ MessageComponent{
                                        message: msg,
                                        is_user: false,
                                    }}
                                },
                                Message::User(msg) => {
                                    rsx!{ MessageComponent{
                                        message: msg,
                                        is_user: true,
                                    }}
                                },
                            }
                        }
                    }
                }
                footer {
                    class: "flex items-center justify-around dark:bg-[#303030] dark:text-white bg-[#f5f5f5] mx-3 md:m-0 md:w-6/12 md:self-center",
                    id: "chat-input",
                    textarea {
                        rows: "1",
                        placeholder: "Message lm.rs",
                        class: "bg-transparent",
                        id: "chat-text-area",
                        disabled: is_waiting_for_response(),
                        value: "{user_input}",
                        oninput: move |e| user_input.set(e.value().clone()),
                        onkeydown: move|event|{
                            if event.key() == Key::Control{
                                is_ctrl_pressed.set(true);
                            }
                        },
                        onkeyup: move|event|{
                            if event.key() == Key::Control{
                                is_ctrl_pressed.set(false);
                            }
                        },
                        onkeypress: move |event| {
                            if event.key() == Key::Enter && is_ctrl_pressed(){
                                handle_submit()
                            }
                        },
                    }
                    button {
                        id: "chat-send-button",
                        class: "flex",
                        disabled: is_waiting_for_response(),
                        onclick: move |_| handle_submit(),
                        svg {
                            stroke: "currentColor",
                            fill: "currentColor",
                            "stroke-width": "0",
                            view_box: "0 0 448 512",
                            class: "self-center",
                            height: "20px",
                            width: "20px",
                            xmlns: "http://www.w3.org/2000/svg",
                            path {
                                d: "M34.9 289.5l-22.2-22.2c-9.4-9.4-9.4-24.6 0-33.9L207 39c9.4-9.4 24.6-9.4 33.9 0l194.3 194.3c9.4 9.4 9.4 24.6 0 33.9L413 289.4c-9.5 9.5-25 9.3-34.3-.4L264 168.6V456c0 13.3-10.7 24-24 24h-32c-13.3 0-24-10.7-24-24V168.6L69.2 289.1c-9.3 9.8-24.8 10-34.3.4z"
                            }
                        }
                    }
                }
            }
        }
    }
}

#[component]
fn MessageComponent(message: String, is_user: bool) -> Element {
    rsx! {
        div {
            class: if is_user {"flex items-center message_l1_user" } else { "flex items-center message_l1_bot"},
            style: if is_user {r#"
                align-self: flex-end;
                margin-top: 0px;
                margin-bottom: 0px;
                width: 100%;
            "# }else{ r#"
                align-self: flex-start;
                margin-top: 30px;
                margin-bottom: 30px;
                width: 100%;
            "#},
            div {
                class: "flex flex-col",
                    style: if is_user {r#"
                        margin-left: 0px;
                        text-wrap: wrap;
                        overflow-wrap: break-word;
                        width: 100%;
                    "# }else{ r#"
                        margin-left: 20px;
                        text-wrap: wrap;
                        overflow-wrap: break-word;
                        width: 100%;
                    "#},
                Markdown {
                    class: if is_user { "bg-[#f5f5f5] dark:bg-[#323232] dark:text-white" } else {"dark:text-white"},
                    style: if is_user { r#"
                        border-radius: 20px;
                        padding: 10px 20px 10px 20px;
                        max-width: 40%;
                        margin-left: auto;
                        width: fit-content;
                    "#}else{r#"
                        margin-left: 0px;
                        text-wrap: wrap;
                        overflow-wrap: break-word;
                        width: 100%;
                    "#},
                    input: "{message}"
                }
            }
        }
    }
}

#[component]
fn Markdown(input: String, class: String, style: String) -> Element {
    let parser = pulldown_cmark::Parser::new(&input);
    let mut html_output: String = String::with_capacity(input.len() * 3 / 2);
    pulldown_cmark::html::push_html(&mut html_output, parser);

    rsx! { div {
    class: "{class}",
    style: "{style}",
    dangerous_inner_html: "{html_output}" } }
}

</document_content>
</document>
<document index="12">
<source>Cargo.toml</source>
<document_content>
[package]
name = "lmrs"
version = "0.1.0"
edition = "2021"

[features]
default = []
backend = ["dep:tokio-tungstenite", "dep:futures-util", "dep:tokio"]
desktop = ["dep:dioxus", "dep:dioxus-desktop", "dep:pulldown-cmark"]
backend-multimodal = ["dep:base64", "dep:image", "dep:serde_json", "dep:serde"]
multimodal = ["dep:image"]

[dependencies]
wide = "0.7.28"
memmap2 = "0.9.4"
rayon = "1.10.0"
chrono = "0.4.38"
clap = { version = "4.5.13", features = ["derive"] }

image = { version = "0.25.2", optional = true }
tokio-tungstenite = { version = "0.23.1", optional = true }
futures-util = { version = "0.3.28", default-features = false, features = ["sink", "std"], optional = true }
tokio = { version = "1.0.0", features = ["full"], optional = true }
base64 = { version = "0.22.1", optional = true }
serde_json = { version = "1.0.128", optional = true }
serde = { version = "1.0.210", features = ["derive"], optional = true }


dioxus = { version = "0.5.0", optional = true }
dioxus-desktop = { version = "0.5.0", optional = true }
pulldown-cmark = { version = "0.12.1", default-features = false , features = ["html"], optional = true}

[[bin]]
name = "backend"
required-features = ["backend"]

[[bin]]
name = "desktop"
required-features = ["desktop"]

[profile.release]
codegen-units = 1
lto = true
</document_content>
</document>
</documents>

System:

describe the most important dependencies used by this project

Response:

Based on the provided source code, the most important dependencies used by this project are:

  1. wide: Used for SIMD operations and vector math.

  2. memmap2: Used for memory-mapped file I/O.

  3. rayon: Used for parallel processing and multi-threading.

  4. chrono: Used for date and time functionality.

  5. clap: Used for command-line argument parsing.

  6. image (optional): Used for image processing in multimodal features.

  7. tokio and tokio-tungstenite (optional): Used for asynchronous I/O and WebSocket support in the backend.

  8. futures-util (optional): Used for asynchronous programming constructs.

  9. serde and serde_json (optional): Used for serialization and deserialization of JSON data.

  10. dioxus and dioxus-desktop (optional): Used for building the desktop GUI application.

  11. pulldown-cmark (optional): Used for Markdown parsing in the desktop application.

These dependencies provide crucial functionality for various aspects of the project, including performance optimizations, file handling, parallel processing, command-line interfaces, and optional features for multimodal support, backend communication, and desktop GUI development.

2024-10-11T19:26:08

Prompt:

In more detail how are wide and memmap2 and rayon used?

Response:

Certainly! Let's dive into more detail on how wide, memmap2, and rayon are used in this project:

  1. wide:

The wide crate is used for SIMD (Single Instruction, Multiple Data) operations, which allow for efficient parallel processing of data at the CPU instruction level. In this project, it's primarily used in numerical computations and matrix operations:

  • In src/functional.rs, f32x8 and i32x8 types from wide are used for vectorized operations:

    • In functions like rmsnorm, layernorm, and matmul, f32x8 is used to process 8 float values simultaneously.
    • In matmul_q8 and matmul_q4, i32x8 is used for vectorized integer operations.
  • In src/vision.rs, f32x8 is used in the forward method of the VisionTransformer for efficient matrix multiplications and bias additions.

These SIMD operations significantly speed up the mathematical computations that are core to the transformer model's functionality.
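
As a minimal sketch of this pattern (the `dot` function and the assumption that the input length is a multiple of 8 are illustrative, not project code; the `f32x8` calls mirror those used in src/functional.rs):

```rust
use wide::f32x8;

// Minimal sketch, not project code: a dot product written in the same
// f32x8 style as rmsnorm and layernorm. Assumes `a` and `b` have equal
// lengths that are a multiple of 8, as the project's kernels do.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let n_simd = a.len() / 8;
    let mut acc = f32x8::ZERO;
    for j in 0..n_simd {
        let a_vec = f32x8::from(&a[j * 8..j * 8 + 8]);
        let b_vec = f32x8::from(&b[j * 8..j * 8 + 8]);
        acc += a_vec * b_vec; // eight lane-wise multiplies per iteration
    }
    acc.reduce_add() // horizontal sum of the eight lanes
}
```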

  2. memmap2:

memmap2 is used for memory-mapped file I/O, which allows the program to access file contents as if they were in memory, potentially improving performance for large files:

  • In src/bin/chat.rs and src/bin/backend.rs, memmap2 is used to map the model file into memory:

    let file = File::open(model_path).expect("Error opening model file");
    let data = unsafe { Mmap::map(&file).expect("MMap failed") };
  • The mapped data is then passed to the Transformer::new() function, allowing the model to access its weights and parameters directly from the memory-mapped file.

This approach can be more efficient than reading the entire file into memory, especially for large model files, as it allows the operating system to manage paging and caching of the file contents.
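
For reference, a self-contained sketch of this loading pattern (the `model.bin` path is a placeholder, not a file from the project):

```rust
use memmap2::Mmap;
use std::fs::File;

// Minimal sketch of the mmap-based loading pattern used by the binaries.
fn main() {
    let file = File::open("model.bin").expect("Error opening model file");
    // Safety: the mapping assumes no other process truncates or rewrites
    // the file while it is mapped.
    let data = unsafe { Mmap::map(&file).expect("MMap failed") };
    // Mmap dereferences to &[u8]; pages are faulted in by the OS only
    // when first touched, so startup cost stays low for large models.
    println!("mapped {} bytes", data.len());
}
```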

  3. rayon:

Rayon is used for parallel processing and multi-threading, particularly in computationally intensive operations:

  • In src/functional.rs, rayon's parallel iterators are used in several functions:

    • matmul uses .par_iter_mut().enumerate().for_each() to parallelize matrix multiplication.
    • matmul_q8 and matmul_q4 also use parallel iterators for quantized matrix multiplications.
  • In src/vision.rs, rayon is used in the VisionTransformer's forward method:

    • Multiple .par_chunks_mut().enumerate().for_each() calls are used to parallelize various parts of the transformer's forward pass, including attention computations and linear projections.
  • In src/processor.rs, rayon is used in the PHI3VProcessor's forward method for parallel processing of image features.

By using rayon, the project can take advantage of multi-core processors to speed up these computationally intensive tasks, potentially providing significant performance improvements on modern hardware.
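
A minimal sketch of the par_iter_mut pattern described above, applied to a matrix-vector product (the `matvec` name and its parameters are illustrative; the project's actual matmul combines this with the SIMD kernels shown earlier):

```rust
use rayon::prelude::*;

// Minimal sketch, not the project's matmul: each row of the weight
// matrix `w` (flattened rows of length `n`) is reduced against `x`
// on its own Rayon task, so output elements are computed in parallel
// across CPU cores.
fn matvec(out: &mut [f32], x: &[f32], w: &[f32], n: usize) {
    out.par_iter_mut().enumerate().for_each(|(i, out_i)| {
        let row = &w[i * n..(i + 1) * n];
        *out_i = row.iter().zip(x.iter()).map(|(a, b)| a * b).sum();
    });
}
```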

These three libraries work together to optimize different aspects of the program:

  • wide optimizes at the CPU instruction level with SIMD.
  • memmap2 optimizes file I/O and memory usage.
  • rayon optimizes by utilizing multiple CPU cores for parallel processing.

Together, they contribute to making the language model implementation more efficient and performant.
