Skip to content

Instantly share code, notes, and snippets.

@yoshitsugu
Created November 14, 2017 12:38
Show Gist options
  • Save yoshitsugu/344f41a5c67af70bf00696912068926d to your computer and use it in GitHub Desktop.
Save yoshitsugu/344f41a5c67af70bf00696912068926d to your computer and use it in GitHub Desktop.
extern crate rusty_machine;
extern crate rand;
extern crate gnuplot;
use rusty_machine::linalg::{Matrix, BaseMatrix};
use rusty_machine::learning::k_means::KMeansClassifier;
use rusty_machine::learning::UnSupModel;
use rand::thread_rng;
use rand::distributions::IndependentSample;
use rand::distributions::normal::Normal;
use gnuplot::*;
/// Samples `points_per_centroid` noisy points around every row of `centroids`.
///
/// Each generated point is its centroid plus independent zero-mean normal
/// noise (standard deviation `noise`) added to every feature. Output rows are
/// emitted round-robin: one point per centroid, in centroid row order, and that
/// pass is repeated `points_per_centroid` times. The result has
/// `centroids.rows() * points_per_centroid` rows and `centroids.cols()` columns.
///
/// # Panics
///
/// Panics if `centroids` has no columns or no rows, or if `noise` is negative.
fn generate_data(centroids: &Matrix<f64>,
                 points_per_centroid: usize,
                 noise: f64)
                 -> Matrix<f64> {
    assert!(centroids.cols() > 0, "Centroids cannot be empty.");
    assert!(centroids.rows() > 0, "Centroids cannot be empty.");
    assert!(noise >= 0f64, "Noise must be non-negative.");

    let n_features = centroids.cols();
    let n_rows = centroids.rows() * points_per_centroid;

    let mut rng = thread_rng();
    let normal_rv = Normal::new(0f64, noise);
    // Row-major buffer sized up front for every generated coordinate.
    let mut data = Vec::with_capacity(n_rows * n_features);

    for _ in 0..points_per_centroid {
        for centroid in centroids.iter_rows() {
            // Jitter each coordinate of this centroid independently and append
            // the resulting point directly to the flat buffer.
            data.extend(centroid.iter().map(|&coord| coord + normal_rv.ind_sample(&mut rng)));
        }
    }

    Matrix::new(n_rows, n_features, data)
}
fn main() {
println!("K-Means clustering example:");
const SAMPLES_PER_CENTROID: usize = 2000;
println!("Generating {0} samples from each centroids:",
SAMPLES_PER_CENTROID);
// Choose two cluster centers, at (-0.5, -0.5) and (0, 0.5).
let centroids = Matrix::new(2, 2, vec![-0.5, -0.5, 0.0, 0.5]);
println!("{}", centroids);
// Generate some data randomly around the centroids
let samples = generate_data(&centroids, SAMPLES_PER_CENTROID, 0.4);
// Create a new model with 2 clusters
let mut model = KMeansClassifier::new(2);
// Train the model
println!("Training the model...");
// Our train function returns a Result<(), E>
model.train(&samples).unwrap();
let centroids = model.centroids().as_ref().unwrap();
println!("Model Centroids:\n{:.3}", centroids);
// Predict the classes and partition into
println!("Classifying the samples...");
let classes = model.predict(&samples).unwrap();
println!("Plotting the samples...");
let mut first_rows: Matrix<f64> = Matrix::new(0, 2, vec![]);
let mut second_rows: Matrix<f64> = Matrix::new(0, 2, vec![]);
for (i, x) in samples.iter_rows().enumerate() {
if classes.data()[i] == 0 {
first_rows = first_rows.vcat(&Matrix::new(1, 2, x));
} else {
second_rows = second_rows.vcat(&Matrix::new(1, 2, x));
}
}
let mut fg = Figure::new();
let xs1 = first_rows.select_cols(&[0]).into_vec();
let ys1 = first_rows.select_cols(&[1]).into_vec();
let xs2 = second_rows.select_cols(&[0]).into_vec();
let ys2 = second_rows.select_cols(&[1]).into_vec();
fg.axes2d()
.points(xs1, ys1, &[Caption("Points1"), PointSymbol('x'), Color("#008000"), PointSize(0.5)])
.points(xs2, ys2, &[Caption("Points2"), PointSymbol('o'), Color("#ff4000"), PointSize(0.5)])
.set_title("Plot", &[]);
fg.set_terminal("pngcairo", "/tmp/fg1.png");
fg.show();
println!("Finish! Created /tmp/fg1.png")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment