Instantly share code, notes, and snippets.

Embed
What would you like to do?
Rust naive bayes machine learning classifier
//! Author: Ben McDonald. Adapted from Python's scikit-learn.
//!
//! Naive Bayes is a machine learning classifier. It uses Bayes' formula to classify data
//! based on how well the data fits normal distributions modelled from the training data.
use rulinalg::matrix::BaseMatrix;
use rulinalg::matrix::BaseMatrixMut;
use rulinalg::matrix::{Axes, Matrix};
use std::collections::HashMap;
use std::error::Error;
use std::f64::consts::PI;
/// Per-class Gaussian summary of a training set, as produced by `build_model`.
///
/// `build_model` writes the backing data of both matrices one category after another,
/// each category contributing `num_features` consecutive values, while declaring the
/// shape as `num_features` rows by `num_categories` columns. `fit_model` reads
/// `rows()` as the feature count and `cols()` as the category count and indexes the
/// flat data as `category_index * num_features + feature_index`.
#[derive(Debug, Clone)]
pub struct GaussianBayesModel {
    /// Class labels; entry `n` belongs to the `n`-th group of `num_features` values in
    /// the backing data of the two matrices.
    pub categories: Vec<usize>,
    /// Mean of every feature for every class.
    pub means_matrix: Matrix<f64>,
    /// Variance of every feature for every class.
    pub variances_matrix: Matrix<f64>,
}
/// Error returned by `build_model` when the training data contains no samples.
#[derive(Debug, Display)]
#[display(fmt = "EmptyInputError: Training data is empty")]
pub struct EmptyInputError;
// No overrides needed: `Debug` and `Display` are derived above.
impl Error for EmptyInputError {}
/// Error returned by `fit_model` when the length of the test vector differs from the
/// number of features the model was built with.
#[derive(Debug, Display)]
#[display(fmt = "MismatchedFeaturesError: Test data length does not match number of features in supplied model.")]
pub struct MismatchedFeaturesError;
// No overrides needed: `Debug` and `Display` are derived above.
impl Error for MismatchedFeaturesError {}
///
/// Summarise the dataset per class using normal distributions
/// training_x_y - array of tuple trainging data. len = number of samples, first tuple elements are observations
/// and 2nd is the class
pub fn build_model(training_x_y: Vec<(Vec<f64>, usize)>) -> Result<GaussianBayesModel, Box<dyn Error + 'static>> {
let num_features: usize = training_x_y
.first()
.ok_or_else(|| EmptyInputError {})?
.0
.len();
let mut categories_map: HashMap<usize, Vec<f64>> = HashMap::new();
for x in training_x_y.iter() {
categories_map
.entry(x.1)
.or_insert(Vec::new())
.extend(x.0.clone());
}
let num_categories = categories_map.len();
let mut categories: Vec<usize> = Vec::with_capacity(num_categories);
let mut means: Vec<f64> = Vec::with_capacity(num_categories * num_features);
let mut variances: Vec<f64> = Vec::with_capacity(num_categories * num_features);
for (category, category_grouped) in categories_map.iter() {
categories.push(category.clone());
let num_samples_in_category: usize = category_grouped.len() / num_features;
let cat_matrix = Matrix::new(
num_samples_in_category,
num_features,
category_grouped.clone(),
);
let mean = cat_matrix.mean(Axes::Row);
// cat_matrix.variance can fail if no variance in training feature
let variance = cat_matrix.variance(Axes::Row)?;
means.extend(mean.into_vec());
variances.extend(variance.into_vec());
}
let means_matrix = Matrix::new(num_features, num_categories, means);
let variances_matrix = Matrix::new(num_features, num_categories, variances);
Ok(GaussianBayesModel {
categories,
means_matrix,
variances_matrix,
})
}
/// Find index of highest value in array.
///
/// Ties keep the earliest index, and an empty array yields `0`.
///
/// Values that are not comparable to themselves (such as `f64::NAN`) never win against
/// a properly ordered value: if the current best is such a value, the first ordered
/// element replaces it. Without this, a NaN in first position would hide every later
/// element because `x > NaN` is always false.
fn max_index(array: &[impl PartialOrd]) -> usize {
    let mut best = 0;
    for (idx, value) in array.iter().enumerate().skip(1) {
        let current = &array[best];
        let current_is_unordered = current.partial_cmp(current).is_none();
        let value_is_ordered = value.partial_cmp(value).is_some();
        if value > current || (current_is_unordered && value_is_ordered) {
            best = idx;
        }
    }
    best
}
///
/// Make a prediction
/// test_x - array where len = num_features
/// model - model trained with fn build_model
pub fn fit_model(test_x: Vec<f64>, model: &GaussianBayesModel) -> Result<usize, Box<dyn Error + 'static>> {
let num_categories:usize = model.means_matrix.cols();
let num_features:usize = model.means_matrix.rows();
if test_x.len() != num_features {
return Err((MismatchedFeaturesError {}).into())
}
let test_x_repeated: Vec<f64> = test_x
.iter()
.cycle()
.take(test_x.len() * num_categories)
.map(|&x| x)
.collect::<Vec<f64>>();
let mut test_x_matrix: Matrix<f64> =
Matrix::new(num_features, num_categories, test_x_repeated.clone());
let mean_diff_sqrt: Matrix<f64> =
(test_x_matrix.as_mut_slice() - model.means_matrix.clone()).apply(&|x| -(x * x));
let exponent = mean_diff_sqrt
.elediv(&model.variances_matrix.clone().apply(&|x| x * 2.0))
.apply(&|x| x.exp());
let sqrt_2pi: f64 = ((PI * 2.0) as f64).sqrt();
let inner_divide = model
.variances_matrix
.clone()
.apply(&|x| 1.0 / (sqrt_2pi * x.sqrt()));
let unreduced_probabilies = exponent.elemul(&inner_divide);
// Vector of the probabilities of the test data fitting a catagory.
let unreduced_probabilies_data = unreduced_probabilies.data();
let mut probabilities_each_catagory: Vec<f64> = Vec::with_capacity(num_categories);
for n in 0..(num_categories) {
let mut prob: f64 = 0.0;
for m in 0..(num_features) {
prob += unreduced_probabilies_data[(n * num_features) + m]
}
probabilities_each_catagory.push(prob);
}
// Find the highest probability.
// argmax returns (index_of_max_value, max_value). Take first element, the index
let best_fit_index = max_index(&probabilities_each_catagory);
Ok(model.categories[best_fit_index])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// No samples at all: `build_model` must fail.
    #[test]
    #[should_panic]
    fn test_panic_on_empty_input() {
        build_model(vec![]).unwrap();
    }

    /// A single sample: `build_model` must fail.
    #[test]
    #[should_panic]
    fn test_panic_on_one_input() {
        build_model(vec![(vec![11.0, 2.0], 1)]).unwrap();
    }

    /// Samples with differing numbers of observations must be rejected.
    #[test]
    #[should_panic]
    fn test_panic_on_mismatched_features() {
        build_model(vec![(vec![11.0, 2.0], 1), (vec![10.0], 1)]).unwrap();
    }

    /// Category 2 has only one sample, so `build_model` must fail.
    #[test]
    #[should_panic]
    fn test_panic_no_variance() {
        build_model(vec![
            (vec![11.0, 2.0], 1),
            (vec![10.0, 1.0], 1),
            (vec![11.0, 2.0], 2),
        ])
        .unwrap();
    }

    /// A test vector whose length differs from the model's feature count is an error.
    #[test]
    fn test_error_on_mismatched_test_features() {
        let model: GaussianBayesModel =
            build_model(vec![(vec![1.0, 20.0], 1), (vec![3.0, 22.0], 1)]).unwrap();
        assert!(fit_model(vec![1.0], &model).is_err());
    }

    #[test]
    fn test_model() {
        let x_input = vec![
            (vec![1.0, 20.0], 1),
            (vec![20.0, 210.0], 0),
            (vec![3.0, 22.0], 1),
            (vec![40.0, 220.0], 0),
            (vec![6.0, 10.0], 2),
            (vec![7.0, 11.0], 2),
            (vec![7.0, 11.0], 2),
        ];
        // Per-class feature means of the data above
        // (the order of the classes inside the model is not fixed):
        //   class 0: [30.0,        215.0      ]
        //   class 1: [ 2.0,         21.0      ]
        //   class 2: [ 6.66666667,  10.66666667]
        // Per-class variances, assuming an n - 1 denominator (TODO confirm against
        // rulinalg's `variance`):
        //   class 0: [200.0,        50.0      ]
        //   class 1: [  2.0,         2.0      ]
        //   class 2: [  0.33333333,  0.33333333]
        let model: GaussianBayesModel = build_model(x_input).unwrap();
        assert_eq!(fit_model(vec![1.0, 20.0], &model).unwrap(), 1);
        assert_eq!(fit_model(vec![20.0, 210.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![25.0, 225.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![30.0, 215.0], &model).unwrap(), 0);
        assert_eq!(fit_model(vec![3.0, 22.0], &model).unwrap(), 1);
        assert_eq!(fit_model(vec![6.0, 10.0], &model).unwrap(), 2);
    }
}
@Charles-Johnson

This comment has been minimized.

Copy link

Charles-Johnson commented Jan 11, 2019

Why do you need to define the struct ModelError when it's just a String?

@benjaminmcdonald

This comment has been minimized.

Copy link
Owner

benjaminmcdonald commented Jan 12, 2019

Not easy to simplify at the moment: using the try operator (?) on an Option produces a NoneError, which is only available on nightly.
https://doc.rust-lang.org/std/option/struct.NoneError.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment