Last active
December 14, 2018 11:03
-
-
Save otorreno/ca6c5347c1bbde2d4fedd02b51d02cbb to your computer and use it in GitHub Desktop.
Apache Ignite Machine Learning example with keepBinaryCache
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.ignite.Ignite; | |
import org.apache.ignite.IgniteCache; | |
import org.apache.ignite.Ignition; | |
import org.apache.ignite.binary.BinaryObject; | |
import org.apache.ignite.binary.BinaryObjectBuilder; | |
import org.apache.ignite.cache.query.QueryCursor; | |
import org.apache.ignite.cache.query.ScanQuery; | |
import org.apache.ignite.configuration.CacheConfiguration; | |
import org.apache.ignite.ml.clustering.kmeans.KMeansModel; | |
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer; | |
import org.apache.ignite.ml.math.Tracer; | |
import org.apache.ignite.ml.math.distances.ManhattanDistance; | |
import org.apache.ignite.ml.math.functions.IgniteBiFunction; | |
import org.apache.ignite.ml.math.primitives.vector.Vector; | |
import org.apache.ignite.ml.math.primitives.vector.VectorUtils; | |
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; | |
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; | |
import javax.cache.Cache; | |
import java.util.Random; | |
import java.util.UUID; | |
public class KmeansBinaryObjectCache { | |
private static IgniteCache<Integer, BinaryObject> populateCache(Ignite ignite) { | |
CacheConfiguration<Integer, BinaryObject> cacheConfiguration = new CacheConfiguration<>(); | |
cacheConfiguration.setName("TEST_" + UUID.randomUUID()); | |
IgniteCache<Integer, BinaryObject> cache = ignite.createCache(cacheConfiguration).withKeepBinary(); | |
BinaryObjectBuilder builder = ignite.binary().builder("testType"); | |
for (int i = 0; i < 1_000; i++) { | |
BinaryObject value = builder.setField("label", (i < 500)? 0.0 : 1.0) | |
.setField("feat1", new Random().nextDouble()) | |
.setField("feat2", new Random().nextDouble()) | |
.setField("feat3", "Bob_" + new Random().nextInt()) | |
.build(); | |
cache.put(i, value); | |
} | |
return cache; | |
} | |
/** | |
* Run example. | |
*/ | |
public static void main(String[] args) { | |
System.out.println(); | |
System.out.println(">>> Kmeans binary object cache started."); | |
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { | |
IgniteCache<Integer, BinaryObject> dataCache = populateCache(ignite); | |
// Defines first preprocessor that extracts features from an upstream data. | |
IgniteBiFunction<Integer, BinaryObject, Vector> featureExtractor | |
= (k, v) -> VectorUtils.of(new double[]{v.field("feat1"), v.field("feat2")}); | |
IgniteBiFunction<Integer, BinaryObject, Double> lbExtractor = (k, v) -> (double) v.field("label"); | |
IgniteBiFunction<Integer, BinaryObject, Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, BinaryObject>() | |
.withP(1) | |
.fit( | |
ignite, | |
dataCache, | |
featureExtractor | |
); | |
KMeansTrainer trainer = new KMeansTrainer().withDistance(new ManhattanDistance()).withSeed(7867L); | |
KMeansModel kmdl = trainer.fit(ignite, dataCache, normalizationPreprocessor, lbExtractor); | |
System.out.println("\n>>> Trained model"); | |
System.out.println(">>> KMeans centroids"); | |
Tracer.showAscii(kmdl.getCenters()[0]); | |
Tracer.showAscii(kmdl.getCenters()[1]); | |
System.out.println(">>>"); | |
System.out.println(">>> -----------------------------------"); | |
System.out.println(">>> | Predicted cluster\t| Real Label\t|"); | |
System.out.println(">>> -----------------------------------"); | |
int amountOfErrors = 0; | |
int totalAmount = 0; | |
try (QueryCursor<Cache.Entry<Integer, BinaryObject>> observations = dataCache.query(new ScanQuery<>())) { | |
for (Cache.Entry<Integer, BinaryObject> observation : observations) { | |
BinaryObject val = observation.getValue(); | |
double[] inputs = new double[]{val.field("feat1"), val.field("feat2")}; | |
double groundTruth = val.field("label"); | |
double prediction = kmdl.apply(new DenseVector(inputs)); | |
totalAmount++; | |
if (groundTruth != prediction) | |
amountOfErrors++; | |
System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth); | |
} | |
System.out.println(">>> ---------------------------------"); | |
System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); | |
System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); | |
System.out.println(">>> KMeans clustering algorithm over cached binary object dataset usage example completed."); | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment