-
-
Save zoq/0c4181694e7e6d3d516b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cerrno>
#include <cstdlib>
#include <iostream>
#include <limits>

#include <shogun/base/init.h>
// Expose CKMeans internals (Weights, clustknb) to the KMeans subclass below.
// NOTE(review): redefining a keyword is formally UB; kept narrowly scoped to
// this one header, and standard headers are included before it on purpose.
#define private protected
#include <shogun/clustering/KMeans.h>
#undef private
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/features/RealFileFeatures.h>
#include <shogun/io/AsciiFile.h>
using namespace shogun;
class KMeans : public CKMeans | |
{ | |
public: | |
KMeans(int32_t k_, CDistance* d) : CKMeans(k_, d) { } | |
virtual bool train_machine(CFeatures* data, CDenseFeatures<float64_t>* centroids) | |
{ | |
ASSERT(distance); | |
if (data) | |
distance->init(data, data); | |
ASSERT(distance->get_feature_type() == F_DREAL); | |
CDenseFeatures<float64_t>* lhs = (CDenseFeatures<float64_t>*) | |
distance->get_lhs(); | |
ASSERT(lhs); | |
int32_t num = lhs->get_num_vectors(); | |
SG_UNREF(lhs); | |
Weights = SGVector<float64_t>(num); | |
for (int32_t i = 0; i < num; ++i) | |
Weights.vector[i] = 1.0; | |
clustknb(true, centroids->get_feature_matrix().matrix); | |
return true; | |
} | |
}; | |
int main(int argc, char** argv) | |
{ | |
init_shogun_with_defaults(); | |
const char* dataset = argv[1]; | |
const char* centroids = argv[2]; | |
int32_t clusters = atoi(argv[3]); | |
CAsciiFile* dfile = new CAsciiFile(dataset); | |
SGMatrix<float64_t> dmat = SGMatrix<float64_t>(); | |
dmat.load(dfile); | |
SG_UNREF(dfile); | |
CAsciiFile* cfile = new CAsciiFile(centroids); | |
SGMatrix<float64_t> cmat = SGMatrix<float64_t>(); | |
cmat.load(cfile); | |
SG_UNREF(cfile); | |
CDenseFeatures<float64_t>* data = new CDenseFeatures<float64_t>(dmat); | |
SG_REF(data); | |
CDenseFeatures<float64_t>* cent = new CDenseFeatures<float64_t>(cmat); | |
SG_REF(cent); | |
CEuclideanDistance* dist = new CEuclideanDistance(data, data); | |
KMeans k(clusters, dist); | |
k.train_machine(data, cent); | |
SG_UNREF(data); | |
SG_UNREF(cent); | |
exit_shogun(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment