Skip to content

Instantly share code, notes, and snippets.

@josephrocca
Last active October 20, 2022 14:42
Show Gist options
  • Save josephrocca/7e3b9723b263ab1dbf9d8b106c1fb721 to your computer and use it in GitHub Desktop.
Save josephrocca/7e3b9723b263ab1dbf9d8b106c1fb721 to your computer and use it in GitHub Desktop.
stable_diffusion_jax-to-onnx - scheduler_loop_body.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"name": "stable_diffusion_jax-to-onnx - scheduler_loop_body.ipynb",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"gpuClass": "premium"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/josephrocca/7e3b9723b263ab1dbf9d8b106c1fb721/stable_diffusion_jax-to-onnx.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"# Pinned dependency versions used for this notebook. %pip (not !pip) installs into the running kernel's environment.\n",
"# TODO: pin jax/jaxlib as well -- `--upgrade` picks up whatever release is current, so re-runs are not reproducible.\n",
"%pip install transformers==4.23.1 huggingface_hub==0.10.0 ftfy==6.1.1 flax==0.6.1 git+https://github.com/huggingface/diffusers.git@v0.5.0 git+https://github.com/onnx/tensorflow-onnx@v1.12.1\n",
"%pip install --upgrade jax jaxlib\n",
"%pip install tensorflow==2.9.2"
],
"metadata": {
"id": "0YHLndloz1U_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from PIL import Image\n",
"import tensorflow as tf\n",
"from jax.experimental import jax2tf\n",
"from diffusers import FlaxStableDiffusionPipeline\n",
"from huggingface_hub.hf_api import HfFolder\n",
"import tf2onnx\n",
"\n",
"# Hugging Face access token for downloading CompVis/stable-diffusion-v1-4. A read-only token is sufficient.\n",
"# It is read from the HF_TOKEN environment variable (or prompted for) rather than hardcoded, so no credential\n",
"# ends up in the notebook or in anything it is shared to.\n",
"hf_token = os.environ.get(\"HF_TOKEN\") or getpass.getpass(\"Hugging Face token: \")\n",
"HfFolder.save_token(hf_token)"
],
"metadata": {
"id": "UFMtdmPeyxpi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load the Flax (fp32) Stable Diffusion pipeline without a safety checker.\n",
"# `params` is keyed by sub-model; the UNet weights are used below as params[\"unet\"].\n",
"pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\n",
"    \"CompVis/stable-diffusion-v1-4\",\n",
"    revision=\"flax\",\n",
"    dtype=jnp.float32,\n",
"    safety_checker=None,\n",
")"
],
"metadata": {
"id": "PPtraQX34Az7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from diffusers.schedulers.scheduling_pndm_flax import PNDMSchedulerState\n",
"\n",
"unet_params = params[\"unet\"]\n",
"num_inference_steps = 10  # also determines the timestep-array lengths in the export signature\n",
"guidance_scale = 7.5  # classifier-free guidance weight; captured as a constant when the function is traced\n",
"batch_size = 1\n",
"\n",
"# PNDMSchedulerState field names, used both to rebuild the state from the flat input dict and to flatten it again.\n",
"SCHEDULER_STATE_KEYS = (\"_timesteps\", \"num_inference_steps\", \"prk_timesteps\", \"plms_timesteps\", \"timesteps\", \"cur_model_output\", \"counter\", \"cur_sample\", \"ets\")\n",
"\n",
"def scheduler_loop_body(unet_params, step, latents, context, scheduler_state_dict):\n",
"    \"\"\"Run one denoising step: UNet noise prediction with classifier-free guidance, then a PNDM scheduler step.\n",
"\n",
"    The scheduler state is passed in and returned as a flat dict of arrays (keys: SCHEDULER_STATE_KEYS)\n",
"    instead of a PNDMSchedulerState, so the function can be given a plain-tensor signature for jax2tf/ONNX.\n",
"\n",
"    Args:\n",
"        unet_params: UNet parameter pytree (params[\"unet\"]).\n",
"        step: Index into scheduler_state.timesteps selecting the current timestep.\n",
"        latents: Current noisy latents.\n",
"        context: Encoder hidden states for the doubled batch. The first half is treated as the unconditional\n",
"            prediction, so presumably [unconditional; text] along axis 0 -- TODO confirm against the caller.\n",
"        scheduler_state_dict: Dict containing every field named in SCHEDULER_STATE_KEYS.\n",
"\n",
"    Returns:\n",
"        (latents, scheduler_state_dict) after one scheduler step.\n",
"    \"\"\"\n",
"    # Build the state by field name rather than by position, so a field-order mismatch cannot silently swap values.\n",
"    scheduler_state = PNDMSchedulerState(**{key: scheduler_state_dict[key] for key in SCHEDULER_STATE_KEYS})\n",
"\n",
"    # NOTE: I've successfully exported this group of 6 lines on their own, so I don't think any of these lines are the source of the issue.\n",
"    latents_input = jnp.concatenate([latents] * 2)  # one UNet call covers both guidance branches\n",
"    t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]\n",
"    timestep = jnp.broadcast_to(t, latents_input.shape[0])  # already int32 because t is int32\n",
"    noise_pred = pipeline.unet.apply({\"params\": unet_params}, latents_input, timestep, encoder_hidden_states=context).sample\n",
"    noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)\n",
"    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)\n",
"\n",
"    # compute the previous noisy sample x_t -> x_t-1\n",
"    latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()\n",
"\n",
"    return latents, {key: getattr(scheduler_state, key) for key in SCHEDULER_STATE_KEYS}"
],
"metadata": {
"id": "m0nGZcx_ZAnU"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"unet_params_vars = tf.nest.map_structure(tf.Variable, unet_params)\n",
"\n",
"# Build the prediction function by closing over `unet_params_vars`. If you instead closed over `unet_params`, your SavedModel would have no variables and the parameters would be included in the function graph.\n",
"# jax2tf.convert is applied once here instead of inside the wrapper, so the converted function is not rebuilt on every call.\n",
"scheduler_loop_body_jax2tf = jax2tf.convert(scheduler_loop_body, enable_xla=False)\n",
"\n",
"def scheduler_loop_body_tf(step, latents, context, scheduler_state_dict):\n",
"    \"\"\"TF-callable loop body: forwards to the converted JAX function with the UNet weights as tf.Variables.\"\"\"\n",
"    return scheduler_loop_body_jax2tf(unet_params_vars, step, latents, context, scheduler_state_dict)\n",
"\n",
"# Signature shapes are derived from num_inference_steps / batch_size. With 10 steps and batch 1 they are (10,), (11,), (1, 4, 64, 64), ...\n",
"# plms_timesteps/timesteps are one longer than _timesteps and prk_timesteps is empty, which presumably reflects PNDM with\n",
"# skip_prk_steps -- TODO confirm if the scheduler config changes. The leading 4 of `ets` is presumably PNDM's history length.\n",
"latent_shape = [batch_size, 4, 64, 64]\n",
"scheduler_state_dict_signature = {\n",
"    \"_timesteps\": tf.TensorSpec([num_inference_steps], tf.int32, name=\"_timesteps\"),\n",
"    \"num_inference_steps\": tf.TensorSpec([], tf.int32, name=\"num_inference_steps\"),\n",
"    \"prk_timesteps\": tf.TensorSpec([0], tf.float32, name=\"prk_timesteps\"),\n",
"    \"plms_timesteps\": tf.TensorSpec([num_inference_steps + 1], tf.int32, name=\"plms_timesteps\"),\n",
"    \"timesteps\": tf.TensorSpec([num_inference_steps + 1], tf.int32, name=\"timesteps\"),\n",
"    \"cur_model_output\": tf.TensorSpec(latent_shape, tf.float32, name=\"cur_model_output\"),\n",
"    \"counter\": tf.TensorSpec([], tf.int32, name=\"counter\"),\n",
"    \"cur_sample\": tf.TensorSpec(latent_shape, tf.float32, name=\"cur_sample\"),\n",
"    \"ets\": tf.TensorSpec([4] + latent_shape, tf.float32, name=\"ets\"),\n",
"}\n",
"\n",
"my_model = tf.Module()\n",
"my_model.f = tf.function(scheduler_loop_body_tf, autograph=False, jit_compile=True, input_signature=[\n",
"    tf.TensorSpec([], tf.uint32, name=\"step\"),\n",
"    tf.TensorSpec(latent_shape, tf.float32, name=\"latents\"),\n",
"    # context covers the doubled (unconditional + text) batch fed to the UNet\n",
"    tf.TensorSpec([2 * batch_size, 77, 768], tf.float32, name=\"context\"),\n",
"    scheduler_state_dict_signature,\n",
"])"
],
"metadata": {
"id": "QSlZAKV6ysXt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# NOT WORKING (tf2onnx v1.12.1):\n",
"# ERROR:tf2onnx.tf_utils:pass1 convert failed for name: \"jax2tf_scheduler_loop_body_/switch_case/indexed_case\", op: \"StatelessCase\"\n",
"# ValueError: You passed in an iterable attribute but I cannot figure out its applicable type.\n",
"# NOTE(review): the failing node is a tf.switch_case inside the converted function. The UNet lines export fine on\n",
"# their own, so it presumably comes from control flow in pipeline.scheduler.step -- TODO confirm.\n",
"# Reuse the tf.function's own input signature so the export cannot drift from it. The result is assigned so the\n",
"# (large) model proto is not echoed as the cell output.\n",
"onnx_model_proto, onnx_external_storage = tf2onnx.convert.from_function(\n",
"    my_model.f, input_signature=my_model.f.input_signature,\n",
"    large_model=True, opset=16, output_path=\"scheduler_loop_body.onnx\")"
],
"metadata": {
"id": "HKFd3XqHJ7L9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "f33d0abf-69a8-45c5-f6b6-fd6f46cfcd96"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"ERROR:tf2onnx.tf_utils:pass1 convert failed for name: \"jax2tf_scheduler_loop_body_/switch_case/indexed_case\"\n",
"op: \"StatelessCase\"\n",
"input: \"jax2tf_scheduler_loop_body_/clip_by_value_1\"\n",
"input: \"jax2tf_scheduler_loop_body_/jit_fn_/AddV2\"\n",
"input: \"jax2tf_arg_693\"\n",
"input: \"jax2tf_arg_689\"\n",
"input: \"jax2tf_arg_694\"\n",
"input: \"jax2tf_arg_696\"\n",
"input: \"jax2tf_arg_695\"\n",
"input: \"jax2tf_arg_697\"\n",
"input: \"jax2tf_arg_690\"\n",
"input: \"jax2tf_arg_687\"\n",
"input: \"jax2tf_arg_692\"\n",
"attr {\n",
" key: \"Tin\"\n",
" value {\n",
" list {\n",
" type: DT_FLOAT\n",
" type: DT_FLOAT\n",
" type: DT_INT32\n",
" type: DT_INT32\n",
" type: DT_FLOAT\n",
" type: DT_INT32\n",
" type: DT_INT32\n",
" type: DT_INT32\n",
" type: DT_FLOAT\n",
" type: DT_FLOAT\n",
" }\n",
" }\n",
"}\n",
"attr {\n",
" key: \"Tout\"\n",
" value {\n",
" list {\n",
" type: DT_INT32\n",
" type: DT_INT32\n",
" type: DT_FLOAT\n",
" type: DT_INT32\n",
" type: DT_INT32\n",
" type: DT_FLOAT\n",
" type: DT_INT32\n",
" type: DT_FLOAT\n",
" type: DT_FLOAT\n",
" }\n",
" }\n",
"}\n",
"attr {\n",
" key: \"_read_only_resource_inputs\"\n",
" value {\n",
" list {\n",
" }\n",
" }\n",
"}\n",
"attr {\n",
" key: \"_xla_propagate_compile_time_consts\"\n",
" value {\n",
" b: true\n",
" }\n",
"}\n",
"attr {\n",
" key: \"branches\"\n",
" value {\n",
" list {\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch0_33143\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch1_33144\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch2_33145\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch3_33146\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch4_33147\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch5_33148\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch6_33149\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch7_33150\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch8_33151\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch9_33152\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch10_33153\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch11_33154\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch12_33155\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch13_33156\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch14_33157\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch15_33158\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch16_33159\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch17_33160\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch18_33161\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch19_33162\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch20_33163\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch21_33164\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch22_33165\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch23_33166\"\n",
" }\n",
" func {\n",
" name: \"jax2tf_scheduler_loop_body__switch_case_indexed_case_branch24_33167\"\n",
" }\n",
" }\n",
" }\n",
"}\n",
"attr {\n",
" key: \"output_shapes\"\n",
" value {\n",
" list {\n",
" shape {\n",
" dim {\n",
" size: 10\n",
" }\n",
" }\n",
" shape {\n",
" }\n",
" shape {\n",
" dim {\n",
" }\n",
" }\n",
" shape {\n",
" dim {\n",
" size: 11\n",
" }\n",
" }\n",
" shape {\n",
" dim {\n",
" size: 11\n",
" }\n",
" }\n",
" shape {\n",
" dim {\n",
" size: 1\n",
" }\n",
" dim {\n",
" size: 4\n",
" }\n",
" dim {\n",
" size: 64\n",
" }\n",
" dim {\n",
" size: 64\n",
" }\n",
" }\n",
" shape {\n",
" }\n",
" shape {\n",
" dim {\n",
" size: 1\n",
" }\n",
" dim {\n",
" size: 4\n",
" }\n",
" dim {\n",
" size: 64\n",
" }\n",
" dim {\n",
" size: 64\n",
" }\n",
" }\n",
" shape {\n",
" dim {\n",
" size: 4\n",
" }\n",
" dim {\n",
" size: 1\n",
" }\n",
" dim {\n",
" size: 4\n",
" }\n",
" dim {\n",
" size: 64\n",
" }\n",
" dim {\n",
" size: 64\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
", ex=You passed in an iterable attribute but I cannot figure out its applicable type.\n"
]
},
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-61da6cf24366>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensorSpec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m77\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m768\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"context\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mscheduler_state_dict_signature\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m ], large_model=True, opset=16, output_path=\"scheduler_loop_body.onnx\")\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py\u001b[0m in \u001b[0;36mfrom_function\u001b[0;34m(function, input_signature, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)\u001b[0m\n\u001b[1;32m 577\u001b[0m \u001b[0mtensors_to_rename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtensors_to_rename\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 578\u001b[0m \u001b[0minitialized_tables\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitialized_tables\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 579\u001b[0;31m output_path=output_path)\n\u001b[0m\u001b[1;32m 580\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel_proto\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexternal_tensor_storage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py\u001b[0m in \u001b[0;36m_convert_common\u001b[0;34m(frozen_graph, name, large_model, output_path, output_frozen_graph, custom_ops, custom_op_handlers, optimizers, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_graph_def\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfrozen_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m g = process_tf_graph(tf_graph, const_node_values=const_node_values,\n\u001b[0;32m--> 165\u001b[0;31m custom_op_handlers=custom_op_handlers, **kwargs)\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconstants\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mENV_TF2ONNX_CATCH_ERRORS\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mcatch_errors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconstants\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mENV_TF2ONNX_CATCH_ERRORS\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"TRUE\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py\u001b[0m in \u001b[0;36mprocess_tf_graph\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 458\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 459\u001b[0m main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,\n\u001b[0;32m--> 460\u001b[0;31m ignore_default, use_default)\n\u001b[0m\u001b[1;32m 461\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mmain_g\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msubgraphs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py\u001b[0m in \u001b[0;36mgraphs_from_tf\u001b[0;34m(tf_graph, input_names, output_names, shape_override, const_node_values, ignore_default, use_default)\u001b[0m\n\u001b[1;32m 472\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshape_override\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0mshape_override\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 474\u001b[0;31m \u001b[0mordered_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresolve_functions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 475\u001b[0m \u001b[0msubgraphs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mordered_func\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/tf_loader.py\u001b[0m in \u001b[0;36mresolve_functions\u001b[0;34m(tf_graph)\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdep\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mordered\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mordered\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 768\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunctions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtflist_to_onnx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 769\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfdef\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtf_graph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_functions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: 
disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tf2onnx/tf_utils.py\u001b[0m in \u001b[0;36mtflist_to_onnx\u001b[0;34m(g, shape_override, const_node_values, ignore_default, use_default)\u001b[0m\n\u001b[1;32m 459\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtakeit\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 460\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 461\u001b[0;31m \u001b[0monnx_node\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhelper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_node\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mattr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 462\u001b[0m \u001b[0monnx_nodes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0monnx_node\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 463\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/onnx/helper.py\u001b[0m in \u001b[0;36mmake_node\u001b[0;34m(op_type, inputs, outputs, name, doc_string, domain, **kwargs)\u001b[0m\n\u001b[1;32m 118\u001b[0m node.attribute.extend(\n\u001b[1;32m 119\u001b[0m \u001b[0mmake_attribute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 121\u001b[0m if value is not None)\n\u001b[1;32m 122\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/onnx/helper.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0mmake_attribute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m if value is not None)\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/onnx/helper.py\u001b[0m in \u001b[0;36mmake_attribute\u001b[0;34m(key, value, doc_string)\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 539\u001b[0m raise ValueError(\n\u001b[0;32m--> 540\u001b[0;31m \u001b[0;34m\"You passed in an iterable attribute but I cannot figure out \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 541\u001b[0m \"its applicable type.\")\n\u001b[1;32m 542\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: You passed in an iterable attribute but I cannot figure out its applicable type."
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment