Last active
March 7, 2021 10:43
-
-
Save thomasdarimont/ca37355db8cffdc3ebd8 to your computer and use it in GitHub Desktop.
A small jpmml-evaluator demo (using a dummy model) showing how to dynamically select a regression model based on an input attribute value. Here, the value of "attribute1" determines which model is used.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<!--
  Segmented regression model: the MiningModel's Segmentation uses
  multipleModelMethod="selectFirst", so the first Segment whose predicate
  evaluates to true supplies the prediction for "rate".
-->
<PMML version="4.2"
      xmlns="http://www.dmg.org/PMML-4_2"
        >
    <Header copyright="Copyright (c) 2014 tom" description="Linear Regression Model">
        <Extension name="user" value="tom" extender="Rattle/PMML"/>
        <Application name="Rattle/PMML" version="1.3"/>
        <Timestamp>2014-03-15 13:18:06</Timestamp>
    </Header>
    <DataDictionary numberOfFields="3">
        <DataField name="rate" optype="continuous" dataType="double"/>
        <DataField name="year" optype="continuous" dataType="double"/>
        <!-- Selector field: its value decides which segment (model) is used. -->
        <DataField name="attribute1" optype="categorical" dataType="string"/>
    </DataDictionary>
    <MiningModel functionName="regression">
        <MiningSchema>
            <MiningField name="rate" usageType="predicted"/>
            <MiningField name="year" usageType="active"/>
            <MiningField name="attribute1" usageType="active"/>
        </MiningSchema>
        <Segmentation multipleModelMethod="selectFirst">
            <!-- attribute1 == "1315": rate = 10000 + 1.5 * year -->
            <Segment id="A">
                <SimplePredicate field="attribute1" operator="equal" value="1315"/>
                <RegressionModel modelName="interest-rate-simple-linear-regression-1" functionName="regression"
                                 algorithmName="least squares"
                                 targetFieldName="rate">
                    <MiningSchema>
                        <MiningField name="rate" usageType="predicted"/>
                        <MiningField name="year" usageType="active"/>
                        <MiningField name="attribute1" usageType="active"/>
                    </MiningSchema>
                    <Output>
                        <OutputField name="Predicted_rate" feature="predictedValue"/>
                    </Output>
                    <RegressionTable intercept="10000">
                        <NumericPredictor name="year" exponent="1" coefficient="1.5"/>
                    </RegressionTable>
                </RegressionModel>
            </Segment>
            <!-- attribute1 == "1330": rate = 20000 - 0.5 * year -->
            <Segment id="B">
                <SimplePredicate field="attribute1" operator="equal" value="1330"/>
                <RegressionModel modelName="interest-rate-simple-linear-regression-2" functionName="regression"
                                 algorithmName="least squares"
                                 targetFieldName="rate">
                    <MiningSchema>
                        <MiningField name="rate" usageType="predicted"/>
                        <MiningField name="year" usageType="active"/>
                        <MiningField name="attribute1" usageType="active"/>
                    </MiningSchema>
                    <Output>
                        <OutputField name="Predicted_rate" feature="predictedValue"/>
                    </Output>
                    <RegressionTable intercept="20000">
                        <NumericPredictor name="year" exponent="1" coefficient="-0.5"/>
                    </RegressionTable>
                </RegressionModel>
            </Segment>
            <!--
              Default model: the <True/> predicate always matches, so with
              selectFirst this segment is used only when no earlier segment
              matched. rate = 5000 - 0.75 * year
            -->
            <Segment id="C">
                <True/>
                <RegressionModel modelName="interest-rate-simple-linear-regression-3" functionName="regression"
                                 algorithmName="least squares"
                                 targetFieldName="rate">
                    <MiningSchema>
                        <MiningField name="rate" usageType="predicted"/>
                        <MiningField name="year" usageType="active"/>
                        <MiningField name="attribute1" usageType="active"/>
                    </MiningSchema>
                    <Output>
                        <OutputField name="Predicted_rate" feature="predictedValue"/>
                    </Output>
                    <RegressionTable intercept="5000">
                        <NumericPredictor name="year" exponent="1" coefficient="-0.75"/>
                    </RegressionTable>
                </RegressionModel>
            </Segment>
        </Segmentation>
    </MiningModel>
</PMML>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package de.tutorials.training; | |
import org.dmg.pmml.FieldName; | |
import org.dmg.pmml.PMML; | |
import org.jpmml.evaluator.FieldValue; | |
import org.jpmml.evaluator.ModelEvaluator; | |
import org.jpmml.evaluator.ModelEvaluatorFactory; | |
import org.jpmml.manager.PMMLManager; | |
import org.jpmml.model.ImportFilter; | |
import org.jpmml.model.JAXBUtil; | |
import org.xml.sax.InputSource; | |
import java.io.InputStream; | |
import java.util.HashMap; | |
import java.util.LinkedHashMap; | |
import java.util.List; | |
import java.util.Map; | |
public class DynamicPmmlModelSelectionExample { | |
public static void main(String[] args) throws Exception { | |
String modelFile = "dynamic-model-selection-example.pmml.xml"; | |
InputStream inputStream = DynamicPmmlModelSelectionExample.class.getClassLoader().getResourceAsStream(modelFile); | |
PMML pmml = JAXBUtil.unmarshalPMML(ImportFilter.apply(new InputSource(inputStream))); | |
PMMLManager pmmlManager = new PMMLManager(pmml); | |
ModelEvaluator<?> evaluator = (ModelEvaluator<?>) pmmlManager.getModelManager(ModelEvaluatorFactory.getInstance()); | |
System.out.println(score(evaluator, new HashMap<String,Object>(){ | |
{ | |
put("year",2014); | |
put("attribute1","1330"); | |
} | |
})); | |
System.out.println(score(evaluator, new HashMap<String,Object>(){ | |
{ | |
put("year",2014); | |
put("attribute1","1315"); | |
} | |
})); | |
System.out.println(score(evaluator, new HashMap<String,Object>(){ | |
{ | |
put("year",2014); | |
put("attribute1","1345"); | |
} | |
})); | |
} | |
private static Map<FieldName, ?> score(ModelEvaluator<?> evaluator, Map<String, Object> input) { | |
Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>(); | |
List<FieldName> activeFields = evaluator.getActiveFields(); | |
for(FieldName activeField : activeFields){ | |
Object rawValue = input.get(activeField.getValue()); | |
FieldValue activeValue = evaluator.prepare(activeField, rawValue); | |
arguments.put(activeField, activeValue); | |
} | |
return evaluator.evaluate(arguments); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{rate=18993.0, Predicted_rate=18993.0} | |
{rate=13021.0, Predicted_rate=13021.0} | |
{rate=3489.5, Predicted_rate=3489.5} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment