Created
November 14, 2017 12:38
-
-
Save yoshitsugu/344f41a5c67af70bf00696912068926d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate rusty_machine; | |
extern crate rand; | |
extern crate gnuplot; | |
use rusty_machine::linalg::{Matrix, BaseMatrix}; | |
use rusty_machine::learning::k_means::KMeansClassifier; | |
use rusty_machine::learning::UnSupModel; | |
use rand::thread_rng; | |
use rand::distributions::IndependentSample; | |
use rand::distributions::normal::Normal; | |
use gnuplot::*; | |
fn generate_data(centroids: &Matrix<f64>, | |
points_per_centroid: usize, | |
noise: f64) | |
-> Matrix<f64> { | |
assert!(centroids.cols() > 0, "Centroids cannot be empty."); | |
assert!(centroids.rows() > 0, "Centroids cannot be empty."); | |
assert!(noise >= 0f64, "Noise must be non-negative."); | |
let mut raw_cluster_data = Vec::with_capacity(centroids.rows() * points_per_centroid * | |
centroids.cols()); | |
let mut rng = thread_rng(); | |
let normal_rv = Normal::new(0f64, noise); | |
for _ in 0..points_per_centroid { | |
// Generate points from each centroid | |
for centroid in centroids.iter_rows() { | |
// Generate a point randomly around the centroid | |
let mut point = Vec::with_capacity(centroids.cols()); | |
for feature in centroid.iter() { | |
point.push(feature + normal_rv.ind_sample(&mut rng)); | |
} | |
// Push point to raw_cluster_data | |
raw_cluster_data.extend(point); | |
} | |
} | |
Matrix::new(centroids.rows() * points_per_centroid, | |
centroids.cols(), | |
raw_cluster_data) | |
} | |
fn main() { | |
println!("K-Means clustering example:"); | |
const SAMPLES_PER_CENTROID: usize = 2000; | |
println!("Generating {0} samples from each centroids:", | |
SAMPLES_PER_CENTROID); | |
// Choose two cluster centers, at (-0.5, -0.5) and (0, 0.5). | |
let centroids = Matrix::new(2, 2, vec![-0.5, -0.5, 0.0, 0.5]); | |
println!("{}", centroids); | |
// Generate some data randomly around the centroids | |
let samples = generate_data(¢roids, SAMPLES_PER_CENTROID, 0.4); | |
// Create a new model with 2 clusters | |
let mut model = KMeansClassifier::new(2); | |
// Train the model | |
println!("Training the model..."); | |
// Our train function returns a Result<(), E> | |
model.train(&samples).unwrap(); | |
let centroids = model.centroids().as_ref().unwrap(); | |
println!("Model Centroids:\n{:.3}", centroids); | |
// Predict the classes and partition into | |
println!("Classifying the samples..."); | |
let classes = model.predict(&samples).unwrap(); | |
println!("Plotting the samples..."); | |
let mut first_rows: Matrix<f64> = Matrix::new(0, 2, vec![]); | |
let mut second_rows: Matrix<f64> = Matrix::new(0, 2, vec![]); | |
for (i, x) in samples.iter_rows().enumerate() { | |
if classes.data()[i] == 0 { | |
first_rows = first_rows.vcat(&Matrix::new(1, 2, x)); | |
} else { | |
second_rows = second_rows.vcat(&Matrix::new(1, 2, x)); | |
} | |
} | |
let mut fg = Figure::new(); | |
let xs1 = first_rows.select_cols(&[0]).into_vec(); | |
let ys1 = first_rows.select_cols(&[1]).into_vec(); | |
let xs2 = second_rows.select_cols(&[0]).into_vec(); | |
let ys2 = second_rows.select_cols(&[1]).into_vec(); | |
fg.axes2d() | |
.points(xs1, ys1, &[Caption("Points1"), PointSymbol('x'), Color("#008000"), PointSize(0.5)]) | |
.points(xs2, ys2, &[Caption("Points2"), PointSymbol('o'), Color("#ff4000"), PointSize(0.5)]) | |
.set_title("Plot", &[]); | |
fg.set_terminal("pngcairo", "/tmp/fg1.png"); | |
fg.show(); | |
println!("Finish! Created /tmp/fg1.png") | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment