Rust Naive Bayes machine learning classifier
/// Author Ben McDonald. Adapted from Python scikit-learn
///
/// Naive Bayes is a machine learning classifier. It uses Bayes' theorem to classify data
/// based on how well data fits into normal distributions modelled from training data
| use rulinalg::matrix::BaseMatrix; | |
| use rulinalg::matrix::BaseMatrixMut; | |
| use rulinalg::matrix::{Axes, Matrix}; | |
| use std::collections::HashMap; | |
| use std::error::Error; | |
| use std::f64::consts::PI; | |
/// Per-class Gaussian statistics produced by `build_model` and consumed by
/// `fit_model` to classify new observations.
pub struct GaussianBayesModel {
    // Class labels; `categories[i]` corresponds to the i-th per-class block
    // in the matrices' underlying (category-major) data.
    pub categories: Vec<usize>,
    // Per-feature means for each class (declared num_features x num_categories).
    pub means_matrix: Matrix<f64>,
    // Per-feature variances for each class; same layout as `means_matrix`.
    pub variances_matrix: Matrix<f64>,
}
/// Error returned by `build_model` when the training data contains no samples.
///
/// `Display` is implemented by hand (instead of a third-party derive such as
/// derive_more) so the error types need only the standard library.
#[derive(Debug)]
pub struct EmptyInputError;

impl std::fmt::Display for EmptyInputError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "EmptyInputError: Training data is empty")
    }
}

impl Error for EmptyInputError {}
/// Error returned by `fit_model` when the test vector's length differs from
/// the number of features the model was trained with.
///
/// `Display` is implemented by hand (instead of a third-party derive such as
/// derive_more) so the error types need only the standard library. The
/// message typo "suppied" is also fixed here.
#[derive(Debug)]
pub struct MismatchedFeaturesError;

impl std::fmt::Display for MismatchedFeaturesError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "MismatchedFeaturesError: Test data length does not match number of features in supplied model."
        )
    }
}

impl Error for MismatchedFeaturesError {}
| /// | |
| /// Summarise the dataset per class using normal distributions | |
| /// training_x_y - array of tuple trainging data. len = number of samples, first tuple elements are observations | |
| /// and 2nd is the class | |
| pub fn build_model(training_x_y: Vec<(Vec<f64>, usize)>) -> Result<GaussianBayesModel, Box<dyn Error + 'static>> { | |
| let num_features: usize = training_x_y | |
| .first() | |
| .ok_or_else(|| EmptyInputError {})? | |
| .0 | |
| .len(); | |
| let mut categories_map: HashMap<usize, Vec<f64>> = HashMap::new(); | |
| for x in training_x_y.iter() { | |
| categories_map | |
| .entry(x.1) | |
| .or_insert(Vec::new()) | |
| .extend(x.0.clone()); | |
| } | |
| let num_categories = categories_map.len(); | |
| let mut categories: Vec<usize> = Vec::with_capacity(num_categories); | |
| let mut means: Vec<f64> = Vec::with_capacity(num_categories * num_features); | |
| let mut variances: Vec<f64> = Vec::with_capacity(num_categories * num_features); | |
| for (category, category_grouped) in categories_map.iter() { | |
| categories.push(category.clone()); | |
| let num_samples_in_category: usize = category_grouped.len() / num_features; | |
| let cat_matrix = Matrix::new( | |
| num_samples_in_category, | |
| num_features, | |
| category_grouped.clone(), | |
| ); | |
| let mean = cat_matrix.mean(Axes::Row); | |
| // cat_matrix.variance can fail if no variance in training feature | |
| let variance = cat_matrix.variance(Axes::Row)?; | |
| means.extend(mean.into_vec()); | |
| variances.extend(variance.into_vec()); | |
| } | |
| let means_matrix = Matrix::new(num_features, num_categories, means); | |
| let variances_matrix = Matrix::new(num_features, num_categories, variances); | |
| Ok(GaussianBayesModel { | |
| categories, | |
| means_matrix, | |
| variances_matrix, | |
| }) | |
| } | |
/// Return the index of the largest element in `array`.
///
/// A later element wins only when strictly greater, so ties resolve to the
/// earliest index; an empty slice yields 0 (callers guarantee non-empty).
fn max_index(array: &[impl PartialOrd]) -> usize {
    let mut best = 0;
    for (idx, candidate) in array.iter().enumerate() {
        if *candidate > array[best] {
            best = idx;
        }
    }
    best
}
| /// | |
| /// Make a prediction | |
| /// test_x - array where len = num_features | |
| /// model - model trained with fn build_model | |
| pub fn fit_model(test_x: Vec<f64>, model: &GaussianBayesModel) -> Result<usize, Box<dyn Error + 'static>> { | |
| let num_categories:usize = model.means_matrix.cols(); | |
| let num_features:usize = model.means_matrix.rows(); | |
| if test_x.len() != num_features { | |
| return Err((MismatchedFeaturesError {}).into()) | |
| } | |
| let test_x_repeated: Vec<f64> = test_x | |
| .iter() | |
| .cycle() | |
| .take(test_x.len() * num_categories) | |
| .map(|&x| x) | |
| .collect::<Vec<f64>>(); | |
| let mut test_x_matrix: Matrix<f64> = | |
| Matrix::new(num_features, num_categories, test_x_repeated.clone()); | |
| let mean_diff_sqrt: Matrix<f64> = | |
| (test_x_matrix.as_mut_slice() - model.means_matrix.clone()).apply(&|x| -(x * x)); | |
| let exponent = mean_diff_sqrt | |
| .elediv(&model.variances_matrix.clone().apply(&|x| x * 2.0)) | |
| .apply(&|x| x.exp()); | |
| let sqrt_2pi: f64 = ((PI * 2.0) as f64).sqrt(); | |
| let inner_divide = model | |
| .variances_matrix | |
| .clone() | |
| .apply(&|x| 1.0 / (sqrt_2pi * x.sqrt())); | |
| let unreduced_probabilies = exponent.elemul(&inner_divide); | |
| // Vector of the probabilities of the test data fitting a catagory. | |
| let unreduced_probabilies_data = unreduced_probabilies.data(); | |
| let mut probabilities_each_catagory: Vec<f64> = Vec::with_capacity(num_categories); | |
| for n in 0..(num_categories) { | |
| let mut prob: f64 = 0.0; | |
| for m in 0..(num_features) { | |
| prob += unreduced_probabilies_data[(n * num_features) + m] | |
| } | |
| probabilities_each_catagory.push(prob); | |
| } | |
| // Find the highest probability. | |
| // argmax returns (index_of_max_value, max_value). Take first element, the index | |
| let best_fit_index = max_index(&probabilities_each_catagory); | |
| Ok(model.categories[best_fit_index]) | |
| } | |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic]
    fn test_panic_on_empty_input() {
        // No samples at all -> EmptyInputError.
        build_model(vec![]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_panic_on_one_input() {
        // A class with a single sample has no computable variance.
        build_model(vec![(vec![11.0, 2.0], 1)]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_panic_on_mismatched_features() {
        // Inconsistent feature counts make the grouped data length disagree
        // with the matrix dimensions.
        build_model(vec![(vec![11.0, 2.0], 1), (vec![10.0], 1)]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_panic_no_variance() {
        // Class 2 has only one sample, so its variance cannot be computed.
        build_model(vec![
            (vec![11.0, 2.0], 1),
            (vec![10.0, 1.0], 1),
            (vec![11.0, 2.0], 2),
        ])
        .unwrap();
    }

    #[test]
    fn test_model() {
        let x_input = vec![
            (vec![1.0, 20.0], 1),
            (vec![20.0, 210.0], 0),
            (vec![3.0, 22.0], 1),
            (vec![40.0, 220.0], 0),
            (vec![6.0, 10.0], 2),
            (vec![7.0, 11.0], 2),
            (vec![7.0, 11.0], 2),
        ];
        // Expected per-class statistics (one row per class 0, 1, 2).
        // (The previous comment had the labels swapped and quoted
        // scikit-learn's smoothed n-denominator variances; rulinalg uses an
        // n - 1 denominator.)
        // means:
        //   [[ 30.0  , 215.0   ],
        //    [  2.0  ,  21.0   ],
        //    [  6.667,  10.667 ]]
        // variances (sample variance, n - 1):
        //   [[ 200.0 ,  50.0   ],
        //    [   2.0 ,   2.0   ],
        //    [   0.333,   0.333 ]]
        let model: GaussianBayesModel = build_model(x_input).unwrap();
        assert_eq!(fit_model(vec![1.0, 20.0], &model).unwrap(), 1);
        assert_eq!(fit_model(vec![20.0, 210.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![25.0, 225.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![30.0, 215.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![3.0, 22.0], &model).unwrap(), 1);
        assert_eq!(fit_model(vec![6.0, 10.0], &model).unwrap(), 2);
    }
}
This comment has been minimized.
This comment has been minimized.
|
Not easy to simplify at the moment: using the try operator (`?`) on an `Option` throws the nightly-only `NoneError`. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This comment has been minimized.
Charles-Johnson commented on Jan 11, 2019
Why do you need to define the struct
`ModelError` when it's just a `String`?