Skip to content

Instantly share code, notes, and snippets.

@otorreno
Last active December 14, 2018 11:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save otorreno/ca6c5347c1bbde2d4fedd02b51d02cbb to your computer and use it in GitHub Desktop.
Save otorreno/ca6c5347c1bbde2d4fedd02b51d02cbb to your computer and use it in GitHub Desktop.
Apache Ignite Machine Learning example with keepBinaryCache
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.binary.BinaryObject;
import org.apache.ignite.binary.BinaryObjectBuilder;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import javax.cache.Cache;
import java.util.Random;
import java.util.UUID;
/**
 * Example that trains a KMeans model on data stored as {@link BinaryObject}s in a
 * keep-binary Ignite cache and then prints the prediction for every cached entry.
 *
 * <p>The features are passed through a normalization preprocessor both when the model is
 * trained and when it is applied, so that training and inference see the same feature space.
 */
public class KmeansBinaryObjectCache {
    /**
     * Creates a uniquely named cache and fills it with 1000 binary objects of type
     * {@code testType}. Each object carries a {@code label} field (0.0 for keys below 500,
     * 1.0 otherwise), two random double features ({@code feat1}, {@code feat2}) and a
     * string field ({@code feat3}).
     *
     * @param ignite Ignite instance used to create the cache and the binary object builder.
     * @return Keep-binary view of the populated cache.
     */
    private static IgniteCache<Integer, BinaryObject> populateCache(Ignite ignite) {
        CacheConfiguration<Integer, BinaryObject> cacheConfiguration = new CacheConfiguration<>();
        cacheConfiguration.setName("TEST_" + UUID.randomUUID());

        IgniteCache<Integer, BinaryObject> cache = ignite.createCache(cacheConfiguration).withKeepBinary();

        BinaryObjectBuilder builder = ignite.binary().builder("testType");

        // A single generator is enough; allocating a new Random for every field is wasteful.
        Random rnd = new Random();

        for (int i = 0; i < 1_000; i++) {
            BinaryObject value = builder.setField("label", (i < 500) ? 0.0 : 1.0)
                .setField("feat1", rnd.nextDouble())
                .setField("feat2", rnd.nextDouble())
                .setField("feat3", "Bob_" + rnd.nextInt())
                .build();

            cache.put(i, value);
        }

        return cache;
    }

    /**
     * Run example.
     *
     * @param args Command line arguments, not used.
     */
    public static void main(String[] args) {
        System.out.println();
        System.out.println(">>> Kmeans binary object cache started.");

        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            IgniteCache<Integer, BinaryObject> dataCache = populateCache(ignite);

            try {
                // Extracts the two numeric features from the binary object.
                IgniteBiFunction<Integer, BinaryObject, Vector> featureExtractor
                    = (k, v) -> VectorUtils.of(new double[] {v.field("feat1"), v.field("feat2")});

                // Extracts the label from the binary object.
                IgniteBiFunction<Integer, BinaryObject, Double> lbExtractor = (k, v) -> (double) v.field("label");

                // Normalizes the extracted features; fitted on the cached data.
                IgniteBiFunction<Integer, BinaryObject, Vector> normalizationPreprocessor =
                    new NormalizationTrainer<Integer, BinaryObject>()
                        .withP(1)
                        .fit(
                            ignite,
                            dataCache,
                            featureExtractor
                        );

                KMeansTrainer trainer = new KMeansTrainer().withDistance(new ManhattanDistance()).withSeed(7867L);

                KMeansModel kmdl = trainer.fit(ignite, dataCache, normalizationPreprocessor, lbExtractor);

                System.out.println("\n>>> Trained model");
                System.out.println(">>> KMeans centroids");

                // Print every centroid instead of assuming that exactly two exist.
                for (Vector center : kmdl.getCenters()) {
                    Tracer.showAscii(center);
                }

                System.out.println(">>>");

                System.out.println(">>> -----------------------------------");
                System.out.println(">>> | Predicted cluster\t| Real Label\t|");
                System.out.println(">>> -----------------------------------");

                int amountOfErrors = 0;
                int totalAmount = 0;

                try (QueryCursor<Cache.Entry<Integer, BinaryObject>> observations = dataCache.query(new ScanQuery<>())) {
                    for (Cache.Entry<Integer, BinaryObject> observation : observations) {
                        BinaryObject val = observation.getValue();

                        // The model was trained on normalized features, so the same preprocessor
                        // must be applied before asking the model for a prediction.
                        Vector inputs = normalizationPreprocessor.apply(observation.getKey(), val);

                        double groundTruth = val.field("label");
                        double prediction = kmdl.apply(inputs);

                        totalAmount++;

                        // NOTE(review): this compares the predicted cluster index with the label;
                        // cluster numbering presumably need not match the 0/1 labels, so the
                        // reported accuracy is illustrative only - confirm.
                        if (groundTruth != prediction) {
                            amountOfErrors++;
                        }

                        System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
                    }

                    System.out.println(">>> ---------------------------------");

                    System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                    System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));

                    System.out.println(">>> KMeans clustering algorithm over cached binary object dataset usage example completed.");
                }
            }
            finally {
                // The cache has a unique generated name and is only used by this example,
                // so remove it instead of leaving it behind in the cluster.
                dataCache.destroy();
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment