@maxwillzq
Created July 26, 2023 04:27
test.log
============================= test session starts ==============================
platform linux -- Python 3.11.4, pytest-7.4.0, pluggy-1.2.0
rootdir: /usr/local/google/home/johnqiangzhang/Documents/jaxonnxruntime
configfile: pyproject.toml
collected 3924 items
tests/onnx_ops_test.py .ss..ssss..ssssFsFsssFsFsss.ss.ss.ssFssFss.ss.ss. [ 1%]
ss.ss.ss.ss.ss.ss.s.sss.s.sss.s.sss.s.sss.s.sss.s.sss.s.sss.s.sss.s.sss. [ 3%]
s.sss.s.sss.s.sss.s.sss.s.sss.s.sss.s.sss..ssss..ssss..ssss..ssss.ssxss. [ 4%]
ssxss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss...ssss.sss.. [ 6%]
sssss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss. [ 8%]
ss..sss..sssss.ss.ss.ss.ssxssxssxssxss.ss.ss.ss.ss.ss.ss.ss.ss.ssFss.ss. [ 10%]
ssxssxssxssxssFssxssxssxssxssxssxssxssxssxss.xssss.xssss.xssss.xssss.xss [ 12%]
ss.xssss.xssss.xssss.xssssFxssss.xssss.xssssxxssssxxssssxxssssxxssssxxss [ 14%]
ssxxssss..ssss..ssss..ssss.Fssss.Fssss..sss.Fsssss..ssss.F.ssssF.ssssF.s [ 15%]
sss..ssssF.ssss..ssss..ssss.sss..ssss..ssss..sssss.ss..sss.ss.sss.ss.ss. [ 17%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.s.ss.s.ssss.ss.ss.ss.ss.ss.ss. [ 19%]
ss.ss.ss.ss.ss.ss..sss.ss.ss.ss.sss..ssss..ssss..sss..sssss.ss.ss.ss.ss. [ 21%]
ss.ss.ss.ss..ss.ssss.ss.ss.ss.s.sss.ss..ssss.ss.s.s.sss.ss.sss.ss..sss.. [ 23%]
ssss..sssss.ss.ss.ss.ss.ss.ss...ssss..ssss.ssss.ss.sxssxsss.ss..ssss.ss. [ 25%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss..ssss.ss.ss.ss.ss.ss.ss.ss.ss.ss. [ 26%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.s.sss.s.sss.ss...ssss..ssssss.ss.ss. [ 28%]
ss.ss.s.ss.sss.ss..ssss..ssss.ss.ss.ss.ss..sss..sssss..sss..sssss.ss.ss. [ 30%]
ss.ss.ss.ss.ss...ssss..ssss.ssss..ssss.s.ss.sss.s.ss.sss.ss.ss.s.ss.sss. [ 32%]
ss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ss [ 34%]
ss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ss [ 36%]
ss..ss.ssss..ss.ssss..ss.ssss...ssss..ssss.ssss.ss...ssss..ssssss..ssss. [ 37%]
.ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssssFssFssFss. [ 39%]
ss.ss.ss.ss.ss.ss.ss.ss..ssss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss. [ 41%]
ss.ss.ss.ss.ss.ss.ss.ss.ssxss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ssFssFss.ss. [ 43%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss..ssss.ss.ss.ss. [ 45%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ssFsFsss.ss..ssss.ss..ss.ssss..ssssFss.Fssss. [ 47%]
Fsss.Fssss.Fsssss.Fsss.Fsssss.Fsss.Fssss.Fssss.Fsssss.Fsss.Fssss.Fsss.Fs [ 48%]
sssss.Fssss.Fssss.Fssss.Fssss.ss.ss.ss.ss.ss.ss.ss.ss.ssxss.ss.ss.ss.ss. [ 50%]
ss.ss.ss.ss.ss.ss.ssFssFss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss. [ 52%]
.ssss.ss.ss.ss.ss.ss.ss.ss.ss..ssss..ssss.ss.ss.ss.ss..ss.ssss..ssss..ss [ 54%]
ss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ss [ 56%]
ss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ss [ 58%]
ss..ssss..ssss..ssss..ssss..ssss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss. [ 59%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss. [ 61%]
ss.ss.ss.ss..ssss..ssss..ssss..ssss..ssss..ssss..ssss..ssss.ss..ssss.ss. [ 63%]
ss.ss.ss.ss.ss.ss.ss.ss.ssFssFssFssFssFssFssFsFsssFssFssFssFssFssFsFssFs [ 65%]
ssFssFssFssFssFssFssFssFssFsFsssFssFssFssFssFssFssFssFFsssFssFsss.ss.ss. [ 67%]
ss.ss.ss.ss.ss.ssFss.ss.ss.ss.ss.ss.ss.ss.ss.ss.s.ss.ss.sss..sss..sssss. [ 69%]
.sss..sssss..sss..sssss..sss..sssss..sss..sssss..sss..sssss..sss..ssss.. [ 70%]
sss..sssss..sss..sssss..sss..ssssss..sss..sss..sssss..sss..sssss..sss..s [ 72%]
ssss..sssss..sss..sssss..sss..sssss..sss..sssss...ssss..ssss.ssss.ss.ss. [ 74%]
Fssss.Fssss..ssss.Fssss..ssss.Fssss.ss.ss..ss.ss.sss..ss.ssss.sss..ssss. [ 76%]
.ssss..ssss.ss.ss.ss.ss..ssss..ssss..ssss..ss.ss.sss.s.sss.ss.sss..ss.ss [ 78%]
ss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss..ss.ssss...ssss.ssss...s [ 80%]
sss.ssss..ssss.ss.ss.ss.ss.s.sss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss. [ 81%]
.ssss.s.sssFss.ss.ss.ss.ss.ss.ss.ss.ss..ssss.ss.ss.ss.ss..ssss..ssss.ss. [ 83%]
ss.ss.ss.ss.ss.ss...ssss..ssss.ssss.s.sss.s.ss.sssFFsFssssFsssFsFsss.ss. [ 85%]
ss.ss.ss.ss.ss.ss.s.ss.ss.ss.ss.ss.s.ssss.ss.s.ss.ss.ss.ss.ss.s.ssss.ss. [ 87%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss. [ 89%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.s.s [ 91%]
ss.s.sss.s.ss.sss.ss.ss.ss.ss.ss.ss..sss.ss.ss.ss.ss.ss.sss..s.ss.sss.ss [ 92%]
.sss.s.sss.ss.ss.sss..s.ssss.ss.ss.s.ssss.s.sss.ss.s.sss..ssss.ss.ss.s.s [ 94%]
ss.ssFsFsFssss.s.sssFsFsFssss.sFsss.sFsss.sFsss.ss.ss.ss.ss.ss.ss.ss.ss. [ 96%]
ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ssFss.ss.ss.ss.ss.ss.ssFss. [ 98%]
ssFss.ss.ss.ss.ss.ss.ss.ss.s.sss.s.sss..ssss.ss.ss.s.sssFss [100%]
=================================== FAILURES ===================================
__________________ OnnxBackendNodeModelTest.test_adagrad_cpu ___________________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_adagrad_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.preview.training"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.preview.training': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
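
Note on the ValueError above: the opset passed in lists only "ai.onnx.preview.training", and the empty value printed for handler.DOMAIN suggests the handler being visited belongs to the default ONNX domain (""). The loop therefore raises as soon as it meets any handler whose domain the model does not import. One possible alternative, sketched below, is to skip such handlers instead of raising. This is only an illustration under assumptions: OpsetId, Handler, Add, Adagrad and get_handlers_for_opset are hypothetical stand-ins, not the jaxonnxruntime classes or its actual fix.

```python
from typing import Dict, NamedTuple, Sequence, Type


class OpsetId(NamedTuple):
    """Hypothetical stand-in for onnx.OperatorSetIdProto."""
    domain: str
    version: int


class Handler:
    """Hypothetical stand-in for the handler base class."""
    DOMAIN = ""  # "" denotes the default ONNX domain
    OP_TYPE = ""


class Add(Handler):
    DOMAIN = ""
    OP_TYPE = "Add"


class Adagrad(Handler):
    DOMAIN = "ai.onnx.preview.training"
    OP_TYPE = "Adagrad"


def get_handlers_for_opset(
    opset: Sequence[OpsetId],
) -> Dict[str, Dict[str, Type[Handler]]]:
    """Collect handlers per domain, skipping domains the model does not import."""
    opset_dict = {o.domain: o.version for o in opset}
    handlers: Dict[str, Dict[str, Type[Handler]]] = {}
    for handler in Handler.__subclasses__():
        if handler.DOMAIN not in opset_dict:
            # Assumption: a domain missing from the opset means the handler
            # is simply not needed for this model, so it is not an error.
            continue
        handlers.setdefault(handler.DOMAIN, {})[handler.OP_TYPE] = handler
    return handlers


# A model that imports only the training domain, as in the failing test.
training_only = [OpsetId("ai.onnx.preview.training", 1)]
result = get_handlers_for_opset(training_only)
```

With this sketch the default-domain Add handler is ignored and only the training-domain handler is returned, so the lookup no longer fails for models whose opset list omits the default domain.
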
______________ OnnxBackendNodeModelTest.test_adagrad_multiple_cpu ______________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_adagrad_multiple_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.preview.training"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.preview.training': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
____________________ OnnxBackendNodeModelTest.test_adam_cpu ____________________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_adam_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.preview.training"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.preview.training': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
_______________ OnnxBackendNodeModelTest.test_adam_multiple_cpu ________________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_adam_multiple_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.preview.training"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.preview.training': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
_____ OnnxBackendNodeModelTest.test_ai_onnx_ml_array_feature_extractor_cpu _____
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_ai_onnx_ml_array_feature_extractor_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.ml"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.ml': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
____________ OnnxBackendNodeModelTest.test_ai_onnx_ml_binarizer_cpu ____________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_ai_onnx_ml_binarizer_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.ml"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.ml': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
___________ OnnxBackendNodeModelTest.test_cast_FLOAT_to_BFLOAT16_cpu ___________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_cast_FLOAT_to_BFLOAT16_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
> self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
jaxonnxruntime/runner.py:361:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class 'onnx_ops_test.Runner'>
ref_outputs = [array([[16117, 16117, 16127, 16209],
[16112, 16209, 15959, 16185],
[32704, 32640, 32640, 65408]], dtype=uint16)]
outputs = [Array([[0.478516, 0.480469, 0.5, 0.820312],
[0.470703, 0.816406, 0.210938, 0.722656],
[nan, inf, inf, -inf]], dtype=bfloat16)]
rtol = 0.001, atol = 1e-07
@classmethod
def assert_similar_outputs(
cls,
ref_outputs: Sequence[Any],
outputs: Sequence[Any],
rtol: float,
atol: float,
) -> None:
"""Assert that two sequences of outputs are similar within given tolerances."""
np.testing.assert_equal(len(outputs), len(ref_outputs))
for i, _ in enumerate(outputs):
if isinstance(outputs[i], (list, tuple)):
for j, _ in enumerate(outputs[i]):
cls.assert_similar_outputs(
ref_outputs[i][j], outputs[i][j], rtol, atol
)
else:
> np.testing.assert_equal(outputs[i].dtype, ref_outputs[i].dtype)
E AssertionError:
E Items are not equal:
E ACTUAL: dtype(bfloat16)
E DESIRED: dtype('uint16')
jaxonnxruntime/runner.py:253: AssertionError
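
Note on the dtype mismatch above: the reference output is stored as uint16 while the backend returns bfloat16, and the uint16 values look like raw bfloat16 bit patterns (for example 32640 is 0x7F80, the pattern for +inf). A minimal sketch of how such bit patterns can be derived with NumPy alone follows. It assumes the values are already exactly representable in bfloat16, so keeping the upper 16 bits of the float32 encoding is lossless; the helper name float32_to_bfloat16_bits is hypothetical and not part of the test runner.

```python
import numpy as np


def float32_to_bfloat16_bits(values: np.ndarray) -> np.ndarray:
    """Return the upper 16 bits of each float32 value as uint16.

    bfloat16 shares the sign and exponent layout of float32 and keeps only
    the top 7 mantissa bits, so this shift truncates the lower 16 bits.
    """
    as_f32 = np.ascontiguousarray(values, dtype=np.float32)
    return (as_f32.view(np.uint32) >> 16).astype(np.uint16)


# Values taken from the log that are exact in bfloat16.
sample = np.array([0.478515625, np.inf, -np.inf], dtype=np.float32)
bits = float32_to_bfloat16_bits(sample)
```

Comparing bit patterns like this sidesteps the dtype check, but the values can still differ when the backend rounds to nearest and the reference truncates: in the log, the outputs 0.478516 and 0.480469 both correspond to the reference value 16117.
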
____________ OnnxBackendNodeModelTest.test_cast_FLOAT_to_STRING_cpu ____________
Unexpected success
----------------------------- Captured stdout call -----------------------------
Cast JAX version do not support STRING type yet.
_________ OnnxBackendNodeModelTest.test_castlike_FLOAT_to_BFLOAT16_cpu _________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_castlike_FLOAT_to_BFLOAT16_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
> self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
jaxonnxruntime/runner.py:361:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:257: in assert_similar_outputs
np.testing.assert_allclose(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x7f97ec1f1800>, array([[ 0, 0, 0, 0],
[ 0, ...[16117, 16117, 16127, 16209],
[16112, 16209, 15959, 16185],
[32704, 32640, 32640, 65408]], dtype=uint16))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0.001, atol=1e-07', 'verbose': True}
@wraps(func)
def inner(*args, **kwds):
with self._recreate_cm():
> return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0.001, atol=1e-07
E
E Mismatched elements: 12 / 12 (100%)
E Max absolute difference: 32704
E Max relative difference: 1.00003064
E x: array([[ 0, 0, 0, 0],
E [ 0, 0, 0, 0],
E [ 0, 65535, 65535, 0]], dtype=uint16)
E y: array([[16117, 16117, 16127, 16209],
E [16112, 16209, 15959, 16185],
E [32704, 32640, 32640, 65408]], dtype=uint16)
/usr/lib/python3.11/contextlib.py:81: AssertionError
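Note on the mismatch above: the expected uint16 values (16117, 32704, 32640, 65408, ...) look like raw bfloat16 bit patterns rather than ordinary integers. That reading is an assumption, since the log does not say so. The sketch below uses only the standard library and a hypothetical helper, `bfloat16_bits_to_float`, to show what those patterns would decode to if the assumption holds.

```python
import math
import struct


def bfloat16_bits_to_float(bits: int) -> float:
    """Decode a 16-bit pattern as bfloat16 (assumed encoding, not confirmed by the log).

    bfloat16 is the upper 16 bits of an IEEE-754 float32, so the pattern is
    shifted into the high half of a 32-bit word and reinterpreted as float32.
    """
    return struct.unpack('<f', struct.pack('<I', (bits & 0xFFFF) << 16))[0]


# A few of the expected values from the assertion message.
for bits in (16117, 32704, 32640, 65408):
    print(bits, bfloat16_bits_to_float(bits))
```

Under this reading the last row of `y` would hold NaN, +inf, +inf and -inf, while the actual output `x` holds 0 and 65535. That suggests a dtype or bit-pattern conversion problem rather than a tolerance problem, though the log alone does not confirm it.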
___ OnnxBackendNodeModelTest.test_center_crop_pad_crop_axes_chw_expanded_cpu ___
graph = node {
output: "CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_k2"
name: "node_0"
op_type: "... }
dim {
dim_value: 10
}
dim {
dim_value: 9
}
}
}
}
}
tensor_dict = {'CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_axes_input': Array([1, 2], dtype=int64), 'CenterC... 'CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_pad_amount_left': Array([0, 0], dtype=int64), ...}
opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97ec1611e0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_center_crop_pad_crop_axes_chw_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
output: "CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_k2"
name: "node_0"
op_type: "... }
dim {
dim_value: 10
}
dim {
dim_value: 9
}
}
}
}
}
tensor_dict = {'CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_axes_input': Array([1, 2], dtype=int64), 'CenterC... 'CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_pad_amount_left': Array([0, 0], dtype=int64), ...}
opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['x', 'CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_pads', '', 'CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_axes_input'], the node proto isinput: "x"
E input: "CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_pads"
E input: ""
E input: "CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_axes_input"
E output: "CenterCropPad_test_center_crop_pad_crop_axes_chw_expanded_function_padded_input"
E name: "node_9"
E op_type: "Pad"
E domain: ""
E .
jaxonnxruntime/call_onnx.py:115: ValueError
___ OnnxBackendNodeModelTest.test_center_crop_pad_crop_axes_hwc_expanded_cpu ___
graph = node {
output: "CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_k2"
name: "node_0"
op_type: "... }
dim {
dim_value: 9
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_axes_input': Array([0, 1], dtype=int64), 'CenterC... 'CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_pad_amount_left': Array([0, 0], dtype=int64), ...}
opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97e47779a0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_center_crop_pad_crop_axes_hwc_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
output: "CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_k2"
name: "node_0"
op_type: "... }
dim {
dim_value: 9
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_axes_input': Array([0, 1], dtype=int64), 'CenterC... 'CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_pad_amount_left': Array([0, 0], dtype=int64), ...}
opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['x', 'CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_pads', '', 'CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_axes_input'], the node proto isinput: "x"
E input: "CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_pads"
E input: ""
E input: "CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_axes_input"
E output: "CenterCropPad_test_center_crop_pad_crop_axes_hwc_expanded_function_padded_input"
E name: "node_9"
E op_type: "Pad"
E domain: ""
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_center_crop_pad_crop_negative_axes_hwc_expanded_cpu _
graph = node {
output: "CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_k2"
name: "node_0"
o... }
dim {
dim_value: 9
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_axes_input': Array([-3, -2], dtype=int64...ropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_pad_amount_left': Array([0, 0], dtype=int64), ...}
opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97e4468220>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_center_crop_pad_crop_negative_axes_hwc_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
output: "CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_k2"
name: "node_0"
o... }
dim {
dim_value: 9
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_axes_input': Array([-3, -2], dtype=int64...ropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_pad_amount_left': Array([0, 0], dtype=int64), ...}
opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['x', 'CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_pads', '', 'CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_axes_input'], the node proto isinput: "x"
E input: "CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_pads"
E input: ""
E input: "CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_axes_input"
E output: "CenterCropPad_test_center_crop_pad_crop_negative_axes_hwc_expanded_function_padded_input"
E name: "node_9"
E op_type: "Pad"
E domain: ""
E .
jaxonnxruntime/call_onnx.py:115: ValueError
___________ OnnxBackendNodeModelTest.test_clip_default_inbounds_cpu ____________
graph = node {
input: "x"
input: ""
input: ""
output: "y"
name: "node_0"
op_type: "Clip"
}
name: "test_clip_defaul...
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'x': array([-1., 0., 1.], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97e451f460>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_clip_default_inbounds_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: ""
input: ""
output: "y"
name: "node_0"
op_type: "Clip"
}
name: "test_clip_defaul...
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'x': array([-1., 0., 1.], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['x', '', ''], the node proto isinput: "x"
E input: ""
E input: ""
E output: "y"
E name: "node_0"
E op_type: "Clip"
E .
jaxonnxruntime/call_onnx.py:115: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
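Note on the `KeyError: ''` above: in ONNX, an empty string in a node's input list marks an optional input that was omitted. Here the Clip node lists inputs `['x', '', '']`, so `min` and `max` are absent, and looking up `''` in `tensor_dict` fails. The sketch below shows one possible way to handle this. `gather_node_inputs` is a hypothetical helper, not part of jaxonnxruntime, and it assumes downstream op handlers accept `None` for an omitted input.

```python
from typing import Any, Dict, List, Optional


def gather_node_inputs(
    tensor_dict: Dict[str, Any], input_names: List[str]
) -> List[Optional[Any]]:
    """Collect a node's input tensors, mapping omitted optional inputs to None.

    An empty name means "optional input not provided"; any other name that is
    missing from tensor_dict still raises KeyError, as before.
    """
    return [tensor_dict[name] if name != '' else None for name in input_names]


# Mirrors the failing Clip case: only 'x' is present, min and max are omitted.
tensor_dict = {'x': [-1.0, 0.0, 1.0]}
print(gather_node_inputs(tensor_dict, ['x', '', '']))
# prints [[-1.0, 0.0, 1.0], None, None]
```

Keeping the `None` placeholders preserves input positions, so a handler can still tell `min` from `max` in the `['x', '', 'max']` case.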
_________ OnnxBackendNodeModelTest.test_clip_default_int8_inbounds_cpu _________
graph = node {
input: "x"
input: ""
input: ""
output: "y"
name: "node_0"
op_type: "Clip"
}
name: "test_clip_defaul...
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'x': array([-1, 0, 1], dtype=int8)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97e44693f0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_clip_default_int8_inbounds_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: ""
input: ""
output: "y"
name: "node_0"
op_type: "Clip"
}
name: "test_clip_defaul...
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'x': array([-1, 0, 1], dtype=int8)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['x', '', ''], the node proto isinput: "x"
E input: ""
E input: ""
E output: "y"
E name: "node_0"
E op_type: "Clip"
E .
jaxonnxruntime/call_onnx.py:115: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
___________ OnnxBackendNodeModelTest.test_clip_default_int8_max_cpu ____________
graph = node {
input: "x"
input: ""
input: "max"
output: "y"
name: "node_0"
op_type: "Clip"
}
name: "test_clip_def... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'max': array(0, dtype=int8), 'x': array([[[ 0, 0, 0, -1, 0],
[ 0, -1, 0, 0, 0],
[ 0, 0, 1, -1..., -1, 0, -1],
[ 1, 0, 0, 1, 1],
[ 1, 0, 0, 1, 0],
[ 0, 0, 0, 0, 0]]], dtype=int8)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97e44519f0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_clip_default_int8_max_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: ""
input: "max"
output: "y"
name: "node_0"
op_type: "Clip"
}
name: "test_clip_def... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'max': array(0, dtype=int8), 'x': array([[[ 0, 0, 0, -1, 0],
[ 0, -1, 0, 0, 0],
[ 0, 0, 1, -1..., -1, 0, -1],
[ 1, 0, 0, 1, 1],
[ 1, 0, 0, 1, 0],
[ 0, 0, 0, 0, 0]]], dtype=int8)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['x', '', 'max'], the node proto isinput: "x"
E input: ""
E input: "max"
E output: "y"
E name: "node_0"
E op_type: "Clip"
E .
jaxonnxruntime/call_onnx.py:115: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
______________ OnnxBackendNodeModelTest.test_clip_default_max_cpu ______________
graph = node {
input: "x"
input: ""
input: "max"
output: "y"
name: "node_0"
op_type: "Clip"
}
name: "test_clip_def... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'max': array(0., dtype=float32), 'x': array([[[-0.67246044, -0.35955316, -0.8131463 , -1.7262826 ,
0.177426... -0.26800337],
[ 0.8024564 , 0.947252 , -0.15501009, 0.61407936,
0.9222067 ]]], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97e44d5840>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_clip_default_max_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: ""
input: "max"
output: "y"
name: "node_0"
op_type: "Clip"
}
name: "test_clip_def... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'max': array(0., dtype=float32), 'x': array([[[-0.67246044, -0.35955316, -0.8131463 , -1.7262826 ,
0.177426... -0.26800337],
[ 0.8024564 , 0.947252 , -0.15501009, 0.61407936,
0.9222067 ]]], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['x', '', 'max'], the node proto isinput: "x"
E input: ""
E input: "max"
E output: "y"
E name: "node_0"
E op_type: "Clip"
E .
jaxonnxruntime/call_onnx.py:115: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
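Note on the two Clip failures above: both report `KeyError: ''` for the input list `['x', '', 'max']`. In ONNX, an empty input name marks an optional input that was omitted (here, Clip's `min`), so there is no tensor to look up for it. The sketch below shows one way an input lookup could tolerate that. It is not the jaxonnxruntime fix; `gather_node_inputs` is a hypothetical helper, and it assumes the op handler accepts `None` for a missing optional input.

```python
from typing import Any, Dict, List, Optional


def gather_node_inputs(
    tensor_dict: Dict[str, Any], input_names: List[str]
) -> List[Optional[Any]]:
  """Looks up each input tensor; an empty name yields None (omitted input).

  Hypothetical helper for illustration only. A non-empty name that is
  missing from tensor_dict still raises KeyError, as before.
  """
  return [tensor_dict[name] if name else None for name in input_names]


# Mirrors the failing node: inputs are x, <omitted min>, max.
tensor_dict = {'x': [3, -1, 0], 'max': 0}
node_inputs = gather_node_inputs(tensor_dict, ['x', '', 'max'])
```

With this shape of lookup, the Clip handler would have to treat a `None` bound as "no clipping on that side".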
___________________ OnnxBackendNodeModelTest.test_loop11_cpu ___________________
graph = node {
input: "trip_count"
input: "cond"
input: "y"
output: "res_y"
output: "res_scan"
name: "node_0"
op...pe {
dim {
dim_value: 5
}
dim {
dim_value: 1
}
}
}
}
}
tensor_dict = {'cond': array(True), 'trip_count': array(5), 'y': array([-2.], dtype=float32)}
opset = [domain: ""
version: 11
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97a423c460>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: 'cond_in'
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_loop11_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "trip_count"
input: "cond"
input: "y"
output: "res_y"
output: "res_scan"
name: "node_0"
op...pe {
dim {
dim_value: 5
}
dim {
dim_value: 1
}
}
}
}
}
tensor_dict = {'cond': array(True), 'trip_count': array(5), 'y': array([-2.], dtype=float32)}
opset = [domain: ""
version: 11
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['trip_count', 'cond', 'y', 'cond_in', 'iter_count', 'y_in'], the node proto isinput: "trip_count"
E input: "cond"
E input: "y"
E output: "res_y"
E output: "res_scan"
E name: "node_0"
E op_type: "Loop"
E attribute {
E name: "body"
E type: GRAPH
E g {
E node {
E input: "cond_in"
E output: "cond_out"
E op_type: "Identity"
E }
E node {
E output: "x"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E dims: 5
E data_type: 1
E float_data: 1
E float_data: 2
E float_data: 3
E float_data: 4
E float_data: 5
E name: "const_tensor_x"
E }
E }
E }
E node {
E output: "one"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E data_type: 7
E int64_data: 1
E name: "const_tensor_one"
E }
E }
E }
E node {
E input: "iter_count"
E input: "one"
E output: "end"
E op_type: "Add"
E }
E node {
E input: "iter_count"
E output: "slice_start"
E op_type: "Unsqueeze"
E attribute {
E name: "axes"
E type: INTS
E ints: 0
E }
E }
E node {
E input: "end"
E output: "slice_end"
E op_type: "Unsqueeze"
E attribute {
E name: "axes"
E type: INTS
E ints: 0
E }
E }
E node {
E input: "x"
E input: "slice_start"
E input: "slice_end"
E output: "slice_out"
E op_type: "Slice"
E }
E node {
E input: "y_in"
E input: "slice_out"
E output: "y_out"
E op_type: "Add"
E }
E node {
E input: "y_out"
E output: "scan_out"
E op_type: "Identity"
E }
E name: "loop_body"
E input {
E name: "iter_count"
E type {
E tensor_type {
E elem_type: 7
E shape {
E }
E }
E }
E }
E input {
E name: "cond_in"
E type {
E tensor_type {
E elem_type: 9
E shape {
E }
E }
E }
E }
E input {
E name: "y_in"
E type {
E tensor_type {
E elem_type: 1
E shape {
E dim {
E dim_value: 1
E }
E }
E }
E }
E }
E output {
E name: "cond_out"
E type {
E tensor_type {
E elem_type: 9
E shape {
E }
E }
E }
E }
E output {
E name: "y_out"
E type {
E tensor_type {
E elem_type: 1
E shape {
E dim {
E dim_value: 1
E }
E }
E }
E }
E }
E output {
E name: "scan_out"
E type {
E tensor_type {
E elem_type: 1
E shape {
E dim {
E dim_value: 1
E }
E }
E }
E }
E }
E }
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 11. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 11. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 11. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 11. Set to 1.
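Note on the Loop failure above: the lookup list `['trip_count', 'cond', 'y', 'cond_in', 'iter_count', 'y_in']` contains the body graph's own formal inputs (`iter_count`, `cond_in`, `y_in`). Those are bound by the Loop operator on each iteration and never exist in the outer `tensor_dict`, hence `KeyError: 'cond_in'`. The sketch below computes only the names a body graph actually reads from the enclosing scope. It is an illustration under stated assumptions, not the library's implementation: the graph is modelled as plain dicts rather than `onnx.GraphProto`, `outer_scope_inputs` is a hypothetical name, nodes are assumed to be in topological order, and nested subgraphs (such as an `If` inside the body) are not followed.

```python
from typing import Any, Dict, List, Set


def outer_scope_inputs(body: Dict[str, Any]) -> List[str]:
  """Returns names the body reads that are defined outside the body.

  Hypothetical helper. A name is local if it is a formal input of the
  body or the output of an earlier node in the body.
  """
  local: Set[str] = set(body['inputs'])
  outer: List[str] = []
  for node in body['nodes']:  # assumed topologically ordered
    for name in node['inputs']:
      if name and name not in local and name not in outer:
        outer.append(name)
    local.update(node['outputs'])
  return outer


# The loop11 body from the log above, reduced to input/output names.
loop11_body = {
    'inputs': ['iter_count', 'cond_in', 'y_in'],
    'nodes': [
        {'inputs': ['cond_in'], 'outputs': ['cond_out']},
        {'inputs': [], 'outputs': ['x']},
        {'inputs': [], 'outputs': ['one']},
        {'inputs': ['iter_count', 'one'], 'outputs': ['end']},
        {'inputs': ['iter_count'], 'outputs': ['slice_start']},
        {'inputs': ['end'], 'outputs': ['slice_end']},
        {'inputs': ['x', 'slice_start', 'slice_end'], 'outputs': ['slice_out']},
        {'inputs': ['y_in', 'slice_out'], 'outputs': ['y_out']},
        {'inputs': ['y_out'], 'outputs': ['scan_out']},
    ],
}
loop11_outer = outer_scope_inputs(loop11_body)

# Made-up variant: one extra node reads 'bias' from the enclosing graph.
variant_body = {
    'inputs': loop11_body['inputs'],
    'nodes': loop11_body['nodes']
    + [{'inputs': ['y_out', 'bias'], 'outputs': ['biased']}],
}
variant_outer = outer_scope_inputs(variant_body)
```

For the loop11 body every name is either a formal input or produced inside the body, so nothing needs to be fetched from the outer `tensor_dict` beyond the Loop node's own three inputs.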
_________________ OnnxBackendNodeModelTest.test_loop13_seq_cpu _________________
graph = node {
input: "trip_count"
input: "cond"
input: "seq_empty"
output: "seq_res"
name: "node_0"
op_type: "Loo...ype {
sequence_type {
elem_type {
tensor_type {
elem_type: 1
}
}
}
}
}
tensor_dict = {'cond': array(True), 'seq_empty': [], 'trip_count': array(5)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f9784707e80>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: 'cond_in'
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_loop13_seq_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "trip_count"
input: "cond"
input: "seq_empty"
output: "seq_res"
name: "node_0"
op_type: "Loo...ype {
sequence_type {
elem_type {
tensor_type {
elem_type: 1
}
}
}
}
}
tensor_dict = {'cond': array(True), 'seq_empty': [], 'trip_count': array(5)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['trip_count', 'cond', 'seq_empty', 'cond_in', 'iter_count', 'seq_in'], the node proto isinput: "trip_count"
E input: "cond"
E input: "seq_empty"
E output: "seq_res"
E name: "node_0"
E op_type: "Loop"
E attribute {
E name: "body"
E type: GRAPH
E g {
E node {
E input: "cond_in"
E output: "cond_out"
E op_type: "Identity"
E }
E node {
E output: "x"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E dims: 5
E data_type: 1
E float_data: 1
E float_data: 2
E float_data: 3
E float_data: 4
E float_data: 5
E name: "const_tensor_x"
E }
E }
E }
E node {
E output: "one"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E data_type: 7
E int64_data: 1
E name: "const_tensor_one"
E }
E }
E }
E node {
E output: "slice_start"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E dims: 1
E data_type: 7
E int64_data: 0
E name: "const_tensor_zero"
E }
E }
E }
E node {
E input: "iter_count"
E input: "one"
E output: "end"
E op_type: "Add"
E }
E node {
E output: "axes"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E data_type: 7
E int64_data: 0
E name: "const_tensor_axes"
E }
E }
E }
E node {
E input: "end"
E input: "axes"
E output: "slice_end"
E op_type: "Unsqueeze"
E }
E node {
E input: "x"
E input: "slice_start"
E input: "slice_end"
E output: "slice_out"
E op_type: "Slice"
E }
E node {
E input: "seq_in"
E input: "slice_out"
E output: "seq_out"
E op_type: "SequenceInsert"
E }
E name: "loop_body"
E input {
E name: "iter_count"
E type {
E tensor_type {
E elem_type: 7
E shape {
E }
E }
E }
E }
E input {
E name: "cond_in"
E type {
E tensor_type {
E elem_type: 9
E shape {
E }
E }
E }
E }
E input {
E name: "seq_in"
E type {
E sequence_type {
E elem_type {
E tensor_type {
E elem_type: 1
E }
E }
E }
E }
E }
E output {
E name: "cond_out"
E type {
E tensor_type {
E elem_type: 9
E shape {
E }
E }
E }
E }
E output {
E name: "seq_out"
E type {
E sequence_type {
E elem_type {
E tensor_type {
E elem_type: 1
E }
E }
E }
E }
E }
E }
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
______________ OnnxBackendNodeModelTest.test_loop16_seq_none_cpu _______________
graph = node {
input: "trip_count"
input: "cond"
input: "opt_seq"
output: "seq_res"
name: "node_0"
op_type: "Loop"...ype {
sequence_type {
elem_type {
tensor_type {
elem_type: 1
}
}
}
}
}
tensor_dict = {'cond': array(True), 'opt_seq': [array(0., dtype=float32)], 'trip_count': array(5)}
opset = [domain: ""
version: 16
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97ec15d2d0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: 'cond_in'
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_loop16_seq_none_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "trip_count"
input: "cond"
input: "opt_seq"
output: "seq_res"
name: "node_0"
op_type: "Loop"...ype {
sequence_type {
elem_type {
tensor_type {
elem_type: 1
}
}
}
}
}
tensor_dict = {'cond': array(True), 'opt_seq': [array(0., dtype=float32)], 'trip_count': array(5)}
opset = [domain: ""
version: 16
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['trip_count', 'cond', 'opt_seq', 'cond_in', 'opt_seq_in', 'iter_count'], the node proto isinput: "trip_count"
E input: "cond"
E input: "opt_seq"
E output: "seq_res"
E name: "node_0"
E op_type: "Loop"
E attribute {
E name: "body"
E type: GRAPH
E g {
E node {
E input: "cond_in"
E output: "cond_out"
E op_type: "Identity"
E }
E node {
E input: "opt_seq_in"
E output: "optional_has_elem"
E op_type: "OptionalHasElement"
E }
E node {
E input: "optional_has_elem"
E output: "optional_is_none"
E op_type: "Not"
E }
E node {
E input: "optional_is_none"
E output: "sequence"
E op_type: "If"
E attribute {
E name: "else_branch"
E type: GRAPH
E g {
E node {
E input: "opt_seq_in"
E output: "seq_in"
E op_type: "OptionalGetElement"
E }
E name: "else_body"
E output {
E name: "seq_in"
E type {
E sequence_type {
E elem_type {
E tensor_type {
E elem_type: 1
E shape {
E }
E }
E }
E }
E }
E }
E }
E }
E attribute {
E name: "then_branch"
E type: GRAPH
E g {
E node {
E output: "constant_in"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E data_type: 1
E float_data: 0
E name: "const_tensor"
E }
E }
E }
E node {
E input: "constant_in"
E output: "init_seq_in"
E op_type: "SequenceConstruct"
E }
E name: "then_body"
E output {
E name: "init_seq_in"
E type {
E sequence_type {
E elem_type {
E tensor_type {
E elem_type: 1
E shape {
E }
E }
E }
E }
E }
E }
E }
E }
E }
E node {
E output: "x"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E dims: 5
E data_type: 1
E float_data: 1
E float_data: 2
E float_data: 3
E float_data: 4
E float_data: 5
E name: "const_tensor_x"
E }
E }
E }
E node {
E output: "one"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E data_type: 7
E int64_data: 1
E name: "const_tensor_one"
E }
E }
E }
E node {
E output: "slice_start"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E dims: 1
E data_type: 7
E int64_data: 0
E name: "const_tensor_zero"
E }
E }
E }
E node {
E input: "iter_count"
E input: "one"
E output: "end"
E op_type: "Add"
E }
E node {
E output: "axes"
E op_type: "Constant"
E attribute {
E name: "value"
E type: TENSOR
E t {
E data_type: 7
E int64_data: 0
E name: "const_tensor_axes"
E }
E }
E }
E node {
E input: "end"
E input: "axes"
E output: "slice_end"
E op_type: "Unsqueeze"
E }
E node {
E input: "x"
E input: "slice_start"
E input: "slice_end"
E output: "slice_out"
E op_type: "Slice"
E }
E node {
E input: "sequence"
E input: "slice_out"
E output: "seq_out"
E op_type: "SequenceInsert"
E }
E name: "loop_body"
E input {
E name: "iter_count"
E type {
E tensor_type {
E elem_type: 7
E shape {
E }
E }
E }
E }
E input {
E name: "cond_in"
E type {
E tensor_type {
E elem_type: 9
E shape {
E }
E }
E }
E }
E input {
E name: "opt_seq_in"
E type {
E optional_type {
E elem_type {
E sequence_type {
E elem_type {
E tensor_type {
E elem_type: 1
E shape {
E }
E }
E }
E }
E }
E }
E }
E }
E output {
E name: "cond_out"
E type {
E tensor_type {
E elem_type: 9
E shape {
E }
E }
E }
E }
E output {
E name: "seq_out"
E type {
E sequence_type {
E elem_type {
E tensor_type {
E elem_type: 1
E shape {
E }
E }
E }
E }
E }
E }
E }
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__ OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_pads_cpu ___
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_maxpool_with_argmax_2d_precomputed_pads_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:67: in run
return model_func(model_params=model_params, inputs=inputs)
jaxonnxruntime/call_onnx.py:168: in model_func
return [tensor_dict[n.name] for n in graph.output]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <iterator object at 0x7f97a4167700>
> return [tensor_dict[n.name] for n in graph.output]
E KeyError: 'z'
jaxonnxruntime/call_onnx.py:168: KeyError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 12. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 12. Set to 1.
_ OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_strides_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_maxpool_with_argmax_2d_precomputed_strides_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:67: in run
return model_func(model_params=model_params, inputs=inputs)
jaxonnxruntime/call_onnx.py:168: in model_func
return [tensor_dict[n.name] for n in graph.output]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <iterator object at 0x7f97845221a0>
> return [tensor_dict[n.name] for n in graph.output]
E KeyError: 'z'
jaxonnxruntime/call_onnx.py:168: KeyError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 12. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 12. Set to 1.
__________________ OnnxBackendNodeModelTest.test_momentum_cpu __________________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_momentum_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.preview.training"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.preview.training': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
_____________ OnnxBackendNodeModelTest.test_momentum_multiple_cpu ______________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_momentum_multiple_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.preview.training"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.preview.training': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
_____________ OnnxBackendNodeModelTest.test_nesterov_momentum_cpu ______________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nesterov_momentum_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:105: in call_onnx_graph
handlers = _get_all_handlers(opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
opset = [domain: "ai.onnx.preview.training"
version: 1
]
def _get_all_handlers(
opset: Sequence[onnx.OperatorSetIdProto],
) -> Dict[str, Dict[str, Type[Handler]]]:
"""Get all ONNX OP_TYPE handlers from Handler subclasses.
Args:
opset: An OperatorSetIdProto message containing the operator set version
information.
Returns:
A dictionary of all the ONNX handlers, where the keys are the domain
names
and the values are nested dictionaries mapping operator names to their
Handler
subclasses.
Raises:
ValueError: If there is no OP_TYPE attribute defined in the Handler class.
"""
handlers: Dict[Any, Any] = {}
for handler in Handler.__subclasses__():
if not hasattr(handler, 'OP_TYPE'):
logger.warning(
(
"%s doesn't have ONNX OP_TYPE. "
'Please use handler.register_op decorator to register it.'
),
handler.__name__,
)
domain = handler.DOMAIN
opset_dict = dict([(o.domain, o.version) for o in opset])
if handler.DOMAIN not in opset_dict:
> raise ValueError(
f'handler.DOMAIN {handler.DOMAIN} is not in opset_dict {opset_dict}'
)
E ValueError: handler.DOMAIN is not in opset_dict {'ai.onnx.preview.training': 1}
jaxonnxruntime/call_onnx.py:229: ValueError
____________ OnnxBackendNodeModelTest.test_nllloss_NC_expanded_cpu _____________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NC_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:51: in onnx_gatherelements
data_swaped = jnp.swapaxes(data, 0, axis)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(int64[3])>with<DynamicJaxprTrace(level=2/0)>, axis1 = 0
axis2 = 1
@util._wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
util.check_arraylike("swapaxes", a)
perm = np.arange(ndim(a))
> perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
E jax._src.traceback_util.UnfilteredStackTrace: IndexError: index 1 is out of bounds for axis 0 with size 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:898: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NC_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:51: in onnx_gatherelements
data_swaped = jnp.swapaxes(data, 0, axis)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(int64[3])>with<DynamicJaxprTrace(level=2/0)>, axis1 = 0
axis2 = 1
@util._wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
util.check_arraylike("swapaxes", a)
perm = np.arange(ndim(a))
> perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
E IndexError: index 1 is out of bounds for axis 0 with size 1
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:898: IndexError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
___________ OnnxBackendNodeModelTest.test_nllloss_NCd1_expanded_cpu ____________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,2])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[2,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,2])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[2,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
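
Note on the failure above: the traceback shows `jnp.choose` being called with `a` as a float32 array of shape [1,3,5,2] and `choices` as an int64 array of shape [2,3], and `jnp.choose` rejects a non-integer `a`. The sketch below is one possible way to express an element-wise gather, assuming the operator is meant to follow ONNX GatherElements semantics. The function name `gather_elements_sketch` is hypothetical and is not jaxonnxruntime code; it uses `jnp.take_along_axis` in place of `jnp.choose`, and checks the index dtype up front so that swapped arguments fail with a clear message.

```python
import numpy as np
import jax.numpy as jnp


def gather_elements_sketch(data, indices, axis=0):
    """Hypothetical stand-in for a GatherElements kernel (not the project's code)."""
    data = jnp.asarray(data)
    indices = jnp.asarray(indices)
    # Fail early if the index operand is not integer typed, which is the
    # condition jnp.choose complained about in the log.
    if not jnp.issubdtype(indices.dtype, jnp.integer):
        raise TypeError(f"indices must be integer typed, got {indices.dtype}")
    # ONNX permits negative indices that count back from the end of `axis`.
    indices = jnp.where(indices < 0, indices + data.shape[axis], indices)
    # out[i][j] = data[i][indices[i][j]] when axis == 1, and so on.
    return jnp.take_along_axis(data, indices, axis=axis)


data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
indices = np.array([[0, 0], [1, 0]], dtype=np.int64)
result = gather_elements_sketch(data, indices, axis=1)
print(result)
```

Whether the real fix belongs in the argument order passed to `jnp.choose` or in a different gather primitive cannot be determined from the log alone; this sketch only illustrates the second option.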
__________ OnnxBackendNodeModelTest.test_nllloss_NCd1_ii_expanded_cpu __________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,2])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[2,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,2])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[2,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
________ OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_expanded_cpu ________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_weight_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,2])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[2,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_weight_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,2])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[2,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
______ OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_ii_expanded_cpu _______
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_weight_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,2])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[2,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1_weight_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,2])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[2,3])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
__________ OnnxBackendNodeModelTest.test_nllloss_NCd1d2_expanded_cpu ___________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
___ OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_mean_expanded_cpu ___
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_reduction_mean_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_reduction_mean_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
___ OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_sum_expanded_cpu ____
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_reduction_sum_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_reduction_sum_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
____ OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_expanded_cpu _____
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_with_weight_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_with_weight_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6,5])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6,5])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6,5])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6,5])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:51: in onnx_gatherelements
data_swaped = jnp.swapaxes(data, 0, axis)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(int64[3])>with<DynamicJaxprTrace(level=2/0)>, axis1 = 0
axis2 = 1
@util._wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
util.check_arraylike("swapaxes", a)
perm = np.arange(ndim(a))
> perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
E jax._src.traceback_util.UnfilteredStackTrace: IndexError: index 1 is out of bounds for axis 0 with size 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:898: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:51: in onnx_gatherelements
data_swaped = jnp.swapaxes(data, 0, axis)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(int64[3])>with<DynamicJaxprTrace(level=2/0)>, axis1 = 0
axis2 = 1
@util._wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
util.check_arraylike("swapaxes", a)
perm = np.arange(ndim(a))
> perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
E IndexError: index 1 is out of bounds for axis 0 with size 1
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:898: IndexError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu __
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6,5,3,4])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6,5,3,4])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6,5,3,4])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6,5,3,4])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6,5,3,4])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6,5,3,4])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: `a` array must be integer typed
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/gatherelements.py:53: in onnx_gatherelements
gathered = jnp.choose(index_swaped, data_swaped, mode='wrap')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Traced<ShapedArray(float32[1,3,5,6,6,5,3,4])>with<DynamicJaxprTrace(level=1/0)>
choices = Traced<ShapedArray(int64[6,3,6,5,3,4])>with<DynamicJaxprTrace(level=1/0)>
out = None, mode = 'wrap'
@util._wraps(np.choose, skip_params=['out'])
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
util.check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
> raise ValueError("`a` array must be integer typed")
E ValueError: `a` array must be integer typed
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1891: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_ OnnxBackendNodeModelTest.test_optional_has_element_empty_no_input_name_optional_input_cpu _
graph = node {
input: ""
output: "output"
name: "node_0"
op_type: "OptionalHasElement"
}
name: "test_optional_has_elem...ional_input"
output {
name: "output"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
tensor_dict = {}, opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97443b37c0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_optional_has_element_empty_no_input_name_optional_input_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: ""
output: "output"
name: "node_0"
op_type: "OptionalHasElement"
}
name: "test_optional_has_elem...ional_input"
output {
name: "output"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
tensor_dict = {}, opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names[''], the node proto isinput: ""
E output: "output"
E name: "node_0"
E op_type: "OptionalHasElement"
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_optional_has_element_empty_no_input_name_tensor_input_cpu _
graph = node {
input: ""
output: "output"
name: "node_0"
op_type: "OptionalHasElement"
}
name: "test_optional_has_elem...ensor_input"
output {
name: "output"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
tensor_dict = {}, opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f974458dba0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_optional_has_element_empty_no_input_name_tensor_input_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: ""
output: "output"
name: "node_0"
op_type: "OptionalHasElement"
}
name: "test_optional_has_elem...ensor_input"
output {
name: "output"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
tensor_dict = {}, opset = [domain: ""
version: 18
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names[''], the node proto isinput: ""
E output: "output"
E name: "node_0"
E op_type: "OptionalHasElement"
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_cpu _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.8, 0.8], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c763c10>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.8, 0.8], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "cubic_coeff_a"
E type: FLOAT
E f: -0.5
E }
E attribute {
E name: "exclude_outside"
E type: INT
E i: 1
E }
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_downsample_scales_cubic_align_corners_cpu _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.8, 0.8], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f972c3efdc0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_cubic_align_corners_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.8, 0.8], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "align_corners"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__ OnnxBackendNodeModelTest.test_resize_downsample_scales_cubic_antialias_cpu __
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97445e9330>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_cubic_antialias_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "antialias"
E type: INT
E i: 1
E }
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_______ OnnxBackendNodeModelTest.test_resize_downsample_scales_cubic_cpu _______
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.8, 0.8], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f972c7134c0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_cubic_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.8, 0.8], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_downsample_scales_linear_align_corners_cpu _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97840e48e0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_linear_align_corners_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "align_corners"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_downsample_scales_linear_antialias_cpu __
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97443b87c0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_linear_antialias_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "antialias"
E type: INT
E i: 1
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
______ OnnxBackendNodeModelTest.test_resize_downsample_scales_linear_cpu _______
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f978429f190>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_linear_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_downsample_scales_linear_half_pixel_symmetric_cpu _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.]]]], dtype=float32), 'scales': array([1. , 1. , 1. , 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f9744221b40>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_linear_half_pixel_symmetric_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.]]]], dtype=float32), 'scales': array([1. , 1. , 1. , 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "half_pixel_symmetric"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
______ OnnxBackendNodeModelTest.test_resize_downsample_scales_nearest_cpu ______
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f978411a530>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_scales_nearest_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'scales': array([1. , 1. , 0.6, 0.6], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__ OnnxBackendNodeModelTest.test_resize_downsample_sizes_cubic_antialias_cpu ___
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f9784580100>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_sizes_cubic_antialias_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "antialias"
E type: INT
E i: 1
E }
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_______ OnnxBackendNodeModelTest.test_resize_downsample_sizes_cubic_cpu ________
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f9744296bf0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_sizes_cubic_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__ OnnxBackendNodeModelTest.test_resize_downsample_sizes_linear_antialias_cpu __
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97087097b0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_sizes_linear_antialias_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "antialias"
E type: INT
E i: 1
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_downsample_sizes_linear_pytorch_half_pixel_cpu _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 3
}
dim {
dim_value: 1
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 3, 1])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97446195d0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_sizes_linear_pytorch_half_pixel_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 3
}
dim {
dim_value: 1
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 3, 1])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "pytorch_half_pixel"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
______ OnnxBackendNodeModelTest.test_resize_downsample_sizes_nearest_cpu _______
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'sizes': array([1, 1, 1, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f972c713520>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_sizes_nearest_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'sizes': array([1, 1, 1, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_downsample_sizes_nearest_not_larger_cpu _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'sizes': array([1, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f972c1c6e00>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_sizes_nearest_not_larger_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'sizes': array([1, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 2
E ints: 3
E }
E attribute {
E name: "keep_aspect_ratio_policy"
E type: STRING
E s: "not_larger"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_downsample_sizes_nearest_not_smaller_cpu _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'sizes': array([1, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f978411b340>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_downsample_sizes_nearest_not_smaller_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.]]]], dtype=float32), 'sizes': array([1, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 2
E ints: 3
E }
E attribute {
E name: "keep_aspect_ratio_policy"
E type: STRING
E s: "not_smaller"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_____ OnnxBackendNodeModelTest.test_resize_tf_crop_and_resize_axes_2_3_cpu _____
graph = node {
input: "X"
input: "roi"
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
a... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'roi': array([0.4, 0.6, 0.6, 0.8], dtype=float32), 'sizes': array([3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97445e9ed0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_tf_crop_and_resize_axes_2_3_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: "roi"
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
a... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'roi': array([0.4, 0.6, 0.6, 0.8], dtype=float32), 'sizes': array([3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', 'roi', '', 'sizes'], the node proto isinput: "X"
E input: "roi"
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 2
E ints: 3
E }
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "tf_crop_and_resize"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_____ OnnxBackendNodeModelTest.test_resize_tf_crop_and_resize_axes_3_2_cpu _____
graph = node {
input: "X"
input: "roi"
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
a... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'roi': array([0.6, 0.4, 0.8, 0.6], dtype=float32), 'sizes': array([3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c7619f0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_tf_crop_and_resize_axes_3_2_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: "roi"
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
a... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'roi': array([0.6, 0.4, 0.8, 0.6], dtype=float32), 'sizes': array([3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', 'roi', '', 'sizes'], the node proto isinput: "X"
E input: "roi"
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 3
E ints: 2
E }
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "tf_crop_and_resize"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_________ OnnxBackendNodeModelTest.test_resize_tf_crop_and_resize_cpu __________
graph = node {
input: "X"
input: "roi"
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
a... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14.,...], dtype=float32), 'roi': array([0. , 0. , 0.4, 0.6, 1. , 1. , 1.2, 1.7], dtype=float32), 'sizes': array([1, 1, 3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f972c7125c0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_tf_crop_and_resize_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: "roi"
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
a... }
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14.,...], dtype=float32), 'roi': array([0. , 0. , 0.4, 0.6, 1. , 1. , 1.2, 1.7], dtype=float32), 'sizes': array([1, 1, 3, 3])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', 'roi', '', 'sizes'], the node proto isinput: "X"
E input: "roi"
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "tf_crop_and_resize"
E }
E attribute {
E name: "extrapolation_value"
E type: FLOAT
E f: 10
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_cpu _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f9744619660>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "cubic_coeff_a"
E type: FLOAT
E f: -0.5
E }
E attribute {
E name: "exclude_outside"
E type: INT
E i: 1
E }
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_upsample_scales_cubic_align_corners_cpu _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97445e9300>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_cubic_align_corners_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "align_corners"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__ OnnxBackendNodeModelTest.test_resize_upsample_scales_cubic_asymmetric_cpu ___
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c776ef0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_cubic_asymmetric_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "asymmetric"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
________ OnnxBackendNodeModelTest.test_resize_upsample_scales_cubic_cpu ________
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97447b9ed0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_cubic_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_upsample_scales_linear_align_corners_cpu _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c7604c0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_linear_align_corners_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "align_corners"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_______ OnnxBackendNodeModelTest.test_resize_upsample_scales_linear_cpu ________
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c51c8b0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_linear_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([1., 1., 2., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_upsample_scales_linear_half_pixel_symmetric_cpu _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([1. , 1. , 2.3 , 2.94], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f972c3eeb60>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_linear_half_pixel_symmetric_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([1. , 1. , 2.3 , 2.94], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "half_pixel_symmetric"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "linear"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__ OnnxBackendNodeModelTest.test_resize_upsample_scales_nearest_axes_2_3_cpu ___
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 6
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([2., 3.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97840e4880>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_nearest_axes_2_3_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 6
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([2., 3.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 2
E ints: 3
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__ OnnxBackendNodeModelTest.test_resize_upsample_scales_nearest_axes_3_2_cpu ___
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 6
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([3., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f9744721540>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_nearest_axes_3_2_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 6
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([3., 2.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 3
E ints: 2
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_______ OnnxBackendNodeModelTest.test_resize_upsample_scales_nearest_cpu _______
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 6
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([1., 1., 2., 3.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97446db190>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_scales_nearest_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: "scales"
output: "Y"
name: "node_0"
op_type: "Resize"
attribute {
... }
dim {
dim_value: 4
}
dim {
dim_value: 6
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'scales': array([1., 1., 2., 3.], dtype=float32)}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', 'scales'], the node proto isinput: "X"
E input: ""
E input: "scales"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
________ OnnxBackendNodeModelTest.test_resize_upsample_sizes_cubic_cpu _________
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 9
}
dim {
dim_value: 10
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([ 1, 1, 9, 10])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97447b9c90>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_sizes_cubic_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 9
}
dim {
dim_value: 10
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([ 1, 1, 9, 10])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "cubic"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
___ OnnxBackendNodeModelTest.test_resize_upsample_sizes_nearest_axes_2_3_cpu ___
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 7
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'sizes': array([7, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f9744721990>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_sizes_nearest_axes_2_3_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 7
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'sizes': array([7, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 2
E ints: 3
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
___ OnnxBackendNodeModelTest.test_resize_upsample_sizes_nearest_axes_3_2_cpu ___
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 7
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'sizes': array([8, 7])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c623d30>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_sizes_nearest_axes_3_2_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 7
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'sizes': array([8, 7])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 3
E ints: 2
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_upsample_sizes_nearest_ceil_half_pixel_cpu _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 8, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c774dc0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_sizes_nearest_ceil_half_pixel_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 8, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "half_pixel"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E attribute {
E name: "nearest_mode"
E type: STRING
E s: "ceil"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_______ OnnxBackendNodeModelTest.test_resize_upsample_sizes_nearest_cpu ________
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 7
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'sizes': array([1, 1, 7, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97445e8e50>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_sizes_nearest_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 7
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'sizes': array([1, 1, 7, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_upsample_sizes_nearest_floor_align_corners_cpu _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 8, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c7604c0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_sizes_nearest_floor_align_corners_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 8, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "align_corners"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E attribute {
E name: "nearest_mode"
E type: STRING
E s: "floor"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__ OnnxBackendNodeModelTest.test_resize_upsample_sizes_nearest_not_larger_cpu __
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'sizes': array([7, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f97087c44c0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_sizes_nearest_not_larger_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[1., 2.],
[3., 4.]]]], dtype=float32), 'sizes': array([7, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "axes"
E type: INTS
E ints: 2
E ints: 3
E }
E attribute {
E name: "keep_aspect_ratio_policy"
E type: STRING
E s: "not_smaller"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
_ OnnxBackendNodeModelTest.test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cpu _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 8, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c621660>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "X"
input: ""
input: ""
input: "sizes"
output: "Y"
name: "node_0"
op_type: "Resize"
attr... }
dim {
dim_value: 8
}
dim {
dim_value: 8
}
}
}
}
}
tensor_dict = {'X': array([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]], dtype=float32), 'sizes': array([1, 1, 8, 8])}
opset = [domain: ""
version: 19
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['X', '', '', 'sizes'], the node proto isinput: "X"
E input: ""
E input: ""
E input: "sizes"
E output: "Y"
E name: "node_0"
E op_type: "Resize"
E attribute {
E name: "coordinate_transformation_mode"
E type: STRING
E s: "asymmetric"
E }
E attribute {
E name: "mode"
E type: STRING
E s: "nearest"
E }
E attribute {
E name: "nearest_mode"
E type: STRING
E s: "round_prefer_ceil"
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
__________________ OnnxBackendNodeModelTest.test_scan_sum_cpu __________________
graph = node {
input: ""
input: "initial"
input: "x"
output: "y"
output: "z"
name: "node_0"
op_type: "Scan"
at... }
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'initial': array([[0., 0.]], dtype=float32), 'x': array([[[1., 2.],
[3., 4.],
[5., 6.]]], dtype=float32)}
opset = [domain: ""
version: 8
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f977c51d6c0>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_scan_sum_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: ""
input: "initial"
input: "x"
output: "y"
output: "z"
name: "node_0"
op_type: "Scan"
at... }
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'initial': array([[0., 0.]], dtype=float32), 'x': array([[[1., 2.],
[3., 4.],
[5., 6.]]], dtype=float32)}
opset = [domain: ""
version: 8
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['', 'initial', 'x'], the node proto isinput: ""
E input: "initial"
E input: "x"
E output: "y"
E output: "z"
E name: "node_0"
E op_type: "Scan"
E attribute {
E name: "body"
E type: GRAPH
E g {
E node {
E input: "sum_in"
E input: "next"
E output: "sum_out"
E op_type: "Add"
E }
E node {
E input: "sum_out"
E output: "scan_out"
E op_type: "Identity"
E }
E name: "scan_body"
E input {
E name: "sum_in"
E type {
E tensor_type {
E elem_type: 1
E shape {
E dim {
E dim_value: 2
E }
E }
E }
E }
E }
E input {
E name: "next"
E type {
E tensor_type {
E elem_type: 1
E shape {
E dim {
E dim_value: 2
E }
E }
E }
E }
E }
E output {
E name: "sum_out"
E type {
E tensor_type {
E elem_type: 1
E shape {
E dim {
E dim_value: 2
E }
E }
E }
E }
E }
E output {
E name: "scan_out"
E type {
E tensor_type {
E elem_type: 1
E shape {
E dim {
E dim_value: 2
E }
E }
E }
E }
E }
E }
E }
E attribute {
E name: "num_scan_inputs"
E type: INT
E i: 1
E }
E .
jaxonnxruntime/call_onnx.py:115: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acosh in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asinh in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atanh in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of BitShift in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of ConstantOfShape in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cosh in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Erf in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of GatherElements in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of NonZero in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of OneHot in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Range in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sinh in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 8. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Where in domain with max_inclusive_version= 8. Set to 1.
_ OnnxBackendNodeModelTest.test_sequence_map_add_1_sequence_1_tensor_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_sequence_map_add_1_sequence_1_tensor_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:84: in call_onnx_model
**onnx_utils.maybe_convert_to_dict(inputs, input_names), **model_params
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
inputs = [[array([0.5488135 , 0.71518934, 0.60276335, 0.5448832 , 0.4236548 ,
0.6458941 , 0.4375872 , 0.891773 , 0.9636...5032, 0.56843394, 0.0187898 ,
0.6176355 , 0.6120957 , 0.616934 , 0.94374806, 0.6818203 ],
dtype=float32)]
input_names = ['x0']
def maybe_convert_to_dict(
inputs: Union[Sequence[Any], Dict[str, Any]],
input_names: Optional[Sequence[Any]] = None,
):
"""Convert inputs to a dictionary with input_names as keys."""
if isinstance(inputs, dict):
return inputs
elif isinstance(inputs, Sequence):
if input_names is None:
raise ValueError("Should provide input names if `inputs` is a Sequence!")
> assert len(inputs) == len(input_names)
E AssertionError
jaxonnxruntime/core/onnx_utils.py:116: AssertionError
___ OnnxBackendNodeModelTest.test_sequence_map_add_2_sequences_expanded_cpu ____
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_sequence_map_add_2_sequences_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:84: in call_onnx_model
**onnx_utils.maybe_convert_to_dict(inputs, input_names), **model_params
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
inputs = [[array([0.8579456 , 0.8472517 , 0.6235637 , 0.3843817 , 0.2975346 ,
0.05671298], dtype=float32), array([0.2726...=float32), array([0.87008727], dtype=float32), array([0.47360805, 0.8009108 , 0.5204775 , 0.67887956], dtype=float32)]]
input_names = ['x0']
def maybe_convert_to_dict(
inputs: Union[Sequence[Any], Dict[str, Any]],
input_names: Optional[Sequence[Any]] = None,
):
"""Convert inputs to a dictionary with input_names as keys."""
if isinstance(inputs, dict):
return inputs
elif isinstance(inputs, Sequence):
if input_names is None:
raise ValueError("Should provide input names if `inputs` is a Sequence!")
> assert len(inputs) == len(input_names)
E AssertionError
jaxonnxruntime/core/onnx_utils.py:116: AssertionError
_ OnnxBackendNodeModelTest.test_sequence_map_identity_1_sequence_1_tensor_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_sequence_map_identity_1_sequence_1_tensor_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:84: in call_onnx_model
**onnx_utils.maybe_convert_to_dict(inputs, input_names), **model_params
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
inputs = [[array([0.84426576, 0.8579456 , 0.8472517 , 0.6235637 , 0.3843817 ,
0.2975346 ], dtype=float32), array([0.9636...rray([0.95715517, 0.14035077, 0.87008727, 0.47360805], dtype=float32)], array([0.46147937, 0.7805292 ], dtype=float32)]
input_names = ['x0']
def maybe_convert_to_dict(
inputs: Union[Sequence[Any], Dict[str, Any]],
input_names: Optional[Sequence[Any]] = None,
):
"""Convert inputs to a dictionary with input_names as keys."""
if isinstance(inputs, dict):
return inputs
elif isinstance(inputs, Sequence):
if input_names is None:
raise ValueError("Should provide input names if `inputs` is a Sequence!")
> assert len(inputs) == len(input_names)
E AssertionError
jaxonnxruntime/core/onnx_utils.py:116: AssertionError
_ OnnxBackendNodeModelTest.test_sequence_map_identity_2_sequences_expanded_cpu _
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_sequence_map_identity_2_sequences_expanded_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:84: in call_onnx_model
**onnx_utils.maybe_convert_to_dict(inputs, input_names), **model_params
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
inputs = [[array([0.84426576, 0.8579456 , 0.8472517 , 0.6235637 , 0.3843817 ,
0.2975346 ], dtype=float32), array([0.9636... 0.2645556 , 0.7742337 , 0.45615032], dtype=float32), array([0.13521817, 0.324141 , 0.14967486], dtype=float32)]]
input_names = ['x0']
def maybe_convert_to_dict(
inputs: Union[Sequence[Any], Dict[str, Any]],
input_names: Optional[Sequence[Any]] = None,
):
"""Convert inputs to a dictionary with input_names as keys."""
if isinstance(inputs, dict):
return inputs
elif isinstance(inputs, Sequence):
if input_names is None:
raise ValueError("Should provide input names if `inputs` is a Sequence!")
> assert len(inputs) == len(input_names)
E AssertionError
jaxonnxruntime/core/onnx_utils.py:116: AssertionError
____________________ OnnxBackendNodeModelTest.test_stft_cpu ____________________
graph = node {
input: "signal"
input: "frame_step"
input: ""
input: "frame_length"
output: "output"
name: "node_0"... }
dim {
dim_value: 9
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'frame_length': array(16), 'frame_step': array(8), 'signal': array([[[ 0.],
[ 1.],
[ 2.],
...21.],
[122.],
[123.],
[124.],
[125.],
[126.],
[127.]]], dtype=float32)}
opset = [domain: ""
version: 17
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
jaxonnxruntime/call_onnx.py:113:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.0 = <list_iterator object at 0x7f9784590190>
> node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
E KeyError: ''
jaxonnxruntime/call_onnx.py:113: KeyError
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_stft_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "signal"
input: "frame_step"
input: ""
input: "frame_length"
output: "output"
name: "node_0"... }
dim {
dim_value: 9
}
dim {
dim_value: 2
}
}
}
}
}
tensor_dict = {'frame_length': array(16), 'frame_step': array(8), 'signal': array([[[ 0.],
[ 1.],
[ 2.],
...21.],
[122.],
[123.],
[124.],
[125.],
[126.],
[127.]]], dtype=float32)}
opset = [domain: ""
version: 17
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
> raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
E ValueError: Fail to get the input tensor of node input names['signal', 'frame_step', '', 'frame_length'], the node proto isinput: "signal"
E input: "frame_step"
E input: ""
E input: "frame_length"
E output: "output"
E name: "node_0"
E op_type: "STFT"
E .
jaxonnxruntime/call_onnx.py:115: ValueError
______________ OnnxBackendNodeModelTest.test_training_dropout_cpu ______________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_training_dropout_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: "r"
input: "t"
output: "y"
name: "node_0"
op_type: "Dropout"
attribute {
na... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'r': array(0.75, dtype=float32), 't': array(True), 'x': array([[[ 1.7640524 , 0.4001572 , 0.978738 , 2.2408931 ,
... -0.02818223],
[ 0.42833188, 0.06651722, 0.3024719 , -0.6343221 ,
-0.36274117]]], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
jit_func = _get_jit_func(node, node_inputs, handlers=handlers)
jit_func_dict[node.name] = jit_func
if config.jaxort_experimental_support_abtract_input_shape:
outputs = jax.eval_shape(jit_func, *node_inputs, **node.attrs_dict)
else:
> outputs = jit_func(*node_inputs, **node.attrs_dict)
E ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'onnx_dropout' while trying to hash an object of type <class 'numpy.ndarray'>, True. The error was:
E TypeError: unhashable type: 'numpy.ndarray'
jaxonnxruntime/call_onnx.py:126: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
__________ OnnxBackendNodeModelTest.test_training_dropout_default_cpu __________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_training_dropout_default_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: "r"
input: "t"
output: "y"
name: "node_0"
op_type: "Dropout"
attribute {
na... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'r': array(0.5, dtype=float32), 't': array(True), 'x': array([[[ 1.7640524 , 0.4001572 , 0.978738 , 2.2408931 ,
... -0.02818223],
[ 0.42833188, 0.06651722, 0.3024719 , -0.6343221 ,
-0.36274117]]], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
jit_func = _get_jit_func(node, node_inputs, handlers=handlers)
jit_func_dict[node.name] = jit_func
if config.jaxort_experimental_support_abtract_input_shape:
outputs = jax.eval_shape(jit_func, *node_inputs, **node.attrs_dict)
else:
> outputs = jit_func(*node_inputs, **node.attrs_dict)
E ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'onnx_dropout' while trying to hash an object of type <class 'numpy.ndarray'>, True. The error was:
E TypeError: unhashable type: 'numpy.ndarray'
jaxonnxruntime/call_onnx.py:126: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
_______ OnnxBackendNodeModelTest.test_training_dropout_default_mask_cpu ________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_training_dropout_default_mask_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: "r"
input: "t"
output: "y"
output: "z"
name: "node_0"
op_type: "Dropout"
attr... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'r': array(0.5, dtype=float32), 't': array(True), 'x': array([[[ 1.7640524 , 0.4001572 , 0.978738 , 2.2408931 ,
... -0.02818223],
[ 0.42833188, 0.06651722, 0.3024719 , -0.6343221 ,
-0.36274117]]], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
jit_func = _get_jit_func(node, node_inputs, handlers=handlers)
jit_func_dict[node.name] = jit_func
if config.jaxort_experimental_support_abtract_input_shape:
outputs = jax.eval_shape(jit_func, *node_inputs, **node.attrs_dict)
else:
> outputs = jit_func(*node_inputs, **node.attrs_dict)
E ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'onnx_dropout' while trying to hash an object of type <class 'numpy.ndarray'>, True. The error was:
E TypeError: unhashable type: 'numpy.ndarray'
jaxonnxruntime/call_onnx.py:126: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
___________ OnnxBackendNodeModelTest.test_training_dropout_mask_cpu ____________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_training_dropout_mask_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: "r"
input: "t"
output: "y"
output: "z"
name: "node_0"
op_type: "Dropout"
attr... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'r': array(0.75, dtype=float32), 't': array(True), 'x': array([[[ 1.7640524 , 0.4001572 , 0.978738 , 2.2408931 ,
... -0.02818223],
[ 0.42833188, 0.06651722, 0.3024719 , -0.6343221 ,
-0.36274117]]], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
jit_func = _get_jit_func(node, node_inputs, handlers=handlers)
jit_func_dict[node.name] = jit_func
if config.jaxort_experimental_support_abtract_input_shape:
outputs = jax.eval_shape(jit_func, *node_inputs, **node.attrs_dict)
else:
> outputs = jit_func(*node_inputs, **node.attrs_dict)
E ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'onnx_dropout' while trying to hash an object of type <class 'numpy.ndarray'>, True. The error was:
E TypeError: unhashable type: 'numpy.ndarray'
jaxonnxruntime/call_onnx.py:126: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
________ OnnxBackendNodeModelTest.test_training_dropout_zero_ratio_cpu _________
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_training_dropout_zero_ratio_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: "r"
input: "t"
output: "y"
name: "node_0"
op_type: "Dropout"
attribute {
na... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'r': array(0., dtype=float32), 't': array(True), 'x': array([[[ 1.7640524 , 0.4001572 , 0.978738 , 2.2408931 ,
... -0.02818223],
[ 0.42833188, 0.06651722, 0.3024719 , -0.6343221 ,
-0.36274117]]], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
jit_func = _get_jit_func(node, node_inputs, handlers=handlers)
jit_func_dict[node.name] = jit_func
if config.jaxort_experimental_support_abtract_input_shape:
outputs = jax.eval_shape(jit_func, *node_inputs, **node.attrs_dict)
else:
> outputs = jit_func(*node_inputs, **node.attrs_dict)
E ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'onnx_dropout' while trying to hash an object of type <class 'numpy.ndarray'>, True. The error was:
E TypeError: unhashable type: 'numpy.ndarray'
jaxonnxruntime/call_onnx.py:126: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
______ OnnxBackendNodeModelTest.test_training_dropout_zero_ratio_mask_cpu ______
test_self = <onnx_ops_test.OnnxBackendNodeModelTest testMethod=test_training_dropout_zero_ratio_mask_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
graph = node {
input: "x"
input: "r"
input: "t"
output: "y"
output: "z"
name: "node_0"
op_type: "Dropout"
attr... }
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
tensor_dict = {'r': array(0., dtype=float32), 't': array(True), 'x': array([[[ 1.7640524 , 0.4001572 , 0.978738 , 2.2408931 ,
... -0.02818223],
[ 0.42833188, 0.06651722, 0.3024719 , -0.6343221 ,
-0.36274117]]], dtype=float32)}
opset = [domain: ""
version: 13
]
def call_onnx_graph(
graph: onnx.GraphProto,
tensor_dict: Dict[str, Any],
opset: ... = None,
) -> Callable[..., Any]:
"""Convert ONNX.GraphProto to jax_func with ONNX.GraphProto.initializer as parameters."""
tensor_ref_dict = build_ref_dict(graph)
graph_helper = OnnxGraph(graph)
# step 1: Trace those static info
jit_func_dict = {}
onnx_node_dict = {}
if opset is None:
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
handlers = _get_all_handlers(opset)
node_execute_order_list = graph_helper.topological_sort()
logger.info('Start tracing the jax_func model to get some static info')
for node_proto in node_execute_order_list:
node = OnnxNode(node_proto, graph_helper)
onnx_node_dict[node.name] = node
try:
node_inputs = [tensor_dict[x] for x in node.inputs + node.subgraph_inputs]
except Exception as e:
raise ValueError(
'Fail to get the input tensor of node input names'
f'{node.inputs + node.subgraph_inputs}, the node proto is'
f'{node.node_proto}.'
) from e
jit_func = _get_jit_func(node, node_inputs, handlers=handlers)
jit_func_dict[node.name] = jit_func
if config.jaxort_experimental_support_abtract_input_shape:
outputs = jax.eval_shape(jit_func, *node_inputs, **node.attrs_dict)
else:
> outputs = jit_func(*node_inputs, **node.attrs_dict)
E ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'onnx_dropout' while trying to hash an object of type <class 'numpy.ndarray'>, True. The error was:
E TypeError: unhashable type: 'numpy.ndarray'
jaxonnxruntime/call_onnx.py:126: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 13. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 13. Set to 1.
___________ OnnxBackendPyTorchConvertedModelTest.test_MaxPool1d_cpu ____________
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool1d_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/maxpool.py:159: in onnx_maxpool
return lax.reduce_window(
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:78: in reduce_window
return monoid_reducer(operand, window_dimensions, window_strides, padding,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:156: in _reduce_window_max
return reduce_window_max_p.bind(
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:380: in bind
return self.bind_with_trace(find_top_trace(args), args, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:383: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1855: in process_primitive
return self.default_process_primitive(primitive, tracers, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1859: in default_process_primitive
out_avals, effects = primitive.abstract_eval(*avals, **params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:416: in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/utils.py:63: in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:430: in _common_reduce_window_shape_rule
return reduce_window_shape_tuple(operand.shape, window_dimensions,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:442: in reduce_window_shape_tuple
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ss = ((2, 10, 4), (0, 0), (0, 0))
def sum_shapes(*ss: Shape) -> Shape:
> return tuple(map(sum_dim, *ss))
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: safe_map() argument 2 is shorter than argument 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:1983: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool1d_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ceil_mode = 0, strides = (1, 1, 4), pads = ((0, 0), (0, 0))
dilations = (1, 1, 1)
@functools.partial(
jit,
static_argnames=(
"ceil_mode",
"strides",
"pads",
"dilations",
"kernel_shape",
),
)
def onnx_maxpool(
*input_args,
ceil_mode: int,
strides: Sequence[int],
pads: Union[Sequence[tuple[int, int]], str],
dilations: Sequence[int],
kernel_shape: Sequence[int],
storage_order: int,
):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#MaxPool for more details."""
assert len(input_args) == 1
x = input_args[0]
if ceil_mode != 0:
raise ValueError("ceil_mode = 1 is not implement yet.")
> return lax.reduce_window(
x, -jnp.inf, lax.max, kernel_shape, strides, pads, None, dilations
)
E ValueError: safe_map() argument 2 is shorter than argument 1
jaxonnxruntime/onnx_ops/maxpool.py:159: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atan in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atanh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of BitShift in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of ConstantOfShape in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Erf in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Expand in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of GatherElements in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of NonZero in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of OneHot in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Range in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Where in domain with max_inclusive_version= 6. Set to 1.
________ OnnxBackendPyTorchConvertedModelTest.test_MaxPool1d_stride_cpu ________
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool1d_stride_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/maxpool.py:159: in onnx_maxpool
return lax.reduce_window(
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:78: in reduce_window
return monoid_reducer(operand, window_dimensions, window_strides, padding,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:156: in _reduce_window_max
return reduce_window_max_p.bind(
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:380: in bind
return self.bind_with_trace(find_top_trace(args), args, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:383: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1855: in process_primitive
return self.default_process_primitive(primitive, tracers, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1859: in default_process_primitive
out_avals, effects = primitive.abstract_eval(*avals, **params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:416: in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/utils.py:63: in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:430: in _common_reduce_window_shape_rule
return reduce_window_shape_tuple(operand.shape, window_dimensions,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:442: in reduce_window_shape_tuple
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ss = ((2, 10, 4), (0, 0), (0, 0))
def sum_shapes(*ss: Shape) -> Shape:
> return tuple(map(sum_dim, *ss))
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: safe_map() argument 2 is shorter than argument 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:1983: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool1d_stride_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ceil_mode = 0, strides = (1, 1, 4), pads = ((0, 0), (0, 0))
dilations = (1, 1, 1)
@functools.partial(
jit,
static_argnames=(
"ceil_mode",
"strides",
"pads",
"dilations",
"kernel_shape",
),
)
def onnx_maxpool(
*input_args,
ceil_mode: int,
strides: Sequence[int],
pads: Union[Sequence[tuple[int, int]], str],
dilations: Sequence[int],
kernel_shape: Sequence[int],
storage_order: int,
):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#MaxPool for more details."""
assert len(input_args) == 1
x = input_args[0]
if ceil_mode != 0:
raise ValueError("ceil_mode = 1 is not implement yet.")
> return lax.reduce_window(
x, -jnp.inf, lax.max, kernel_shape, strides, pads, None, dilations
)
E ValueError: safe_map() argument 2 is shorter than argument 1
jaxonnxruntime/onnx_ops/maxpool.py:159: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atan in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atanh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of BitShift in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of ConstantOfShape in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Erf in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Expand in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of GatherElements in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of NonZero in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of OneHot in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Range in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Where in domain with max_inclusive_version= 6. Set to 1.
_ OnnxBackendPyTorchConvertedModelTest.test_MaxPool1d_stride_padding_dilation_cpu _
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool1d_stride_padding_dilation_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/maxpool.py:159: in onnx_maxpool
return lax.reduce_window(
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:78: in reduce_window
return monoid_reducer(operand, window_dimensions, window_strides, padding,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:156: in _reduce_window_max
return reduce_window_max_p.bind(
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:380: in bind
return self.bind_with_trace(find_top_trace(args), args, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:383: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1855: in process_primitive
return self.default_process_primitive(primitive, tracers, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1859: in default_process_primitive
out_avals, effects = primitive.abstract_eval(*avals, **params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:416: in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/utils.py:63: in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:430: in _common_reduce_window_shape_rule
return reduce_window_shape_tuple(operand.shape, window_dimensions,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:442: in reduce_window_shape_tuple
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ss = ((1, 1, 220000), (0, 100), (0, 100))
def sum_shapes(*ss: Shape) -> Shape:
> return tuple(map(sum_dim, *ss))
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: safe_map() argument 2 is shorter than argument 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:1983: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool1d_stride_padding_dilation_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ceil_mode = 0, strides = (1, 1, 10), pads = ((0, 0), (100, 100))
dilations = (1, 1, 10)
@functools.partial(
jit,
static_argnames=(
"ceil_mode",
"strides",
"pads",
"dilations",
"kernel_shape",
),
)
def onnx_maxpool(
*input_args,
ceil_mode: int,
strides: Sequence[int],
pads: Union[Sequence[tuple[int, int]], str],
dilations: Sequence[int],
kernel_shape: Sequence[int],
storage_order: int,
):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#MaxPool for more details."""
assert len(input_args) == 1
x = input_args[0]
if ceil_mode != 0:
raise ValueError("ceil_mode = 1 is not implement yet.")
> return lax.reduce_window(
x, -jnp.inf, lax.max, kernel_shape, strides, pads, None, dilations
)
E ValueError: safe_map() argument 2 is shorter than argument 1
jaxonnxruntime/onnx_ops/maxpool.py:159: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 12. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 12. Set to 1.
___________ OnnxBackendPyTorchConvertedModelTest.test_MaxPool3d_cpu ____________
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool3d_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/maxpool.py:159: in onnx_maxpool
return lax.reduce_window(
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:78: in reduce_window
return monoid_reducer(operand, window_dimensions, window_strides, padding,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:156: in _reduce_window_max
return reduce_window_max_p.bind(
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:380: in bind
return self.bind_with_trace(find_top_trace(args), args, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:383: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1855: in process_primitive
return self.default_process_primitive(primitive, tracers, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1859: in default_process_primitive
out_avals, effects = primitive.abstract_eval(*avals, **params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:416: in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/utils.py:63: in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:430: in _common_reduce_window_shape_rule
return reduce_window_shape_tuple(operand.shape, window_dimensions,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:442: in reduce_window_shape_tuple
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ss = ((2, 3, 5, 5, 5), (0, 0, 0, 0, 0, 0), (0, 0, 0, 0, 0, 0))
def sum_shapes(*ss: Shape) -> Shape:
> return tuple(map(sum_dim, *ss))
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: safe_map() argument 2 is longer than argument 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:1983: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool3d_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ceil_mode = 0, strides = (1, 1, 2, 2, 2)
pads = ((0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0))
dilations = (1, 1, 1, 1, 1)
@functools.partial(
jit,
static_argnames=(
"ceil_mode",
"strides",
"pads",
"dilations",
"kernel_shape",
),
)
def onnx_maxpool(
*input_args,
ceil_mode: int,
strides: Sequence[int],
pads: Union[Sequence[tuple[int, int]], str],
dilations: Sequence[int],
kernel_shape: Sequence[int],
storage_order: int,
):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#MaxPool for more details."""
assert len(input_args) == 1
x = input_args[0]
if ceil_mode != 0:
raise ValueError("ceil_mode = 1 is not implement yet.")
> return lax.reduce_window(
x, -jnp.inf, lax.max, kernel_shape, strides, pads, None, dilations
)
E ValueError: safe_map() argument 2 is longer than argument 1
jaxonnxruntime/onnx_ops/maxpool.py:159: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atan in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atanh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of BitShift in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of ConstantOfShape in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Erf in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Expand in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of GatherElements in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of NonZero in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of OneHot in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Range in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Where in domain with max_inclusive_version= 6. Set to 1.
________ OnnxBackendPyTorchConvertedModelTest.test_MaxPool3d_stride_cpu ________
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool3d_stride_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/maxpool.py:159: in onnx_maxpool
return lax.reduce_window(
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:78: in reduce_window
return monoid_reducer(operand, window_dimensions, window_strides, padding,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:156: in _reduce_window_max
return reduce_window_max_p.bind(
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:380: in bind
return self.bind_with_trace(find_top_trace(args), args, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:383: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1855: in process_primitive
return self.default_process_primitive(primitive, tracers, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1859: in default_process_primitive
out_avals, effects = primitive.abstract_eval(*avals, **params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:416: in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/utils.py:63: in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:430: in _common_reduce_window_shape_rule
return reduce_window_shape_tuple(operand.shape, window_dimensions,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:442: in reduce_window_shape_tuple
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ss = ((2, 3, 5, 5, 5), (0, 0, 0, 0, 0, 0), (0, 0, 0, 0, 0, 0))
def sum_shapes(*ss: Shape) -> Shape:
> return tuple(map(sum_dim, *ss))
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: safe_map() argument 2 is longer than argument 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:1983: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool3d_stride_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ceil_mode = 0, strides = (1, 1, 2, 2, 2)
pads = ((0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0))
dilations = (1, 1, 1, 1, 1)
@functools.partial(
jit,
static_argnames=(
"ceil_mode",
"strides",
"pads",
"dilations",
"kernel_shape",
),
)
def onnx_maxpool(
*input_args,
ceil_mode: int,
strides: Sequence[int],
pads: Union[Sequence[tuple[int, int]], str],
dilations: Sequence[int],
kernel_shape: Sequence[int],
storage_order: int,
):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#MaxPool for more details."""
assert len(input_args) == 1
x = input_args[0]
if ceil_mode != 0:
raise ValueError("ceil_mode = 1 is not implement yet.")
> return lax.reduce_window(
x, -jnp.inf, lax.max, kernel_shape, strides, pads, None, dilations
)
E ValueError: safe_map() argument 2 is longer than argument 1
jaxonnxruntime/onnx_ops/maxpool.py:159: ValueError
____ OnnxBackendPyTorchConvertedModelTest.test_MaxPool3d_stride_padding_cpu ____
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool3d_stride_padding_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/maxpool.py:159: in onnx_maxpool
return lax.reduce_window(
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:78: in reduce_window
return monoid_reducer(operand, window_dimensions, window_strides, padding,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:156: in _reduce_window_max
return reduce_window_max_p.bind(
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:380: in bind
return self.bind_with_trace(find_top_trace(args), args, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:383: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1855: in process_primitive
return self.default_process_primitive(primitive, tracers, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1859: in default_process_primitive
out_avals, effects = primitive.abstract_eval(*avals, **params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:416: in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/utils.py:63: in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:430: in _common_reduce_window_shape_rule
return reduce_window_shape_tuple(operand.shape, window_dimensions,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:442: in reduce_window_shape_tuple
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ss = ((2, 3, 5, 5, 5), (0, 0, 0, 1, 1, 1), (0, 0, 0, 1, 1, 1))
def sum_shapes(*ss: Shape) -> Shape:
> return tuple(map(sum_dim, *ss))
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: safe_map() argument 2 is longer than argument 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:1983: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_MaxPool3d_stride_padding_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ceil_mode = 0, strides = (1, 1, 2, 2, 2)
pads = ((0, 0), (0, 0), (0, 0), (1, 1), (1, 1), (1, 1))
dilations = (1, 1, 1, 1, 1)
@functools.partial(
jit,
static_argnames=(
"ceil_mode",
"strides",
"pads",
"dilations",
"kernel_shape",
),
)
def onnx_maxpool(
*input_args,
ceil_mode: int,
strides: Sequence[int],
pads: Union[Sequence[tuple[int, int]], str],
dilations: Sequence[int],
kernel_shape: Sequence[int],
storage_order: int,
):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#MaxPool for more details."""
assert len(input_args) == 1
x = input_args[0]
if ceil_mode != 0:
raise ValueError("ceil_mode = 1 is not implement yet.")
> return lax.reduce_window(
x, -jnp.inf, lax.max, kernel_shape, strides, pads, None, dilations
)
E ValueError: safe_map() argument 2 is longer than argument 1
jaxonnxruntime/onnx_ops/maxpool.py:159: ValueError
______ OnnxBackendPyTorchConvertedModelTest.test_PReLU_1d_multiparam_cpu _______
shapes = ((3,), (2, 3, 4))
def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
# NOTE: We have both cached and uncached versions to handle Tracers in shapes.
try:
> return _broadcast_shapes_cached(*shapes)
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:147:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../test_py_env/lib/python3.11/site-packages/jax/_src/util.py:284: in wrapper
return cached(config._trace_context(), *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/util.py:277: in cached
return f(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:153: in _broadcast_shapes_cached
return _broadcast_shapes_uncached(*shapes)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
shapes = ((3,), (2, 3, 4)), fst = (3,), rst = [(2, 3, 4)]
shape_list = [(1, 1, 3), (2, 3, 4)]
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst
# First check if we need only rank promotion (and not singleton-broadcasting).
try: return _reduce(_broadcast_ranks, rst, fst)
except ValueError: pass
# Next try singleton-broadcasting, padding out ranks using singletons.
ndim = _max(len(shape) for shape in shapes)
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
> raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
E ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4)]
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:169: ValueError
During handling of the above exception, another exception occurred:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_PReLU_1d_multiparam_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/prelu.py:68: in onnx_prelu
return jnp.where(data >= 0, data, slope * data)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:791: in op
return getattr(self.aval, f"_{name}")(self, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258: in deferring_binary_op
return binary_op(*args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py:96: in fn
x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:357: in promote_args
return promote_shapes(fun_name, *promote_dtypes(*args))
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:247: in promote_shapes
result_rank = len(lax.broadcast_shapes(*shapes))
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:149: in broadcast_shapes
return _broadcast_shapes_uncached(*shapes)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
shapes = ((3,), (2, 3, 4)), fst = (3,), rst = [(2, 3, 4)]
shape_list = [(1, 1, 3), (2, 3, 4)]
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst
# First check if we need only rank promotion (and not singleton-broadcasting).
try: return _reduce(_broadcast_ranks, rst, fst)
except ValueError: pass
# Next try singleton-broadcasting, padding out ranks using singletons.
ndim = _max(len(shape) for shape in shapes)
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
> raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4)]
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:169: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_PReLU_1d_multiparam_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/prelu.py:68: in onnx_prelu
return jnp.where(data >= 0, data, slope * data)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:791: in op
return getattr(self.aval, f"_{name}")(self, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258: in deferring_binary_op
return binary_op(*args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py:96: in fn
x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:357: in promote_args
return promote_shapes(fun_name, *promote_dtypes(*args))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun_name = 'multiply'
args = (Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=2/0)>, Traced<ShapedArray(float32[2,3,4])>with<DynamicJaxprTrace(level=2/0)>)
nonscalar_ranks = {1, 3}
def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return [lax.asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes:
# With dynamic shapes we don't support singleton-dimension broadcasting;
# we instead broadcast out to the full shape as a temporary workaround.
# TODO(mattjj): revise this workaround
res_shape = lax.broadcast_shapes(*shapes) # Can raise an error!
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
else:
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
nonscalar_ranks = {len(shp) for shp in shapes if shp}
if len(nonscalar_ranks) < 2:
return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
> result_rank = len(lax.broadcast_shapes(*shapes))
E ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4)]
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:247: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atan in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atanh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of BitShift in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of ConstantOfShape in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Erf in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Expand in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of GatherElements in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of NonZero in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of OneHot in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Range in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Where in domain with max_inclusive_version= 6. Set to 1.
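NOTE (not part of the pytest output): the test_PReLU_1d_multiparam_cpu failure above, and the 2d variant that follows, both come from `slope * data` in onnx_prelu, where a slope of shape (3,) is right-aligned against data of shape (2, 3, 4) or (2, 3, 4, 5) and so is compared with the last axis instead of the channel axis. Below is a minimal sketch of one possible workaround, assuming the slope is meant to be per-channel along axis 1 (as PyTorch's nn.PReLU does). The helper name `align_slope` is hypothetical and is not taken from jaxonnxruntime. The sketch uses NumPy so it runs without JAX; `jax.numpy` follows the same broadcasting rules, so the same reshape should carry over to `jnp`.

```python
import numpy as np


def align_slope(slope, data_ndim):
    """Hypothetical helper: make a per-channel slope line up with axis 1.

    A slope of shape (C,) is reshaped to (C, 1, ..., 1) with
    data_ndim - 2 trailing singleton axes, so that it broadcasts
    against data laid out as (N, C, ...).
    """
    slope = np.asarray(slope)
    if slope.ndim == 1 and data_ndim > 2:
        return slope.reshape(slope.shape + (1,) * (data_ndim - 2))
    # Any other slope shape is left to ordinary broadcasting.
    return slope


def prelu(data, slope):
    """PReLU: data where data >= 0, slope * data elsewhere."""
    data = np.asarray(data)
    slope = align_slope(slope, data.ndim)
    return np.where(data >= 0, data, slope * data)


# Shapes taken from the failing test: slope (3,), data (2, 3, 4).
data = -np.ones((2, 3, 4), dtype=np.float32)
slope = np.array([0.1, 0.2, 0.3], dtype=np.float32)
out = prelu(data, slope)
```

With the reshape, the slope becomes (3, 1) for the 1d case and (3, 1, 1) for the 2d case, both of which broadcast against the data shapes reported in the tracebacks. Whether this belongs in the op itself or in a model preprocessing step is a design choice this log does not settle.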
______ OnnxBackendPyTorchConvertedModelTest.test_PReLU_2d_multiparam_cpu _______
shapes = ((3,), (2, 3, 4, 5))
def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
# NOTE: We have both cached and uncached versions to handle Tracers in shapes.
try:
> return _broadcast_shapes_cached(*shapes)
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:147:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../test_py_env/lib/python3.11/site-packages/jax/_src/util.py:284: in wrapper
return cached(config._trace_context(), *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/util.py:277: in cached
return f(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:153: in _broadcast_shapes_cached
return _broadcast_shapes_uncached(*shapes)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
shapes = ((3,), (2, 3, 4, 5)), fst = (3,), rst = [(2, 3, 4, 5)]
shape_list = [(1, 1, 1, 3), (2, 3, 4, 5)]
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst
# First check if we need only rank promotion (and not singleton-broadcasting).
try: return _reduce(_broadcast_ranks, rst, fst)
except ValueError: pass
# Next try singleton-broadcasting, padding out ranks using singletons.
ndim = _max(len(shape) for shape in shapes)
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
> raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
E ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4, 5)]
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:169: ValueError
During handling of the above exception, another exception occurred:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_PReLU_2d_multiparam_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/prelu.py:68: in onnx_prelu
return jnp.where(data >= 0, data, slope * data)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:791: in op
return getattr(self.aval, f"_{name}")(self, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258: in deferring_binary_op
return binary_op(*args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py:96: in fn
x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:357: in promote_args
return promote_shapes(fun_name, *promote_dtypes(*args))
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:247: in promote_shapes
result_rank = len(lax.broadcast_shapes(*shapes))
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:149: in broadcast_shapes
return _broadcast_shapes_uncached(*shapes)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
shapes = ((3,), (2, 3, 4, 5)), fst = (3,), rst = [(2, 3, 4, 5)]
shape_list = [(1, 1, 1, 3), (2, 3, 4, 5)]
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst
# First check if we need only rank promotion (and not singleton-broadcasting).
try: return _reduce(_broadcast_ranks, rst, fst)
except ValueError: pass
# Next try singleton-broadcasting, padding out ranks using singletons.
ndim = _max(len(shape) for shape in shapes)
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
> raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4, 5)]
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:169: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_PReLU_2d_multiparam_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/prelu.py:68: in onnx_prelu
return jnp.where(data >= 0, data, slope * data)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:791: in op
return getattr(self.aval, f"_{name}")(self, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258: in deferring_binary_op
return binary_op(*args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py:96: in fn
x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:357: in promote_args
return promote_shapes(fun_name, *promote_dtypes(*args))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun_name = 'multiply'
args = (Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=2/0)>, Traced<ShapedArray(float32[2,3,4,5])>with<DynamicJaxprTrace(level=2/0)>)
nonscalar_ranks = {1, 4}
def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return [lax.asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes:
# With dynamic shapes we don't support singleton-dimension broadcasting;
# we instead broadcast out to the full shape as a temporary workaround.
# TODO(mattjj): revise this workaround
res_shape = lax.broadcast_shapes(*shapes) # Can raise an error!
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
else:
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
nonscalar_ranks = {len(shp) for shp in shapes if shp}
if len(nonscalar_ranks) < 2:
return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
> result_rank = len(lax.broadcast_shapes(*shapes))
E ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4, 5)]
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:247: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atan in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atanh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of BitShift in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of ConstantOfShape in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Erf in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Expand in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of GatherElements in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of NonZero in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of OneHot in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Range in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Where in domain with max_inclusive_version= 6. Set to 1.
______ OnnxBackendPyTorchConvertedModelTest.test_PReLU_3d_multiparam_cpu _______
shapes = ((3,), (2, 3, 4, 5, 6))
def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
# NOTE: We have both cached and uncached versions to handle Tracers in shapes.
try:
> return _broadcast_shapes_cached(*shapes)
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:147:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../test_py_env/lib/python3.11/site-packages/jax/_src/util.py:284: in wrapper
return cached(config._trace_context(), *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/util.py:277: in cached
return f(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:153: in _broadcast_shapes_cached
return _broadcast_shapes_uncached(*shapes)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
shapes = ((3,), (2, 3, 4, 5, 6)), fst = (3,), rst = [(2, 3, 4, 5, 6)]
shape_list = [(1, 1, 1, 1, 3), (2, 3, 4, 5, 6)]
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst
# First check if we need only rank promotion (and not singleton-broadcasting).
try: return _reduce(_broadcast_ranks, rst, fst)
except ValueError: pass
# Next try singleton-broadcasting, padding out ranks using singletons.
ndim = _max(len(shape) for shape in shapes)
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
> raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
E ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4, 5, 6)]
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:169: ValueError
During handling of the above exception, another exception occurred:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_PReLU_3d_multiparam_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/prelu.py:68: in onnx_prelu
return jnp.where(data >= 0, data, slope * data)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:791: in op
return getattr(self.aval, f"_{name}")(self, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258: in deferring_binary_op
return binary_op(*args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py:96: in fn
x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:357: in promote_args
return promote_shapes(fun_name, *promote_dtypes(*args))
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:247: in promote_shapes
result_rank = len(lax.broadcast_shapes(*shapes))
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:149: in broadcast_shapes
return _broadcast_shapes_uncached(*shapes)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
shapes = ((3,), (2, 3, 4, 5, 6)), fst = (3,), rst = [(2, 3, 4, 5, 6)]
shape_list = [(1, 1, 1, 1, 3), (2, 3, 4, 5, 6)]
def _broadcast_shapes_uncached(*shapes):
_validate_shapes(shapes)
fst, *rst = shapes
if not rst: return fst
# First check if we need only rank promotion (and not singleton-broadcasting).
try: return _reduce(_broadcast_ranks, rst, fst)
except ValueError: pass
# Next try singleton-broadcasting, padding out ranks using singletons.
ndim = _max(len(shape) for shape in shapes)
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
> raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4, 5, 6)]
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:169: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchConvertedModelTest testMethod=test_PReLU_3d_multiparam_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
jaxonnxruntime/onnx_ops/prelu.py:68: in onnx_prelu
return jnp.where(data >= 0, data, slope * data)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:791: in op
return getattr(self.aval, f"_{name}")(self, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:258: in deferring_binary_op
return binary_op(*args)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py:96: in fn
x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:357: in promote_args
return promote_shapes(fun_name, *promote_dtypes(*args))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun_name = 'multiply'
args = (Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=2/0)>, Traced<ShapedArray(float32[2,3,4,5,6])>with<DynamicJaxprTrace(level=2/0)>)
nonscalar_ranks = {1, 5}
def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return [lax.asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes:
# With dynamic shapes we don't support singleton-dimension broadcasting;
# we instead broadcast out to the full shape as a temporary workaround.
# TODO(mattjj): revise this workaround
res_shape = lax.broadcast_shapes(*shapes) # Can raise an error!
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
else:
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
nonscalar_ranks = {len(shp) for shp in shapes if shp}
if len(nonscalar_ranks) < 2:
return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
> result_rank = len(lax.broadcast_shapes(*shapes))
E ValueError: Incompatible shapes for broadcasting: shapes=[(3,), (2, 3, 4, 5, 6)]
../test_py_env/lib/python3.11/site-packages/jax/_src/numpy/util.py:247: ValueError
_________ OnnxBackendPyTorchOperatorModelTest.test_operator_chunk_cpu __________
test_self = <onnx_ops_test.OnnxBackendPyTorchOperatorModelTest testMethod=test_operator_chunk_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
> self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
jaxonnxruntime/runner.py:361:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:257: in assert_similar_outputs
np.testing.assert_allclose(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x7f97ec159c60>, array([0.], dtype=float32), array([0., 1.], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0.001, atol=1e-07', 'verbose': True}
@wraps(func)
def inner(*args, **kwds):
with self._recreate_cm():
> return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0.001, atol=1e-07
E
E (shapes (1,), (2,) mismatch)
E x: array([0.], dtype=float32)
E y: array([0., 1.], dtype=float32)
/usr/lib/python3.11/contextlib.py:81: AssertionError
_________ OnnxBackendPyTorchOperatorModelTest.test_operator_index_cpu __________
test_self = <onnx_ops_test.OnnxBackendPyTorchOperatorModelTest testMethod=test_operator_index_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:120: in call_onnx_graph
jit_func = _get_jit_func(node, node_inputs, handlers=handlers)
jaxonnxruntime/call_onnx.py:253: in _get_jit_func
return handler.handle(node, inputs, **kwargs)
jaxonnxruntime/core/handler.py:80: in handle
return ver_handle(node, inputs, **kwargs) # pylint: disable=not-callable
jaxonnxruntime/onnx_ops/slice.py:61: in version_1
cls._prepare(node, inputs, onnx_slice)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class 'jaxonnxruntime.onnx_ops.slice.Slice'>
node = <jaxonnxruntime.core.onnx_node.OnnxNode object at 0x7f97f1039a50>
inputs = [array([[0.]], dtype=float32)]
onnx_jax_impl = <PjitFunction of <function onnx_slice at 0x7f97f20f5d00>>
@classmethod
def _prepare(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
):
> node.attrs_dict['starts'] = tuple(inputs[1].tolist())
E IndexError: list index out of range
jaxonnxruntime/onnx_ops/slice.py:45: IndexError
________ OnnxBackendPyTorchOperatorModelTest.test_operator_maxpool_cpu _________
test_self = <onnx_ops_test.OnnxBackendPyTorchOperatorModelTest testMethod=test_operator_maxpool_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
jaxonnxruntime/onnx_ops/maxpool.py:159: in onnx_maxpool
return lax.reduce_window(
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:78: in reduce_window
return monoid_reducer(operand, window_dimensions, window_strides, padding,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:156: in _reduce_window_max
return reduce_window_max_p.bind(
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:380: in bind
return self.bind_with_trace(find_top_trace(args), args, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:383: in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1855: in process_primitive
return self.default_process_primitive(primitive, tracers, params)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1859: in default_process_primitive
out_avals, effects = primitive.abstract_eval(*avals, **params)
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:416: in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/utils.py:63: in standard_abstract_eval
return core.ShapedArray(shape_rule(*avals, **kwargs),
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:430: in _common_reduce_window_shape_rule
return reduce_window_shape_tuple(operand.shape, window_dimensions,
../test_py_env/lib/python3.11/site-packages/jax/_src/lax/windowed_reductions.py:442: in reduce_window_shape_tuple
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ss = ((20, 16, 50), (0, 0), (0, 0))
def sum_shapes(*ss: Shape) -> Shape:
> return tuple(map(sum_dim, *ss))
E jax._src.traceback_util.UnfilteredStackTrace: ValueError: safe_map() argument 2 is shorter than argument 1
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
../test_py_env/lib/python3.11/site-packages/jax/_src/core.py:1983: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchOperatorModelTest testMethod=test_operator_maxpool_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ceil_mode = 0, strides = (1, 1, 2), pads = ((0, 0), (0, 0))
dilations = (1, 1, 1)
@functools.partial(
jit,
static_argnames=(
"ceil_mode",
"strides",
"pads",
"dilations",
"kernel_shape",
),
)
def onnx_maxpool(
*input_args,
ceil_mode: int,
strides: Sequence[int],
pads: Union[Sequence[tuple[int, int]], str],
dilations: Sequence[int],
kernel_shape: Sequence[int],
storage_order: int,
):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#MaxPool for more details."""
assert len(input_args) == 1
x = input_args[0]
if ceil_mode != 0:
raise ValueError("ceil_mode = 1 is not implement yet.")
> return lax.reduce_window(
x, -jnp.inf, lax.max, kernel_shape, strides, pads, None, dilations
)
E ValueError: safe_map() argument 2 is shorter than argument 1
jaxonnxruntime/onnx_ops/maxpool.py:159: ValueError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atan in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atanh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of BitShift in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of ConstantOfShape in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Erf in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Expand in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of GatherElements in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of NonZero in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of OneHot in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Range in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Where in domain with max_inclusive_version= 6. Set to 1.
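Note on test_operator_maxpool_cpu above: the locals in the traceback show a rank-3 operand, (20, 16, 50), with strides = (1, 1, 2) and dilations = (1, 1, 1) of length 3, but pads = ((0, 0), (0, 0)) of length 2. lax.reduce_window expects one (low, high) padding pair per operand axis, which is consistent with the "argument 2 is shorter than argument 1" error raised in sum_shapes. The sketch below shows one plausible way to build full-rank padding. It assumes the ONNX pads attribute is the flat list [x1_begin, ..., xn_begin, x1_end, ..., xn_end] over spatial axes only, and that the operand layout is (N, C, spatial...). The helper name expand_onnx_pads is hypothetical and is not part of jaxonnxruntime.

```python
from typing import Sequence


def expand_onnx_pads(
    pads: Sequence[int], operand_rank: int
) -> tuple[tuple[int, int], ...]:
  """Hypothetical helper: flat ONNX spatial pads -> one (lo, hi) pair per axis."""
  # Assumption: layout is (N, C, spatial...), so two leading axes get no padding.
  n_spatial = operand_rank - 2
  if len(pads) != 2 * n_spatial:
    raise ValueError(
        f"expected {2 * n_spatial} pad values for rank {operand_rank}, "
        f"got {len(pads)}"
    )
  # ONNX lists all begin values first, then all end values.
  begins, ends = pads[:n_spatial], pads[n_spatial:]
  return ((0, 0), (0, 0)) + tuple(zip(begins, ends))


# Rank-3 operand from the traceback, (20, 16, 50), with 1-D pads [0, 0].
full_pads = expand_onnx_pads([0, 0], operand_rank=3)
assert len(full_pads) == 3  # now matches len(strides) and len(dilations)
```

With padding of length 3, every shape-like argument passed to reduce_window has the same length as the operand rank. Whether this is the actual cause inside maxpool.py cannot be confirmed from the log alone.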
__________ OnnxBackendPyTorchOperatorModelTest.test_operator_view_cpu __________
test_self = <onnx_ops_test.OnnxBackendPyTorchOperatorModelTest testMethod=test_operator_view_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
> outputs = list(prepared_model.run(inputs))
jaxonnxruntime/runner.py:355:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:250: in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:158: in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
../test_py_env/lib/python3.11/site-packages/jax/_src/api.py:306: in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:505: in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:971: in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:345: in memoized_fun
ans = call(fun, *args)
../test_py_env/lib/python3.11/site-packages/jax/_src/pjit.py:924: in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/profiler.py:314: in wrapper
return func(*args, **kwargs)
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2155: in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
../test_py_env/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2177: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
../test_py_env/lib/python3.11/site-packages/jax/_src/linear_util.py:188: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
axis = 1
input_args = (Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>,)
x = Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>, dim = 1
@functools.partial(jit, static_argnames="axis")
def onnx_flatten(*input_args, axis):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Flatten for more details."""
axis = 1 if axis is None else axis
assert len(input_args) == 1
x = input_args[0]
dim = len(x.shape)
> assert axis < dim and axis >= -dim, f"axis should with [{-dim}, {dim}]"
E jax._src.traceback_util.UnfilteredStackTrace: AssertionError: axis should with [-1, 1]
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
jaxonnxruntime/onnx_ops/flatten.py:91: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
test_self = <onnx_ops_test.OnnxBackendPyTorchOperatorModelTest testMethod=test_operator_view_cpu>
device = 'CPU'
def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
model_pb_path = os.path.join(model_dir, 'model.onnx')
model = onnx.load(model_pb_path)
model_marker[0] = model
if (
hasattr(self.backend, 'is_compatible')
and callable(self.backend.is_compatible)
and not self.backend.is_compatible(model, device)
):
raise unittest.SkipTest('Not compatible with backend')
prepared_model = self.backend.prepare(model, device)
assert prepared_model is not None
for test_data_npz in glob.glob(
os.path.join(model_dir, 'test_data_*.npz')
):
test_data = np.load(test_data_npz, encoding='bytes')
inputs = list(test_data['inputs'])
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
raise e
ref_outputs = test_data['outputs']
self.assert_similar_outputs(
ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
)
for test_data_dir in glob.glob(os.path.join(model_dir, 'test_data_set*')):
inputs = []
inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))
for i in range(inputs_num):
input_file = os.path.join(test_data_dir, f'input_{i}.pb')
self._load_proto(input_file, inputs, model.graph.input[i].type)
ref_outputs = []
ref_outputs_num = len(
glob.glob(os.path.join(test_data_dir, 'output_*.pb'))
)
for i in range(ref_outputs_num):
output_file = os.path.join(test_data_dir, f'output_{i}.pb')
self._load_proto(output_file, ref_outputs, model.graph.output[i].type)
try:
outputs = list(prepared_model.run(inputs))
except NotImplementedError as e:
print(e)
return
except Exception as e:
> raise e
jaxonnxruntime/runner.py:360:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxonnxruntime/runner.py:355: in run
outputs = list(prepared_model.run(inputs))
jaxonnxruntime/backend.py:64: in run
model_func, model_params = call_onnx.call_onnx_model(
jaxonnxruntime/call_onnx.py:86: in call_onnx_model
model_func = call_onnx_graph(graph, tensor_dict, opset=opset)
jaxonnxruntime/call_onnx.py:126: in call_onnx_graph
outputs = jit_func(*node_inputs, **node.attrs_dict)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
axis = 1
input_args = (Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>,)
x = Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>, dim = 1
@functools.partial(jit, static_argnames="axis")
def onnx_flatten(*input_args, axis):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Flatten for more details."""
axis = 1 if axis is None else axis
assert len(input_args) == 1
x = input_args[0]
dim = len(x.shape)
> assert axis < dim and axis >= -dim, f"axis should with [{-dim}, {dim}]"
E AssertionError: axis should with [-1, 1]
jaxonnxruntime/onnx_ops/flatten.py:91: AssertionError
------------------------------ Captured log call -------------------------------
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Acosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Asinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atan in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Atanh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of BitShift in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of CastLike in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of ConstantOfShape in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cos in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Cosh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Einsum in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Erf in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Expand in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of GatherElements in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of LessOrEqual in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of NonZero in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of OneHot in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Range in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sin in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Sinh in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Trilu in domain with max_inclusive_version= 6. Set to 1.
WARNING jaxonnxruntime.core.handler:handler.py:64 Fail to get since_version of Where in domain with max_inclusive_version= 6. Set to 1.
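Note on test_operator_view_cpu above: the failing check is `axis < dim and axis >= -dim` with axis = 1 and a float32[1] input (dim = 1). As far as the ONNX operator documentation for Flatten describes it, axis may lie in the closed range [-r, r], where r is the input rank, so axis == r appears to be a valid value that this assertion rejects. The sketch below computes the output shape under that reading. The function name flatten_shape is hypothetical and only illustrates the bounds check; it is not the jaxonnxruntime implementation.

```python
import math
from typing import Sequence


def flatten_shape(shape: Sequence[int], axis: int = 1) -> tuple[int, int]:
  """Hypothetical helper: 2-D output shape of ONNX Flatten for a given axis."""
  rank = len(shape)
  # Assumption: axis is valid anywhere in the closed range [-rank, rank].
  if not -rank <= axis <= rank:
    raise ValueError(f"axis should be within [{-rank}, {rank}], got {axis}")
  if axis < 0:
    axis += rank
  # math.prod of an empty slice is 1, which covers axis == 0 and axis == rank.
  return (math.prod(shape[:axis]), math.prod(shape[axis:]))


# The case from the traceback: rank-1 input of shape (1,) with axis = 1.
print(flatten_shape((1,), axis=1))
```

Under this reading the rank-1 case yields a 2-D shape with a trailing dimension of 1 instead of an assertion failure.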
=========================== short test summary info ============================
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_adagrad_cpu - V...
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_adagrad_multiple_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_adam_cpu - Valu...
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_adam_multiple_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_ai_onnx_ml_array_feature_extractor_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_ai_onnx_ml_binarizer_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT_to_BFLOAT16_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_cast_FLOAT_to_STRING_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_castlike_FLOAT_to_BFLOAT16_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_center_crop_pad_crop_axes_chw_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_center_crop_pad_crop_axes_hwc_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_center_crop_pad_crop_negative_axes_hwc_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_clip_default_inbounds_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_clip_default_int8_inbounds_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_clip_default_int8_max_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_clip_default_max_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_loop11_cpu - Va...
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_loop13_seq_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_loop16_seq_none_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_maxpool_with_argmax_2d_precomputed_pads_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_maxpool_with_argmax_2d_precomputed_strides_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_momentum_cpu - ...
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_momentum_multiple_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nesterov_momentum_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NC_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_ii_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_ii_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_mean_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_sum_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_optional_has_element_empty_no_input_name_optional_input_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_optional_has_element_empty_no_input_name_tensor_input_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_cubic_align_corners_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_cubic_antialias_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_cubic_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_linear_align_corners_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_linear_antialias_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_linear_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_linear_half_pixel_symmetric_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_scales_nearest_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_sizes_cubic_antialias_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_sizes_cubic_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_sizes_linear_antialias_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_sizes_linear_pytorch_half_pixel_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_sizes_nearest_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_sizes_nearest_not_larger_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_downsample_sizes_nearest_not_smaller_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_tf_crop_and_resize_axes_2_3_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_tf_crop_and_resize_axes_3_2_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_tf_crop_and_resize_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_cubic_align_corners_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_cubic_asymmetric_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_cubic_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_linear_align_corners_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_linear_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_linear_half_pixel_symmetric_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_nearest_axes_2_3_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_nearest_axes_3_2_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_scales_nearest_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_sizes_cubic_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_sizes_nearest_axes_2_3_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_sizes_nearest_axes_3_2_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_sizes_nearest_ceil_half_pixel_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_sizes_nearest_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_sizes_nearest_floor_align_corners_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_sizes_nearest_not_larger_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_scan_sum_cpu - ...
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_sequence_map_add_1_sequence_1_tensor_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_sequence_map_add_2_sequences_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_sequence_map_identity_1_sequence_1_tensor_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_sequence_map_identity_2_sequences_expanded_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_stft_cpu - Valu...
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_training_dropout_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_training_dropout_default_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_training_dropout_default_mask_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_training_dropout_mask_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_training_dropout_zero_ratio_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendNodeModelTest::test_training_dropout_zero_ratio_mask_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_MaxPool1d_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_MaxPool1d_stride_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_MaxPool1d_stride_padding_dilation_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_MaxPool3d_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_MaxPool3d_stride_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_MaxPool3d_stride_padding_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_PReLU_1d_multiparam_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_PReLU_2d_multiparam_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchConvertedModelTest::test_PReLU_3d_multiparam_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchOperatorModelTest::test_operator_chunk_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchOperatorModelTest::test_operator_index_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchOperatorModelTest::test_operator_maxpool_cpu
FAILED tests/onnx_ops_test.py::OnnxBackendPyTorchOperatorModelTest::test_operator_view_cpu
========= 106 failed, 1155 passed, 2616 skipped, 47 xfailed in 38.12s ==========