Skip to content

Instantly share code, notes, and snippets.

@tteofili
Last active March 19, 2021 16:05
Show Gist options
  • Save tteofili/2d135e78c7edf39010f14c13dcd54397 to your computer and use it in GitHub Desktop.
Save tteofili/2d135e78c7edf39010f14c13dcd54397 to your computer and use it in GitHub Desktop.
/*
* Copyright 2021 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.kie.kogito.explainability.explainability.integrationtests.pmml;

import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CompletableFuture;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.kie.api.pmml.PMML4Result;
import org.kie.api.pmml.PMMLRequestData;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.pmml.api.runtime.PMMLContext;
import org.kie.pmml.api.runtime.PMMLRuntime;
import org.kie.pmml.evaluator.core.PMMLContextImpl;

import static org.kie.kogito.pmml.utils.PMMLUtils.getPMMLRequestData;
import static org.kie.pmml.evaluator.assembler.factories.PMMLRuntimeFactoryInternal.getPMMLRuntime;

class MinimalNumericPMMLLimeBenchmarkTest {
private static final String MODEL_NAME = "RandomForestClassifier";
private static PMMLRuntime minimalNumericRuntime;
@BeforeAll
static void setUpBefore() {
minimalNumericRuntime = getPMMLRuntime(Paths.get("/home/tteofili/dev/benchmark-models/minimal-numerical/models/model.pmml").toFile());
}
@Test
void test() throws Exception {
Random random = new Random();
random.setSeed(4);
List<Type> schema = new ArrayList<>();
schema.add(Type.NUMBER);
schema.add(Type.NUMBER);
schema.add(Type.NUMBER);
schema.add(Type.NUMBER);
DataDistribution dataDistribution = DataUtils.readCSV(Paths.get("/home/tteofili/dev/benchmark-models/minimal-numerical/data/processed/inputs.csv"), schema, true);
PredictionProvider model = inputs -> CompletableFuture.supplyAsync(() -> {
List<PredictionOutput> outputs = new ArrayList<>(inputs.size());
for (PredictionInput input : inputs) {
Map<String, Object> map = new HashMap<>();
for (Feature f : input.getFeatures()) {
map.put(f.getName(), f.getValue().asNumber());
}
final PMMLRequestData pmmlRequestData = getPMMLRequestData("RandomForestClassifier", map);
final PMMLContext pmmlContext = new PMMLContextImpl(pmmlRequestData);
PMML4Result pmml4Result = minimalNumericRuntime.evaluate(MODEL_NAME, pmmlContext);
Map<String, Object> resultVariables = pmml4Result.getResultVariables();
String score = "" + resultVariables.get("probability_1");
String approved = "" + resultVariables.get("predicted_Approved");
PredictionOutput predictionOutput = new PredictionOutput(List.of(
new Output("predicted_Approved", Type.TEXT, new Value(approved), Double.parseDouble(score))));
outputs.add(predictionOutput);
}
return outputs;
});
LocalExplainer<Map<String, Saliency>> limeExplainer = new LimeExplainer();
List<PredictionInput> samples = dataDistribution.getAllSamples();
List<PredictionOutput> predictionOutputs = model.predictAsync(samples).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
List<Prediction> predictions = DataUtils.getPredictions(samples, predictionOutputs);
double meanImpactScore = 0;
for (Prediction prediction : predictions) {
Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
for (Saliency saliency : saliencyMap.values()) {
double v = ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2));
meanImpactScore += v;
}
}
System.out.println(meanImpactScore / (double) predictions.size());
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment