Last active
January 9, 2024 00:12
-
-
Save i418c/d58f02b7f89adf4511792e12cd5ea65f to your computer and use it in GitHub Desktop.
tfp-distribution-unsubscriptable.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyMMFjuTV9G8La0eHWfT+ETc", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/i418c/d58f02b7f89adf4511792e12cd5ea65f/tfp-distribution-unsubscriptable.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip uninstall -y tensorflow tensorflow_probability\n", | |
"!pip install tf-nightly tfp-nightly" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "NWMS4SXpIJsU", | |
"outputId": "a7d67d03-989e-46da-ce0c-2dee340b1d4c" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Found existing installation: tensorflow 2.15.0\n", | |
"Uninstalling tensorflow-2.15.0:\n", | |
" Successfully uninstalled tensorflow-2.15.0\n", | |
"Found existing installation: tensorflow-probability 0.22.0\n", | |
"Uninstalling tensorflow-probability-0.22.0:\n", | |
" Successfully uninstalled tensorflow-probability-0.22.0\n", | |
"Collecting tf-nightly\n", | |
" Downloading tf_nightly-2.16.0.dev20240106-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (566.2 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m566.2/566.2 MB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting tfp-nightly\n", | |
" Downloading tfp_nightly-0.24.0.dev20240108-py2.py3-none-any.whl (6.9 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.9/6.9 MB\u001b[0m \u001b[31m84.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (1.4.0)\n", | |
"Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (1.6.3)\n", | |
"Requirement already satisfied: flatbuffers>=23.5.26 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (23.5.26)\n", | |
"Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (0.5.4)\n", | |
"Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (0.2.0)\n", | |
"Collecting h5py>=3.10.0 (from tf-nightly)\n", | |
" Downloading h5py-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.8/4.8 MB\u001b[0m \u001b[31m94.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (16.0.6)\n", | |
"Collecting ml-dtypes~=0.3.1 (from tf-nightly)\n", | |
" Downloading ml_dtypes-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m89.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (3.3.0)\n", | |
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (23.2)\n", | |
"Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (3.20.3)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (67.7.2)\n", | |
"Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (1.16.0)\n", | |
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (2.4.0)\n", | |
"Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (4.5.0)\n", | |
"Requirement already satisfied: wrapt<1.15,>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (1.14.1)\n", | |
"Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (1.60.0)\n", | |
"Collecting tb-nightly~=2.16.0.a (from tf-nightly)\n", | |
" Downloading tb_nightly-2.16.0a20240108-py3-none-any.whl (5.5 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m104.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting tf-estimator-nightly~=2.14.0.dev (from tf-nightly)\n", | |
" Downloading tf_estimator_nightly-2.14.0.dev2023080308-py2.py3-none-any.whl (440 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m440.9/440.9 kB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting keras-nightly~=3.0.0.dev (from tf-nightly)\n", | |
" Downloading keras_nightly-3.0.3.dev2024010803-py3-none-any.whl (1.0 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m72.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (0.35.0)\n", | |
"Requirement already satisfied: numpy<2.0.0,>=1.23.5 in /usr/local/lib/python3.10/dist-packages (from tf-nightly) (1.23.5)\n", | |
"Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from tfp-nightly) (4.4.2)\n", | |
"Requirement already satisfied: cloudpickle>=1.3 in /usr/local/lib/python3.10/dist-packages (from tfp-nightly) (2.2.1)\n", | |
"Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from tfp-nightly) (0.1.8)\n", | |
"Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tf-nightly) (0.42.0)\n", | |
"Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras-nightly~=3.0.0.dev->tf-nightly) (13.7.0)\n", | |
"Collecting namex (from keras-nightly~=3.0.0.dev->tf-nightly)\n", | |
" Downloading namex-0.0.7-py3-none-any.whl (5.8 kB)\n", | |
"Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tb-nightly~=2.16.0.a->tf-nightly) (2.17.3)\n", | |
"Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tb-nightly~=2.16.0.a->tf-nightly) (1.2.0)\n", | |
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tb-nightly~=2.16.0.a->tf-nightly) (3.5.1)\n", | |
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tb-nightly~=2.16.0.a->tf-nightly) (2.31.0)\n", | |
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tb-nightly~=2.16.0.a->tf-nightly) (0.7.2)\n", | |
"Collecting tf-keras-nightly (from tb-nightly~=2.16.0.a->tf-nightly)\n", | |
" Downloading tf_keras_nightly-2.16.0.dev2024010810-py3-none-any.whl (1.7 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tb-nightly~=2.16.0.a->tf-nightly) (3.0.1)\n", | |
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.16.0.a->tf-nightly) (5.3.2)\n", | |
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.16.0.a->tf-nightly) (0.3.0)\n", | |
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.16.0.a->tf-nightly) (4.9)\n", | |
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<2,>=0.5->tb-nightly~=2.16.0.a->tf-nightly) (1.3.1)\n", | |
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tb-nightly~=2.16.0.a->tf-nightly) (3.3.2)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tb-nightly~=2.16.0.a->tf-nightly) (3.6)\n", | |
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tb-nightly~=2.16.0.a->tf-nightly) (2.0.7)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tb-nightly~=2.16.0.a->tf-nightly) (2023.11.17)\n", | |
"Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tb-nightly~=2.16.0.a->tf-nightly) (2.1.3)\n", | |
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras-nightly~=3.0.0.dev->tf-nightly) (3.0.0)\n", | |
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras-nightly~=3.0.0.dev->tf-nightly) (2.16.1)\n", | |
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras-nightly~=3.0.0.dev->tf-nightly) (0.1.2)\n", | |
"Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tb-nightly~=2.16.0.a->tf-nightly) (0.5.1)\n", | |
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tb-nightly~=2.16.0.a->tf-nightly) (3.2.2)\n", | |
"Installing collected packages: namex, tfp-nightly, tf-keras-nightly, tf-estimator-nightly, ml-dtypes, h5py, keras-nightly, tb-nightly, tf-nightly\n", | |
" Attempting uninstall: ml-dtypes\n", | |
" Found existing installation: ml-dtypes 0.2.0\n", | |
" Uninstalling ml-dtypes-0.2.0:\n", | |
" Successfully uninstalled ml-dtypes-0.2.0\n", | |
" Attempting uninstall: h5py\n", | |
" Found existing installation: h5py 3.9.0\n", | |
" Uninstalling h5py-3.9.0:\n", | |
" Successfully uninstalled h5py-3.9.0\n", | |
"Successfully installed h5py-3.10.0 keras-nightly-3.0.3.dev2024010803 ml-dtypes-0.3.2 namex-0.0.7 tb-nightly-2.16.0a20240108 tf-estimator-nightly-2.14.0.dev2023080308 tf-keras-nightly-2.16.0.dev2024010810 tf-nightly-2.16.0.dev20240106 tfp-nightly-0.24.0.dev20240108\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"id": "tpaR6JRyHHfI", | |
"outputId": "af354477-848b-4209-f7cb-9d0686b14f11", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"2.16.0-dev20240106\n", | |
"0.24.0-dev20240108\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"import tensorflow_probability as tfp\n", | |
"from tensorflow import keras\n", | |
"\n", | |
"print(tf.__version__)\n", | |
"print(tfp.__version__)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"encoded_size=2\n", | |
"\n", | |
"# EncoderDecoder: Dense(128, relu) -> Dense(IndependentNormal.params_size(encoded_size))\n", | |
"# -> tfp.layers.IndependentNormal(event_shape=[encoded_size]); call() returns the output of\n", | |
"# that distribution layer.\n", | |
"# In the recorded run, calling this layer with inputs set to a\n", | |
"# tfp.distributions._TensorCoercible fails at the dense1 call with\n", | |
"# TypeError: 'NoneType' object is not subscriptable.\n", | |
"class EncoderDecoder(keras.Layer):\n", | |
" def __init__(self):\n", | |
" super().__init__()\n", | |
" self.dense1=keras.layers.Dense(128, activation='relu')\n", | |
" self.dense2=keras.layers.Dense(tfp.layers.IndependentNormal.params_size(encoded_size))\n", | |
" self.dist=tfp.layers.IndependentNormal(event_shape=[encoded_size],validate_args=True)\n", | |
"\n", | |
" def call(self,inputs):\n", | |
" x=self.dense1(inputs=inputs)  # statement the recorded traceback points at\n", | |
" x=self.dense2(x)\n", | |
" return self.dist(x)\n", | |
"\n", | |
"# AutoEncoder chains two separate EncoderDecoder instances: the encoder's output is passed\n", | |
"# to the decoder through the inputs= keyword and the decoder's output is returned.\n", | |
"# In the recorded run the encoder output is a tfp.distributions._TensorCoercible\n", | |
"# (batch_shape=[32, 40], event_shape=[2]) and the decoder call is where the TypeError surfaces.\n", | |
"class AutoEncoder(keras.Model):\n", | |
" def __init__(self):\n", | |
" super().__init__()\n", | |
" self.encoder = EncoderDecoder()\n", | |
" self.decoder = EncoderDecoder()\n", | |
"\n", | |
" def call(self, x):\n", | |
" encoded = self.encoder(x)\n", | |
" decoded = self.decoder(inputs=encoded)  # fails here in the recorded run\n", | |
" return decoded\n", | |
"\n", | |
"# Loss passed to model.compile: returns the negated y_pred.log_prob(y_true).\n", | |
"# Assumes y_pred is a distribution-like object exposing log_prob - TODO confirm.\n", | |
"# NOTE(review): the recorded traceback stops inside AutoEncoder.call, so this function\n", | |
"# does not appear to be reached in that run.\n", | |
"def negative_log_likelihood(y_true, y_pred):\n", | |
" return -y_pred.log_prob(y_true)" | |
], | |
"metadata": { | |
"id": "oKaB8_XWKQvk" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Random inputs and targets for the reproduction; np.random is not seeded in this notebook,\n", | |
"# so the values differ between runs.\n", | |
"batch_size=32\n", | |
"output_dim=2\n", | |
"\n", | |
"train=np.random.rand(batch_size,40,5)\n", | |
"# NOTE(review): verify has 10 on axis 1 while train has 40 - confirm that log_prob can\n", | |
"# accept this target shape; the recorded run fails before the loss would reveal it.\n", | |
"verify=np.random.rand(batch_size,10,output_dim)" | |
], | |
"metadata": { | |
"id": "qmGP6BkSN8dW" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Build the model and compile it with Adam (learning_rate=0.0001) and the\n", | |
"# negative_log_likelihood loss; eager execution and JIT compilation are both disabled.\n", | |
"model=AutoEncoder()\n", | |
"model.compile(optimizer=tf.keras.optimizers.Adam(\n", | |
" learning_rate=0.0001), loss=negative_log_likelihood,\n", | |
" run_eagerly=False, jit_compile=False)" | |
], | |
"metadata": { | |
"id": "XosLavY5LLCj" | |
}, | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Reproduction trigger: in the recorded run this raises TypeError ('NoneType' object is not\n", | |
"# subscriptable) from EncoderDecoder.call() right after 'Epoch 1/10' is printed.\n", | |
"model.fit(train,verify,epochs=10,batch_size=batch_size)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 425 | |
}, | |
"id": "1c-723fFRBtr", | |
"outputId": "25f99fe0-17a4-4b85-ac4c-6877c5610124" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch 1/10\n" | |
] | |
}, | |
{ | |
"output_type": "error", | |
"ename": "TypeError", | |
"evalue": "Exception encountered when calling EncoderDecoder.call().\n\n\u001b[1m'NoneType' object is not subscriptable\u001b[0m\n\nArguments received by EncoderDecoder.call():\n • inputs=<tfp.distributions._TensorCoercible 'tensor_coercible' batch_shape=[32, 40] event_shape=[2] dtype=float32>", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-8-ed773d6be14e>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mverify\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;31m# To get the full stack trace, call:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0;31m# `keras.config.disable_traceback_filtering()`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 123\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 124\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-5-758248108003>\u001b[0m in \u001b[0;36mcall\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mencoded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mdecoded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mencoded\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdecoded\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-5-758248108003>\u001b[0m in \u001b[0;36mcall\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdense1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdense2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mTypeError\u001b[0m: Exception encountered when calling EncoderDecoder.call().\n\n\u001b[1m'NoneType' object is not subscriptable\u001b[0m\n\nArguments received by EncoderDecoder.call():\n • inputs=<tfp.distributions._TensorCoercible 'tensor_coercible' batch_shape=[32, 40] event_shape=[2] dtype=float32>" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment