Skip to content

Instantly share code, notes, and snippets.

@y3nr1ng
Created February 4, 2018 15:38
Show Gist options
  • Save y3nr1ng/d6e63c2d08611ecd3474e74308382af7 to your computer and use it in GitHub Desktop.
Save y3nr1ng/d6e63c2d08611ecd3474e74308382af7 to your computer and use it in GitHub Desktop.
Simple k-means implementation.
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iostream>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
// Path of the input file opened by main().
constexpr const char* DATASET_FILENAME = "iris.data";
// The iteration stops once the mean squared centroid displacement reported
// by one k-means step drops below this value.
constexpr double TOLERANCE = 1e-3;
// Upper bound on the iteration counter of the k-means driver loop.
constexpr int MAX_ITER = 1000;
// A 2-D point described by (length, width) plus an integer label.
// Both constructors leave the label at -1.
struct Data {
    int label;
    float length;
    float width;

    // Point at the origin.
    Data() : label(-1), length(0.0f), width(0.0f) {}

    // Point at (_length, _width).
    Data(float _length, float _width)
        : label(-1), length(_length), width(_width) {}
};
// Squared Euclidean distance between two points in the (length, width) plane.
// The label fields are ignored.
float calc_dist2(const Data data_a, const Data data_b) {
    const float d_length = data_a.length - data_b.length;
    const float d_width = data_a.width - data_b.width;
    return d_length * d_length + d_width * d_width;
}
float calc_tol(
const std::vector<Data>& curr_cents,
const std::vector<Data>& next_cents
) {
const std::size_t n = curr_cents.size();
float dist2_sum = 0.0f;
for (auto i = 0; i < n; i++) {
dist2_sum += calc_dist2(curr_cents[i], next_cents[i]);
}
return dist2_sum / n;
}
int find_label(const Data data, const std::vector<Data>& cents) {
const std::size_t n = cents.size();
float min_dist2 = 0.0f;
int min_index = -1;
for (auto i = 0; i < n; i++) {
float dist2 = calc_dist2(data, cents[i]);
if (min_dist2 > dist2 or min_index < 0) {
min_dist2 = dist2;
min_index = i;
}
}
return min_index;
}
void update_centroid(
const std::vector<Data>& dataset,
std::vector<Data>& centroids
) {
const std::size_t n_cents = centroids.size();
for (auto i = 0; i < n_cents; i++) {
centroids[i].label = 0;
centroids[i].length = centroids[i].width = 0.0f;
}
const std::size_t n_dataset = dataset.size();
for (auto i = 0; i < n_dataset; i++) {
const Data& data = dataset[i];
centroids[data.label].label++;
centroids[data.label].length += data.length;
centroids[data.label].width += data.width;
}
for (auto i = 0; i < n_cents; i++) {
centroids[i].length /= centroids[i].label;
centroids[i].width /= centroids[i].label;
}
}
float kmean(
std::vector<Data>& dataset,
const std::vector<Data>& curr_cents, std::vector<Data>& next_cents
) {
const std::size_t n = dataset.size();
for (auto i = 0; i < n; i++) {
dataset[i].label = find_label(dataset[i], curr_cents);
}
update_centroid(dataset, next_cents);
return calc_tol(curr_cents, next_cents);
}
void init_centroids(const std::vector<Data>& dataset, std::vector<Data>& cents) {
float min_length = -1, max_length = -1;
float min_width = -1, max_width = -1;
const std::size_t n_dataset = dataset.size();
for (auto i = 0; i < n_dataset; i++) {
if (min_length < 0 or min_length > dataset[i].length) {
min_length = dataset[i].length;
}
if (max_length < 0 or max_length < dataset[i].length) {
max_length = dataset[i].length;
}
if (min_width < 0 or min_width > dataset[i].width) {
min_width = dataset[i].width;
}
if (max_width < 0 or max_width < dataset[i].width) {
max_width = dataset[i].width;
}
}
std::cerr << std::endl;
std::cerr << "min(length) = " << min_length << ", max(length) = " << max_length << std::endl;
std::cerr << "min(width) = " << min_width << ", max(width) = " << max_width << std::endl;
std::cerr << std::endl;
std::cerr << "Generating centroids" << std::endl;
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
std::mt19937 rng(seed);
std::uniform_real_distribution<float> length_gen(min_length, max_length);
std::uniform_real_distribution<float> width_gen(min_width, max_width);
const std::size_t n_cents = cents.size();
for (auto i = 0; i < n_cents; i++) {
cents[i].label = i;
cents[i].length = length_gen(rng);
cents[i].width = width_gen(rng);
std::cerr << i << ", (";
std::cerr << cents[i].length << ", " << cents[i].width << ")" << std::endl;
}
}
std::vector<Data> kmean(std::vector<Data>& dataset, int n_groups = 3) {
std::vector<Data> curr_cents(n_groups), next_cents(n_groups);
init_centroids(dataset, curr_cents);
float error;
for (auto i = 0; ; i++) {
error = kmean(dataset, curr_cents, next_cents);
std::swap(curr_cents, next_cents);
std::cerr << std::endl;
std::cerr << "iter " << i << ", error = " << error << std::endl;
// early stop conditions
if (error < TOLERANCE or i >= MAX_ITER) {
break;
}
// reset the computation if NaN reached
if (std::isnan(error)) {
std::cerr << ".. reset" << std::endl;
init_centroids(dataset, curr_cents);
i = 0;
}
}
return curr_cents;
}
// Predicate for std::remove_if: true for characters that are neither
// printable nor blank (space / horizontal tab) in the current C locale.
struct InvalidChar {
    bool operator()(char c) const {
        // <cctype> functions require a value representable as unsigned char.
        const unsigned char uc = static_cast<unsigned char>(c);
        return !std::isprint(uc) && !std::isblank(uc);
    }
};
// Read one sample per line from `infile` and append it to `dataset`.
//
// Each line is expected to hold three whitespace-separated fields:
// a species token (read and discarded), then length and width as floats.
// Lines that do not parse completely (blank lines, missing or non-numeric
// fields) are skipped. Previously such lines appended stale or
// uninitialised values.
//
// The number of entries in `dataset` is logged to std::cerr at the end.
void read_from_file(std::ifstream& infile, std::vector<Data>& dataset) {
    std::cerr << std::endl;
    std::cerr << "Reading from file" << std::endl;

    std::string line;
    std::string species;
    std::istringstream iss;
    while (std::getline(infile, line, '\n')) {
        // strip characters that are neither printable nor blank (e.g. '\r')
        line.erase(std::remove_if(line.begin(), line.end(), InvalidChar()), line.end());

        iss.clear();
        iss.str(line);

        float length = 0.0f;
        float width = 0.0f;
        if (!(iss >> species >> length >> width)) {
            continue;  // malformed or empty line
        }
        dataset.emplace_back(length, width);
    }

    std::cerr << dataset.size() << " entries read" << std::endl;
}
int main(void) {
std::ifstream infile(DATASET_FILENAME);
std::vector<Data> dataset, centroids;
read_from_file(infile, dataset);
centroids = kmean(dataset, 3);
// using label field to do the accounting
const std::size_t n_dataset = dataset.size();
const std::size_t n_cents = centroids.size();
for (auto i = 0; i < n_cents; i++) {
centroids[i].label = 0;
}
for (auto i = 0; i < n_dataset; i++) {
const Data& data = dataset[i];
centroids[data.label].label++;
}
std::cout << std::endl;
for (auto i = 0; i < n_cents; i++) {
std::cout << "(" << centroids[i].length << ", " << centroids[i].width << "), ";
std::cout << "n = " << centroids[i].label << std::endl;
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment