Skip to content

Instantly share code, notes, and snippets.

@thomasdarimont
Last active March 7, 2021 10:43
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save thomasdarimont/ca37355db8cffdc3ebd8 to your computer and use it in GitHub Desktop.
Save thomasdarimont/ca37355db8cffdc3ebd8 to your computer and use it in GitHub Desktop.
A small jpmml-evaluator demo (with a dummy model) showing how to dynamically select a regression model based on an input attribute value. In this example, the value of "attribute1" selects the model.
<?xml version="1.0" encoding="UTF-8"?>
<!--
  Dummy PMML document for the dynamic model selection demo.
  The top-level MiningModel uses a Segmentation with multipleModelMethod="selectFirst":
  the first Segment whose predicate on "attribute1" evaluates to true supplies the
  prediction. Segment C (predicate <True/>) always matches and therefore acts as the
  default model when neither "1315" nor "1330" is given.
  Every segment predicts: rate = intercept + coefficient * year.
-->
<PMML version="4.2"
      xmlns="http://www.dmg.org/PMML-4_2"
>
  <Header copyright="Copyright (c) 2014 tom" description="Linear Regression Model">
    <Extension name="user" value="tom" extender="Rattle/PMML"/>
    <Application name="Rattle/PMML" version="1.3"/>
    <Timestamp>2014-03-15 13:18:06</Timestamp>
  </Header>
  <DataDictionary numberOfFields="3">
    <DataField name="rate" optype="continuous" dataType="double"/>
    <DataField name="year" optype="continuous" dataType="double"/>
    <!-- Selector field: its value decides which segment (model) is used. -->
    <DataField name="attribute1" optype="categorical" dataType="string"/>
  </DataDictionary>
  <MiningModel functionName="regression">
    <MiningSchema>
      <MiningField name="rate" usageType="predicted"/>
      <MiningField name="year" usageType="active"/>
      <MiningField name="attribute1" usageType="active"/>
    </MiningSchema>
    <Segmentation multipleModelMethod="selectFirst">
      <!-- attribute1 == "1315": rate = 10000 + 1.5 * year -->
      <Segment id="A">
        <SimplePredicate field="attribute1" operator="equal" value="1315"/>
        <RegressionModel modelName="interest-rate-simple-linear-regression-1" functionName="regression"
                         algorithmName="least squares"
                         targetFieldName="rate">
          <MiningSchema>
            <MiningField name="rate" usageType="predicted"/>
            <MiningField name="year" usageType="active"/>
            <MiningField name="attribute1" usageType="active"/>
          </MiningSchema>
          <Output>
            <OutputField name="Predicted_rate" feature="predictedValue"/>
          </Output>
          <RegressionTable intercept="10000">
            <NumericPredictor name="year" exponent="1" coefficient="1.5"/>
          </RegressionTable>
        </RegressionModel>
      </Segment>
      <!-- attribute1 == "1330": rate = 20000 - 0.5 * year -->
      <Segment id="B">
        <SimplePredicate field="attribute1" operator="equal" value="1330"/>
        <RegressionModel modelName="interest-rate-simple-linear-regression-2" functionName="regression"
                         algorithmName="least squares"
                         targetFieldName="rate">
          <MiningSchema>
            <MiningField name="rate" usageType="predicted"/>
            <MiningField name="year" usageType="active"/>
            <MiningField name="attribute1" usageType="active"/>
          </MiningSchema>
          <Output>
            <OutputField name="Predicted_rate" feature="predictedValue"/>
          </Output>
          <RegressionTable intercept="20000">
            <NumericPredictor name="year" exponent="1" coefficient="-0.5"/>
          </RegressionTable>
        </RegressionModel>
      </Segment>
      <!-- Default model (any other attribute1 value): rate = 5000 - 0.75 * year.
           Must stay last: with selectFirst, a <True/> segment placed earlier would
           shadow the segments after it. -->
      <Segment id="C">
        <True/>
        <!-- Distinct modelName: previously duplicated segment B's "-2" name. -->
        <RegressionModel modelName="interest-rate-simple-linear-regression-3" functionName="regression"
                         algorithmName="least squares"
                         targetFieldName="rate">
          <MiningSchema>
            <MiningField name="rate" usageType="predicted"/>
            <MiningField name="year" usageType="active"/>
            <MiningField name="attribute1" usageType="active"/>
          </MiningSchema>
          <Output>
            <OutputField name="Predicted_rate" feature="predictedValue"/>
          </Output>
          <RegressionTable intercept="5000">
            <NumericPredictor name="year" exponent="1" coefficient="-0.75"/>
          </RegressionTable>
        </RegressionModel>
      </Segment>
    </Segmentation>
  </MiningModel>
</PMML>
package de.tutorials.training;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.manager.PMMLManager;
import org.jpmml.model.ImportFilter;
import org.jpmml.model.JAXBUtil;
import org.xml.sax.InputSource;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
 * Small jpmml-evaluator demo showing how a PMML {@code MiningModel} can dynamically select
 * a regression model based on the value of an input attribute ({@code attribute1}).
 *
 * <p>The PMML document is loaded from the classpath resource
 * {@code dynamic-model-selection-example.pmml.xml}, and one sample record per entry in
 * {@link #ATTRIBUTE1_VALUES} is scored and printed to standard output.
 */
public class DynamicPmmlModelSelectionExample {

    /** Classpath location of the PMML document scored by this demo. */
    private static final String MODEL_FILE = "dynamic-model-selection-example.pmml.xml";

    /** Value of the {@code year} input field used for every sample record. */
    private static final int YEAR = 2014;

    /**
     * {@code attribute1} values to score, in output order. With the accompanying demo model,
     * "1330" and "1315" match dedicated segments, while "1345" matches neither predicate and
     * falls through to the default segment.
     */
    private static final String[] ATTRIBUTE1_VALUES = {"1330", "1315", "1345"};

    /**
     * Loads the demo model and prints the evaluation result for each sample record.
     *
     * @param args ignored
     * @throws Exception if the PMML resource is missing or cannot be parsed/evaluated
     */
    public static void main(String[] args) throws Exception {
        ModelEvaluator<?> evaluator = loadEvaluator(MODEL_FILE);
        for (String attribute1 : ATTRIBUTE1_VALUES) {
            System.out.println(score(evaluator, sampleInput(YEAR, attribute1)));
        }
    }

    /**
     * Parses a PMML document from the classpath and creates an evaluator for its model.
     *
     * @param resourceName classpath resource name of the PMML document
     * @return an evaluator for the model contained in the document
     * @throws IllegalStateException if the resource cannot be found on the classpath
     * @throws Exception if reading, parsing or unmarshalling the document fails
     */
    private static ModelEvaluator<?> loadEvaluator(String resourceName) throws Exception {
        InputStream inputStream =
                DynamicPmmlModelSelectionExample.class.getClassLoader().getResourceAsStream(resourceName);
        // getResourceAsStream signals a missing resource with null; fail with a clear message
        // instead of an obscure error from deep inside the XML parser.
        if (inputStream == null) {
            throw new IllegalStateException("PMML resource not found on classpath: " + resourceName);
        }
        PMML pmml;
        try {
            pmml = JAXBUtil.unmarshalPMML(ImportFilter.apply(new InputSource(inputStream)));
        } finally {
            // The document is fully unmarshalled at this point, so the stream can be released.
            inputStream.close();
        }
        PMMLManager pmmlManager = new PMMLManager(pmml);
        return (ModelEvaluator<?>) pmmlManager.getModelManager(ModelEvaluatorFactory.getInstance());
    }

    /**
     * Builds one raw input record for the demo model.
     *
     * @param year raw value for the {@code year} field
     * @param attribute1 raw value for the {@code attribute1} selector field
     * @return a new mutable map keyed by input field name
     */
    private static Map<String, Object> sampleInput(Object year, String attribute1) {
        Map<String, Object> input = new HashMap<String, Object>();
        input.put("year", year);
        input.put("attribute1", attribute1);
        return input;
    }

    /**
     * Evaluates the model for one raw input record.
     *
     * <p>Each active field of the model is looked up in {@code input} by field name, converted
     * with {@link ModelEvaluator#prepare}, and then the whole argument map is evaluated. A
     * field missing from {@code input} is passed to {@code prepare} as {@code null}.
     *
     * @param evaluator evaluator for the model to score
     * @param input raw input values keyed by field name
     * @return the evaluation result keyed by target/output field name
     */
    private static Map<FieldName, ?> score(ModelEvaluator<?> evaluator, Map<String, Object> input) {
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        List<FieldName> activeFields = evaluator.getActiveFields();
        for (FieldName activeField : activeFields) {
            Object rawValue = input.get(activeField.getValue());
            FieldValue activeValue = evaluator.prepare(activeField, rawValue);
            arguments.put(activeField, activeValue);
        }
        return evaluator.evaluate(arguments);
    }
}
{rate=18993.0, Predicted_rate=18993.0}
{rate=13021.0, Predicted_rate=13021.0}
{rate=3489.5, Predicted_rate=3489.5}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment