Skip to content

Instantly share code, notes, and snippets.

@spicoflorin
Last active April 19, 2019 10:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save spicoflorin/af538892470216beab6bc6ec8bba995c to your computer and use it in GitHub Desktop.
Save spicoflorin/af538892470216beab6bc6ec8bba995c to your computer and use it in GitHub Desktop.
@RestController
public class SGMController {
private ModelEvaluator<MiningModel> modelEvaluator;
@PostConstruct
public void init() {
try {
PMML pmml = createPMMLfromFile("iris_rf.pmml");
modelEvaluator = new MiningModelEvaluator(pmml);
} catch (SAXException | JAXBException | IOException e) {
e.printStackTrace();
}
}
@RequestMapping(value = "/ping", method = RequestMethod.GET)
public String ping() {
return "";
}
@RequestMapping(value = "/invocations", method = RequestMethod.POST)
public String invoke(HttpServletRequest request) throws IOException {
return predict(request.getReader().lines(), modelEvaluator);
}
private static String predict(Stream<String> inputData,
ModelEvaluator<MiningModel> modelEvaluator) {
String returns = inputData.map(dataLine -> {
Map<FieldName, FieldValue> arguments = readArgumentsFromLine(dataLine, modelEvaluator);
modelEvaluator.verify();
Map<FieldName, ?> results = modelEvaluator.evaluate(arguments);
FieldName targetName = modelEvaluator.getTargetField();
Object targetValue = results.get(targetName);
ProbabilityClassificationMap nodeMap = (ProbabilityClassificationMap) targetValue;
return ( nodeMap != null && nodeMap.getResult() != null) ? nodeMap.getResult().toString() : "NA for input->"+dataLine;
}).collect(Collectors.joining(System.lineSeparator()));
return returns;
}
private static PMML createPMMLfromFile(String fileName)
throws SAXException, JAXBException, IOException {
try (
InputStream pmmlFile = SGMController.class.getClassLoader().getResourceAsStream(fileName)) {
String pmmlString = new Scanner(pmmlFile).useDelimiter("\\Z").next();
InputStream is = new ByteArrayInputStream(pmmlString.getBytes());
InputSource source = new InputSource(is);
SAXSource transformedSource = ImportFilter.apply(source);
return JAXBUtil.unmarshalPMML(transformedSource);
}
}
private static Map<FieldName, FieldValue> readArgumentsFromLine(String line,
ModelEvaluator<MiningModel> modelEvaluator) {
Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
String[] lineArgs = line.split(",");
if (lineArgs.length != 5)
return arguments;
FieldValue sepalLength = modelEvaluator.prepare(new FieldName("Sepal.Length"),
lineArgs[0].isEmpty() ? 0 : lineArgs[0]);
FieldValue sepalWidth = modelEvaluator.prepare(new FieldName("Sepal.Width"),
lineArgs[1].isEmpty() ? 0 : lineArgs[1]);
FieldValue petalLength = modelEvaluator.prepare(new FieldName("Petal.Length"),
lineArgs[2].isEmpty() ? 0 : lineArgs[2]);
FieldValue petalWidth = modelEvaluator.prepare(new FieldName("Petal.Width"),
lineArgs[3].isEmpty() ? 0 : lineArgs[3]);
arguments.put(new FieldName("Sepal.Length"), sepalLength);
arguments.put(new FieldName("Sepal.Width"), sepalWidth);
arguments.put(new FieldName("Petal.Length"), petalLength);
arguments.put(new FieldName("Petal.Width"), petalWidth);
return arguments;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment