Skip to content

Instantly share code, notes, and snippets.

@tteofili
Last active December 10, 2020 14:52
Show Gist options
  • Save tteofili/78e3896b342623576420224de05a24c4 to your computer and use it in GitHub Desktop.
Save tteofili/78e3896b342623576420224de05a24c4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
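"# Explainability Core example\n",
"\n",
"This notebook uses LIME from Kogito's `explainability-core` module to explain the predictions of an Apache OpenNLP language detector.\n",
"\n",
"## Dependencies\n",
"\n",
"Start by declaring the Maven dependencies the notebook needs:"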
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%%loadFromPOM\n",
"<repository>\n",
" <id>jboss-public-repository-group</id>\n",
" <name>JBoss Public Repository Group</name>\n",
" <url>https://repository.jboss.org/nexus/content/groups/public/</url>\n",
"</repository>\n",
"\n",
"<dependency>\n",
" <groupId>org.kie.kogito</groupId>\n",
" <artifactId>kogito-bom</artifactId>\n",
" <version>1.0.0.Final</version>\n",
"</dependency>\n",
"\n",
"<dependency>\n",
" <groupId>org.kie.kogito</groupId>\n",
" <artifactId>kogito-apps</artifactId>\n",
" <version>1.0.0.Final</version>\n",
"</dependency>\n",
"\n",
"<dependency>\n",
" <groupId>org.kie.kogito</groupId>\n",
" <artifactId>explainability-core</artifactId>\n",
" <version>1.0.0.Final</version>\n",
"</dependency>\n",
"\n",
"<dependency>\n",
" <groupId>org.apache.opennlp</groupId>\n",
" <artifactId>opennlp-tools</artifactId>\n",
" <version>1.9.2</version>\n",
"</dependency>"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"// Enable IJava's shell magics so that later cells can run commands through %sh.\n",
"io.github.spencerpark.ijava.IJava.getKernelInstance().getMagics().registerMagics(io.github.spencerpark.jupyter.kernel.magic.common.Shell.class);"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"%sh wget -nc https://downloads.apache.org/opennlp/models/langdetect/1.8.3/langdetect-183.bin"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
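"## Example\n",
"\n",
"Load the OpenNLP language-detection model, wrap it as a Kogito `PredictionProvider`, run it on a short Italian phrase, and explain the resulting prediction with LIME."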
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import java.io.InputStream;\n",
"import java.io.FileInputStream;\n",
"import opennlp.tools.langdetect.Language;\n",
"import opennlp.tools.langdetect.LanguageDetector;\n",
"import opennlp.tools.langdetect.LanguageDetectorModel;\n",
"import opennlp.tools.langdetect.LanguageDetectorME;\n",
"\n",
"// Load the pre-trained language-detection model downloaded above.\n",
"// try-with-resources closes the model file once the model has been read,\n",
"// instead of leaving the FileInputStream open for the lifetime of the kernel.\n",
"LanguageDetectorModel languageDetectorModel;\n",
"try (InputStream is = new FileInputStream(\"langdetect-183.bin\")) {\n",
"    languageDetectorModel = new LanguageDetectorModel(is);\n",
"}\n",
"LanguageDetector languageDetector = new LanguageDetectorME(languageDetectorModel);"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import java.util.LinkedList;\n",
"import java.util.List;\n",
"import java.util.Map;\n",
"import java.util.Random;\n",
"import java.util.concurrent.CompletableFuture;\n",
"\n",
"import org.kie.kogito.explainability.model.Type;\n",
"import org.kie.kogito.explainability.model.Value;\n",
"import org.kie.kogito.explainability.model.Feature;\n",
"import org.kie.kogito.explainability.model.FeatureFactory;\n",
"import org.kie.kogito.explainability.model.Output;\n",
"import org.kie.kogito.explainability.model.PerturbationContext;\n",
"import org.kie.kogito.explainability.model.Prediction;\n",
"import org.kie.kogito.explainability.model.PredictionInput;\n",
"import org.kie.kogito.explainability.model.PredictionOutput;\n",
"import org.kie.kogito.explainability.model.PredictionProvider;\n",
"\n",
"// Wrap the OpenNLP detector as a Kogito PredictionProvider. For every input, the string\n",
"// values of its features are joined with single spaces, and the detected language is\n",
"// returned as one text output named \"lang\", scored with the detector's confidence.\n",
"PredictionProvider model = inputs -> CompletableFuture.supplyAsync(() -> {\n",
"    List<PredictionOutput> detections = new LinkedList<>();\n",
"    for (PredictionInput current : inputs) {\n",
"        StringBuilder sentence = new StringBuilder();\n",
"        for (Feature feature : current.getFeatures()) {\n",
"            // separate words with a single space, but never start with one\n",
"            if (sentence.length() != 0) {\n",
"                sentence.append(' ');\n",
"            }\n",
"            sentence.append(feature.getValue().asString());\n",
"        }\n",
"        Language detected = languageDetector.predictLanguage(sentence.toString());\n",
"        Output langOutput = new Output(\"lang\", Type.TEXT, new Value<>(detected.getLang()), detected.getConfidence());\n",
"        detections.add(new PredictionOutput(List.of(langOutput)));\n",
"    }\n",
"    return detections;\n",
"});"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Output{value=Value{ita}, type=text, score=0.028884185063368605, name='lang'}]\n"
]
}
],
"source": [
"String inputText = \"italiani spaghetti pizza mandolino\";\n",
"\n",
"// The whole phrase goes in as one full-text feature named \"text\".\n",
"List<Feature> features = new LinkedList<>(List.of(FeatureFactory.newFulltextFeature(\"text\", inputText)));\n",
"PredictionInput input = new PredictionInput(features);\n",
"\n",
"// Run the model on this single input and wait for its output.\n",
"List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();\n",
"PredictionOutput output = predictionOutputs.get(0);\n",
"\n",
"// Pair input and output so the prediction can be explained in the next cell.\n",
"Prediction prediction = new Prediction(input, output);\n",
"\n",
"System.out.println(output.getOutputs());"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feature{name='text', type=text, value=Value{italiani}}: 0.0\n",
"Feature{name='text', type=text, value=Value{spaghetti}}: 0.09680093109947152\n",
"Feature{name='text', type=text, value=Value{pizza}}: 0.053478635073901735\n",
"Feature{name='text', type=text, value=Value{mandolino}}: 0.0\n"
]
}
],
"source": [
"import java.util.concurrent.TimeUnit;\n",
"import org.kie.kogito.explainability.Config;\n",
"import org.kie.kogito.explainability.model.Saliency;\n",
"import org.kie.kogito.explainability.model.FeatureImportance;\n",
"import org.kie.kogito.explainability.local.lime.LimeExplainer;\n",
"\n",
"LimeExplainer limeExplainer = new LimeExplainer(1, 100);\n",
"\n",
"// Explain the prediction from the previous cell, waiting at most the library's\n",
"// configured async timeout for the result.\n",
"Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)\n",
"    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());\n",
"\n",
"// Print the importance LIME assigned to each feature for the \"lang\" output.\n",
"saliencyMap.get(\"lang\").getPerFeatureImportance().forEach(fi -> System.out.println(fi.getFeature() + \": \" + fi.getScore()));"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Java",
"language": "java",
"name": "java"
},
"language_info": {
"codemirror_mode": "java",
"file_extension": ".jshell",
"mimetype": "text/x-java-source",
"name": "Java",
"pygments_lexer": "java",
"version": "11.0.3+7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment