Skip to content

Instantly share code, notes, and snippets.

@AbdealiLoKo
Created July 21, 2020 19:33
Show Gist options
  • Save AbdealiLoKo/1dd5b7677435ba22f9ab3e26016bb3e7 to your computer and use it in GitHub Desktop.
Save AbdealiLoKo/1dd5b7677435ba22f9ab3e26016bb3e7 to your computer and use it in GitHub Desktop.
Comparing py-java libraries
# Example:
# PYJAVA_LIB=jpype venv/bin/python pyjava.py
import os
from datetime import datetime
from time import perf_counter

from jpmml_evaluator import _package_classpath
# Which Python<->Java bridge to benchmark; selected through the PYJAVA_LIB
# environment variable (see the usage example at the top of the script).
lib = os.environ.get('PYJAVA_LIB')
if lib not in ('py4j', 'jnius', 'jpype'):
    # Fail fast with a readable message. An `assert` would be stripped under
    # `python -O`, and an unsupported value would otherwise fall through every
    # branch below and only surface later as a NameError on `jString` & co.
    raise SystemExit(
        f'Set env var PYJAVA_LIB to py4j/jnius/jpype (got {lib!r})'
    )

##### Create a JVM
# Every branch binds the same set of names (jString, jDouble, jFile,
# jLinkedHashMap, jLoadingModelEvaluatorBuilder, jModelEvaluationContext,
# jEvaluatorUtil) so the rest of the script is bridge-agnostic.
# The measured time covers importing the bridge library, starting the JVM and
# resolving these Java classes.
start_t = perf_counter()  # monotonic, high-resolution clock for benchmarking
if lib == 'py4j':
    from py4j.java_gateway import JavaGateway

    gateway = JavaGateway.launch_gateway(
        classpath=os.pathsep.join(_package_classpath())
    )
    # Classes are looked up by their fully-qualified name on the JVM view.
    jString = gateway.jvm.__getattr__('java.lang.String')
    jDouble = gateway.jvm.__getattr__('java.lang.Double')
    jFile = gateway.jvm.__getattr__('java.io.File')
    jLinkedHashMap = gateway.jvm.__getattr__('java.util.LinkedHashMap')
    jLoadingModelEvaluatorBuilder = gateway.jvm.__getattr__('org.jpmml.evaluator.LoadingModelEvaluatorBuilder')
    jModelEvaluationContext = gateway.jvm.__getattr__('org.jpmml.evaluator.ModelEvaluationContext')
    jEvaluatorUtil = gateway.jvm.__getattr__('org.jpmml.evaluator.EvaluatorUtil')
elif lib == 'jnius':
    import jnius_config

    # The classpath is configured before `import jnius`; keep this order.
    jnius_config.set_classpath(*_package_classpath())
    import jnius

    jString = jnius.autoclass('java.lang.String')
    jDouble = jnius.autoclass('java.lang.Double')
    jFile = jnius.autoclass('java.io.File')
    jLinkedHashMap = jnius.autoclass('java.util.LinkedHashMap')
    jLoadingModelEvaluatorBuilder = jnius.autoclass('org.jpmml.evaluator.LoadingModelEvaluatorBuilder')
    jModelEvaluationContext = jnius.autoclass('org.jpmml.evaluator.ModelEvaluationContext')
    jEvaluatorUtil = jnius.autoclass('org.jpmml.evaluator.EvaluatorUtil')
elif lib == 'jpype':
    import jpype
    import jpype.imports

    # The Java-package imports below are issued only after the JVM is started.
    jpype.startJVM(classpath=_package_classpath())
    from java.lang import String as jString
    from java.lang import Double as jDouble
    from java.io import File as jFile
    from java.util import LinkedHashMap as jLinkedHashMap
    from org.jpmml.evaluator import LoadingModelEvaluatorBuilder as jLoadingModelEvaluatorBuilder
    from org.jpmml.evaluator import ModelEvaluationContext as jModelEvaluationContext
    from org.jpmml.evaluator import EvaluatorUtil as jEvaluatorUtil
elapsed = perf_counter() - start_t
print(f"createjvm: {elapsed:.3f}s")
##### Load a Model
# Build and verify an evaluator from the PMML file 100 times, timing each
# complete build. The last `evaluator` built stays bound after the loop.
# NOTE(review): the PMML path is relative, so this presumably has to be run
# from the repository root - confirm.
times = []
for _ in range(100):
    start_t = perf_counter()  # monotonic clock; wall-clock time can jump
    evaluatorBuilder = jLoadingModelEvaluatorBuilder()
    evaluatorBuilder.setLocatable(True)
    evaluatorBuilder.load(jFile(jString("jpmml_evaluator/tests/resources/DecisionTreeIris.pmml")))
    evaluator = evaluatorBuilder.build().verify()
    times.append(perf_counter() - start_t)
total = sum(times)
print(f"loadmodel: tot={total:.6f} max={max(times):.6f}s avg={total / len(times):.6f}s")
##### Query field info
# Time 100 rounds of reading the input, target and output field names from
# the evaluator. The name lists are discarded; each comprehension exists to
# make one getName() call per field so that call is part of the measurement.
times = []
for _ in range(100):
    start_t = perf_counter()  # monotonic clock; wall-clock time can jump
    inputFields = evaluator.getInputFields()
    vals = [inputField.getName() for inputField in inputFields]
    # print("Input fields: ", vals)
    targetFields = evaluator.getTargetFields()
    vals = [targetField.getName() for targetField in targetFields]
    # print("Target field(s): ", vals)
    outputFields = evaluator.getOutputFields()
    vals = [outputField.getName() for outputField in outputFields]
    # print("Output fields: ", vals)
    times.append(perf_counter() - start_t)
total = sum(times)
print(f"fields : tot={total:.6f} max={max(times):.6f}s avg={total / len(times):.6f}s")
##### Score
# Time 100 evaluations of the model. Each round measures the whole trip:
# building the Python dict, copying it into a Java LinkedHashMap, encoding
# the keys, evaluating, and decoding the results.
times = []
for i in range(100):
    val = i * 0.001  # small per-iteration offset so inputs differ each round
    start_t = perf_counter()  # monotonic clock; wall-clock time can jump
    arguments1 = {
        "Sepal.Length": 5.1 + val,
        "Sepal.Width": 3.5 + val,
        "Petal.Length": 1.4 + val,
        "Petal.Width": 0.2 + val,
    }
    # Copy the Python dict into a Java map, boxing keys and values explicitly.
    arguments2 = jLinkedHashMap()
    for k, v in arguments1.items():
        arguments2.put(jString(k), jDouble(v))
    arguments = jEvaluatorUtil.encodeKeys(arguments2)
    results1 = evaluator.evaluate(arguments)
    # The decoded result is not used; decoding is done only to be timed.
    results2 = jEvaluatorUtil.decodeAll(results1)
    times.append(perf_counter() - start_t)
total = sum(times)
print(f"score : tot={total:.6f} max={max(times):.6f}s avg={total / len(times):.6f}s")
# jpype
# createjvm: 0.550s
# loadmodel: tot=1.466451 max=1.064521s avg=0.014665s
# fields : tot=0.019881 max=0.009795s avg=0.000199s
# score : tot=0.033356 max=0.023338s avg=0.000334s
# jnius
# createjvm: 0.249s
# loadmodel: tot=1.773011 max=1.385274s avg=0.017730s
# fields : tot=0.039058 max=0.012234s avg=0.000391s
# score : tot=0.067590 max=0.031904s avg=0.000676s
# py4j
# createjvm: 0.222s
# loadmodel: tot=0.616913 max=0.027464s avg=0.006169s
# fields : tot=0.699152 max=0.026426s avg=0.006992s
# score : tot=0.389583 max=0.017620s avg=0.003896s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment