Skip to content

Instantly share code, notes, and snippets.

@josephrocca
Last active November 28, 2022 10:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save josephrocca/db146f00593f86e86a0ecc87c49453d3 to your computer and use it in GitHub Desktop.
Save josephrocca/db146f00593f86e86a0ecc87c49453d3 to your computer and use it in GitHub Desktop.
jax2tf2onnx - jax.lax.switch with functions.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"name": "jax2tf2onnx - jax.lax.switch with functions.ipynb",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"gpuClass": "premium"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/josephrocca/db146f00593f86e86a0ecc87c49453d3/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"# Install into the running kernel's environment (%pip rather than !pip).\n",
"# tf2onnx is pinned to the v1.12.1 tag; jax/jaxlib are NOT pinned, so results\n",
"# depend on whichever release is current when this cell runs.\n",
"%pip install git+https://github.com/onnx/tensorflow-onnx@v1.12.1\n",
"%pip install --upgrade jax jaxlib"
],
"metadata": {
"id": "0YHLndloz1U_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import tensorflow as tf\n",
"from jax.experimental import jax2tf\n",
"import tf2onnx"
],
"metadata": {
"id": "UFMtdmPeyxpi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def test_jax(n, operand):\n",
"  \"\"\"Apply one of two branch functions to `operand`, chosen by index `n`.\n",
"\n",
"  Args:\n",
"    n: integer scalar branch index. 0 selects `fn1` (adds 2), 1 selects\n",
"      `fn2` (multiplies by 2). `jax.lax.switch` clamps out-of-range\n",
"      indices into bounds.\n",
"    operand: value handed to the selected branch.\n",
"\n",
"  Returns:\n",
"    The output of the selected branch applied to `operand`.\n",
"  \"\"\"\n",
"\n",
"  def fn1(a):\n",
"    return a + 2\n",
"\n",
"  def fn2(a):\n",
"    return a * 2\n",
"\n",
"  # lax.switch keeps the branch selection inside the traced computation,\n",
"  # so it works when `n` is only known at run time.\n",
"  result = jax.lax.switch(\n",
"      n,\n",
"      [fn1, fn2],\n",
"      operand,\n",
"  )\n",
"\n",
"  # A Python `if` on `n` does not work under jax.jit: `n` is a tracer with\n",
"  # no concrete value at trace time, so it cannot be used as a bool.\n",
"  # if n == 0:\n",
"  #   result = fn1(operand)\n",
"  # if n == 1:\n",
"  #   result = fn2(operand)\n",
"\n",
"  # Indexing a Python list with `n` fails under jax.jit for the same reason:\n",
"  # a tracer cannot be converted to a concrete list index.\n",
"  # fns = [fn1, fn2]\n",
"  # result = fns[n](operand)\n",
"\n",
"  return result\n"
],
"metadata": {
"id": "m0nGZcx_ZAnU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Sanity check under jit: index 1 selects the doubling branch for operand 2.\n",
"jitted_test_jax = jax.jit(test_jax)\n",
"jitted_test_jax(1, 2)"
],
"metadata": {
"id": "qrvnptD7GW5c"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Both inputs are exported as scalar uint32 tensors named \"n\" and \"operand\".\n",
"scalar_input_signature = [\n",
"    tf.TensorSpec([], tf.uint32, name=\"n\"),\n",
"    tf.TensorSpec([], tf.uint32, name=\"operand\"),\n",
"]\n",
"\n",
"# Convert the JAX function to a TF callable, then wrap it in a tf.function\n",
"# attached to a tf.Module so it can be written out as a SavedModel.\n",
"converted_test_jax = jax2tf.convert(test_jax, enable_xla=False)\n",
"\n",
"my_model = tf.Module()\n",
"my_model.f = tf.function(\n",
"    converted_test_jax,\n",
"    jit_compile=True,\n",
"    autograph=False,\n",
"    input_signature=scalar_input_signature,\n",
")\n",
"tf.saved_model.save(my_model, './test')"
],
"metadata": {
"id": "Mvk1dkd20zqm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# NOT WORKING with tf2onnx v1.12.1 (kept as a reproduction of the failure):\n",
"# ValueError: You passed in an iterable attribute but I cannot figure out its applicable type.\n",
"# The error is reported for the `StatelessCase` node produced from jax.lax.switch.\n",
"onnx_input_signature = [\n",
"    tf.TensorSpec([], tf.uint32, name=\"n\"),\n",
"    tf.TensorSpec([], tf.uint32, name=\"operand\"),\n",
"]\n",
"tf2onnx.convert.from_function(\n",
"    my_model.f,\n",
"    input_signature=onnx_input_signature,\n",
"    opset=17,\n",
"    output_path=\"test.onnx\",\n",
")"
],
"metadata": {
"id": "HKFd3XqHJ7L9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "9637614f-51d9-4db8-d548-74fa93ebda92"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tf2onnx/tf_loader.py:715: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
"ERROR:tf2onnx.tf_utils:pass1 convert failed for name: \"jax2tf_test_jax_/switch_case/indexed_case\"\n",
"op: \"StatelessCase\"\n",
"input: \"jax2tf_test_jax_/clip_by_value\"\n",
"input: \"operand\"\n",
"attr {\n",
" key: \"Tin\"\n",
" value {\n",
" list {\n",
" type: DT_UINT32\n",
" }\n",
" }\n",
"}\n",
"attr {\n",
" key: \"Tout\"\n",
" value {\n",
" list {\n",
" type: DT_UINT32\n",
" }\n",
" }\n",
"}\n",
"attr {\n",
" key: \"_read_only_resource_inputs\"\n",
" value {\n",
" list {\n",
" }\n",
" }\n",
"}\n",
"attr {\n",
" key: \"_xla_propagate_compile_time_consts\"\n",
" value {\n",
" b: true\n",
" }\n",
"}\n",
"attr {\n",
" key: \"branches\"\n",
" value {\n",
" list {\n",
" func {\n",
" name: \"jax2tf_test_jax__switch_case_indexed_case_branch0_225\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_test_jax__switch_case_indexed_case_branch1_226\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_test_jax__switch_case_indexed_case_branch2_227\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_test_jax__switch_case_indexed_case_branch3_228\"\n",
" }\n",
" }\n",
" }\n",
"}\n",
"attr {\n",
" key: \"output_shapes\"\n",
" value {\n",
" list {\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
"}\n",
", ex=You passed in an iterable attribute but I cannot figure out its applicable type.\n"
]
},
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-e084234668ae>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensorSpec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muint32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"n\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensorSpec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muint32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"operand\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m ], opset=17, output_path=\"test.onnx\")\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py\u001b[0m in \u001b[0;36mfrom_function\u001b[0;34m(function, input_signature, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)\u001b[0m\n\u001b[1;32m 577\u001b[0m \u001b[0mtensors_to_rename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtensors_to_rename\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 578\u001b[0m \u001b[0minitialized_tables\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitialized_tables\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 579\u001b[0;31m output_path=output_path)\n\u001b[0m\u001b[1;32m 580\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel_proto\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexternal_tensor_storage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py\u001b[0m in \u001b[0;36m_convert_common\u001b[0;34m(frozen_graph, name, large_model, output_path, output_frozen_graph, custom_ops, custom_op_handlers, optimizers, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_graph_def\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfrozen_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m g = process_tf_graph(tf_graph, const_node_values=const_node_values,\n\u001b[0;32m--> 165\u001b[0;31m custom_op_handlers=custom_op_handlers, **kwargs)\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconstants\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mENV_TF2ONNX_CATCH_ERRORS\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mcatch_errors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconstants\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mENV_TF2ONNX_CATCH_ERRORS\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"TRUE\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py\u001b[0m in \u001b[0;36mprocess_tf_graph\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 458\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 459\u001b[0m main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,\n\u001b[0;32m--> 460\u001b[0;31m ignore_default, use_default)\n\u001b[0m\u001b[1;32m 461\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mmain_g\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msubgraphs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py\u001b[0m in \u001b[0;36mgraphs_from_tf\u001b[0;34m(tf_graph, input_names, output_names, shape_override, const_node_values, ignore_default, use_default)\u001b[0m\n\u001b[1;32m 472\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshape_override\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0mshape_override\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 474\u001b[0;31m \u001b[0mordered_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresolve_functions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 475\u001b[0m \u001b[0msubgraphs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mordered_func\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/tf_loader.py\u001b[0m in \u001b[0;36mresolve_functions\u001b[0;34m(tf_graph)\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdep\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mordered\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mordered\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 768\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunctions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtflist_to_onnx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfdef\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtf_graph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_functions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/tf_utils.py\u001b[0m in \u001b[0;36mtflist_to_onnx\u001b[0;34m(g, shape_override, const_node_values, ignore_default, use_default)\u001b[0m\n\u001b[1;32m 459\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtakeit\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 460\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 461\u001b[0;31m \u001b[0monnx_node\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhelper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_node\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mattr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 462\u001b[0m \u001b[0monnx_nodes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0monnx_node\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 463\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/onnx/helper.py\u001b[0m in \u001b[0;36mmake_node\u001b[0;34m(op_type, inputs, outputs, name, doc_string, domain, **kwargs)\u001b[0m\n\u001b[1;32m 118\u001b[0m node.attribute.extend(\n\u001b[1;32m 119\u001b[0m \u001b[0mmake_attribute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 121\u001b[0m if value is not None)\n\u001b[1;32m 122\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/onnx/helper.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0mmake_attribute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m if value is not None)\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/onnx/helper.py\u001b[0m in \u001b[0;36mmake_attribute\u001b[0;34m(key, value, doc_string)\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 539\u001b[0m raise ValueError(\n\u001b[0;32m--> 540\u001b[0;31m \u001b[0;34m\"You passed in an iterable attribute but I cannot figure out \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 541\u001b[0m \"its applicable type.\")\n\u001b[1;32m 542\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: You passed in an iterable attribute but I cannot figure out its applicable type."
]
}
]
},
{
"cell_type": "code",
"source": [
"# This doesn't work either - same error.\n",
"# !python -m tf2onnx.convert --opset 17 --saved-model test --output test.onnx"
],
"metadata": {
"id": "QgU9FaVLNDnt"
},
"execution_count": null,
"outputs": []
}
]
}
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment