Skip to content

Instantly share code, notes, and snippets.

@edumucelli
Created April 14, 2017 14:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save edumucelli/fd0cfdcb621e6f7154019c855acfeb4e to your computer and use it in GitHub Desktop.
Save edumucelli/fd0cfdcb621e6f7154019c855acfeb4e to your computer and use it in GitHub Desktop.
# Train a random-forest classifier on the built-in iris data with caret and
# export the fitted model as PMML.
library(caret)
library(r2pmml)

data(iris)

# Predict Species from every other column of iris.
rf_fit <- train(Species ~ ., data = iris, method = "rf")
print(rf_fit)

# Written to the current working directory as "rf_fit.pmml", the file name the
# Java Runner below loads by default.
r2pmml(rf_fit, "rf_fit.pmml")
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ProbabilityDistribution;
import java.util.Map;
import java.util.concurrent.Callable;
/**
 * Scores a single input row against a PMML model evaluator and returns the predicted
 * probability of one target category (by default {@code "versicolor"}).
 *
 * <p>Instances are intended to be submitted to an {@code ExecutorService}.
 * NOTE(review): many tasks may share the same {@code evaluator} across threads; this assumes
 * {@link ModelEvaluator#evaluate} is safe for concurrent use in the JPMML version in use — confirm.
 */
class ParallelPredictor implements Callable<Double> {
    /** Category whose probability is reported when none is given explicitly. */
    private static final String DEFAULT_CATEGORY = "versicolor";

    private final Map<FieldName, FieldValue> arguments;
    private final ModelEvaluator<?> evaluator;
    private final String category;

    /**
     * Creates a task reporting the probability of {@code "versicolor"}.
     *
     * @param arguments input field values for one row
     * @param evaluator evaluator of the loaded model
     */
    public ParallelPredictor(Map<FieldName, FieldValue> arguments, ModelEvaluator<?> evaluator) {
        this(arguments, evaluator, DEFAULT_CATEGORY);
    }

    /**
     * Creates a task reporting the probability of an arbitrary target category.
     *
     * @param arguments input field values for one row
     * @param evaluator evaluator of the loaded model
     * @param category  target category whose probability {@link #call()} returns
     */
    public ParallelPredictor(
            Map<FieldName, FieldValue> arguments, ModelEvaluator<?> evaluator, String category) {
        this.arguments = arguments;
        this.evaluator = evaluator;
        this.category = category;
    }

    /**
     * Evaluates the row and returns the probability of the configured category.
     *
     * @return the probability reported by the model's target {@link ProbabilityDistribution}
     * @throws ClassCastException if the target result is not a {@link ProbabilityDistribution}
     *     (i.e. the model is not a probabilistic classifier)
     * @throws Exception anything thrown by the evaluator
     */
    @Override
    public Double call() throws Exception {
        Object target = evaluator.evaluate(arguments).get(evaluator.getTargetFieldName());
        return ((ProbabilityDistribution) target).getProbability(category);
    }
}
import lombok.extern.slf4j.Slf4j;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
/**
 * Benchmarks concurrent scoring of randomly generated iris-style rows against a PMML model.
 *
 * <p>Call {@link #buildEvaluator()} once to load the model, then {@link #predict()} to run the
 * benchmark. Each repeat's wall-clock time in milliseconds is logged at INFO level.
 */
@Slf4j
class Predictor {
    private static final Random RANDOM = new Random();
    private static final int NUMBER_OF_THREADS = 100;
    private static final int NUMBER_OF_REPEATS = 33;
    private static final int NUMBER_OF_ROWS = 100;

    // Input field names never change, so create them once instead of once per row.
    private static final FieldName SEPAL_LENGTH = FieldName.create("Sepal.Length");
    private static final FieldName SEPAL_WIDTH = FieldName.create("Sepal.Width");
    private static final FieldName PETAL_LENGTH = FieldName.create("Petal.Length");
    private static final FieldName PETAL_WIDTH = FieldName.create("Petal.Width");

    private final String modelFilename;
    private ModelEvaluator<?> evaluator;

    Predictor(String modelFilename) {
        this.modelFilename = modelFilename;
    }

    /**
     * Loads the PMML document from the model file and builds an evaluator for it.
     *
     * @throws IllegalStateException if the file cannot be read or parsed; previously the error
     *     was only printed and a {@code null} document was passed on to the factory
     */
    void buildEvaluator() {
        File inputFilePath = new File(this.modelFilename);
        PMML pmml;
        try (InputStream in = new FileInputStream(inputFilePath)) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(in);
        } catch (SAXException | JAXBException | IOException e) {
            throw new IllegalStateException("Failed to load PMML model from " + inputFilePath, e);
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        this.evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
    }

    /**
     * Runs {@code NUMBER_OF_REPEATS} timed batches of {@code NUMBER_OF_ROWS} random rows each on
     * a fixed-size thread pool and logs each batch's duration in milliseconds.
     *
     * <p>If the calling thread is interrupted, the benchmark stops early and the interrupt flag is
     * restored.
     *
     * @throws IllegalStateException if called before {@link #buildEvaluator()}, or if any
     *     prediction task fails
     */
    void predict() {
        if (evaluator == null) {
            throw new IllegalStateException("buildEvaluator() must be called before predict()");
        }
        ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_THREADS);
        try {
            for (int repeat = 0; repeat < NUMBER_OF_REPEATS; repeat++) {
                // A fresh batch per repeat, so every timing covers exactly NUMBER_OF_ROWS
                // predictions rather than all rows generated so far.
                List<Callable<Double>> batch = buildBatch(NUMBER_OF_ROWS);
                Instant start = Instant.now();
                runAll(executor, batch);
                Instant end = Instant.now();
                log.info("{}", Duration.between(start, end).toMillis());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.warn("Prediction benchmark interrupted", e);
        } finally {
            // Pool threads from Executors.newFixedThreadPool are non-daemon; without a shutdown
            // they keep the JVM alive after main() returns.
            executor.shutdownNow();
        }
    }

    /** Builds {@code rows} independent prediction tasks, each over a freshly randomized row. */
    private List<Callable<Double>> buildBatch(int rows) {
        List<Callable<Double>> batch = new ArrayList<>(rows);
        for (int i = 0; i < rows; i++) {
            batch.add(new ParallelPredictor(randomArguments(), evaluator));
        }
        return batch;
    }

    /** Creates one input row with all four measurements drawn from {@link #randomValue()}. */
    private Map<FieldName, FieldValue> randomArguments() {
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        arguments.put(SEPAL_LENGTH, continuousDouble(randomValue()));
        arguments.put(SEPAL_WIDTH, continuousDouble(randomValue()));
        arguments.put(PETAL_LENGTH, continuousDouble(randomValue()));
        arguments.put(PETAL_WIDTH, continuousDouble(randomValue()));
        return arguments;
    }

    /** Wraps a number as a continuous double-typed PMML field value. */
    private static FieldValue continuousDouble(Double value) {
        return FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, value);
    }

    /**
     * Submits every task and waits for all of them to finish.
     *
     * @throws IllegalStateException wrapping the cause of the first failed task (in batch order)
     * @throws InterruptedException if interrupted while waiting
     */
    private static void runAll(ExecutorService executor, List<Callable<Double>> batch)
            throws InterruptedException {
        for (Future<Double> future : executor.invokeAll(batch)) {
            try {
                future.get();
            } catch (ExecutionException e) {
                throw new IllegalStateException("Prediction task failed", e.getCause());
            }
        }
    }

    /**
     * Returns a pseudo-random value in {@code [1, 10)}.
     *
     * @return {@code 1 + 9 * u} where {@code u} is drawn uniformly from {@code [0, 1)}
     */
    public Double randomValue() {
        return 1 + (10 - 1) * RANDOM.nextDouble();
    }
}
/**
 * Command-line entry point: loads a PMML model and runs the {@link Predictor} benchmark.
 *
 * <p>Usage: {@code java Runner [model.pmml]}. The model path defaults to {@code rf_fit.pmml} in
 * the working directory.
 */
public class Runner {
    private static final String DEFAULT_MODEL_FILE = "rf_fit.pmml";

    /**
     * Runs the benchmark.
     *
     * @param args optional; {@code args[0]} overrides the model file path
     */
    public static void main(String[] args) {
        String modelFile = args.length > 0 ? args[0] : DEFAULT_MODEL_FILE;
        Predictor predictor = new Predictor(modelFile);
        predictor.buildEvaluator();
        predictor.predict();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment