@PaulSayantan
Last active August 4, 2020 17:21
Tensorflow
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "tensorflow-basics.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "IwkK3jqVNZWE",
"colab_type": "text"
},
"source": [
"# <h1 align=\"center\">TensorFlow For Beginners</h1>\n",
"\n",
"<h6 align=\"right\">𝓢𝓪𝔂𝓪𝓷𝓽𝓪𝓷 𝓟𝓪𝓾𝓵</h6>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "97JiPuTIRnue",
"colab_type": "text"
},
"source": [
"## Install TensorFlow"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KrtAtRtPRghC",
"colab_type": "text"
},
"source": [
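"To install TensorFlow on a local machine (since TensorFlow 2.1, the `tensorflow` package already includes GPU support, so the separate `tensorflow-gpu` package is not needed):\n",
"\n",
"~~~\n",
"pip install -q tensorflow tensorflow-addons\n",
"~~~"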
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AHPTAvG_T5v7",
"colab_type": "text"
},
"source": [
"## Import TensorFlow"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5TNjDOQ0TQvu",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"import cProfile"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "JX8ec78tEkwx",
"colab_type": "text"
},
"source": [
"# <h1 align=\"center\">-----Part - 1-----</h1>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xe4A9AtQEmLJ",
"colab_type": "text"
},
"source": [
"## Eager execution\n",
"\n",
"TensorFlow's eager execution is an imperative programming environment that evaluates operations immediately, without building graphs: operations return concrete values instead of constructing a computational graph to run later. This makes it easy to get started with TensorFlow and debug models, and it reduces boilerplate as well. To follow along with this guide, run the code samples below in an interactive Python interpreter.\n",
"\n",
"Eager execution is a flexible machine learning platform for research and experimentation, providing:\n",
"\n",
" - An **intuitive interface** — Structure your code naturally and use Python data structures. Quickly iterate on small models and small data.\n",
" - **Easier debugging** — Call ops directly to inspect running models and test changes. Use standard Python debugging tools for immediate error reporting.\n",
" - **Natural control flow** — Use Python control flow instead of graph control flow, simplifying the specification of dynamic models.\n",
"\n",
"Eager execution supports most TensorFlow operations and GPU acceleration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sVADt9EWFH70",
"colab_type": "text"
},
"source": [
"In TensorFlow 2.0 and later, eager execution is enabled by default."
]
},
{
"cell_type": "code",
"metadata": {
"id": "XFME4pXGFQaH",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "17bcf5d3-f7ec-4294-dbbb-e0aff69642b6"
},
"source": [
"tf.executing_eagerly()"
],
"execution_count": 2,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 2
}
]
},
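`tf.executing_eagerly()` returns `True` at the top level, but code that `tf.function` traces into a graph does not run eagerly. A minimal sketch contrasting the two; the function names `check_mode` and `traced` are purely illustrative, and it assumes the default setting where `tf.config.run_functions_eagerly` has not been switched on:

```python
import tensorflow as tf

def check_mode():
    # Called directly from Python, so it runs eagerly
    return tf.executing_eagerly()

@tf.function
def traced():
    # While tf.function traces this body into a graph, executing_eagerly() is False;
    # that Python bool is captured as a constant in the traced graph
    return tf.constant(tf.executing_eagerly())

eager_result = check_mode()
graph_result = bool(traced().numpy())

print("eager:", eager_result)
print("inside tf.function:", graph_result)
```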
{
"cell_type": "markdown",
"metadata": {
"id": "ISciwl9fFYPv",
"colab_type": "text"
},
"source": [
"Now you can run TensorFlow operations and the results will return immediately:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "p68paqllFZud",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "f8418c95-dcd0-4a84-f835-d0435dc40ddd"
},
"source": [
"x = [[2.]]\n",
"m = tf.matmul(x, x)\n",
"print(\"hello, {}\".format(m))"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"hello, [[4.]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CSv0MgWBFg1P",
"colab_type": "text"
},
"source": [
"Eager execution works nicely with NumPy. NumPy operations accept `tf.Tensor` arguments. The TensorFlow `tf.math` operations convert Python objects and NumPy arrays to `tf.Tensor` objects. The `tf.Tensor.numpy` method returns the object's value as a NumPy `ndarray`."
]
},
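A small sketch of the round trip described above, with made-up values: a NumPy function applied to a `tf.Tensor` returns an `ndarray`, a `tf.math` op applied to a NumPy array returns a `tf.Tensor`, and `.numpy()` converts the result back.

```python
import numpy as np
import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# NumPy functions accept a tf.Tensor and return a NumPy ndarray
np_result = np.multiply(t, 2)

# tf.math ops accept NumPy arrays and return a tf.Tensor
tf_result = tf.math.add(np.ones((2, 2), dtype=np.float32), 1.0)

# .numpy() converts a tf.Tensor back into an ndarray
back = tf_result.numpy()

print(type(np_result).__name__, np_result.tolist())
print(type(back).__name__, back.tolist())
```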
{
"cell_type": "code",
"metadata": {
"id": "Z5EpgU-rFupc",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "15f78ef0-7dcc-4f7e-a449-4c13bef0d921"
},
"source": [
"a = tf.constant([[1, 2],\n",
" [3, 4]])\n",
"print(a)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[1 2]\n",
" [3 4]], shape=(2, 2), dtype=int32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "563JWqChF03W",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "b5ba034a-f015-4ad7-dafa-b8ac73764c5e"
},
"source": [
"# Broadcasting support\n",
"b = tf.add(a, 1)\n",
"print(b)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[2 3]\n",
" [4 5]], shape=(2, 2), dtype=int32)\n"
],
"name": "stdout"
}
]
},
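The `tf.add(a, 1)` cell above broadcasts a scalar. Broadcasting also works between two tensors of different shapes, following the NumPy rules: shapes are compared from the right, and each pair of dimensions must be equal or one of them must be 1. A small sketch with made-up values:

```python
import tensorflow as tf

col = tf.constant([[1], [2], [3]])   # shape (3, 1)
row = tf.constant([10, 20, 30, 40])  # shape (4,)

# Broadcasting stretches both operands to the common shape (3, 4):
# col is repeated across the columns, row is repeated across the rows
grid = col + row

print(grid.shape)
print(grid.numpy())
```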
{
"cell_type": "code",
"metadata": {
"id": "1qI3F6dwF5PV",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "5e79f471-e802-46c1-df76-3746c7396ed8"
},
"source": [
"# Operator overloading is supported\n",
"print(a * b)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[ 2 6]\n",
" [12 20]], shape=(2, 2), dtype=int32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y20Ic3HpUxBy",
"colab_type": "text"
},
"source": [
"## Create Tensors\n",
"\n",
"**What is a Tensor?**\n",
"Tensors are multi-dimensional arrays with a uniform type (called a `dtype`). You can see all supported dtypes at `tf.dtypes.DType`.\n",
"\n",
"If you're familiar with NumPy, tensors are (kind of) like NumPy arrays (`np.array`).\n",
"\n",
"All tensors are immutable like Python numbers and strings: you can never update the contents of a tensor, only create a new one."
]
},
{
"cell_type": "code",
"metadata": {
"id": "j3DNdSTDUwXX",
"colab_type": "code",
"colab": {}
},
"source": [
"# a string tensor variable\n",
"string = tf.Variable(\"Tensorflow is awesome!\", dtype=tf.string)\n",
"\n",
"# an integer tensor variable\n",
"num = tf.Variable(324, dtype=tf.int32)\n",
"\n",
"# a floating-point tensor variable\n",
"# (pass dtype by keyword: the second positional argument of tf.Variable is `trainable`)\n",
"decimal = tf.Variable(3.256, dtype=tf.float64)\n",
"\n",
"# a boolean tensor variable\n",
"boolean = tf.Variable([False, False, False, True])\n",
"\n",
"# a complex tensor variable\n",
"complex_var = tf.Variable([5 + 4j, 6 + 1j])"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fcE_YGXRVUcF",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 100
},
"outputId": "c5b08ace-d1e4-489a-836c-898a6cc5f77f"
},
"source": [
"print(string)\n",
"\n",
"print(num)\n",
"\n",
"print(decimal)\n",
"\n",
"print(boolean)\n",
"\n",
"print(complex_var)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"<tf.Variable 'Variable:0' shape=() dtype=string, numpy=b'Tensorflow is awesome!'>\n",
"<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=324>\n",
"<tf.Variable 'Variable:0' shape=() dtype=float64, numpy=3.256>\n",
"<tf.Variable 'Variable:0' shape=(4,) dtype=bool, numpy=array([False, False, False, True])>\n",
"<tf.Variable 'Variable:0' shape=(2,) dtype=complex128, numpy=array([5.+4.j, 6.+1.j])>\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GBSalmJfaD34",
"colab_type": "text"
},
"source": [
"Most tensor operations work on variables as expected, although variables cannot be reshaped."
]
},
{
"cell_type": "code",
"metadata": {
"id": "JXKfWk46aF6g",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 183
},
"outputId": "414f7f60-dc59-4ccc-b21d-de01e771e875"
},
"source": [
"my_tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])\n",
"my_var = tf.Variable(my_tensor)\n",
"\n",
"print(\"\\nViewing Variable as Tensor: \\n\", tf.convert_to_tensor(my_var))\n",
"print(\"\\nIndex of highest value: \\n\", tf.argmax(my_var))\n",
"\n",
"# This creates a new tensor; it does not reshape the variable.\n",
"print(\"\\nCopying and reshaping: \", tf.reshape(my_var, [1, 4]))"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"\n",
"Viewing Variable as Tensor: \n",
" tf.Tensor(\n",
"[[1. 2.]\n",
" [3. 4.]], shape=(2, 2), dtype=float32)\n",
"\n",
"Index of highest value: \n",
" tf.Tensor([1 1], shape=(2,), dtype=int64)\n",
"\n",
"Copying and reshaping: tf.Tensor([[1. 2. 3. 4.]], shape=(1, 4), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zIg02e-pbESj",
"colab_type": "text"
},
"source": [
"Variables are backed by tensors. We can reassign the tensor using `tf.Variable.assign`. Calling assign does not (usually) allocate a new tensor; instead, the existing tensor's memory is reused."
]
},
{
"cell_type": "code",
"metadata": {
"id": "vo3gtTTdaFkK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "b21c65d7-acdf-4dea-f09b-a73ee6fe417c"
},
"source": [
"a = tf.Variable([2.0, 3.0])\n",
"\n",
"# This will keep the same dtype, float32\n",
"# Assigning is successful since shape is same\n",
"a.assign([1, 2])\n",
"a"
],
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([1., 2.], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RsMCrevCbvTy",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "c3dda58c-5121-4cf7-c373-6171b7e6d6b0"
},
"source": [
"# Not allowed as it resizes the variable: \n",
"try:\n",
" a.assign([1.0, 2.0, 3.0])\n",
"except Exception as e:\n",
" print(f\"{type(e).__name__}: {e}\")"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"ValueError: Shapes (2,) and (3,) are incompatible\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GiEA_KE4mR4a",
"colab_type": "text"
},
"source": [
"When we use a variable like a tensor in operations, we should usually operate on the backing tensor.\n",
"\n",
"Creating new variables from existing variables duplicates the backing tensors. Two variables will not share the same memory."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rX0AgqoVmegJ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 100
},
"outputId": "e5dacc37-ab10-4b2e-a17a-d2d769f6128f"
},
"source": [
"a = tf.Variable([2.0, 3.0])\n",
"\n",
"# Create b based on the value of a\n",
"b = tf.Variable(a)\n",
"a.assign([5, 6])\n",
"\n",
"# b was created from a's value but has its own copy, so changing a does not change b\n",
"print(a.numpy())\n",
"print(b.numpy())\n",
"\n",
"print()\n",
"\n",
"# other versions of assign\n",
"print(a.assign_add([2,3]).numpy())\n",
"print(a.assign_sub([7,9]).numpy())\n"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"[5. 6.]\n",
"[2. 3.]\n",
"\n",
"[7. 9.]\n",
"[0. 0.]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jA7CfRe6nG9-",
"colab_type": "text"
},
"source": [
"## Lifecycles, naming, and watching\n",
"\n",
"In TensorFlow, `tf.Variable` instances have the same lifecycle as other Python objects. When there are no references to a variable, it is automatically deallocated.\n",
"\n",
"Variables can also be named, which helps us track and debug them. Two variables can be given the same name."
]
},
{
"cell_type": "code",
"metadata": {
"id": "hx_AzCC4nV-5",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 183
},
"outputId": "f4a539c6-f4ab-40c9-b99d-38131e8f409f"
},
"source": [
"# Create a variable named \"Mark\"\n",
"a = tf.Variable(my_tensor, name=\"Mark\")\n",
"\n",
"# A new variable with the same name, but a different value and a different backing tensor\n",
"# Note that operations with a scalar, like +, -, *, / or **, are broadcast elementwise\n",
"b = tf.Variable(my_tensor ** 2, name=\"Mark\")\n",
"\n",
"print(a)\n",
"print(b)\n",
"\n",
"print()\n",
"\n",
"# Elementwise comparison: the values differ (only 1 == 1**2 matches), despite the shared name\n",
"print(a == b)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"<tf.Variable 'Mark:0' shape=(2, 2) dtype=float32, numpy=\n",
"array([[1., 2.],\n",
" [3., 4.]], dtype=float32)>\n",
"<tf.Variable 'Mark:0' shape=(2, 2) dtype=float32, numpy=\n",
"array([[ 1., 4.],\n",
" [ 9., 16.]], dtype=float32)>\n",
"\n",
"tf.Tensor(\n",
"[[ True False]\n",
" [False False]], shape=(2, 2), dtype=bool)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "awP_HcJ8pDBW",
"colab_type": "text"
},
"source": [
"Variable names are preserved when saving and loading models. By default, variables in models acquire unique variable names automatically, so we don't need to assign them ourselves unless we want to.\n",
"\n",
"Although variables are important for differentiation, some variables will not need to be differentiated. We can turn off gradients for a variable by setting `trainable` to `False` at creation.\n",
"\n",
"An example of a variable that would not need gradients is a training step counter:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rB3ADZchpN1Q",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "8e7a8557-1dba-4787-b434-b67518769068"
},
"source": [
"step_counter = tf.Variable(1, trainable=False)\n",
"step_counter"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=1>"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
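To see what `trainable=False` changes in practice, here is a small sketch using `tf.GradientTape`, which automatically watches only trainable variables. The variable `w` is a hypothetical stand-in for a model weight, and the step counter is a float here only so that the multiplication type-checks:

```python
import tensorflow as tf

w = tf.Variable(3.0)                              # trainable by default
step_counter = tf.Variable(1.0, trainable=False)  # excluded from gradients

with tf.GradientTape() as tape:
    # The tape automatically records reads of trainable variables only
    y = w * step_counter

grad_w, grad_step = tape.gradient(y, [w, step_counter])

print(grad_w)     # dy/dw equals the value of step_counter
print(grad_step)  # None, because the non-trainable variable was not watched
```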
{
"cell_type": "markdown",
"metadata": {
"id": "OFruI9-VpezL",
"colab_type": "text"
},
"source": [
"## Placing variables and tensors\n",
"\n",
"For better performance, TensorFlow will attempt to place tensors and variables on the fastest device compatible with its dtype. This means most variables are placed on a GPU if one is available."
]
},
{
"cell_type": "code",
"metadata": {
"id": "VG85AHZOp5cf",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "26f13e60-5bee-456c-85d0-5dc3be496c59"
},
"source": [
"# with tf.device('CPU:0'):\n",
"with tf.device('GPU:0'):\n",
"\n",
" # Create some tensors\n",
" a = tf.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n",
" b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n",
" c = tf.matmul(a, b)\n",
"\n",
"print(c)"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[22. 28.]\n",
" [49. 64.]], shape=(2, 2), dtype=float32)\n"
],
"name": "stdout"
}
]
},
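The cell above hard-codes `GPU:0`. A minimal sketch that instead picks the device at runtime with `tf.config.list_physical_devices` and then reads back where the result lives through the tensor's `.device` attribute (the variable name `device` is just illustrative):

```python
import tensorflow as tf

# Fall back to the CPU when no GPU is visible (e.g. on a CPU-only machine)
device = "GPU:0" if tf.config.list_physical_devices("GPU") else "CPU:0"

with tf.device(device):
    a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    c = tf.matmul(a, b)

# Each eager tensor records the full name of the device it was produced on
print(device, c.device)
print(c.numpy())
```

The full device string has the form `/job:localhost/replica:0/task:0/device:CPU:0`, so comparing its suffix is enough to check placement.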
{
"cell_type": "markdown",
"metadata": {
"id": "4MMymJNGqxJu",
"colab_type": "text"
},
"source": [
"However, we can override this. We can place a float tensor and a variable on the CPU, even if a GPU is available. By turning on device placement logging with `tf.debugging.set_log_device_placement(True)`, we can see which device each operation runs on.\n",
"\n",
"It's possible to set the location of a variable or tensor on one device and do the computation on another device. This introduces latency, because the data needs to be copied between the devices."
]
},
{
"cell_type": "code",
"metadata": {
"id": "KZql_DfPrHJb",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "816acf2e-abec-46df-f74c-201386c9212e"
},
"source": [
"with tf.device('CPU:0'):\n",
" a = tf.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n",
" b = tf.Variable([[1.0, 2.0, 3.0]])\n",
"\n",
"with tf.device('GPU:0'):\n",
" # Element-wise multiply\n",
" k = a * b\n",
"\n",
"print(k)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[ 1. 4. 9.]\n",
" [ 4. 10. 18.]], shape=(2, 3), dtype=float32)\n"
],
"name": "stdout"
}
]
},
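The text above mentions device placement logging, but the cell does not switch it on. A minimal sketch using `tf.debugging.set_log_device_placement`; the log lines are printed by the TensorFlow runtime and their exact format varies between versions, so the sketch also reads the `.device` attribute directly:

```python
import tensorflow as tf

# Ask the runtime to log the device each op is placed on
tf.debugging.set_log_device_placement(True)

with tf.device("CPU:0"):
    a = tf.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    b = tf.Variable([[1.0, 2.0, 3.0]])

# No device scope here, so TensorFlow chooses the device (a GPU if one is available)
k = a * b

# Switch logging off again so later cells are not flooded with placement logs
tf.debugging.set_log_device_placement(False)

print(a.device)
print(k.device)
print(k.numpy())
```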
{
"cell_type": "markdown",
"metadata": {
"id": "2ECFtOKmXV9H",
"colab_type": "text"
},
"source": [
"## Rank/Degree of a Tensor\n",
"\n",
"The rank (or degree) of a tensor is simply the number of dimensions (axes) it has.\n",
"\n",
"**Note:** A tensor of rank 0 is known as a scalar tensor."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Kri2I1maW4K3",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "7099a509-1fdf-4990-d415-d2f7e6cc5b5c"
},
"source": [
"rank_0_tensor = tf.constant(4)\n",
"print(rank_0_tensor)"
],
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(4, shape=(), dtype=int32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "61GnKe2IX3S-",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "407cca26-a679-4fd3-e452-73b3564277b5"
},
"source": [
"rank_1_tensor = tf.constant([2.0, 3.0, 4.0])\n",
"print(rank_1_tensor)"
],
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([2. 3. 4.], shape=(3,), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "YzRLzDmYX5n4",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 83
},
"outputId": "a5dab959-a6c2-45c6-88c3-0f1025f80a92"
},
"source": [
"rank_2_tensor = tf.constant([[1, 2],\n",
" [3, 4],\n",
" [5, 6]], dtype=tf.float16)\n",
"print(rank_2_tensor)"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[1. 2.]\n",
" [3. 4.]\n",
" [5. 6.]], shape=(3, 2), dtype=float16)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "yVC0MrGhbpH7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 166
},
"outputId": "9631bcfd-a63b-4afc-dfd9-ddc75ff3329a"
},
"source": [
"# There can be an arbitrary number of axes (sometimes called \"dimensions\")\n",
"rank_3_tensor = tf.constant([\n",
" [[0, 1, 2, 3, 4],\n",
" [5, 6, 7, 8, 9]],\n",
" [[10, 11, 12, 13, 14],\n",
" [15, 16, 17, 18, 19]],\n",
" [[20, 21, 22, 23, 24],\n",
" [25, 26, 27, 28, 29]],])\n",
" \n",
"print(rank_3_tensor)"
],
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[[ 0 1 2 3 4]\n",
" [ 5 6 7 8 9]]\n",
"\n",
" [[10 11 12 13 14]\n",
" [15 16 17 18 19]]\n",
"\n",
" [[20 21 22 23 24]\n",
" [25 26 27 28 29]]], shape=(3, 2, 5), dtype=int32)\n"
],
"name": "stdout"
}
]
},
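The printed shapes above imply each tensor's rank, but TensorFlow can also report it directly. A short sketch using `tf.rank` (which returns a scalar tensor) and the `.ndim` attribute (a plain Python int); the tensors here are new illustrative ones built with the same shapes as above:

```python
import tensorflow as tf

tensors = {
    "scalar": tf.constant(4),
    "vector": tf.constant([2.0, 3.0, 4.0]),
    "matrix": tf.zeros([3, 2]),
    "rank-3": tf.zeros([3, 2, 5]),
}

for name, t in tensors.items():
    # tf.rank gives the number of axes as a tensor; .ndim gives the same value as an int
    print(name, int(tf.rank(t)), t.ndim)
```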
{
"cell_type": "markdown",
"metadata": {
"id": "zgeSDcwcr-Lp",
"colab_type": "text"
},
"source": [
"We can convert a tensor to a NumPy array using either `np.array` or the `tensor.numpy()` method."
]
},
{
"cell_type": "code",
"metadata": {
"id": "l5M_Y7j7r0gf",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "49e94515-d129-4ea7-e104-923665d112b5"
},
"source": [
"np.array(rank_2_tensor)"
],
"execution_count": 21,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[1., 2.],\n",
" [3., 4.],\n",
" [5., 6.]], dtype=float16)"
]
},
"metadata": {
"tags": []
},
"execution_count": 21
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "IYUWzAhesIDY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "2d8f9e9a-6ed2-4d58-e305-dde1539f1d68"
},
"source": [
"rank_2_tensor.numpy()"
],
"execution_count": 22,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[1., 2.],\n",
" [3., 4.],\n",
" [5., 6.]], dtype=float16)"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9s7BICRLsZyv",
"colab_type": "text"
},
"source": [
"The base `tf.Tensor` class requires tensors to be \"rectangular\": along each axis, every element is the same size. However, there are specialized types of tensors that can handle different shapes:\n",
"\n",
" - ragged (see RaggedTensor below)\n",
" - sparse (see SparseTensor below)\n",
"\n",
"We can do basic math on tensors, including addition, element-wise multiplication, and matrix multiplication."
]
},
{
"cell_type": "code",
"metadata": {
"id": "q8SCu8UhsSTf",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 216
},
"outputId": "9ca20654-da03-4e41-9434-de412ca26c12"
},
"source": [
"a = tf.constant([[1, 2],\n",
" [3, 4]])\n",
"b = tf.constant([[1, 1],\n",
" [1, 1]])\n",
"\n",
"print(tf.add(a, b), \"\\n\")\n",
"print(tf.multiply(a, b), \"\\n\")\n",
"print(tf.matmul(a, b), \"\\n\")"
],
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[2 3]\n",
" [4 5]], shape=(2, 2), dtype=int32) \n",
"\n",
"tf.Tensor(\n",
"[[1 2]\n",
" [3 4]], shape=(2, 2), dtype=int32) \n",
"\n",
"tf.Tensor(\n",
"[[3 3]\n",
" [7 7]], shape=(2, 2), dtype=int32) \n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qRbavpSOsxFA",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 216
},
"outputId": "5abc35fe-6bc0-46bb-f6d2-7e15abea7f7b"
},
"source": [
"print(a + b, \"\\n\") # element-wise addition\n",
"print(a * b, \"\\n\") # element-wise multiplication\n",
"print(a @ b, \"\\n\") # matrix multiplication"
],
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[2 3]\n",
" [4 5]], shape=(2, 2), dtype=int32) \n",
"\n",
"tf.Tensor(\n",
"[[1 2]\n",
" [3 4]], shape=(2, 2), dtype=int32) \n",
"\n",
"tf.Tensor(\n",
"[[3 3]\n",
" [7 7]], shape=(2, 2), dtype=int32) \n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6mXVX2tztCQE",
"colab_type": "text"
},
"source": [
"Tensors are used in all kinds of operations."
]
},
{
"cell_type": "code",
"metadata": {
"id": "xiuhGSBJs0bU",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 100
},
"outputId": "cd38aeff-0a37-44d1-dde5-bc7e5c6375d0"
},
"source": [
"c = tf.constant([[4.0, 5.0], [10.0, 1.0]])\n",
"\n",
"# Find the largest value\n",
"print(tf.reduce_max(c))\n",
"# Find the index of the largest value\n",
"print(tf.argmax(c))\n",
"# Compute the softmax\n",
"print(tf.nn.softmax(c))"
],
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(10.0, shape=(), dtype=float32)\n",
"tf.Tensor([1 0], shape=(2,), dtype=int64)\n",
"tf.Tensor(\n",
"[[2.6894143e-01 7.3105860e-01]\n",
" [9.9987662e-01 1.2339458e-04]], shape=(2, 2), dtype=float32)\n"
],
"name": "stdout"
}
]
},
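The outputs above can be checked by hand. `tf.nn.softmax` normalizes along the last axis by default, so each row of `c` is handled separately. For the first row $[4, 5]$:

$$
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}, \qquad
\frac{e^{4}}{e^{4}+e^{5}} = \frac{1}{1+e} \approx 0.2689, \qquad
\frac{e^{5}}{e^{4}+e^{5}} = \frac{e}{1+e} \approx 0.7311
$$

For the second row $[10, 1]$ the values are $\frac{1}{1+e^{-9}} \approx 0.99988$ and $\frac{e^{-9}}{1+e^{-9}} \approx 1.234 \times 10^{-4}$, matching the printed tensor. Similarly, `tf.argmax` reduces along axis 0 by default, which is why it returns one index per column, `[1 0]`: 10 is the largest value in the first column and 5 in the second.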
{
"cell_type": "markdown",
"metadata": {
"id": "JfazDoZptKPt",
"colab_type": "text"
},
"source": [
"## Tensor Shapes\n",
"\n",
"Tensors have shapes. Some vocabulary:\n",
"\n",
" - `Shape`: The length (number of elements) of each of the dimensions of a tensor.\n",
" - `Rank`: Number of tensor dimensions. A scalar has rank 0, a vector has rank 1, a matrix has rank 2.\n",
" - `Axis` or `Dimension`: A particular dimension of a tensor.\n",
" - `Size`: The total number of items in the tensor, i.e. the product of the elements of the shape vector."
]
},
{
"cell_type": "code",
"metadata": {
"id": "NTqRarNatF4j",
"colab_type": "code",
"colab": {}
},
"source": [
"rank_4_tensor = tf.zeros([3, 2, 4, 5])"
],
"execution_count": 26,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hMpWFT1BuaVt",
"colab_type": "text"
},
"source": [
"![rank-4-tensor](https://user-images.githubusercontent.com/53504602/88804798-0b386400-d1cc-11ea-921f-10da4bee975f.png)\n",
"<center> <img src=\"https://user-images.githubusercontent.com/53504602/88805060-52bef000-d1cc-11ea-871a-56a463d387dc.png\"> </center>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "g81qP9IAtZgC",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 216
},
"outputId": "fe1f71a2-458d-4171-e130-b0e5fa19926c"
},
"source": [
"print(\"\\nType of every element:\", rank_4_tensor.dtype)\n",
"print(\"\\nNumber of dimensions:\", rank_4_tensor.ndim)\n",
"print(\"\\nShape of tensor:\", rank_4_tensor.shape)\n",
"print(\"\\nElements along axis 0 of tensor:\", rank_4_tensor.shape[0])\n",
"print(\"\\nElements along the last axis of tensor:\", rank_4_tensor.shape[-1])\n",
"print(\"\\nTotal number of elements (3*2*4*5): \", tf.size(rank_4_tensor).numpy())"
],
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"text": [
"\n",
"Type of every element: <dtype: 'float32'>\n",
"\n",
"Number of dimensions: 4\n",
"\n",
"Shape of tensor: (3, 2, 4, 5)\n",
"\n",
"Elements along axis 0 of tensor: 3\n",
"\n",
"Elements along the last axis of tensor: 5\n",
"\n",
"Total number of elements (3*2*4*5): 120\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7X1Z-ep9vTMK",
"colab_type": "text"
},
"source": [
"## Indexing\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vh4uX3RAwNON",
"colab_type": "text"
},
"source": [
"\n",
"### Single-axis indexing\n",
"\n",
"TensorFlow follows standard Python indexing rules, similar to indexing a list or a string in Python, and the basic rules for NumPy indexing.\n",
"\n",
" - indexes start at 0\n",
" - negative indices count backwards from the end\n",
" - colons, `:`, are used for slices `start:stop:step`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "l1VLrObctgZH",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 83
},
"outputId": "5480eaad-5454-42ed-991c-c438fcc3ca99"
},
"source": [
"rank_1_tensor = tf.constant([0, 1, 1, 2, 3, 5, 8, 13, 21, 34])\n",
"print(rank_1_tensor.numpy())\n",
"\n",
"# Indexing with a scalar removes the dimension\n",
"print(\"First:\", rank_1_tensor[0].numpy())\n",
"print(\"Second:\", rank_1_tensor[1].numpy())\n",
"print(\"Last:\", rank_1_tensor[-1].numpy())"
],
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"text": [
"[ 0 1 1 2 3 5 8 13 21 34]\n",
"First: 0\n",
"Second: 1\n",
"Last: 34\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "A7tmaCdev-RG",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 116
},
"outputId": "a699562c-72fc-412f-dc0f-0444ae504264"
},
"source": [
"# Indexing with a : slice keeps the dimension\n",
"print(\"Everything:\", rank_1_tensor[:].numpy())\n",
"print(\"Before 4:\", rank_1_tensor[:4].numpy())\n",
"print(\"From 4 to the end:\", rank_1_tensor[4:].numpy())\n",
"print(\"From 2, before 7:\", rank_1_tensor[2:7].numpy())\n",
"print(\"Every other item:\", rank_1_tensor[::2].numpy())\n",
"print(\"Reversed:\", rank_1_tensor[::-1].numpy())"
],
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"text": [
"Everything: [ 0 1 1 2 3 5 8 13 21 34]\n",
"Before 4: [0 1 1 2]\n",
"From 4 to the end: [ 3 5 8 13 21 34]\n",
"From 2, before 7: [1 2 3 5 8]\n",
"Every other item: [ 0 1 3 8 21]\n",
"Reversed: [34 21 13 8 5 3 2 1 1 0]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5YIKuBrowJGl",
"colab_type": "text"
},
"source": [
"### Multi-axis indexing\n",
"\n",
"Higher rank tensors are indexed by passing multiple indices.\n",
"\n",
"The exact same rules as in the single-axis case apply to each axis independently."
]
},
{
"cell_type": "code",
"metadata": {
"id": "HwX5M-9pwCrW",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 150
},
"outputId": "5576f636-cb34-4d34-f369-d1278b9b916d"
},
"source": [
"print(rank_3_tensor.numpy())"
],
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"text": [
"[[[ 0 1 2 3 4]\n",
" [ 5 6 7 8 9]]\n",
"\n",
" [[10 11 12 13 14]\n",
" [15 16 17 18 19]]\n",
"\n",
" [[20 21 22 23 24]\n",
" [25 26 27 28 29]]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "slTishmMwR0a",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "3be4e10f-74fa-4bd3-b657-45ef367dc1a7"
},
"source": [
"# Pull out the 2nd row of the 2nd batch of a rank-3 tensor\n",
"print(rank_3_tensor[1, 1].numpy())"
],
"execution_count": 31,
"outputs": [
{
"output_type": "stream",
"text": [
"[15 16 17 18 19]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "2jr2fjeYwVv5",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "a38f40ca-6fc7-49e8-e611-89f9f2b5d74a"
},
"source": [
"# Pull out the 2nd value of the 2nd row of the 2nd batch of a rank-3 tensor\n",
"print(rank_3_tensor[1, 1, 1].numpy())"
],
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"text": [
"16\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "m3v0fiNJxLEH",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 83
},
"outputId": "63251a50-d64a-45cb-c812-712263e16972"
},
"source": [
"print(rank_2_tensor)"
],
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[1. 2.]\n",
" [3. 4.]\n",
" [5. 6.]], shape=(3, 2), dtype=float16)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "JS5fwlqhwmHF",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 150
},
"outputId": "7be47527-86f1-434e-beb4-fb85964b39ba"
},
"source": [
"# Get row and column tensors\n",
"print(\"Second row:\", rank_2_tensor[1, :].numpy())\n",
"print(\"Second column:\", rank_2_tensor[:, 1].numpy())\n",
"print(\"Last row:\", rank_2_tensor[-1, :].numpy())\n",
"print(\"First item in last column:\", rank_2_tensor[0, -1].numpy())\n",
"print(\"Skip the first row:\")\n",
"print(rank_2_tensor[1:, :].numpy(), \"\\n\")"
],
"execution_count": 34,
"outputs": [
{
"output_type": "stream",
"text": [
"Second row: [3. 4.]\n",
"Second column: [2. 4. 6.]\n",
"Last row: [5. 6.]\n",
"First item in last column: 2.0\n",
"Skip the first row:\n",
"[[3. 4.]\n",
" [5. 6.]] \n",
"\n"
],
"name": "stdout"
}
]
},
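The per-axis rules also combine on a rank-3 tensor, mixing plain indices and slices. A small sketch that rebuilds a tensor with the same values as `rank_3_tensor` (via `tf.range` and `tf.reshape`, so the block runs on its own):

```python
import tensorflow as tf

# Same values as rank_3_tensor above: 0..29 arranged as (3, 2, 5)
rank_3_tensor = tf.reshape(tf.range(30), [3, 2, 5])

# Last item of every row in every batch: slice axes 0 and 1, index axis 2
last_items = rank_3_tensor[:, :, 4]

# First row of each batch, keeping every other column
first_rows = rank_3_tensor[:, 0, ::2]

print(last_items.numpy())
print(first_rows.numpy())
```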
{
"cell_type": "markdown",
"metadata": {
"id": "SEp0lPJhyK6e",
"colab_type": "text"
},
"source": [
"## Manipulating Shapes\n",
"\n",
"Reshaping a tensor is a very common and useful operation.\n",
"\n",
"The `tf.reshape` operation is fast and cheap as the underlying data does not need to be duplicated."
]
},
{
"cell_type": "code",
"metadata": {
"id": "B1HAwcRnx1dg",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "55c058fe-865a-4a6b-ee6b-902d25157ac7"
},
"source": [
"# Shape returns a `TensorShape` object that shows the size on each dimension\n",
"var_x = tf.Variable(tf.constant([[1], [2], [3]]))\n",
"print(var_x.shape)"
],
"execution_count": 35,
"outputs": [
{
"output_type": "stream",
"text": [
"(3, 1)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "UiCWod2qyWtM",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "690c342d-da09-43fc-ea9f-88d7ce0dc93a"
},
"source": [
"type(var_x)"
],
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensorflow.python.ops.resource_variable_ops.ResourceVariable"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9j75D1bYybvp",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "5e0f7553-666b-4ca3-f2ea-aaf6d56de39d"
},
"source": [
"# We can convert this object into a Python list, too\n",
"print(var_x.shape.as_list())"
],
"execution_count": 37,
"outputs": [
{
"output_type": "stream",
"text": [
"[3, 1]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FHzsAtwQyvQs",
"colab_type": "text"
},
"source": [
"A tensor can be reshaped into any new shape that has the same total number of elements."
]
},
{
"cell_type": "code",
"metadata": {
"id": "nqxRtSIZylMZ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "3fb1eec4-3c04-4e55-b293-a9584d6ca065"
},
"source": [
"# We can reshape a tensor to a new shape.\n",
"# Note that we're passing in a list\n",
"reshaped = tf.reshape(var_x, [1, 3])\n",
"\n",
"print(var_x.shape)\n",
"print(reshaped.shape)"
],
"execution_count": 38,
"outputs": [
{
"output_type": "stream",
"text": [
"(3, 1)\n",
"(1, 3)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3n2K3HoAy6iy",
"colab_type": "text"
},
"source": [
"The data keeps its layout in memory, and a new tensor is created with the requested shape, pointing to the same data. TensorFlow uses C-style \"row-major\" memory ordering, where incrementing the right-most index corresponds to a single step in memory."
]
},
{
"cell_type": "code",
"metadata": {
"id": "D5fWmZlyyzD6",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 416
},
"outputId": "300943e6-0d04-4679-de40-e38f9a36c2c3"
},
"source": [
"print(rank_3_tensor)\n",
"\n",
"# A `-1` passed in the `shape` argument says \"Whatever fits\".\n",
"print(tf.reshape(rank_3_tensor, [-1]))\n",
"\n",
"print(tf.reshape(rank_3_tensor, [3*2, 5]), \"\\n\")\n",
"print(tf.reshape(rank_3_tensor, [3, -1]))"
],
"execution_count": 39,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[[ 0 1 2 3 4]\n",
" [ 5 6 7 8 9]]\n",
"\n",
" [[10 11 12 13 14]\n",
" [15 16 17 18 19]]\n",
"\n",
" [[20 21 22 23 24]\n",
" [25 26 27 28 29]]], shape=(3, 2, 5), dtype=int32)\n",
"tf.Tensor(\n",
"[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
" 24 25 26 27 28 29], shape=(30,), dtype=int32)\n",
"tf.Tensor(\n",
"[[ 0 1 2 3 4]\n",
" [ 5 6 7 8 9]\n",
" [10 11 12 13 14]\n",
" [15 16 17 18 19]\n",
" [20 21 22 23 24]\n",
" [25 26 27 28 29]], shape=(6, 5), dtype=int32) \n",
"\n",
"tf.Tensor(\n",
"[[ 0 1 2 3 4 5 6 7 8 9]\n",
" [10 11 12 13 14 15 16 17 18 19]\n",
" [20 21 22 23 24 25 26 27 28 29]], shape=(3, 10), dtype=int32)\n"
],
"name": "stdout"
}
]
},
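{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A minimal sketch of the row-major rule above, assuming TensorFlow 2.x with eager execution (the names `t` and `flat` are illustrative): element `[i, j]` of a `(2, 3)` tensor sits at position `i * 3 + j` of its flattened view, and `-1` asks TensorFlow to infer one dimension.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"# A (2, 3) tensor holding the values 0..5\n",
"t = tf.reshape(tf.range(6), [2, 3])\n",
"\n",
"# Flattening walks the right-most axis fastest (row-major order)\n",
"flat = tf.reshape(t, [-1])\n",
"print(flat.numpy().tolist())  # [0, 1, 2, 3, 4, 5]\n",
"\n",
"# Element [i, j] lives at flat index i * 3 + j\n",
"i, j = 1, 2\n",
"print(int(t[i, j]), int(flat[i * 3 + j]))  # 5 5\n",
"\n",
"# -1 infers the missing dimension: 6 elements / 3 rows = 2 columns\n",
"print(tf.reshape(t, [3, -1]).shape)  # (3, 2)\n",
"```"
]
},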
{
"cell_type": "markdown",
"metadata": {
"id": "9sa_4J3yzkSq",
"colab_type": "text"
},
"source": [
"`tf.reshape` cannot swap axes; use `tf.transpose` for that."
]
},
{
"cell_type": "code",
"metadata": {
"id": "2c4CWu7pzPqV",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"outputId": "b0902080-330b-499f-a63f-dcf05c0c9abb"
},
"source": [
"# We can't reorder axes with reshape.\n",
"print(tf.reshape(rank_3_tensor, [2, 3, 5]), \"\\n\") \n",
"\n",
"# This is a mess\n",
"print(tf.reshape(rank_3_tensor, [5, 6]), \"\\n\")\n",
"\n",
"# This doesn't work at all\n",
"try:\n",
" tf.reshape(rank_3_tensor, [7, -1])\n",
"except Exception as e:\n",
" print(f\"{type(e).__name__}: {e}\")"
],
"execution_count": 40,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[[ 0 1 2 3 4]\n",
" [ 5 6 7 8 9]\n",
" [10 11 12 13 14]]\n",
"\n",
" [[15 16 17 18 19]\n",
" [20 21 22 23 24]\n",
" [25 26 27 28 29]]], shape=(2, 3, 5), dtype=int32) \n",
"\n",
"tf.Tensor(\n",
"[[ 0 1 2 3 4 5]\n",
" [ 6 7 8 9 10 11]\n",
" [12 13 14 15 16 17]\n",
" [18 19 20 21 22 23]\n",
" [24 25 26 27 28 29]], shape=(5, 6), dtype=int32) \n",
"\n",
"InvalidArgumentError: Input to reshape is a tensor with 30 values, but the requested shape requires a multiple of 7 [Op:Reshape]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVEr-zzP0J4U",
"colab_type": "text"
},
"source": [
"![rank-3-tensor-reordering](https://user-images.githubusercontent.com/53504602/88807586-88190d00-d1cf-11ea-9307-cc3f8084339a.png)\n",
"\n",
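"Reshaping to `(2, 3, 5)` succeeds, but it does not swap any axes: the 30 values keep their row-major order and are simply regrouped, so the result is not a transpose of the original `(3, 2, 5)` tensor. Reshaping to `(5, 6)` regroups the same flat sequence in yet another way, and `[7, -1]` fails because 30 is not a multiple of 7."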
]
},
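{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To actually swap axes, `tf.transpose` takes a permutation of the axes. A minimal sketch, assuming a `(3, 2, 5)` tensor built with the same values as `rank_3_tensor` (the names `t`, `swapped`, and `regrouped` are illustrative): both results have shape `(2, 3, 5)`, but only the transpose moves whole rows across the first two axes.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"# Same values as rank_3_tensor: 0..29 in shape (3, 2, 5)\n",
"t = tf.reshape(tf.range(30), [3, 2, 5])\n",
"\n",
"# perm=[1, 0, 2] swaps the first two axes\n",
"swapped = tf.transpose(t, perm=[1, 0, 2])\n",
"# reshape only regroups the flat row-major sequence\n",
"regrouped = tf.reshape(t, [2, 3, 5])\n",
"\n",
"print(swapped.shape, regrouped.shape)  # (2, 3, 5) (2, 3, 5)\n",
"# Same shape, different contents\n",
"print(swapped[0, 1].numpy().tolist())    # [10, 11, 12, 13, 14], i.e. t[1, 0]\n",
"print(regrouped[0, 1].numpy().tolist())  # [5, 6, 7, 8, 9], i.e. t[0, 1]\n",
"```"
]
},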
{
"cell_type": "markdown",
"metadata": {
"id": "rvVNLCwN32DI",
"colab_type": "text"
},
"source": [
"## More on DTypes\n",
"\n",
"To inspect a `tf.Tensor`'s data type use the `Tensor.dtype` property.\n",
"\n",
"When creating a `tf.Tensor` from a Python object you may optionally specify the datatype.\n",
"\n",
"If we don't, TensorFlow chooses a datatype that can represent our data. TensorFlow converts Python integers to `tf.int32` and Python floating-point numbers to `tf.float32`. Otherwise, TensorFlow uses the same rules NumPy uses when converting to arrays.\n",
"\n",
"We can cast between types with `tf.cast`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Wqz9NS7azuMn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "fcd78d1f-62f8-486e-ec64-c581355819f1"
},
"source": [
"f64_tensor = tf.constant([2.2, 3.3, 4.4], dtype=tf.float64)\n",
"f16_tensor = tf.cast(f64_tensor, dtype=tf.float16)\n",
"\n",
"# Now, let's cast to a uint8 and lose the decimal precision\n",
"u8_tensor = tf.cast(f16_tensor, dtype=tf.uint8)\n",
"print(u8_tensor)"
],
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([2 3 4], shape=(3,), dtype=uint8)\n"
],
"name": "stdout"
}
]
},
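{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A short sketch of the default dtype rules described above, assuming TensorFlow 2.x: a Python int becomes `int32`, a Python float becomes `float32`, an explicit `dtype` overrides the default, and casting a float to an integer type drops the fractional part.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"print(tf.constant(1).dtype.name)    # int32 (Python int)\n",
"print(tf.constant(1.0).dtype.name)  # float32 (Python float)\n",
"print(tf.constant([1, 2], dtype=tf.float64).dtype.name)  # float64 (explicit)\n",
"\n",
"# Float -> int casting drops the fractional part\n",
"print(tf.cast(tf.constant([1.8, 2.2]), tf.int32).numpy().tolist())  # [1, 2]\n",
"```"
]
},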
{
"cell_type": "markdown",
"metadata": {
"id": "FSvLaWuR4PSC",
"colab_type": "text"
},
"source": [
"## Broadcasting\n",
"\n",
"Broadcasting is a concept borrowed from the equivalent feature in NumPy. In short, under certain conditions, smaller tensors are \"stretched\" automatically to fit larger tensors when running combined operations on them.\n",
"\n",
"The simplest and most common case is when you attempt to multiply or add a tensor to a scalar. In that case, the scalar is broadcast to be the same shape as the other argument."
]
},
{
"cell_type": "code",
"metadata": {
"id": "YyKrFa5x4KEb",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "a5495ec5-ec7e-4557-c9a6-916abfb7a1c6"
},
"source": [
"x = tf.constant([1, 2, 3])\n",
"\n",
"y = tf.constant(2)\n",
"z = tf.constant([2, 2, 2])\n",
"# All of these are the same computation\n",
"print(tf.multiply(x, 2))\n",
"print(x * y)\n",
"print(x * z)"
],
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([2 4 6], shape=(3,), dtype=int32)\n",
"tf.Tensor([2 4 6], shape=(3,), dtype=int32)\n",
"tf.Tensor([2 4 6], shape=(3,), dtype=int32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p-_yEAfs4_Fr",
"colab_type": "text"
},
"source": [
"Likewise, 1-sized dimensions can be stretched out to match the other arguments. Both arguments can be stretched in the same computation.\n",
"\n",
"In this case, a 3x1 matrix is element-wise multiplied by a 1x4 matrix to produce a 3x4 matrix. Note how the leading 1 is optional: the shape of y is `[4]`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "wsOahbW-40AT",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 200
},
"outputId": "4607535d-7af9-4dda-c5ab-f7b3380d1fc2"
},
"source": [
"# Broadcasting: shapes (3, 1) and (4,) combine into a (3, 4) result\n",
"x = tf.reshape(x, [3, 1])\n",
"y = tf.range(1, 5)\n",
"print(x, \"\\n\")\n",
"print(y, \"\\n\")\n",
"print(tf.multiply(x, y))"
],
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[1]\n",
" [2]\n",
" [3]], shape=(3, 1), dtype=int32) \n",
"\n",
"tf.Tensor([1 2 3 4], shape=(4,), dtype=int32) \n",
"\n",
"tf.Tensor(\n",
"[[ 1 2 3 4]\n",
" [ 2 4 6 8]\n",
" [ 3 6 9 12]], shape=(3, 4), dtype=int32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ndRbyaUq5p0T",
"colab_type": "text"
},
"source": [
"![broadcasting](https://user-images.githubusercontent.com/53504602/88810398-eb586e80-d1d2-11ea-8b48-1deb5f402783.png)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aAsaaSoi5L3B",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 83
},
"outputId": "1b117153-c83b-43e4-c387-4a9964b9afec"
},
"source": [
"# The same computation without broadcasting: both operands written out as (3, 4)\n",
"x_stretch = tf.constant([[1, 1, 1, 1],\n",
" [2, 2, 2, 2],\n",
" [3, 3, 3, 3]])\n",
"\n",
"y_stretch = tf.constant([[1, 2, 3, 4],\n",
" [1, 2, 3, 4],\n",
" [1, 2, 3, 4]])\n",
"\n",
"print(x_stretch * y_stretch)"
],
"execution_count": 44,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[ 1 2 3 4]\n",
" [ 2 4 6 8]\n",
" [ 3 6 9 12]], shape=(3, 4), dtype=int32)\n"
],
"name": "stdout"
}
]
},
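{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The shape rule behind these examples is NumPy's: compare shapes starting from the right-most dimension, treat missing leading dimensions as 1, and let a size-1 dimension stretch to match the other. A minimal sketch; `broadcast_shape` is a hypothetical helper written only for illustration, cross-checked against `tf.broadcast_static_shape`.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"def broadcast_shape(a, b):\n",
"    # Hypothetical helper: NumPy-style broadcasting of two shape tuples\n",
"    result = []\n",
"    for i in range(1, max(len(a), len(b)) + 1):\n",
"        # Walk right-to-left; missing leading dimensions count as 1\n",
"        da = a[-i] if i <= len(a) else 1\n",
"        db = b[-i] if i <= len(b) else 1\n",
"        if da != db and da != 1 and db != 1:\n",
"            raise ValueError(f'incompatible dimensions {da} and {db}')\n",
"        # A size-1 dimension stretches to the other size\n",
"        result.append(db if da == 1 else da)\n",
"    return tuple(reversed(result))\n",
"\n",
"print(broadcast_shape((3, 1), (4,)))  # (3, 4)\n",
"print(tf.broadcast_static_shape(tf.TensorShape([3, 1]), tf.TensorShape([4])))\n",
"```"
]
},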
{
"cell_type": "markdown",
"metadata": {
"id": "ciVjIcjV6BzD",
"colab_type": "text"
},
"source": [
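"We can also see what broadcasting produces by using `tf.broadcast_to`.\n",
"\n",
"Mathematical ops like `tf.multiply` usually broadcast without materializing the expanded tensors in memory, but `tf.broadcast_to` does nothing special to save memory: here, we are materializing the full tensor."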
]
},
{
"cell_type": "code",
"metadata": {
"id": "UwW1m04C55AA",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 83
},
"outputId": "3b09c477-a594-480d-8324-3b9aadca3996"
},
"source": [
"print(tf.broadcast_to(tf.constant([1, 2, 3]), [3, 3]))"
],
"execution_count": 45,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[1 2 3]\n",
" [1 2 3]\n",
" [1 2 3]], shape=(3, 3), dtype=int32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "El7ffoa_6iyB",
"colab_type": "text"
},
"source": [
"## `tf.convert_to_tensor`\n",
"\n",
"Most ops, like `tf.matmul` and `tf.reshape`, take arguments of class `tf.Tensor`. However, as in the examples above, we frequently pass Python objects shaped like tensors.\n",
"\n",
"Most, but not all, operations call `convert_to_tensor` on non-tensor arguments. There is a registry of conversions, and most object classes like NumPy's `ndarray`, `TensorShape`, Python lists, and `tf.Variable` will all convert automatically.\n",
"\n",
"If we have our own type that we'd like to convert to a tensor automatically, we can register a conversion with `tf.register_tensor_conversion_function`."
]
},
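{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A minimal sketch of registering a conversion, assuming TensorFlow 2.x. `Point` and `point_to_tensor` are hypothetical names; the conversion function receives `value`, `dtype`, `name`, and `as_ref`, and returns a tensor. Once registered, ops that call `convert_to_tensor` accept `Point` objects directly.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"class Point:\n",
"    # Hypothetical user-defined type\n",
"    def __init__(self, x, y):\n",
"        self.x = x\n",
"        self.y = y\n",
"\n",
"def point_to_tensor(value, dtype=None, name=None, as_ref=False):\n",
"    # Turn a Point into a 1-D tensor [x, y]\n",
"    return tf.constant([value.x, value.y], dtype=dtype, name=name)\n",
"\n",
"tf.register_tensor_conversion_function(Point, point_to_tensor)\n",
"\n",
"p = Point(1.0, 2.0)\n",
"print(tf.convert_to_tensor(p))\n",
"# tf.add converts the non-tensor Point argument through the registry\n",
"print(tf.add(p, 1.0))\n",
"```"
]
},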
{
"cell_type": "markdown",
"metadata": {
"id": "zxATPa5G7BFs",
"colab_type": "text"
},
"source": [
"## Ragged Tensors\n",
"\n",
"A tensor with variable numbers of elements along some axis is called \"ragged\". Use `tf.ragged.RaggedTensor` for ragged data.\n",
"\n",
"For example, this cannot be represented as a regular tensor:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aLDcaYo26OPX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "fb22ad45-d703-4a02-ecba-9472b33f1cee"
},
"source": [
"ragged_list = [\n",
" [0, 1, 2, 3],\n",
" [4, 5],\n",
" [6, 7, 8],\n",
" [9]]\n",
"\n",
"try:\n",
" tensor = tf.constant(ragged_list)\n",
"except Exception as e:\n",
" print(f\"{type(e).__name__}: {e}\")"
],
"execution_count": 46,
"outputs": [
{
"output_type": "stream",
"text": [
"ValueError: Can't convert non-rectangular Python sequence to Tensor.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HP9Dl9bC_up6",
"colab_type": "text"
},
"source": [
"Instead, create a `tf.RaggedTensor` using `tf.ragged.constant`:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "N1D0m3id_qy5",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "135e8cda-396c-4573-ff72-a8e0005a6ac4"
},
"source": [
"ragged_tensor = tf.ragged.constant(ragged_list)\n",
"print(ragged_tensor)"
],
"execution_count": 47,
"outputs": [
{
"output_type": "stream",
"text": [
"<tf.RaggedTensor [[0, 1, 2, 3], [4, 5], [6, 7, 8], [9]]>\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "nVMP-yQC_ztc",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "557a70ec-7103-49da-cca2-47e526132557"
},
"source": [
"# The shape of a tf.RaggedTensor contains unknown (None) dimensions\n",
"print(ragged_tensor.shape)"
],
"execution_count": 48,
"outputs": [
{
"output_type": "stream",
"text": [
"(4, None)\n"
],
"name": "stdout"
}
]
},
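{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A ragged tensor can also be built from a flat list of values plus the length of each row. A minimal sketch using `tf.RaggedTensor.from_row_lengths` on the same data as `ragged_list` (the names `rt` and `dense` are illustrative); `to_tensor` pads the short rows, with zeros by default, when a regular tensor is needed.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"rt = tf.RaggedTensor.from_row_lengths(\n",
"    values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n",
"    row_lengths=[4, 2, 3, 1])\n",
"print(rt)  # <tf.RaggedTensor [[0, 1, 2, 3], [4, 5], [6, 7, 8], [9]]>\n",
"print(rt.row_lengths().numpy().tolist())  # [4, 2, 3, 1]\n",
"\n",
"# Pad into a regular (4, 4) tensor; missing entries become 0\n",
"dense = rt.to_tensor()\n",
"print(dense.shape)  # (4, 4)\n",
"```"
]
},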
{
"cell_type": "markdown",
"metadata": {
"id": "is29pqhcACFW",
"colab_type": "text"
},
"source": [
"## String tensors\n",
"\n",
"`tf.string` is a `dtype`, which is to say we can represent data as strings (variable-length byte arrays) in tensors.\n",
"\n",
"The strings are atomic and cannot be indexed the way Python strings are. The length of the string is not one of the dimensions of the tensor.\n",
"\n",
"Here is a scalar string tensor:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "6fg2Q1sb_9IQ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "3f9f1d22-ed33-4dc8-b855-337d8b23d283"
},
"source": [
"# Tensors can be strings, too. Here is a scalar string.\n",
"scalar_string_tensor = tf.constant(\"Gray wolf\")\n",
"print(scalar_string_tensor)"
],
"execution_count": 49,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(b'Gray wolf', shape=(), dtype=string)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "R6iDUykGANzj",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "fba08646-2d34-4383-b49d-36282aee2220"
},
"source": [
"# Three strings of different lengths in one tensor: this is OK.\n",
"tensor_of_strings = tf.constant([\"Gray wolf\",\n",
" \"Quick brown fox\",\n",
" \"Lazy dog\"])\n",
"\n",
"# Note that the shape is (3,). The string length is not included.\n",
"print(tensor_of_strings)"
],
"execution_count": 50,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([b'Gray wolf' b'Quick brown fox' b'Lazy dog'], shape=(3,), dtype=string)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qL4lkx7KAZj0",
"colab_type": "text"
},
"source": [
"In the above printout, the `b` prefix indicates that the `tf.string` dtype is not a Unicode string but a byte string.\n",
"\n",
"If we pass Unicode characters, they are UTF-8 encoded."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZorYTCozAUY3",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "976a2b4e-d5d1-4b5c-d03b-e1ed81f80571"
},
"source": [
"tf.constant(\"🥳👍\")"
],
"execution_count": 51,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: shape=(), dtype=string, numpy=b'\\xf0\\x9f\\xa5\\xb3\\xf0\\x9f\\x91\\x8d'>"
]
},
"metadata": {
"tags": []
},
"execution_count": 51
}
]
},
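{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Because the tensor holds UTF-8 bytes, its byte length and its character length differ. A minimal sketch using `tf.strings.length` with `unit='UTF8_CHAR'`, and `tf.strings.unicode_decode` to recover the Unicode code points:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"s = tf.constant('🥳👍')\n",
"print(tf.strings.length(s).numpy())                    # 8 (bytes)\n",
"print(tf.strings.length(s, unit='UTF8_CHAR').numpy())  # 2 (characters)\n",
"print(tf.strings.unicode_decode(s, 'UTF-8').numpy().tolist())  # [129395, 128077]\n",
"```"
]
},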
{
"cell_type": "markdown",
"metadata": {
"id": "cqcac_S_Alrc",
"colab_type": "text"
},
"source": [
"Some basic functions with strings can be found in `tf.strings`, including `tf.strings.split`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "XtOpHL7jAihj",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "dccd99e9-ffe2-4cf7-def2-eeebb55226e6"
},
"source": [
"# We can use split to split a string into a tensor of substrings\n",
"print(tf.strings.split(scalar_string_tensor, sep=\" \"))"
],
"execution_count": 52,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([b'Gray' b'wolf'], shape=(2,), dtype=string)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Tbpvyl__Aqgl",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "497e4809-216f-4431-e784-7fa96966df9a"
},
"source": [
"# because each string might be separated into many different parts\n",
"# we get a RaggedTensor\n",
"print(tf.strings.split(tensor_of_strings))"
],
"execution_count": 53,
"outputs": [
{
"output_type": "stream",
"text": [
"<tf.RaggedTensor [[b'Gray', b'wolf'], [b'Quick', b'brown', b'fox'], [b'Lazy', b'dog']]>\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "N6M3DxJUBCEo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "416ca18f-176e-4637-f865-cafc7cba7095"
},
"source": [
"# Split a string on spaces and convert the pieces to a float32 tensor\n",
"text = tf.constant(\"1 10 100\")\n",
"print(tf.strings.to_number(tf.strings.split(text, \" \")))"
],
"execution_count": 54,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([ 1. 10. 100.], shape=(3,), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pJhYhDDQBVMq",
"colab_type": "text"
},
"source": [
"We can't use `tf.cast` to turn a string tensor into numbers, but we can convert it into bytes, and then into numbers."
]
},
{
"cell_type": "code",
"metadata": {
"id": "41E3rtFNBR0N",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "35ce211a-8122-4a4d-96ca-87887bdb0dd1"
},
"source": [
"byte_strings = tf.strings.bytes_split(tf.constant(\"Duck\"))\n",
"byte_ints = tf.io.decode_raw(tf.constant(\"Duck\"), tf.uint8)\n",
"print(\"Byte strings:\", byte_strings)\n",
"print(\"Bytes:\", byte_ints)"
],
"execution_count": 55,
"outputs": [
{
"output_type": "stream",
"text": [
"Byte strings: tf.Tensor([b'D' b'u' b'c' b'k'], shape=(4,), dtype=string)\n",
"Bytes: tf.Tensor([ 68 117 99 107], shape=(4,), dtype=uint8)\n"
],
"name": "stdout"
}
]
},
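{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The byte values can also be turned back into a string. A minimal sketch, assuming the text is ASCII so that each byte equals its Unicode code point; `tf.strings.unicode_encode` expects `int32` code points, hence the cast (the names `codes` and `restored` are illustrative).\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"codes = tf.io.decode_raw(tf.constant('Duck'), tf.uint8)\n",
"print(codes.numpy().tolist())  # [68, 117, 99, 107]\n",
"\n",
"# ASCII only: each byte is also the character's code point\n",
"restored = tf.strings.unicode_encode(tf.cast(codes, tf.int32), 'UTF-8')\n",
"print(restored.numpy())  # b'Duck'\n",
"```"
]
},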
{
"cell_type": "markdown",
"metadata": {
"id": "LO-h_5KwBkqv",
"colab_type": "text"
},
"source": [
"The `tf.string` dtype is used for all raw bytes data in TensorFlow. The `tf.io` module contains functions for converting data to and from bytes, including decoding images and parsing CSV."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G-SkL7ETBq6o",
"colab_type": "text"
},
"source": [
"## Sparse tensors\n",
"\n",
"Sometimes, our data is sparse, like a very wide embedding space. TensorFlow supports `tf.sparse.SparseTensor` and related operations to store sparse data efficiently."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SHzIWWjhDqyk",
"colab_type": "text"
},
"source": [
"![sparseTensor](https://user-images.githubusercontent.com/53504602/88815483-075f0e80-d1d9-11ea-9714-b2bc7ac278f0.png)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "L0dCQJx0BeyX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 170
},
"outputId": "c6530722-2813-4dda-b19e-4bc500921ab1"
},
"source": [
"# Sparse tensors store values by index in a memory-efficient manner\n",
"sparse_tensor = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],\n",
" values=[1, 2],\n",
" dense_shape=[3, 4])\n",
"print(sparse_tensor, \"\\n\")\n",
"\n",
"# We can convert sparse tensors to dense\n",
"print(tf.sparse.to_dense(sparse_tensor))"
],
"execution_count": 56,
"outputs": [
{
"output_type": "stream",
"text": [
"SparseTensor(indices=tf.Tensor(\n",
"[[0 0]\n",
" [1 2]], shape=(2, 2), dtype=int64), values=tf.Tensor([1 2], shape=(2,), dtype=int32), dense_shape=tf.Tensor([3 4], shape=(2,), dtype=int64)) \n",
"\n",
"tf.Tensor(\n",
"[[1 0 0 0]\n",
" [0 0 2 0]\n",
" [0 0 0 0]], shape=(3, 4), dtype=int32)\n"
],
"name": "stdout"
}
]
},
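{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Many `tf.sparse` ops work on the sparse form directly, and `tf.sparse.from_dense` goes the other way, keeping only the non-zero entries. A minimal sketch reusing the values above (the names `st`, `dense`, and `back` are illustrative):\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"st = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],\n",
"                            values=[1, 2],\n",
"                            dense_shape=[3, 4])\n",
"# Reduce over the stored values only: 1 + 2\n",
"print(tf.sparse.reduce_sum(st).numpy())  # 3\n",
"\n",
"# Dense -> sparse keeps only the non-zero entries\n",
"dense = tf.sparse.to_dense(st)\n",
"back = tf.sparse.from_dense(dense)\n",
"print(back.indices.numpy().tolist())  # [[0, 0], [1, 2]]\n",
"print(back.values.numpy().tolist())   # [1, 2]\n",
"```"
]
},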
{
"cell_type": "markdown",
"metadata": {
"id": "22uRy3N5HdeB",
"colab_type": "text"
},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mHv9tN0tEFrC",
"colab_type": "text"
},
"source": [
"# <h1 align=\"center\">-----Part - 2-----</h1>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9YxabPGKHy24",
"colab_type": "text"
},
"source": [
"## Gradients and Automatic Differentiation\n",
"\n",
"[Automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is useful for implementing machine learning algorithms such as [backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JuZBm6VgIXJj",
"colab_type": "text"
},
"source": [
"### Computing gradients\n",
"\n",
"To differentiate automatically, TensorFlow needs to remember what operations happen in what order during the *forward pass*. Then, during the *backward pass*, TensorFlow traverses this list of operations in reverse order to compute gradients."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DgHi8FqsI8mQ",
"colab_type": "text"
},
"source": [
"### Gradient tapes\n",
"\n",
"TensorFlow provides the `tf.GradientTape` API for automatic differentiation; that is, computing the gradient of a computation with respect to some inputs, usually `tf.Variables`. TensorFlow \"records\" relevant operations executed inside the context of a `tf.GradientTape` onto a \"tape\". TensorFlow then uses that tape to compute the gradients of a \"recorded\" computation using reverse mode differentiation."
]
},
{
"cell_type": "code",
"metadata": {
"id": "o5ezpS7vELU_",
"colab_type": "code",
"colab": {}
},
"source": [
"x = tf.Variable(3.0)\n",
"\n",
"with tf.GradientTape() as tape:\n",
" y = x**2"
],
"execution_count": 57,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "HezrbCY1JlS3",
"colab_type": "text"
},
"source": [
"Once we've recorded some operations, `GradientTape.gradient(target, sources)` calculates the gradient of some target (often a loss) relative to some source (often the model's variables)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "2d1LfXvBJbuf",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "c33cfa7d-e3f4-4070-b9bf-85e0662f5be4"
},
"source": [
"dy_dx = tape.gradient(y, x)\n",
"\n",
"# dy/dx = 2x\n",
"# for x = 3.0, dy/dx = 6.0\n",
"dy_dx.numpy()"
],
"execution_count": 58,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"6.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 58
}
]
},
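{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"One way to sanity-check a tape gradient is to compare it with a central finite difference, (f(x + h) - f(x - h)) / 2h. A minimal sketch; the step size `h = 1e-3` is an assumed value, and the two numbers only agree approximately because of floating-point error.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"x = tf.Variable(3.0)\n",
"with tf.GradientTape() as tape:\n",
"    y = x ** 2\n",
"analytic = tape.gradient(y, x).numpy()\n",
"\n",
"# Central difference of f(x) = x**2 at x = 3.0, in Python floats\n",
"h = 1e-3\n",
"numeric = ((3.0 + h) ** 2 - (3.0 - h) ** 2) / (2 * h)\n",
"print(analytic, numeric)  # both approximately 6.0\n",
"```"
]
},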
{
"cell_type": "markdown",
"metadata": {
"id": "VYI9TwHzKEwl",
"colab_type": "text"
},
"source": [
"The above example uses scalars, but `tf.GradientTape` works as easily on any tensor:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "CArN7kiyKBC-",
"colab_type": "code",
"colab": {}
},
"source": [
"w = tf.Variable(tf.random.normal((3, 2)), name='w')\n",
"b = tf.Variable(tf.zeros(2, dtype=tf.float32), name='b')\n",
"x = [[1., 2., 3.]]\n",
"\n",
"with tf.GradientTape(persistent=True) as tape:\n",
" y = x @ w + b\n",
" loss = tf.reduce_mean(y**2)"
],
"execution_count": 59,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2DvBVXOEKrpT",
"colab_type": "text"
},
"source": [
"To get the gradient of `loss` with respect to both variables, we can pass both as sources to the gradient method. The tape is flexible about how sources are passed and will accept any nested combination of lists or dictionaries and return the gradient structured the same way (`tf.nest`)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "xD0rhGNXKqiK",
"colab_type": "code",
"colab": {}
},
"source": [
"[dl_dw, dl_db] = tape.gradient(loss, [w, b])"
],
"execution_count": 60,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "WCgCu_M7K5MI",
"colab_type": "text"
},
"source": [
"The gradient with respect to each source has the shape of the source:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2laVs2_WKzfs",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "f6fffc4c-563c-4de0-c32c-28ecd9a89b47"
},
"source": [
"print(w.shape)\n",
"print(dl_dw.shape)"
],
"execution_count": 61,
"outputs": [
{
"output_type": "stream",
"text": [
"(3, 2)\n",
"(3, 2)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "821lABq4K8rC",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "02fe3337-1ca2-463a-c88b-ca69603cfc27"
},
"source": [
"print(b.shape)\n",
"print(dl_db.shape)"
],
"execution_count": 62,
"outputs": [
{
"output_type": "stream",
"text": [
"(2,)\n",
"(2,)\n"
],
"name": "stdout"
}
]
},
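{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"These gradients are what a training step consumes. A minimal sketch of plain gradient descent on a small linear model, with hypothetical starting values and an assumed learning rate of 0.01 (neither comes from the notebook); `assign_sub` updates each variable in place, and the loss should shrink on every step.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"w = tf.Variable([[1.0], [2.0], [3.0]])\n",
"b = tf.Variable([0.0])\n",
"x = tf.constant([[1.0, 2.0, 3.0]])\n",
"learning_rate = 0.01  # assumed value for illustration\n",
"\n",
"losses = []\n",
"for step in range(3):\n",
"    with tf.GradientTape() as tape:\n",
"        loss = tf.reduce_mean((x @ w + b) ** 2)\n",
"    dl_dw, dl_db = tape.gradient(loss, [w, b])\n",
"    # Move each variable against its gradient\n",
"    w.assign_sub(learning_rate * dl_dw)\n",
"    b.assign_sub(learning_rate * dl_db)\n",
"    losses.append(float(loss))\n",
"\n",
"print(losses)  # approximately [196.0, 96.04, 47.06]\n",
"```"
]
},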
{
"cell_type": "markdown",
"metadata": {
"id": "JUIrVIHSLN9z",
"colab_type": "text"
},
"source": [
"## Gradients with respect to a model\n",
"\n",
"It's common to collect `tf.Variables` into a `tf.Module` or one of its subclasses (`layers.Layer`, `keras.Model`) for checkpointing and exporting.\n",
"\n",
"In most cases, we would like to calculate gradients with respect to a model's trainable variables. Since all subclasses of `tf.Module` aggregate their variables in the `Module.trainable_variables` property, we can calculate these gradients in a few lines of code:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "my7rOIUbLFEi",
"colab_type": "code",
"colab": {}
},
"source": [
"layer = tf.keras.layers.Dense(2, activation='sigmoid')\n",
"x = tf.constant([[1., 2., 3.]])\n",
"\n",
"with tf.GradientTape() as tape:\n",
" # Forward pass\n",
" y = layer(x)\n",
" loss = tf.reduce_mean(y**2)\n",
"\n",
"# Calculate gradients with respect to every trainable variable\n",
"grad = tape.gradient(loss, layer.trainable_variables)"
],
"execution_count": 63,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dG2YYqaiLnON",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "7092a17a-114b-4f62-c876-39c3f4e3bca0"
},
"source": [
"for var, g in zip(layer.trainable_variables, grad):\n",
" print(f'{var.name}, shape: {g.shape}')"
],
"execution_count": 64,
"outputs": [
{
"output_type": "stream",
"text": [
"dense/kernel:0, shape: (3, 2)\n",
"dense/bias:0, shape: (2,)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3AbYd9ruL7of",
"colab_type": "text"
},
"source": [
"## Controlling what the tape watches\n",
"\n",
"The default behavior is to record all operations after accessing a trainable `tf.Variable`. The reasons for this are:\n",
"\n",
" - The tape needs to know which operations to record in the forward pass to calculate the gradients in the backwards pass.\n",
" - The tape holds references to intermediate outputs, so you don't want to record unnecessary operations.\n",
" - The most common use case involves calculating the gradient of a loss with respect to all of a model's trainable variables.\n",
"\n",
"For example the following fails to calculate a gradient because the `tf.Tensor` is not \"watched\" by default, and the `tf.Variable` is not trainable:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fX3Q46rOL4NH",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 83
},
"outputId": "6386fd21-c51f-4067-c383-eb815fb2ddb6"
},
"source": [
"# A trainable variable\n",
"x0 = tf.Variable(3.0, name='x0')\n",
"# Not trainable\n",
"x1 = tf.Variable(3.0, name='x1', trainable=False)\n",
"# Not a Variable: A variable + tensor returns a tensor.\n",
"x2 = tf.Variable(2.0, name='x2') + 1.0\n",
"# Not a variable\n",
"x3 = tf.constant(3.0, name='x3')\n",
"\n",
"with tf.GradientTape() as tape:\n",
" y = (x0**5 + 3*x0**4 + 2*x0**2 + 2) + (x1**2) + (x2**2)\n",
"\n",
"grad = tape.gradient(y, [x0, x1, x2, x3])\n",
"\n",
"for g in grad:\n",
" print(g)"
],
"execution_count": 65,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(741.0, shape=(), dtype=float32)\n",
"None\n",
"None\n",
"None\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rWJuOhYRMmeP",
"colab_type": "text"
},
"source": [
"`tf.GradientTape` provides hooks that give the user control over what is or is not watched.\n",
"\n",
"To record gradients with respect to a `tf.Tensor`, we need to call `GradientTape.watch(x)`:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "VeNUbJUwMwit",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "68fb667d-5f4e-457f-fe25-da7761cf658f"
},
"source": [
"x = tf.constant(3.0)\n",
"with tf.GradientTape() as tape:\n",
" tape.watch(x)\n",
" y = x**2\n",
"\n",
"# dy = 2x * dx\n",
"dy_dx = tape.gradient(y, x)\n",
"print(dy_dx.numpy())"
],
"execution_count": 66,
"outputs": [
{
"output_type": "stream",
"text": [
"6.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OooqRAdEOl2v",
"colab_type": "text"
},
"source": [
"Conversely, to disable the default behavior of watching all `tf.Variables`, set `watch_accessed_variables=False` when creating the gradient tape. This calculation uses two variables, but only connects the gradient for one of the variables:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xUlEtEY1MlYL",
"colab_type": "code",
"colab": {}
},
"source": [
"x0 = tf.Variable(0.0)\n",
"x1 = tf.Variable(10.0)\n",
"\n",
"with tf.GradientTape(watch_accessed_variables=False) as tape:\n",
" tape.watch(x1)\n",
" y0 = tf.math.sin(x0)\n",
" y1 = tf.nn.softplus(x1)\n",
" y = y0 + y1\n",
" ys = tf.reduce_sum(y)"
],
"execution_count": 67,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "A8LLYndmNLtk",
"colab_type": "text"
},
"source": [
"Since `GradientTape.watch` was not called on `x0`, no gradient is computed with respect to it:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AC-c8q5LNSrf",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "68cfa686-b54a-4b3b-fc8a-eb97b3412c48"
},
"source": [
"# d(softplus(x1))/dx1 = sigmoid(x1); x0 was not watched, so its gradient is None\n",
"grad = tape.gradient(ys, {'x0': x0, 'x1': x1})\n",
"\n",
"print('dy/dx0:', grad['x0'])\n",
"print('dy/dx1:', grad['x1'].numpy())"
],
"execution_count": 68,
"outputs": [
{
"output_type": "stream",
"text": [
"dy/dx0: None\n",
"dy/dx1: 0.9999546\n"
],
"name": "stdout"
}
]
},
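{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"We can ask the tape which variables it actually watched with `GradientTape.watched_variables()`, and cross-check the result above: the derivative of softplus is the sigmoid function, so the gradient for `x1 = 10.0` should equal `sigmoid(10.0)`. A minimal sketch (the names `watched` and `g` are illustrative):\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"x0 = tf.Variable(0.0)\n",
"x1 = tf.Variable(10.0)\n",
"with tf.GradientTape(watch_accessed_variables=False) as tape:\n",
"    tape.watch(x1)\n",
"    ys = tf.math.sin(x0) + tf.nn.softplus(x1)\n",
"\n",
"# Only x1 was watched\n",
"watched = tape.watched_variables()\n",
"print([v.numpy() for v in watched])\n",
"\n",
"# d softplus(x)/dx = sigmoid(x)\n",
"g = tape.gradient(ys, x1)\n",
"print(g.numpy(), tf.math.sigmoid(10.0).numpy())\n",
"```"
]
},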
{
"cell_type": "markdown",
"metadata": {
"id": "gIcaRPwTNqM4",
"colab_type": "text"
},
"source": [
"## Intermediate Results\n",
"\n",
"We can also request gradients of the output with respect to intermediate values computed inside the `tf.GradientTape` context."
]
},
{
"cell_type": "code",
"metadata": {
"id": "_OAl_MnwNZiL",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "62e0e941-1f14-431f-a507-6ddab89c31a0"
},
"source": [
"x = tf.constant(3.0)\n",
"\n",
"with tf.GradientTape() as tape:\n",
" tape.watch(x)\n",
" y = x * x\n",
" z = y * y\n",
"\n",
"# Use the tape to compute the gradient of z w.r.t y\n",
"# dz/dy = 2 * y, where y = x ** 2\n",
"print(tape.gradient(z, y).numpy())"
],
"execution_count": 69,
"outputs": [
{
"output_type": "stream",
"text": [
"18.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BlYvb2YxPO63",
"colab_type": "text"
},
"source": [
"The resources held by a `GradientTape` are released as soon as the `GradientTape.gradient()` method is called. To compute multiple gradients over the same computation, create a `persistent` gradient tape. A persistent tape allows multiple calls to the `gradient()` method; its resources are released only when the tape object is garbage collected."
]
},
{
"cell_type": "code",
"metadata": {
"id": "FX1IW0DSN8M1",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "a6968ebd-cc38-424c-a427-4ca46c47c5ce"
},
"source": [
"x = tf.constant([1.0, 3.0])\n",
"# With a non-persistent tape, the second gradient() call raises a RuntimeError\n",
"with tf.GradientTape(persistent=False) as tape:\n",
" tape.watch(x)\n",
" y = x * x\n",
" z = y * y\n",
"\n",
"try:\n",
" print(tape.gradient(z, x).numpy()) # [4. 108.], i.e. 4 * x**3\n",
" print(tape.gradient(y, x).numpy()) # [2. 6.], i.e. 2 * x\n",
"except Exception as e:\n",
" print(f\"{type(e).__name__}: {e}\")\n"
],
"execution_count": 70,
"outputs": [
{
"output_type": "stream",
"text": [
"[ 4. 108.]\n",
"RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3mu4uKBLPgVL",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "c19c2340-86ab-44d1-f29a-89b648be024d"
},
"source": [
"x = tf.constant([1.0, 3.0])\n",
"# With a persistent tape, gradient() can be called several times\n",
"with tf.GradientTape(persistent=True) as tape:\n",
" tape.watch(x)\n",
" y = x * x\n",
" z = y * y\n",
"\n",
"try:\n",
" print(tape.gradient(z, x).numpy()) # [4. 108.], i.e. 4 * x**3\n",
" print(tape.gradient(y, x).numpy()) # [2. 6.], i.e. 2 * x\n",
"except Exception as e:\n",
" print(f\"{type(e).__name__}: {e}\")"
],
"execution_count": 71,
"outputs": [
{
"output_type": "stream",
"text": [
"[ 4. 108.]\n",
"[2. 6.]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "R6HlvMcVP95s",
"colab_type": "code",
"colab": {}
},
"source": [
"del tape # Drop the reference to the tape"
],
"execution_count": 72,
"outputs": []
},
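{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"With a persistent tape, the intermediate gradients also illustrate the chain rule: for y = x * x and z = y * y, dz/dx = dz/dy * dy/dx. A minimal sketch at x = 3.0:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"x = tf.constant(3.0)\n",
"with tf.GradientTape(persistent=True) as tape:\n",
"    tape.watch(x)\n",
"    y = x * x\n",
"    z = y * y\n",
"\n",
"dz_dy = tape.gradient(z, y)  # 2 * y = 18\n",
"dy_dx = tape.gradient(y, x)  # 2 * x = 6\n",
"dz_dx = tape.gradient(z, x)  # 4 * x**3 = 108\n",
"del tape  # release the persistent tape's resources\n",
"\n",
"print(float(dz_dy * dy_dx), float(dz_dx))  # 108.0 108.0\n",
"```"
]
},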
{
"cell_type": "markdown",
"metadata": {
"id": "WcQlN74wQShH",
"colab_type": "text"
},
"source": [
"## Gradients of non-scalar targets\n",
"\n",
"A gradient is fundamentally an operation on a scalar."
]
},
{
"cell_type": "code",
"metadata": {
"id": "kEP0NQ_SQEae",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "3aadf23e-2b0b-49af-845d-248ab9456dca"
},
"source": [
"x = tf.Variable(2.0)\n",
"with tf.GradientTape(persistent=True) as tape:\n",
" y0 = x**2\n",
" y1 = 1 / x\n",
"\n",
"print(tape.gradient(y0, x).numpy())\n",
"print(tape.gradient(y1, x).numpy())"
],
"execution_count": 73,
"outputs": [
{
"output_type": "stream",
"text": [
"4.0\n",
"-0.25\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xnEcROy-RUA3",
"colab_type": "text"
},
"source": [
"If several targets are passed to a single `gradient` call instead, the result for each source is:\n",
"\n",
" - The gradient of the sum of the targets, or equivalently\n",
" - The sum of the gradients of each target.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1Inb-6lyROD6",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "96e1379f-f67e-422f-cd9d-eb1ddf30931a"
},
"source": [
"x = tf.Variable(2.0)\n",
"with tf.GradientTape() as tape:\n",
" y0 = x**2\n",
" y1 = 1 / x\n",
"\n",
"print(tape.gradient({'y0': y0, 'y1': y1}, x).numpy())"
],
"execution_count": 74,
"outputs": [
{
"output_type": "stream",
"text": [
"3.75\n"
],
"name": "stdout"
}
]
},
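{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two descriptions above can be checked numerically. This sketch reuses the same functions, y0 = x**2 and y1 = 1/x at x = 2, and compares a list of targets, the gradient of their explicit sum, and the sum of the separate gradients; all three should equal 2*2 - 1/4 = 3.75:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"x = tf.Variable(2.0)\n",
"with tf.GradientTape(persistent=True) as tape:\n",
"  y0 = x**2\n",
"  y1 = 1 / x\n",
"  total = y0 + y1\n",
"\n",
"g_list = tape.gradient([y0, y1], x)  # targets passed together\n",
"g_sum = tape.gradient(total, x)  # gradient of the explicit sum\n",
"g_each = tape.gradient(y0, x) + tape.gradient(y1, x)  # sum of the separate gradients\n",
"del tape  # release the resources held by the persistent tape\n",
"\n",
"print(g_list.numpy(), g_sum.numpy(), g_each.numpy())  # 3.75 3.75 3.75"
],
"execution_count": null,
"outputs": []
},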
{
"cell_type": "code",
"metadata": {
"id": "B8U0U3lwRdH7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "a3585a04-c9f3-49d7-933c-4b6f74fea36f"
},
"source": [
"x = tf.Variable(2.)\n",
"\n",
"with tf.GradientTape() as tape:\n",
" y = x * [3., 4.]\n",
"\n",
"print(tape.gradient(y, x).numpy())"
],
"execution_count": 75,
"outputs": [
{
"output_type": "stream",
"text": [
"7.0\n"
],
"name": "stdout"
}
]
},
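{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the vector target above is summed, the individual derivatives d(3x)/dx = 3 and d(4x)/dx = 4 are not visible separately. When they are needed, `tf.GradientTape.jacobian` returns one derivative per target element instead of their sum. A minimal sketch with the same target (the tape is made persistent only so that both `gradient` and `jacobian` can be called on it):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"x = tf.Variable(2.)\n",
"\n",
"with tf.GradientTape(persistent=True) as tape:\n",
"  y = x * [3., 4.]\n",
"\n",
"grad = tape.gradient(y, x)  # summed over the elements of y: 3 + 4\n",
"jac = tape.jacobian(y, x)  # one derivative per element of y\n",
"del tape\n",
"\n",
"print(grad.numpy())  # 7.0\n",
"print(jac.numpy())  # [3. 4.]"
],
"execution_count": null,
"outputs": []
},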
{
"cell_type": "markdown",
"metadata": {
"id": "A53bVdyiR3fW",
"colab_type": "text"
},
"source": [
"This makes it simple to take the gradient of the sum of a collection of losses, or the gradient of the sum of an element-wise loss calculation.\n",
"\n",
"For an element-wise calculation, the gradient of the sum gives the derivative of each element with respect to its corresponding input element, since the elements are independent of each other:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1JWD8uEBR08t",
"colab_type": "code",
"colab": {}
},
"source": [
"x = tf.linspace(-10.0, 10.0, 200+1)"
],
"execution_count": 76,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8j9xC7S2SIIA",
"colab_type": "code",
"colab": {}
},
"source": [
"with tf.GradientTape() as tape:\n",
" tape.watch(x)\n",
" y = tf.nn.sigmoid(x)\n",
"\n",
"dy_dx = tape.gradient(y, x)"
],
"execution_count": 77,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5zGGoChOSO78",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 278
},
"outputId": "022eb0f2-5b75-4f86-f77c-13936060b520"
},
"source": [
"plt.plot(x, y, label='y')\n",
"plt.plot(x, dy_dx, label='dy/dx')\n",
"plt.legend()\n",
"_ = plt.xlabel('x')"
],
"execution_count": 78,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEGCAYAAAB1iW6ZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxU9b3/8dcnk41AWANhCRhURBZFVperVEEBsaJoraCt/dW61Ft7a3d7banWent7b9t7a2ttrVZrK+LuxYoCVq1WRVkEJCCyyBIgAQKEQLZZvr8/zoBDnMAAMzmTyfv5eMwjM+d8Z+YzZybvnHznnO/XnHOIiEjrl+V3ASIikhwKdBGRDKFAFxHJEAp0EZEMoUAXEckQ2X49cVFRkSstLfXr6UVEWqXFixfvdM51j7fOt0AvLS1l0aJFfj29iEirZGYbm1unLhcRkQyhQBcRyRAKdBGRDOFbH3o8wWCQ8vJy6uvr/S6lxeXn51NSUkJOTo7fpYhIK5VWgV5eXk5hYSGlpaWYmd/ltBjnHFVVVZSXl9O/f3+/yxGRVuqIXS5m9icz225mK5pZb2Z2r5mtNbPlZjbiWIupr6+nW7dubSrMAcyMbt26tcn/TEQkeRLpQ38EmHSY9RcDA6KXm4D7j6egthbmB7TV1y0iyXPELhfn3BtmVnqYJpcBjzpvHN4FZtbZzHo557YlqUYRyVDOORpCEeqDYeqDEYLhCKGII3TwpyMU+eR6OOIIRiKEY5aHI46Ic0Qi4ICIc+DA4Yg4cM5b5qLPd/C21+zQZcSsi/48WOshdccud3GXN71P7Mrxg4oZ1rfzcW+/ppLRh94H2Bxzuzy67FOBbmY34e3F069fvyQ8tYj4xTnH7togO2oa2F5Tz679jeytD7G3LkhNfYi99cFDrtc1hg8Gd33ok+ttyYF/xHt0zE/bQE+Yc+4B4AGAUaNGaWYNkTQWjji27K5jQ9V+Nu2q9S5VtWyrrmN7TQM79zUQDMf/Nc4NZNGxXTYd83MobJdDx/xsunfIIz8nQH5OVvRn4JPb2d71nICRE8gikGXkBIxAVhbZWUZ2wKLLouuyPmmTlWUEzDCDrGhiHrh+8Cdet+ahtz+97MB9zMCIXo95XbFdo4cuj9+mpSUj0LcAfWNul0SXtTozZsyga9eu3HbbbQDccccd9OjRg2984xs+VyaSWo2hCB9s2cOKLXv5sGIvK7fV8FFFDXXB8ME2udlZ9O3Sjt6d23Fyj0J6dMyje4e8gz+7dcilY7scOubnkJ8T8PHVtF3JCPTZwK1mNgs4E6hORv/5XS+UsXLr3uMuLtbg3h358aVDml1//fXXc8UVV3DbbbcRiUSYNWsW7733XlJrEEkH9cEwizfu5t2Pd/Hex1W8v2kPDSGv+6NzQQ6DenZk2pi+DCwupLSoPSd0K6C4MJ+sLH15n86OGOhm9jhwPlBkZuXAj4EcAOfc74E5wGRgLVALfDlVxaZaaWkp3bp14/3336eyspLhw4fTrVs3v8sSSYrquiCvrKxk/spK3lizg9rGMFnm7ehce+YJjOnflTP6dqa4Y56OumqlEjnKZfoR1jvga0mrKOpwe9KpdMMNN/DII49QUVHB9ddf70sNIskSCkd4c81Onl5SzvyVlTSGIhR3zGPq8D6MO7UHo/t3pWO+zk7OFGl1pmg6mDp1KjNmzCAYDDJz5ky/yxE5JjX1QZ5cVM7Db31M+e46uhTkcM2Yflw+vA/DSjppDzxDKdCbyM3N5YILLqBz584EAvpiR1qXvfVB/vjGeh55awM1DSHGlHbljsmDGD+omNxsjcWX6RToTUQiERYsWMBTTz3ldykiCWsIhfnLOxu577W17K4NMvm0ntw89qSUHOss6UuBHmPlypV89rOfZerUqQwYMMDvckQS8u76Kn7w7Aes37mf8wYU8b2Jp3JaSSe/yxIfKNBjDB48mPXr1/tdhk
hC9tYH+c+XPmTmu5vo27Udj3x5NOcP7OF3WeIjBbpIK/RBeTW3PLaYrXvquPG8/nzzolMoyNWvc1unT4BIK+Kc4/H3NnPn7DKKOuTy1FfPYeQJXfwuS9KEAl2klQiFI/zw+RXMWriZ8wYU8etpw+naPtfvsiSNKNBFWoG6xjBff3wJr6zazq0XnMw3LzqFgE7DlyZ0YOoR3HnnnfziF784bJtZs2Zxzz33fGp5aWkpO3fuTFVp0kZU1wb54kPv8vcPt3P35UP5zsSBCnOJS4GeBC+99BKTJh1uUieRY7O3Psg1Dy5gWfkefjt9BF886wS/S5I0pkCP45577uGUU07h3HPPZfXq1YTDYUaM+GSq1DVr1hy87Zxj6dKljBgxgqqqKiZMmMCQIUO44YYbcNEZShYuXMjpp59OfX09+/fvZ8iQIaxYEXeKVpGD6hrD3PDIIlZX1PDAdaO45PRefpckaS59+9Bfuh0qPkjuY/Y8DS7+z8M2Wbx4MbNmzWLp0qWEQiFGjBjByJEj6dSpE0uXLuWMM87g4Ycf5stf9gaVfP/99xk2bBhmxl133cW5557LjBkzePHFF3nooYcAGD16NFOmTOGHP/whdXV1fOELX2Do0KHJfW2SUYLhCF+buYSFG3dx77ThXKDjyyUB6RvoPnnzzTeZOnUqBQUFAEyZMgXwRmF8+OGH+dWvfsUTTzxxcJz0l19+mYsvvhiAN954g2effRaASy65hC5dPjmcbMaMGYwePZr8/HzuvffelnxJ0so45/j+M8t59cPt3DN1KJcO6+13SdJKpG+gH2FPuqVdeeWV3HXXXYwbN46RI0ceHCd93rx5PPPMM0e8f1VVFfv27SMYDFJfX0/79u1TXbK0Un96awPPLtnCNy88hWvPVJ+5JE596E2MHTuW559/nrq6OmpqanjhhRcAyM/PZ+LEidxyyy0Hu1uqq6sJhUIHw33s2LEHh9x96aWX2L1798HHvfnmm7n77ru59tpr+f73v9/Cr0pai3fWVfEfc1YxcUgx/zb+ZL/LkVYmfffQfTJixAiuvvpqhg0bRo8ePRg9evTBdddeey3PPfccEyZMAGD+/PlceOGFB9f/+Mc/Zvr06QwZMoRzzjmHfv36AfDoo4+Sk5PDNddcQzgc5pxzzuHVV19l3LhxLfviJK1tq67j1plLKO1WwC+uGqYxy+Wo2YEjMVraqFGj3KJFiw5ZtmrVKgYNGuRLPYn4xS9+QXV1NXfffTfg9avfcMMNnHXWWUl5/HR//ZI6oXCEq/7wDmsq9/H81/6Fk3t08LskSVNmttg5NyreOu2hJ2jq1KmsW7eOV1999eCyBx980MeKJJP84Y31vL9pD/dOH64wl2OmQE/Qc88953cJkqFWbt3L/77yEZec3ospOqJFjkPafSnqVxeQ39rq627rGkMRvv3UMjq1y+Xuy3RughyftAr0/Px8qqqq2ly4OeeoqqoiPz/f71Kkhf3m1TWs2raXn11xmkZOlOOWVl0uJSUllJeXs2PHDr9LaXH5+fmUlJT4XYa0oLXba7j/9XVcMaIPFw0u9rscyQBpFeg5OTn079/f7zJEUs45x10vrKQgN8Adk3VkkyRHWnW5iLQVc8sqeXPNTr510Sl065DndzmSIRToIi2sPhjmpy+uZGBxIV/QcLiSRGnV5SLSFvzhH+sp313H4zeeRXZA+1SSPPo0ibSg7TX13P+PtVxyWi/OPqmb3+VIhlGgi7Sg3722jmDY8d2JA/0uRTKQAl2khWzdU8fMdzdx1cgSSos0fLIknwJdpIX85tU1AHx9/ACfK5FMlVCgm9kkM1ttZmvN7PY46/uZ2Wtm9r6ZLTezyckvVaT12rBzP08uKmf6mL706dzO73IkQx0x0M0sANwHXAwMBqab2eAmzX4IPOmcGw5MA36X7EJFWrN7/76G7Czjaxdo0gpJnUT20McAa51z651zjcAs4LImbRzQMXq9E7A1eSWKtG6bqmp5fukWrj
v7BHp01Hg9kjqJBHofYHPM7fLoslh3Al8ws3JgDvD1eA9kZjeZ2SIzW9QWx2uRtunBf64nkGXccN6JfpciGS5ZX4pOBx5xzpUAk4G/mNmnHts594BzbpRzblT37t2T9NQi6WvX/kaeXLSZy8/oQ7H2ziXFEgn0LUDfmNsl0WWxvgI8CeCcewfIB4qSUaBIa/boOxuoD0a4aaz2ziX1Egn0hcAAM+tvZrl4X3rObtJmEzAewMwG4QW6+lSkTatrDPPoOxsZd2oPBhQX+l2OtAFHDHTnXAi4FZgLrMI7mqXMzH5iZlOizb4N3Ghmy4DHgf/n2tosFSJNPL2knF37G7V3Li0mocG5nHNz8L7sjF02I+b6SuBfkluaSOsVjjgeenM9w0o6cWb/rn6XI22EzhQVSYE3PtrBhqpavnLeiZiZ3+VIG6FAF0mBvy7YSFGHPCYN6el3KdKGKNBFkqx8dy2vrt7O1aNLyM3Wr5i0HH3aRJLs8fc2YcD0Mf38LkXaGAW6SBI1hiI8sXAz407tQUmXAr/LkTZGgS6SRHPLKti5r5FrNVeo+ECBLpJEf12wkb5d2/GZARraQlqeAl0kSdbv2Me7H+/imjEnkJWlQxWl5SnQRZLkmSXlBLKMK0c0HYxUpGUo0EWSIBxxPLtkC2MHFGnMc/GNAl0kCd5et5Nt1fV8bmTfIzcWSREFukgSPL24nE7tchg/qIffpUgbpkAXOU5764O8vKKCKcN6k58T8LscacMU6CLH6cXl22gIRfjcyBK/S5E2ToEucpyeXlzOgB4dOL2kk9+lSBunQBc5Dht27mfxxt1cObJEw+SK7xToIsfhhWVbAZgyrLfPlYgo0EWOmXOO2cu2Mqa0K707t/O7HBEFusix+rCihjXb93HpGdo7l/SgQBc5RrOXbSWQZUweqlmJJD0o0EWOgXOOF5Zt5dyTi+jWIc/vckQABbrIMVmyaQ/lu+v0ZaikFQW6yDF4YdlW8rKzmDCk2O9SRA5SoIscpVA4wt+Wb2PcqT0ozM/xuxyRgxToIkdp4Ybd7NzXwKXqbpE0o0AXOUpzyyrIy87i/IGaZk7SiwJd5Cg455i/spLzBnSnIDfb73JEDqFAFzkKZVv3smVPnb4MlbSkQBc5CvPKKsgyuHCQAl3SjwJd5CjMLatkdGlXurbP9bsUkU9JKNDNbJKZrTaztWZ2ezNtPm9mK82szMxmJrdMEf9t2Lmf1ZU1TByiU/0lPR3xWx0zCwD3ARcB5cBCM5vtnFsZ02YA8APgX5xzu81MEytKxpm3sgJA/eeSthLZQx8DrHXOrXfONQKzgMuatLkRuM85txvAObc9uWWK+G9uWSVDenekpEuB36WIxJVIoPcBNsfcLo8ui3UKcIqZvWVmC8xsUrwHMrObzGyRmS3asWPHsVUs4oPtNfUs2bRb3S2S1pL1pWg2MAA4H5gO/NHMOjdt5Jx7wDk3yjk3qnt3nZQhrccrK7fjnLpbJL0lEuhbgL4xt0uiy2KVA7Odc0Hn3MfAR3gBL5IR5pZVcEK3AgYWF/pdikizEgn0hcAAM+tvZrnANGB2kzbP4+2dY2ZFeF0w65NYp4hvauqDvL1uJxMGF2siaElrRwx051wIuBWYC6wCnnTOlZnZT8xsSrTZXKDKzFYCrwHfdc5VpapokZb02uodBMNO/eeS9hIajMI5NweY02TZjJjrDvhW9CKSUeaVVVDUIY/h/br4XYrIYelMUZHDaAiFeX31Di4a3INAlrpbJL0p0EUO4+11VexrCDFB3S3SCijQRQ5jXlkFHfKyOeekbn6XInJECnSRZoQj3tjn5w/sTl52wO9yRI5IgS7SjPc37WbnvkZ1t0iroUAXacbcsgpyA1lcoKnmpJVQoIvE4Zxj3spKzjm5G4X5OX6XI5IQBbpIHKsra9hYVcuEwepukdZDgS4Sx9wVlZjBhYM1tL+0Hgp0kTjmraxgRL8u9CjM97sUkYQp0EWa2LyrlrKte5mooX
KllVGgizQxf2UlgPrPpdVRoIs0MbesgoHFhZQWtfe7FJGjokAXiVG1r4GFG3apu0VaJQW6SIy/f7idiENnh0qrpEAXiTGvrII+ndsxpHdHv0sROWoKdJGo/Q0h3lizk4s01Zy0Ugp0kag3PtpBYyiiqeak1VKgi0TNLaugS0EOo0s11Zy0Tgp0ESAYjvD3D7czflAx2QH9WkjrpE+uCLBgfRU19SF1t0irpkAXAeaVVdIuJ8B5A4r8LkXkmCnQpc2LRBzzVlbwmVO6k5+jqeak9VKgS5u3fEs1lXsbmKCzQ6WVU6BLmze3rIJAljH+VAW6tG4KdGnz5pZVcNaJXelUoKnmpHVToEubtnb7Ptbv2K+jWyQjKNClTZtbVgHARYPV3SKtnwJd2rR5ZRUMK+lEr07t/C5F5Lgp0KXN2rKnjmXl1Uwcqu4WyQwKdGmzXl7hdbdcPLSXz5WIJEdCgW5mk8xstZmtNbPbD9PuSjNzZjYqeSWKpMbLK7Zxas9C+muqOckQRwx0MwsA9wEXA4OB6WY2OE67QuAbwLvJLlIk2bbX1LNo424mqbtFMkgie+hjgLXOufXOuUZgFnBZnHZ3Az8H6pNYn0hKzC2rxDl1t0hmSSTQ+wCbY26XR5cdZGYjgL7OuRcP90BmdpOZLTKzRTt27DjqYkWS5eUV2zixqD2nFHfwuxSRpDnuL0XNLAv4FfDtI7V1zj3gnBvlnBvVvXv3431qkWOye38jC9bvYtLQnppqTjJKIoG+Begbc7skuuyAQmAo8LqZbQDOAmbri1FJV/NXVhKOOHW3SMZJJNAXAgPMrL+Z5QLTgNkHVjrnqp1zRc65UudcKbAAmOKcW5SSikWO00srtlHSpR1D+3T0uxSRpDpioDvnQsCtwFxgFfCkc67MzH5iZlNSXaBIMu2tD/LPtTuZNETdLZJ5shNp5JybA8xpsmxGM23PP/6yRFLj1VXbCYYdF5+mwxUl8+hMUWlTXlqxjeKOeQzv28XvUkSSToEubUZtY4h/fLSDiUN6kpWl7hbJPAp0aTNeX72D+mBEZ4dKxlKgS5vxwrKtFHXIZUxpV79LEUkJBbq0CTX1Qf7+4XYuOa0X2QF97CUz6ZMtbcK8skoaQxGmnNHb71JEUkaBLm3C7GVb6dO5HSP66egWyVwKdMl4Vfsa+OfanVw6rLdOJpKMpkCXjDdnRQXhiGPKMHW3SGZToEvGe2HZVk7u0YFBvQr9LkUkpRToktG2VdexcMMuLj1d3S2S+RToktFmL92Kc+joFmkTFOiSsZxzPLOknOH9OmsiaGkTFOiSsT7YUs1Hlfv43MgSv0sRaREKdMlYTy8uJzc7i8+eru4WaRsU6JKRGkJhZi/bysQhPenULsfvckRahAJdMtKrq7azpzao7hZpUxTokpGeXlxOccc8zj25yO9SRFqMAl0yzvaael7/aAdXjCghoIkspA1RoEvGeW7JFsIRx5Uj1N0ibYsCXTJKJOKY+d4mRpd24eQeHfwuR6RFKdAlo/xz7U42VtXyhbNO8LsUkRaX7XcBIsn01wUb6dY+99jmDW2ogT2boH4v5HeCLidArs4wldZDgS4ZY1t1Ha+squSmsSeRlx1I7E61u2DpTFjxNGxbBi7yybqsbOg9HE67Ck6/Gtp1Tk3hIkmiQJeM8fi7m3DAtWf2O3LjYB28dS+8fS807oM+I2Hsd6HHIMjrCPV7oHIlrJ0PL30PXr0HzvsWnHULZOel/LWIHAsFumSEYDjCrIWbOf+U7vTtWnD4xuWL4bmboWoNDLoUzv8BFA/5dLuhV8L4H8HWpfD6z+CVH8MHT8HU30PP01LzQkSOg74UlYwwf2Ul22sajvxl6KKH4U8TIVgLX3wOrv5r/DCP1fsMuOYJmP4E7NsOD14Iy59MXvEiSaJAl4zw0D8/pqRLO84f2CN+g0gE5t4Bf7sNTvwM3PIWnDTu6J5k4CS45W3oMwqevRH+8d/HX7hIEinQpdVbvHEXiz
fu5ivn9o9/ZmgkDLNvhXd+C2NugmuehHZdju3JOnSH656H06fBaz+Fv/8EnDu+FyCSJOpDl1bvD/9YT6d2OXx+VN9Pr3QO/vZNWPqY11f+me/D8U5FF8iBy+/3vhx985cQaoAJPz3+xxU5TgntoZvZJDNbbWZrzez2OOu/ZWYrzWy5mf3dzHRWh7SI9Tv2MX9VJdedfQLt85rsnzgH834IS/4M530bzr89eaGblQWX/hrG3Ozt+b9yZ3IeV+Q4HHEP3cwCwH3ARUA5sNDMZjvnVsY0ex8Y5ZyrNbNbgP8Crk5FwSKx/vjmx+QEsrju7NJPr/zHz72wPfOrMO5HyX9yM7j45xAJwVv/Cx17w5k3J/95RBKUyB76GGCtc269c64RmAVcFtvAOfeac642enMBoFGRJOV21DTwzJJyrhxRQvfCJseGL3vCO9TwjGth4s9S1x1iBpP/GwZeAi99H1b+X2qeRyQBiQR6H2BzzO3y6LLmfAV4Kd4KM7vJzBaZ2aIdO3YkXqVIHI+8/THBcIQbz+t/6IryxTD761B6ntctkpXi7/6zAnDlg1AyGp65ETYtSO3ziTQjqZ90M/sCMAqIezyXc+4B59wo59yo7t27J/OppY2p2tfAI29tYPLQXpzYPWZUxb3bYNY1UNgTrvqz9wVmS8gt8I5V79QHnvgC7Nl85PuIJFkigb4FiD18oCS67BBmdiFwBzDFOdeQnPJE4vvDG+upC4b55kUDPlkYrPPCvHEfTJ8F7bu1bFEFXb3nDdZH66g98n1EkiiRQF8IDDCz/maWC0wDZsc2MLPhwB/wwnx78ssU+cT2vfX8+e0NXD68Dyf3KPQWOgcvfAO2LoErHoDiwf4U132g1/1S8YF37LuOUZcWdMRAd86FgFuBucAq4EnnXJmZ/cTMpkSb/TfQAXjKzJaa2exmHk7kuN332lrCEcdt40/5ZOHb98LyJ+CCH8Kpl/hXHHhnlI7/Eax4Bv75P/7WIm1KQicWOefmAHOaLJsRc/3CJNclElf57lpmvreJz4/uS79u0UG4PpoH838Mgy+Hsd/xt8ADzv0WVJZ5Z5L2GOyFvEiK6dR/aVX+Z/4azIyvjzvZW7DjI3jmK97oh5f/Ln3O1jSDKb+FXqfDMzfAjtV+VyRtgAJdWo0lm3bzzJJyvvwvpfTq1A7qdsPj07xT8KfNTL/ZhXILvLpy8r0663b7XZFkOAW6tAqRiOPO2WX0KMzj6+MGQDgET33ZmzLu6r9C5zjjuKSDTiVefXs2w9PXe3WLpIgCXVqFpxZvZnl5Nf8+eRAd8rJh/gxY/xp89n+g31l+l3d4/c6CS34J6171JskQSRGNtihpr7ouyH+9vJrRpV247IzesPjPsOA+b4yWEV/0u7zEjPwSVK7wxpYpHgpnTPe7IslA2kOXtPfLeavZXdvInVOGYBvehBe/BSeNhwn3+F3a0Zn4H9B/LLzwb7B5od/VSAZSoEtae3vdTh59ZyPXnV3KkLyd8MQXodvJcNXDEGhl/2AGcrzhCDr2hieuhb1b/a5IMowCXdJWTX2Q7z61nP5F7fn+Z4ph5ue9gbCmz4L8Tn6Xd2wKusK0x6FxP8y6VsMDSFIp0CVt3fPiKrZV1/GLKwfT7vkDR7Q8Bl37H/nO6ax4sDc8wdb3vblJI2G/K5IMoUCXtPTah9uZtXAzN489kZHL74KP34BL74UTzva7tOQ49RJvcowP/+aNo64xXyQJWlknpLQFW/bU8e2nljGwuJBvB2bB+3+Fsd/LvCNDzrwZqsu9cWg6lcC5t/ldkbRyCnRJK/XBMLf8dTHBUITHhrxH9tv/CyO/DBf8u9+lpcaFd8HeLd7x6e26eIc3ihwjBbqkDeccP3p+BcvLq/nbuRsoevtub8CtS36ZPmO0JFtWFlx+P9RXe8P/5rSD0z/vd1XSSqkPXdLGX9/dxFOLy/n9kFUMXXQHnHiB9+VhVsDv0lIrO88bHqD0XHjuq5qXVI6ZAl3Sws
srKvjx/61gRu9FTFz3UzjpApj+uBd2bUFOO+9wzJJR3pgvZc/5XZG0Qgp08d0/1+zk3x5/n+8VvcX1u36FnTzeO1Y7p53fpbWsvA5w7VPQJxrqi//sd0XSyijQxVfvb9rNTX9ZyF0dnuWrNffBgIneseY5+X6X5o/8TvDF5+Ckcd4QAW/92u+KpBVRoItvFqyv4vqH3uZXOb9nesOTMPyLMK0Nh/kBuQXefyiDL/dGlXzxOxAO+l2VtAI6ykV88fKKCu6c9ToP5/2WM8IfwAV3wNjvZu7RLEcrOxc+9yeYX+KN0LjzI7jqEW/oAJFmaA9dWtxj727koZkzeTH33xlma2DqH+Az31OYN5UVgIn3wGW/g03vwIPjoeIDv6uSNKZAlxZTHwxz+1NLWT/75zye+1O6dOqIfWU+DJvmd2npbfi18KW/eQN5/XEcLLhfQwVIXAp0aREbq/bzr795hqkf3MyPch4jMHAiWTe97k2iLEfW70y45S1vHPiXb/dGnty7ze+qJM2oD11SKhxxPPb2Osrn/Yb7sh4nJy8HJt+HnXGtuliOVvsi79j8hQ/CvB/CfWNg/AwYdX3mn3wlCVGgS8p8VFnDzMf/zPRd93NdVjn1J1xA9hW/9QaikmNjBmNu9A5rfPHbMOc7sHQmTPpPby9e2jQFuiRd5d56npz9AoM/+h13Zi1hf4e+uM/+hfxBl2qvPFm6neQdr77iGZj77/CnCTBwMoz7kTfeurRJ5nz6cmXUqFFu0aJFvjy3pMb26jpenjubfmX3c769T12gkMg536D92K/r2PJUatzvfVH61q+hoQYGXQrnfB36jvG7MkkBM1vsnBsVd50CXY7Xyk0VfPDSQwzZ+hRD7WP2BToSGvM1On/mXyG/o9/ltR21u+Dt38Cih7zRG0vGeN0zgy5te8MoZDAFuiTdzr37WfTa8wTKnuXMhrfoaHVU5p9I9lk30u3s67xxScQfDftg6WOw4HewewPkdYKhV8Bpn4N+Z+sL1FZOgS7HzTnHhvItrF/wAtnrX2FI7XsU2V72WwHbeo6n5/k30uGUseojTyeRCGx8y5vxaeX/QagOCopg4MUwYII3XK/OPG11FOhy1CIRx8frVrP1g9dwG9+hZ0rE928AAAsoSURBVPUyTnIbCZhjrxWypds5dBp5Jb1HXab+8dagYR+sfQVWvQAfzYXGGsCgeCj0P88L915nQMfe+qOc5hTo0iznHDt3VLJ94yqqNy4jUrGSwurV9Gn8mCKrBmA/+WwuGEJjr9EUj5hM8aBz9W97axYOwpYl3sTbG96Aze9BqN5bV9ANep4OvYZBj0HQ7WToeqL25NPIcQe6mU0Cfg0EgAedc//ZZH0e8CgwEqgCrnbObTjcYyrQUy8cDrN7ZwXVO8rZt3ML9Xu2Ea6uwPZXkrNvK53qt1AcrqCj1R68Tx25bM05gb0dT4Few+g19HyKB4zAAjk+vhJJqVADbF0KFcth2zLvsn0VRGJGeGzXBbqe5O3BF/aCjr2gsDcU9vRuF3SF/M4Q0JHQqXa4QD/i1jezAHAfcBFQDiw0s9nOuZUxzb4C7HbOnWxm04CfA1cff+mZIRIOEwoFiYRDhEJBwqEQ4VAjkVCIUDhIJBQiHA4SCYeJhIOEg42EGusINdQSbqwn3FhLpLHOuwTrIViHC9VDsJ6sUC1ZjXvJDtaQF9pHXmgfBZF9tHe1dKCWInMUNamn1uVRFShiT14fPuwwHOtSSn7xSRSfdAbd+w7kJP1Sti3Zed5JSbEnJoUavS9Uq9bCrnVQtQ52rYcdH8L616Fhb/zHyi2Edp29cG8XveR18o6yyS2AnALv+sGf0evZ+RDIgawc749CVo53O5ALWdnx11lWzEXdRJDYiUVjgLXOufUAZjYLuAyIDfTLgDuj158Gfmtm5lLQn7Pw2V/TY8UDABgOi3kKwwHOW35wqdfmkNvRS+z94t1u7j4W87gHbsd7jABhsomQZY7cJLz2phpdgH
rLY7+1pzarAw2BDtTk92JXTiGR3EJcXkesfRG5nXtT0LUXhUUldOlRQkGHThSY0TcFNUmGyM6F7qd4l3ga9kHNtuilAup2Q90eqN9z6PWda73wD9ZBsPaTrp1UiA14rEngR0M/7vUD7Q/8xtshPw5d1qRNossO+YNjcP73YeiVx/Vy40kk0PsAm2NulwNNzzE+2MY5FzKzaqAbsDO2kZndBNwE0K9fv2MqOKewO1UFJ30St/ZJnH5yGw7GbnR97BsWt601fYzYNyfrkHYHHyO2XZz7uKxsLCsblxXw+pyzcrCsAGRlY4Fs72dWNmQFyApkQyCbrKxsLJBDICefQF4B2bn55OS3Jye/gNz8duTltSevXXty8wvIzc4mF9CR3tLi8jpA3gAoGnB094tEvKNtDgR8sM47MSpU7/XtR4IQDkV/BmOWBSESOvQ2zht10kWaubgmP5tccIe2hZhRLGP2RZsuO2Q/NZFlcR4rv/PRbbcEtej/1s65B4AHwOtDP5bHOOOia+Cia5Jal4i0kKwsyG3vXSTpEhk+dwsc8t95SXRZ3DZmlg10wvtyVEREWkgigb4QGGBm/c0sF5gGzG7SZjbwpej1zwGvpqL/XEREmnfELpdon/itwFy8wxb/5JwrM7OfAIucc7OBh4C/mNlaYBde6IuISAtKqA/dOTcHmNNk2YyY6/XAVcktTUREjoamoBMRyRAKdBGRDKFAFxHJEAp0EZEM4dtoi2a2A9h4jHcvoslZqGlCdR0d1XX00rU21XV0jqeuE5xz3eOt8C3Qj4eZLWputDE/qa6jo7qOXrrWprqOTqrqUpeLiEiGUKCLiGSI1hroD/hdQDNU19FRXUcvXWtTXUcnJXW1yj50ERH5tNa6hy4iIk0o0EVEMkTaBrqZXWVmZWYWMbNRTdb9wMzWmtlqM5vYzP37m9m70XZPRIf+TXaNT5jZ0uhlg5ktbabdBjP7INou5TNjm9mdZrYlprbJzbSbFN2Ga83s9hao67/N7EMzW25mz5lZ3GlbWmp7Hen1m1le9D1eG/0slaaqlpjn7Gtmr5nZyujn/xtx2pxvZtUx7++MeI+VgtoO+76Y597o9lpuZiNaoKaBMdthqZntNbPbmrRpse1lZn8ys+1mtiJmWVczm29ma6I/uzRz3y9F26wxsy/Fa3NEzrm0vACDgIHA68ComOWDgWVAHtAfWAcE4tz/SWBa9PrvgVtSXO8vgRnNrNsAFLXgtrsT+M4R2gSi2+5EIDe6TQenuK4JQHb0+s+Bn/u1vRJ5/cC/Ar+PXp8GPNEC710vYET0eiHwUZy6zgf+1lKfp0TfF2Ay8BLeXIxnAe+2cH0BoALvxBtfthcwFhgBrIhZ9l/A7dHrt8f73ANdgfXRn12i17sc7fOn7R66c26Vc251nFWXAbOccw3OuY+BtXgTWR9kZgaMw5uwGuDPwOWpqjX6fJ8HHk/Vc6TAwcm/nXONwIHJv1PGOTfPOReK3lyAN/uVXxJ5/ZfhfXbA+yyNj77XKeOc2+acWxK9XgOswpuztzW4DHjUeRYAnc2sVws+/3hgnXPuWM9AP27OuTfw5oSIFfs5ai6LJgLznXO7nHO7gfnApKN9/rQN9MOIN2l10w98N2BPTHjEa5NM5wGVzrk1zax3wDwzWxydKLsl3Br9t/dPzfyLl8h2TKXr8fbm4mmJ7ZXI6z9k8nPgwOTnLSLaxTMceDfO6rPNbJmZvWRmQ1qopCO9L35/pqbR/E6VH9vrgGLn3Lbo9QqgOE6bpGy7Fp0kuikzewXoGWfVHc65/2vpeuJJsMbpHH7v/Fzn3BYz6wHMN7MPo3/JU1IXcD9wN94v4N143UHXH8/zJaOuA9vLzO4AQsBjzTxM0rdXa2NmHYBngNucc3ubrF6C162wL/r9yPPAgBYoK23fl+h3ZFOAH8RZ7df2+hTnnDOzlB0r7mugO+cuPIa7JTJpdRXev3vZ0T2reG2SUqN5k2JfAYw8zGNsif7cbm
bP4f27f1y/CIluOzP7I/C3OKsS2Y5Jr8vM/h/wWWC8i3YexnmMpG+vOI5m8vNya8HJz80sBy/MH3POPdt0fWzAO+fmmNnvzKzIOZfSQagSeF9S8plK0MXAEudcZdMVfm2vGJVm1ss5ty3aBbU9TpsteH39B5TgfX94VFpjl8tsYFr0CIT+eH9p34ttEA2K1/AmrAZvAutU7fFfCHzonCuPt9LM2ptZ4YHreF8MrojXNlma9FtObeb5Epn8O9l1TQK+B0xxztU206altldaTn4e7aN/CFjlnPtVM216HujLN7MxeL/HKf1Dk+D7Mhu4Lnq0y1lAdUxXQ6o1+1+yH9uridjPUXNZNBeYYGZdol2kE6LLjk5LfPN7LBe8ICoHGoBKYG7MujvwjlBYDVwcs3wO0Dt6/US8oF8LPAXkpajOR4CvNlnWG5gTU8ey6KUMr+sh1dvuL8AHwPLoh6lX07qityfjHUWxroXqWovXT7g0evl907pacnvFe/3AT/D+4ADkRz87a6OfpRNbYBudi9dVtjxmO00GvnrgcwbcGt02y/C+XD6nBeqK+740qcuA+6Lb8wNijk5LcW3t8QK6U8wyX7YX3h+VbUAwml9fwfve5e/AGuAVoGu07SjgwZj7Xh/9rK0Fvnwsz69T/0VEMkRr7HIREZE4FOgiIhlCgS4ikiEU6CIiGUKBLiKSIRToIiIZQoEuIpIhFOgiUWY2OjqgWX70zMgyMxvqd10iidKJRSIxzOyneGeItgPKnXM/87kkkYQp0EViRMd1WQjU450iHva5JJGEqctF5FDdgA54swXl+1yLyFHRHrpIDDObjTd7UX+8Qc1u9bkkkYT5Oh66SDoxs+uAoHNuppkFgLfNbJxz7lW/axNJhPbQRUQyhPrQRUQyhAJdRCRDKNBFRDKEAl1EJEMo0EVEMoQCXUQkQyjQRUQyxP8H/f40iDElTOEAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
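{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the element-wise claim, the tape's result can be compared with the closed-form derivative of the sigmoid, sigmoid(x) * (1 - sigmoid(x)). This sketch recomputes everything on its own grid (the names `xs` and `ys` are chosen so the cells above are not overwritten):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"xs = tf.linspace(-10.0, 10.0, 201)\n",
"\n",
"with tf.GradientTape() as tape:\n",
"  tape.watch(xs)  # a plain tensor must be watched explicitly\n",
"  ys = tf.nn.sigmoid(xs)\n",
"\n",
"dys_dxs = tape.gradient(ys, xs)  # gradient of sum(ys), element by element\n",
"analytic = ys * (1.0 - ys)  # closed form: sigmoid(x) * (1 - sigmoid(x))\n",
"max_err = tf.reduce_max(tf.abs(dys_dxs - analytic))\n",
"print(max_err.numpy())  # close to 0"
],
"execution_count": null,
"outputs": []
},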
{
"cell_type": "markdown",
"metadata": {
"id": "u4e4OiMpTMCX",
"colab_type": "text"
},
"source": [
"## Control flow\n",
"\n",
"Because tapes record operations as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled.\n",
"\n",
"Here a different variable is used on each branch of an `if`. The gradient only connects to the variable that was used:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7HkKAo-6SRyP",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "4ceed639-2c3c-4e81-81fd-d0741f81fc1f"
},
"source": [
"x = tf.constant(1.0)\n",
"\n",
"v0 = tf.Variable(2.0)\n",
"v1 = tf.Variable(2.0)\n",
"\n",
"with tf.GradientTape(persistent=True) as tape:\n",
" tape.watch(x)\n",
" if x > 0.0:\n",
" result = v0\n",
" else:\n",
" result = v1**2\n",
"\n",
"dv0, dv1 = tape.gradient(result, [v0, v1])\n",
"\n",
"print(dv0)\n",
"print(dv1)"
],
"execution_count": 79,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(1.0, shape=(), dtype=float32)\n",
"None\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "JALvw4ntThPs",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "16758cc6-7e3c-4b3c-a7fa-621d96d9d012"
},
"source": [
"dx = tape.gradient(result, x)\n",
"print(dx)"
],
"execution_count": 80,
"outputs": [
{
"output_type": "stream",
"text": [
"None\n"
],
"name": "stdout"
}
]
},
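{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same applies to loops: the tape records only the iterations that actually run. In this sketch (values chosen for illustration), a `while` loop multiplies by `x` three times, so the recorded computation is y = x**4 and the gradient at x = 2 should be 4 * 2**3 = 32:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"x = tf.Variable(2.0)\n",
"\n",
"with tf.GradientTape() as tape:\n",
"  y = x\n",
"  steps = 0\n",
"  while steps < 3:  # ordinary Python loop; every executed op is recorded\n",
"    y = y * x\n",
"    steps += 1\n",
"\n",
"dy_dx = tape.gradient(y, x)\n",
"print(dy_dx.numpy())  # 32.0"
],
"execution_count": null,
"outputs": []
},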
{
"cell_type": "markdown",
"metadata": {
"id": "acV5Bk_zVPvi",
"colab_type": "text"
},
"source": [
"## Getting a gradient of `None`\n",
"\n",
"When a target is not connected to a source you will get a gradient of `None`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "jEwhzyZKVLH9",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "01402729-8387-4645-cd5c-de35ae893018"
},
"source": [
"x = tf.Variable(2.)\n",
"y = tf.Variable(3.)\n",
"\n",
"with tf.GradientTape() as tape:\n",
" z = y * y\n",
" \n",
"print(tape.gradient(z, x))"
],
"execution_count": 81,
"outputs": [
{
"output_type": "stream",
"text": [
"None\n"
],
"name": "stdout"
}
]
},
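{
"cell_type": "markdown",
"metadata": {},
"source": [
"With several sources, the `None` entries show which ones the target does not depend on. A minimal sketch reusing the same two variables; `tape.gradient` accepts a nested structure of sources such as a dict and returns the same structure, and the name-listing comprehension is only an illustration:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"x = tf.Variable(2.)\n",
"y = tf.Variable(3.)\n",
"\n",
"with tf.GradientTape() as tape:\n",
"  z = y * y\n",
"\n",
"grads = tape.gradient(z, {'x': x, 'y': y})  # result is a dict with the same keys\n",
"unconnected = [name for name, g in grads.items() if g is None]\n",
"print(unconnected)  # ['x']\n",
"print(grads['y'].numpy())  # 6.0"
],
"execution_count": null,
"outputs": []
},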
{
"cell_type": "markdown",
"metadata": {
"id": "5P44uHqDVxOg",
"colab_type": "text"
},
"source": [
"### Replaced a variable with a tensor \n",
"\n",
"In the section on \"Controlling what the tape watches\" we saw that the tape will automatically watch a `tf.Variable` but not a `tf.Tensor`.\n",
"\n",
"One common error is to inadvertently replace a `tf.Variable` with a `tf.Tensor` instead of using `Variable.assign` to update the `tf.Variable`:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "XMlyEeQAVaIZ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "e603167f-8f4d-49cc-f826-6a1c5145f607"
},
"source": [
"x = tf.Variable(2.0)\n",
"\n",
"for epoch in range(2):\n",
" with tf.GradientTape() as tape:\n",
" y = x+1\n",
"\n",
" print(type(x).__name__, \":\", tape.gradient(y, x))\n",
" x = x + 1 # This should be `x.assign_add(1)`"
],
"execution_count": 82,
"outputs": [
{
"output_type": "stream",
"text": [
"ResourceVariable : tf.Tensor(1.0, shape=(), dtype=float32)\n",
"EagerTensor : None\n"
],
"name": "stdout"
}
]
},
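{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is the same loop with the update done in place through `assign_add`, as the comment above suggests. `x` stays a `tf.Variable`, so the tape watches it automatically on every pass and both gradients are 1.0:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"x = tf.Variable(2.0)\n",
"grads = []\n",
"\n",
"for epoch in range(2):\n",
"  with tf.GradientTape() as tape:\n",
"    y = x + 1\n",
"\n",
"  grads.append(tape.gradient(y, x))\n",
"  print(type(x).__name__, ':', grads[-1])\n",
"  x.assign_add(1)  # in-place update: x stays a tf.Variable\n",
"\n",
"print(x.numpy())  # 4.0"
],
"execution_count": null,
"outputs": []
},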
{
"cell_type": "markdown",
"metadata": {
"id": "Bq4hw7hnWSQW",
"colab_type": "text"
},
"source": [
"### Did calculations outside of TensorFlow\n",
"\n",
"The tape can't record the gradient path if the calculation exits TensorFlow. For example:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "b9y8xkg_WDjZ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "8c7e6ec0-6e4e-4ae2-a514-bbfbfa4ac26c"
},
"source": [
"x = tf.Variable([[1.0, 2.0],\n",
" [3.0, 4.0]], dtype=tf.float32)\n",
"\n",
"with tf.GradientTape() as tape:\n",
" x2 = x**2\n",
"\n",
" # This step is calculated with NumPy\n",
" y = np.mean(x2, axis=0)\n",
"\n",
" # Like most ops, reduce_mean will cast the NumPy array to a constant tensor\n",
" # using `tf.convert_to_tensor`.\n",
" y = tf.reduce_mean(y, axis=0)\n",
"\n",
"print(tape.gradient(y, x))"
],
"execution_count": 83,
"outputs": [
{
"output_type": "stream",
"text": [
"None\n"
],
"name": "stdout"
}
]
},
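{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keeping every step in TensorFlow restores the gradient path. This sketch replaces `np.mean` with `tf.reduce_mean`; since y is then the mean of all four entries of x**2, the expected gradient is 2x/4 = x/2:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"x = tf.Variable([[1.0, 2.0],\n",
"                 [3.0, 4.0]], dtype=tf.float32)\n",
"\n",
"with tf.GradientTape() as tape:\n",
"  x2 = x**2\n",
"  y = tf.reduce_mean(x2, axis=0)  # stays on the tape, unlike np.mean\n",
"  y = tf.reduce_mean(y, axis=0)\n",
"\n",
"dy_dx = tape.gradient(y, x)\n",
"print(dy_dx.numpy())  # expected: x / 2"
],
"execution_count": null,
"outputs": []
},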
{
"cell_type": "markdown",
"metadata": {
"id": "2z_AEk4lYqq3",
"colab_type": "text"
},
"source": [
"### Took gradients through an integer or string\n",
"\n",
"Integers and strings are not differentiable. If a calculation path uses these data types there will be no gradient.\n",
"\n",
"Nobody expects strings to be differentiable, but it's easy to accidentally create an int constant or variable if we don't specify the `dtype`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "RMdLb8zSX9Vn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 86
},
"outputId": "757d91a1-36f0-4d26-faa6-05aa3906866d"
},
"source": [
"x = tf.Variable([[2, 2],\n",
" [2, 2]])\n",
"\n",
"with tf.GradientTape() as tape:\n",
" # x was built from Python ints, so its dtype is int32: the path to x is blocked.\n",
" y = tf.cast(x, tf.float32)\n",
" y = tf.reduce_sum(x)\n",
"\n",
"\n",
"try:\n",
" print(tape.gradient(y, x))\n",
"except Exception as e:\n",
" print(f\"{type(e).__name__}: {e}\")"
],
"execution_count": 84,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:The dtype of the target tensor must be floating (e.g. tf.float32) when calling GradientTape.gradient, got tf.int32\n",
"WARNING:tensorflow:The dtype of the source tensor must be floating (e.g. tf.float32) when calling GradientTape.gradient, got tf.int32\n",
"None\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxiIEbaCZsCT",
"colab_type": "text"
},
"source": [
"TensorFlow doesn't automatically cast between types, so in practice you'll often get a type error instead of a missing gradient."
]
},
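{
"cell_type": "markdown",
"metadata": {},
"source": [
"The usual fix is to make the source floating point from the start, either with float literals or an explicit `dtype`. A minimal sketch; since y is the sum of all entries, every element of the gradient should be 1:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"x = tf.Variable([[2, 2],\n",
"                 [2, 2]], dtype=tf.float32)  # explicit float dtype instead of the inferred int32\n",
"\n",
"with tf.GradientTape() as tape:\n",
"  y = tf.reduce_sum(x)\n",
"\n",
"dy_dx = tape.gradient(y, x)\n",
"print(dy_dx.numpy())  # all ones"
],
"execution_count": null,
"outputs": []
},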
{
"cell_type": "markdown",
"metadata": {
"id": "ArRuVlm-Z3bP",
"colab_type": "text"
},
"source": [
"## No gradient registered\n",
"\n",
"Some `tf.Operations` are **registered as being non-differentiable** and will return `None`. Others have **no gradient registered**.\n",
"\n",
"The `tf.raw_ops` page shows which low-level ops have gradients registered.\n",
"\n",
"If you attempt to take a gradient through a float op that has no gradient registered, the tape raises an error instead of silently returning `None`. This way you know something has gone wrong.\n",
"\n",
"For example, the `tf.image.adjust_contrast` function wraps `raw_ops.AdjustContrastv2`, which could have a gradient, but that gradient is not implemented:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "yNvyzXpoZXeC",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "3a98e6f8-fd8b-4134-d5ba-e0bf96142050"
},
"source": [
"image = tf.Variable([[[0.5, 0.0, 0.0]]])\n",
"delta = tf.Variable(0.1)\n",
"\n",
"with tf.GradientTape() as tape:\n",
" new_image = tf.image.adjust_contrast(image, delta)\n",
"\n",
"try:\n",
" print(tape.gradient(new_image, [image, delta]))\n",
" assert False # This should not happen.\n",
"except LookupError as e:\n",
" print(f'{type(e).__name__}: {e}')"
],
"execution_count": 85,
"outputs": [
{
"output_type": "stream",
"text": [
"LookupError: gradient registry has no entry for: AdjustContrastv2\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FQCcxIroa6yf",
"colab_type": "text"
},
"source": [
"If we need to differentiate through this op, we can implement its gradient and register it (using `tf.RegisterGradient`), or re-implement the function using other, differentiable ops."
]
},
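{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of the second option, contrast adjustment can be rewritten with ordinary differentiable ops. Following the formula in the `tf.image.adjust_contrast` documentation, (x - mean) * contrast_factor + mean with the mean taken per channel over height and width, the helper `adjust_contrast_differentiable` below (a hypothetical name, not a TensorFlow API) lets gradients reach both the image and the factor. The sample image values are illustrative and chosen so the adjusted pixels stay inside [0, 1]:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import tensorflow as tf\n",
"\n",
"def adjust_contrast_differentiable(image, contrast_factor):\n",
"  # Hypothetical helper: per-channel mean over the height and width axes\n",
"  # (an image is laid out as [..., height, width, channels]).\n",
"  mean = tf.reduce_mean(image, axis=[-3, -2], keepdims=True)\n",
"  return (image - mean) * contrast_factor + mean\n",
"\n",
"image = tf.Variable([[[0.4, 0.2, 0.5]],\n",
"                     [[0.6, 0.4, 0.5]]])  # shape (2, 1, 3): height 2, width 1, 3 channels\n",
"factor = tf.Variable(2.0)\n",
"\n",
"with tf.GradientTape() as tape:\n",
"  new_image = adjust_contrast_differentiable(image, factor)\n",
"\n",
"# Gradient of the sum of new_image with respect to both inputs.\n",
"d_image, d_factor = tape.gradient(new_image, [image, factor])\n",
"print(d_image.numpy())  # ones: each channel's sum is unchanged by the adjustment\n",
"print(d_factor.numpy())  # about 0: deviations from the mean sum to zero\n",
"\n",
"# The forward values should match the built-in op (no gradient needed here).\n",
"reference = tf.image.adjust_contrast(image, 2.0)\n",
"max_diff = tf.reduce_max(tf.abs(new_image - reference))\n",
"print(max_diff.numpy())  # close to 0"
],
"execution_count": null,
"outputs": []
},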
{
"cell_type": "markdown",
"metadata": {
"id": "opIwCyfGbpAi",
"colab_type": "text"
},
"source": [
"## Zeros instead of `None`\n",
"In some cases it is more convenient to get 0 instead of `None` for unconnected gradients. The `unconnected_gradients` argument of `tape.gradient` controls what is returned for them:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "yRbOhiY7anh9",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "6de782ce-3836-4ae4-947a-43897c0a4df8"
},
"source": [
"x = tf.Variable([2., 2.])\n",
"y = tf.Variable(3.)\n",
"\n",
"with tf.GradientTape() as tape:\n",
" z = y**2\n",
"print(tape.gradient(z, x, unconnected_gradients=tf.UnconnectedGradients.ZERO))"
],
"execution_count": 86,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([0. 0.], shape=(2,), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uzRlblJddKmh",
"colab_type": "text"
},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WAfMbcS3nIk6",
"colab_type": "text"
},
"source": [
"# <h1 align=\"center\">-----Part - 3-----</h1>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nn7d2B4em3dD",
"colab_type": "text"
},
"source": [
"## Logistic Regression with Tensorflow"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M4PRDo7Sow-1",
"colab_type": "text"
},
"source": [
"### Import packages"
]
},
{
"cell_type": "code",
"metadata": {
"id": "EQ9xnnCynDev",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from IPython.display import clear_output\n",
"import urllib\n",
"\n",
"# hide warnings\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"import tensorflow as tf"
],
"execution_count": 87,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "udtuULucovjJ",
"colab_type": "text"
},
"source": [
"### Import Titanic dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7_iYLO3Uv11S",
"colab_type": "code",
"colab": {}
},
"source": [
"df_train = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv') # training data\n",
"df_test = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv') # testing data"
],
"execution_count": 88,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JvRPg8hqoab1",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 197
},
"outputId": "cb8ca95c-86c8-410e-f472-dd40a75f2c73"
},
"source": [
"# see what the dataset looks like\n",
"df_train.head()"
],
"execution_count": 89,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>survived</th>\n",
" <th>sex</th>\n",
" <th>age</th>\n",
" <th>n_siblings_spouses</th>\n",
" <th>parch</th>\n",
" <th>fare</th>\n",
" <th>class</th>\n",
" <th>deck</th>\n",
" <th>embark_town</th>\n",
" <th>alone</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>male</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" <td>Third</td>\n",
" <td>unknown</td>\n",
" <td>Southampton</td>\n",
" <td>n</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" <td>First</td>\n",
" <td>C</td>\n",
" <td>Cherbourg</td>\n",
" <td>n</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>female</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" <td>Third</td>\n",
" <td>unknown</td>\n",
" <td>Southampton</td>\n",
" <td>y</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>female</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" <td>First</td>\n",
" <td>C</td>\n",
" <td>Southampton</td>\n",
" <td>n</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>male</td>\n",
" <td>28.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.4583</td>\n",
" <td>Third</td>\n",
" <td>unknown</td>\n",
" <td>Queenstown</td>\n",
" <td>y</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" survived sex age ... deck embark_town alone\n",
"0 0 male 22.0 ... unknown Southampton n\n",
"1 1 female 38.0 ... C Cherbourg n\n",
"2 1 female 26.0 ... unknown Southampton y\n",
"3 1 female 35.0 ... C Southampton n\n",
"4 0 male 28.0 ... unknown Queenstown y\n",
"\n",
"[5 rows x 10 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 89
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "adGXjAQwxPMx",
"colab_type": "code",
"colab": {}
},
"source": [
"y_train = df_train.pop('survived')\n",
"y_test = df_test.pop('survived')"
],
"execution_count": 90,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BRhb9kyAyXLJ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 287
},
"outputId": "9c732f99-e896-46b5-b8c6-2f71090219e6"
},
"source": [
"df_train.describe()"
],
"execution_count": 91,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>n_siblings_spouses</th>\n",
" <th>parch</th>\n",
" <th>fare</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>627.000000</td>\n",
" <td>627.000000</td>\n",
" <td>627.000000</td>\n",
" <td>627.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>29.631308</td>\n",
" <td>0.545455</td>\n",
" <td>0.379585</td>\n",
" <td>34.385399</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>12.511818</td>\n",
" <td>1.151090</td>\n",
" <td>0.792999</td>\n",
" <td>54.597730</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.750000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>23.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>7.895800</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>28.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>15.045800</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>35.000000</td>\n",
" <td>1.000000</td>\n",
" <td>0.000000</td>\n",
" <td>31.387500</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>80.000000</td>\n",
" <td>8.000000</td>\n",
" <td>5.000000</td>\n",
" <td>512.329200</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age n_siblings_spouses parch fare\n",
"count 627.000000 627.000000 627.000000 627.000000\n",
"mean 29.631308 0.545455 0.379585 34.385399\n",
"std 12.511818 1.151090 0.792999 54.597730\n",
"min 0.750000 0.000000 0.000000 0.000000\n",
"25% 23.000000 0.000000 0.000000 7.895800\n",
"50% 28.000000 0.000000 0.000000 15.045800\n",
"75% 35.000000 1.000000 0.000000 31.387500\n",
"max 80.000000 8.000000 5.000000 512.329200"
]
},
"metadata": {
"tags": []
},
"execution_count": 91
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "umu-hX_53lkE",
"colab_type": "code",
"colab": {}
},
"source": [
"# Categorical features of the training/testing dataset\n",
"cat_col = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck', 'embark_town', 'alone']\n",
"\n",
"# Numerical features of the training/testing dataset\n",
"num_col = ['age', 'fare']"
],
"execution_count": 92,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "HG2dTC_m6P8P",
"colab_type": "text"
},
"source": [
"### Create Base Feature Columns\n",
"\n",
"Estimators use a system called feature columns to describe how the model should interpret each of the raw input features. An Estimator expects a vector of numeric inputs, and feature columns describe how the model should convert each feature.\n",
"\n",
"Selecting and crafting the right set of feature columns is key to learning an effective model. A feature column can be either one of the raw inputs in the original features `dict` (a base feature column) or any new column created using transformations defined over one or more base columns (a derived feature column).\n",
"\n",
"The linear estimator uses both numeric and categorical features. Feature columns work with all TensorFlow estimators, and their purpose is to define the features used for modeling. They also provide some feature engineering capabilities such as one-hot encoding, normalization, and bucketization."
]
},
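{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"As a minimal sketch of what that conversion looks like, the next cell feeds a hypothetical three-row batch (made-up values, not the Titanic data) through two feature columns: an `indicator_column`, which one-hot encodes `sex`, and a `numeric_column`, which passes `fare` through as a float. It assumes a TF 2.x release that still provides `tf.feature_column` and `tf.keras.layers.DenseFeatures`, the same generation of APIs this notebook relies on."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"\n",
"# Hypothetical mini-batch of raw features (three made-up passengers)\n",
"raw = {'sex': tf.constant(['male', 'female', 'male']),\n",
"       'fare': tf.constant([7.25, 71.28, 8.05])}\n",
"\n",
"sex = tf.feature_column.categorical_column_with_vocabulary_list('sex', ['male', 'female'])\n",
"fare = tf.feature_column.numeric_column('fare')\n",
"\n",
"# indicator_column one-hot encodes the categorical column;\n",
"# numeric_column passes the raw value through as a float\n",
"layer = tf.keras.layers.DenseFeatures([tf.feature_column.indicator_column(sex), fare])\n",
"dense = layer(raw)\n",
"\n",
"# one row per passenger: 2 one-hot slots for sex + 1 numeric slot for fare\n",
"print(dense.shape)\n",
"print(dense.numpy())"
],
"execution_count": null,
"outputs": []
},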
{
"cell_type": "code",
"metadata": {
"id": "DFTMq7Xp6PXr",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 336
},
"outputId": "f040b8de-acb4-4161-9064-3a570f0b1485"
},
"source": [
"# to read more : https://www.tensorflow.org/api_docs/python/tf/feature_column \n",
"feature_cols = []\n",
"\n",
"for feature in cat_col:\n",
"    # build a vocabulary: the list of unique values in this data column\n",
"    vocab = df_train[feature].unique()\n",
"    # create a categorical feature column from that vocabulary\n",
"    # and append it to the list of feature columns\n",
"    feature_cols.append(tf.feature_column.categorical_column_with_vocabulary_list(feature, vocab))\n",
"\n",
"# to read more : https://www.tensorflow.org/api_docs/python/tf/feature_column/numeric_column\n",
"for feature in num_col:\n",
"    feature_cols.append(tf.feature_column.numeric_column(feature, dtype=tf.float64))\n",
"\n",
"# Show the base feature columns\n",
"for f in feature_cols:\n",
"    print(f, '\\n')"
],
"execution_count": 93,
"outputs": [
{
"output_type": "stream",
"text": [
"VocabularyListCategoricalColumn(key='sex', vocabulary_list=('male', 'female'), dtype=tf.string, default_value=-1, num_oov_buckets=0) \n",
"\n",
"VocabularyListCategoricalColumn(key='n_siblings_spouses', vocabulary_list=(1, 0, 3, 4, 2, 5, 8), dtype=tf.int64, default_value=-1, num_oov_buckets=0) \n",
"\n",
"VocabularyListCategoricalColumn(key='parch', vocabulary_list=(0, 1, 2, 5, 3, 4), dtype=tf.int64, default_value=-1, num_oov_buckets=0) \n",
"\n",
"VocabularyListCategoricalColumn(key='class', vocabulary_list=('Third', 'First', 'Second'), dtype=tf.string, default_value=-1, num_oov_buckets=0) \n",
"\n",
"VocabularyListCategoricalColumn(key='deck', vocabulary_list=('unknown', 'C', 'G', 'A', 'B', 'D', 'F', 'E'), dtype=tf.string, default_value=-1, num_oov_buckets=0) \n",
"\n",
"VocabularyListCategoricalColumn(key='embark_town', vocabulary_list=('Southampton', 'Cherbourg', 'Queenstown', 'unknown'), dtype=tf.string, default_value=-1, num_oov_buckets=0) \n",
"\n",
"VocabularyListCategoricalColumn(key='alone', vocabulary_list=('n', 'y'), dtype=tf.string, default_value=-1, num_oov_buckets=0) \n",
"\n",
"NumericColumn(key='age', shape=(1,), default_value=None, dtype=tf.float64, normalizer_fn=None) \n",
"\n",
"NumericColumn(key='fare', shape=(1,), default_value=None, dtype=tf.float64, normalizer_fn=None) \n",
"\n"
],
"name": "stdout"
}
]
},
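{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"All of the columns above are base feature columns. A derived feature column is built from one or more base columns. As a hedged illustration, the next cell uses `tf.feature_column.bucketized_column` to turn a numeric `age` into a one-hot vector over age ranges. The boundaries 18, 35 and 60 are arbitrary choices for this sketch, not values used elsewhere in this notebook. It assumes a TF 2.x release that still provides `tf.feature_column` and `tf.keras.layers.DenseFeatures`."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"age = tf.feature_column.numeric_column('age')\n",
"\n",
"# Hypothetical boundaries -> 4 buckets: (-inf, 18), [18, 35), [35, 60), [60, +inf)\n",
"age_buckets = tf.feature_column.bucketized_column(age, boundaries=[18.0, 35.0, 60.0])\n",
"\n",
"# one made-up age per bucket\n",
"layer = tf.keras.layers.DenseFeatures([age_buckets])\n",
"one_hot = layer({'age': tf.constant([5.0, 22.0, 38.0, 80.0])})\n",
"\n",
"print(one_hot.numpy())  # a single 1 per row, in that age's bucket\n",
"print(np.argmax(one_hot.numpy(), axis=1))  # bucket index of each age: [0 1 2 3]"
],
"execution_count": null,
"outputs": []
},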
{
"cell_type": "markdown",
"metadata": {
"id": "Zb7TFCSaA846",
"colab_type": "text"
},
"source": [
"### Define an Input Function"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hanfq8D6I8Jp",
"colab_type": "text"
},
"source": [
"An input function specifies how data is converted into a `tf.data.Dataset` that feeds the input pipeline in a streaming fashion. A `tf.data.Dataset` can be built from many sources, such as a pandas DataFrame, a CSV file, and more."
]
},
{
"cell_type": "code",
"metadata": {
"id": "FZ99gqeb85mp",
"colab_type": "code",
"colab": {}
},
"source": [
"def make_input_func(data_df, target_df, num_epochs=10, shuffle=True, batch_size=64):\n",
"    def input_function():\n",
"        # create a tf.data.Dataset of (features dict, label) pairs\n",
"        dataset = tf.data.Dataset.from_tensor_slices((dict(data_df), target_df))\n",
"\n",
"        # shuffle the dataset; shuffle() returns a new dataset, so reassign it\n",
"        # to know more : https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle\n",
"        if shuffle:\n",
"            dataset = dataset.shuffle(1000)\n",
"\n",
"        # divide the dataset into mini-batches,\n",
"        # then repeat the whole sequence of batches once per epoch\n",
"        # to know more : https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch\n",
"        batch = dataset.batch(batch_size).repeat(num_epochs)\n",
"\n",
"        # return the dataset\n",
"        return batch\n",
"    return input_function\n",
"\n",
"# create a training input function\n",
"train_input_fn = make_input_func(df_train, y_train)\n",
"\n",
"# create a testing input function from the test split;\n",
"# a single unshuffled pass is enough for evaluation\n",
"test_input_fn = make_input_func(df_test, y_test, num_epochs=1, shuffle=False)"
],
"execution_count": 94,
"outputs": []
},
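{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To see what `make_input_func` streams into the estimator, the next cell applies the same shuffle, batch and repeat steps to a hypothetical five-row DataFrame standing in for `df_train`/`y_train`. Because `batch()` runs before `repeat()`, every epoch ends with its own partial batch. Swapping the order to `.repeat(2).batch(2)` would instead give five full batches of 2, one of which mixes rows from both epochs."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import pandas as pd\n",
"import tensorflow as tf\n",
"\n",
"# Hypothetical toy data standing in for df_train / y_train\n",
"features = pd.DataFrame({'age': [22.0, 38.0, 26.0, 35.0, 28.0],\n",
"                         'sex': ['male', 'female', 'female', 'female', 'male']})\n",
"labels = pd.Series([0, 1, 1, 1, 0])\n",
"\n",
"# same order as input_function: shuffle, then batch, then repeat\n",
"dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))\n",
"dataset = dataset.shuffle(5).batch(2).repeat(2)  # batches of 2, for 2 epochs\n",
"\n",
"batch_sizes = [int(label_batch.shape[0]) for _, label_batch in dataset]\n",
"print(batch_sizes)  # [2, 2, 1, 2, 2, 1]"
],
"execution_count": null,
"outputs": []
},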
{
"cell_type": "markdown",
"metadata": {
"id": "UEqv94YSTSuD",
"colab_type": "text"
},
"source": [
"### Creating a Linear Estimator Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ciNeYg_ZT9gp",
"colab_type": "text"
},
"source": [
"With the base feature columns defined, we can build the model and train it. Training is a single method call with the `tf.estimator` API:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "KW7QKvNBRQOE",
"colab_type": "code",
"colab": {}
},
"source": [
"from IPython.display import clear_output\n",
"\n",
"# Create the linear model based on the base feature columns\n",
"linear_est = tf.estimator.LinearClassifier(feature_columns=feature_cols)\n",
"\n",
"# train the estimator: train() consumes the mini-batches produced by train_input_fn\n",
"# until the input function is exhausted (num_epochs passes over the data)\n",
"linear_est.train(train_input_fn)\n",
"\n",
"# evaluate the trained estimator on the testing dataset,\n",
"# in a single pass over the mini-batches produced by test_input_fn\n",
"result = linear_est.evaluate(test_input_fn)\n",
"\n",
"# clear the verbose training/evaluation logs from the cell output\n",
"clear_output()"
],
"execution_count": 95,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "x761ULU8aSj4",
"colab_type": "text"
},
"source": [
"Print the accuracy of the model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "jS_3ido6aXAE",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "919f466c-16df-41bf-fd1d-31f8d505784f"
},
"source": [
"# display the result\n",
"print(\"Accuracy : %.2f \"%(result['accuracy']*100))"
],
"execution_count": 96,
"outputs": [
{
"output_type": "stream",
"text": [
"Accuracy : 80.22 \n"
],
"name": "stdout"
}
]
},
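{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The `accuracy` returned by `evaluate` is the fraction of examples whose predicted class matches the label. In the prediction dictionaries printed further down, `class_ids` is 1 exactly when `logistic` (the sigmoid of `logits`) exceeds 0.5. The next cell is a minimal NumPy sketch of that computation on hypothetical logits and labels, not on real model output."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical logits for five passengers and their true labels (not model output)\n",
"logits = np.array([-2.4, 2.4, 0.3, 1.7, -2.5])\n",
"labels = np.array([0, 1, 0, 1, 0])\n",
"\n",
"# the sigmoid maps each logit to P(survived = 1)\n",
"probs = 1.0 / (1.0 + np.exp(-logits))\n",
"# predict class 1 when that probability exceeds 0.5 (i.e. when the logit is positive)\n",
"predicted = (probs > 0.5).astype(int)\n",
"\n",
"acc = np.mean(predicted == labels)\n",
"print('Accuracy : %.2f ' % (acc * 100))  # Accuracy : 80.00"
],
"execution_count": null,
"outputs": []
},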
{
"cell_type": "markdown",
"metadata": {
"id": "VLnM0mgnWhNt",
"colab_type": "text"
},
"source": [
"So, this model reaches roughly 80% accuracy.\n",
"\n",
"Different values of `batch_size` and `num_epochs` in `make_input_func` give different accuracies, so it is worth experimenting with them.\n"
]
},
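{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"How much work `linear_est.train` does depends directly on these two values. `train` is called without a `steps` argument, so it runs until the input function is exhausted, and because `make_input_func` batches before it repeats, each epoch contributes `ceil(rows / batch_size)` steps. The helper in the next cell (`training_steps` is a hypothetical name, not a TensorFlow API) does that arithmetic for the 627 training rows reported by `describe()`, which helps explain why the full grid search below is slow."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import math\n",
"\n",
"def training_steps(num_rows, batch_size, num_epochs):\n",
"    # batch() runs before repeat(), so every epoch ends with its own (possibly partial) batch\n",
"    return math.ceil(num_rows / batch_size) * num_epochs\n",
"\n",
"num_rows = 627  # training rows, from df_train.describe() above\n",
"\n",
"# default train_input_fn settings: batch_size=64, num_epochs=10\n",
"print(training_steps(num_rows, 64, 10))  # 100\n",
"\n",
"# total steps over every configuration tried by the Model Checker\n",
"total = sum(training_steps(num_rows, b, e)\n",
"            for b in [32, 64, 96, 128, 144]\n",
"            for e in [10, 100, 1000, 10000])\n",
"print(total)  # 522170"
],
"execution_count": null,
"outputs": []
},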
{
"cell_type": "markdown",
"metadata": {
"id": "9xDUwafYgMM2",
"colab_type": "text"
},
"source": [
"**Note:** Only run the next code block if you can spare 1 to 2 hours; it trains and evaluates one model for every combination of batch size and number of epochs.\n",
"\n",
"Running it on a local machine with less than 16 GB of RAM can also exhaust the available memory."
]
},
{
"cell_type": "code",
"metadata": {
"id": "2dM38jBwwTkx",
"colab_type": "code",
"colab": {}
},
"source": [
"#@title Model Checker {display-mode: \"form\"}\n",
"\n",
"# This code will be hidden when the notebook is loaded.\n",
"\n",
"# don't execute\n",
"\n",
"batches = [32, 64, 96, 128, 144]\n",
"epochs = [10, 100, 1000, 10000]\n",
"accuracy = dict()\n",
"\n",
"for batch in batches:\n",
"    for num_epoch in epochs:\n",
"        train_input_fn = make_input_func(df_train, y_train, num_epochs=num_epoch, batch_size=batch)\n",
"        test_input_fn = make_input_func(df_test, y_test, num_epochs=1, shuffle=False)\n",
"\n",
"        # build a fresh estimator for every configuration so each one starts\n",
"        # from untrained weights instead of continuing the previous run\n",
"        est = tf.estimator.LinearClassifier(feature_columns=feature_cols)\n",
"        est.train(train_input_fn)\n",
"        result = est.evaluate(test_input_fn)\n",
"\n",
"        clear_output()\n",
"\n",
"        print(\"Batch Size: \", batch)\n",
"        print(\"Number of Epochs\", num_epoch)\n",
"        print(\"Accuracy : %.2f \"%(result['accuracy']*100))\n",
"\n",
"        key = 'B: ' + str(batch) + ' E: ' + str(num_epoch)\n",
"        accuracy[key] = result['accuracy']*100\n",
"\n",
"clear_output()\n",
"print(\"Model Checker Completed\")\n",
"\n",
"for key, value in accuracy.items():\n",
"    print(key, value)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iym4DBR4h7MV",
"colab_type": "text"
},
"source": [
"### Use the model for prediction"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hTDWZSsodg1v",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 53
},
"outputId": "0a8d5ae5-9007-443a-a6ea-eda5b5d8a465"
},
"source": [
"pred = list(linear_est.predict(test_input_fn))\n",
"clear_output()\n",
"print(pred)"
],
"execution_count": 98,
"outputs": [
{
"output_type": "stream",
"text": [
"[{'logits': array([-2.3956022], dtype=float32), 'logistic': array([0.08350867], dtype=float32), 'probabilities': array([0.9164914 , 0.08350867], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.385266], dtype=float32), 'logistic': array([0.9156968], dtype=float32), 'probabilities': array([0.08430316, 0.9156968 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.3139953], dtype=float32), 'logistic': array([0.5778602], dtype=float32), 'probabilities': array([0.42213982, 0.5778602 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.6583393], dtype=float32), 'logistic': array([0.840015], dtype=float32), 'probabilities': array([0.15998507, 0.840015 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4866104], dtype=float32), 'logistic': array([0.07680219], dtype=float32), 'probabilities': array([0.9231978 , 0.07680218], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5447931], dtype=float32), 'logistic': array([0.07277707], dtype=float32), 'probabilities': array([0.9272229 , 0.07277707], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.06145447], dtype=float32), 'logistic': array([0.51535875], 
dtype=float32), 'probabilities': array([0.4846412 , 0.51535875], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.658158], dtype=float32), 'logistic': array([0.934512], dtype=float32), 'probabilities': array([0.06548796, 0.934512 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.6333369], dtype=float32), 'logistic': array([0.83662623], dtype=float32), 'probabilities': array([0.16337374, 0.83662623], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4253259], dtype=float32), 'logistic': array([0.08126175], dtype=float32), 'probabilities': array([0.9187383 , 0.08126175], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3621705], dtype=float32), 'logistic': array([0.03349888], dtype=float32), 'probabilities': array([0.9665011 , 0.03349888], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.81703705], dtype=float32), 'logistic': array([0.69360703], dtype=float32), 'probabilities': array([0.30639294, 0.69360703], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8063811], dtype=float32), 'logistic': array([0.14107606], dtype=float32), 'probabilities': array([0.8589239 , 0.14107606], 
dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8856239], dtype=float32), 'logistic': array([0.13174422], dtype=float32), 'probabilities': array([0.8682558 , 0.13174422], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.2705683], dtype=float32), 'logistic': array([0.56723243], dtype=float32), 'probabilities': array([0.43276757, 0.56723243], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.9913704], dtype=float32), 'logistic': array([0.7293585], dtype=float32), 'probabilities': array([0.27064148, 0.7293585 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.1162727], dtype=float32), 'logistic': array([0.10752524], dtype=float32), 'probabilities': array([0.8924748 , 0.10752523], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.7778502], dtype=float32), 'logistic': array([0.3147834], dtype=float32), 'probabilities': array([0.6852166 , 0.31478336], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.328057], dtype=float32), 'logistic': array([0.41871348], dtype=float32), 'probabilities': array([0.58128655, 0.41871348], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], 
dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0002491], dtype=float32), 'logistic': array([0.11917676], dtype=float32), 'probabilities': array([0.8808232 , 0.11917677], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.0128727], dtype=float32), 'logistic': array([0.26641804], dtype=float32), 'probabilities': array([0.73358196, 0.266418 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.5022029], dtype=float32), 'logistic': array([0.6229769], dtype=float32), 'probabilities': array([0.37702307, 0.6229769 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.99719244], dtype=float32), 'logistic': array([0.2694938], dtype=float32), 'probabilities': array([0.7305062, 0.2694938], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.735545], dtype=float32), 'logistic': array([0.97669584], dtype=float32), 'probabilities': array([0.02330413, 0.97669584], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': 
array([b'0', b'1'], dtype=object)}, {'logits': array([0.50157696], dtype=float32), 'logistic': array([0.6228299], dtype=float32), 'probabilities': array([0.37717015, 0.6228299 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.4917898], dtype=float32), 'logistic': array([0.02954674], dtype=float32), 'probabilities': array([0.9704533 , 0.02954674], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.11246891], dtype=float32), 'logistic': array([0.47191238], dtype=float32), 'probabilities': array([0.5280876 , 0.47191238], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.6106104], dtype=float32), 'logistic': array([0.16650389], dtype=float32), 'probabilities': array([0.8334961 , 0.16650389], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0002286], dtype=float32), 'logistic': array([0.11917893], dtype=float32), 'probabilities': array([0.8808211 , 0.11917892], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.7156135], dtype=float32), 'logistic': array([0.84756297], dtype=float32), 'probabilities': array([0.15243706, 0.84756297], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.14827865], 
dtype=float32), 'logistic': array([0.4629981], dtype=float32), 'probabilities': array([0.5370019 , 0.46299812], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.3042066], dtype=float32), 'logistic': array([0.78654206], dtype=float32), 'probabilities': array([0.21345791, 0.78654206], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9969987], dtype=float32), 'logistic': array([0.11951841], dtype=float32), 'probabilities': array([0.8804816 , 0.11951841], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.0007672], dtype=float32), 'logistic': array([0.95260876], dtype=float32), 'probabilities': array([0.04739122, 0.95260876], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3349917], dtype=float32), 'logistic': array([0.08826613], dtype=float32), 'probabilities': array([0.91173387, 0.08826613], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.1067624], dtype=float32), 'logistic': array([0.10844129], dtype=float32), 
'probabilities': array([0.8915587 , 0.10844128], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.81493276], dtype=float32), 'logistic': array([0.69315964], dtype=float32), 'probabilities': array([0.30684033, 0.69315964], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2372706], dtype=float32), 'logistic': array([0.09645315], dtype=float32), 'probabilities': array([0.90354687, 0.09645315], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.6111088], dtype=float32), 'logistic': array([0.9315731], dtype=float32), 'probabilities': array([0.06842689, 0.9315731 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2445369], dtype=float32), 'logistic': array([0.77635276], dtype=float32), 'probabilities': array([0.22364727, 0.77635276], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2487648], dtype=float32), 'logistic': array([0.22291403], dtype=float32), 'probabilities': array([0.77708596, 0.22291404], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2424186], dtype=float32), 'logistic': array([0.22401527], dtype=float32), 'probabilities': array([0.77598476, 0.22401528], dtype=float32), 
'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.3875228], dtype=float32), 'logistic': array([0.8001965], dtype=float32), 'probabilities': array([0.19980352, 0.8001965 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0212033], dtype=float32), 'logistic': array([0.11699462], dtype=float32), 'probabilities': array([0.88300544, 0.11699463], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6224706], dtype=float32), 'logistic': array([0.06770618], dtype=float32), 'probabilities': array([0.9322938 , 0.06770618], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7485363], dtype=float32), 'logistic': array([0.1482319], dtype=float32), 'probabilities': array([0.85176814, 0.1482319 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.3519194], dtype=float32), 'logistic': array([0.9130867], dtype=float32), 'probabilities': array([0.08691333, 0.9130867 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.6055837], dtype=float32), 'logistic': array([0.16720267], dtype=float32), 'probabilities': array([0.83279735, 0.16720268], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 
'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9083207], dtype=float32), 'logistic': array([0.05174378], dtype=float32), 'probabilities': array([0.9482562 , 0.05174377], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7057766], dtype=float32), 'logistic': array([0.15371232], dtype=float32), 'probabilities': array([0.84628767, 0.15371232], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3828526], dtype=float32), 'logistic': array([0.08448966], dtype=float32), 'probabilities': array([0.91551036, 0.08448965], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.13174719], dtype=float32), 'logistic': array([0.46711072], dtype=float32), 'probabilities': array([0.53288925, 0.46711072], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.8506024], dtype=float32), 'logistic': array([0.05465019], dtype=float32), 'probabilities': array([0.9453498 , 0.05465018], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0655327], dtype=float32), 'logistic': array([0.11249227], dtype=float32), 'probabilities': array([0.88750774, 0.11249228], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', 
b'1'], dtype=object)}, {'logits': array([-1.2988261], dtype=float32), 'logistic': array([0.21436265], dtype=float32), 'probabilities': array([0.7856373 , 0.21436265], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6939638], dtype=float32), 'logistic': array([0.06333049], dtype=float32), 'probabilities': array([0.9366695 , 0.06333049], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2044327], dtype=float32), 'logistic': array([0.09935313], dtype=float32), 'probabilities': array([0.90064687, 0.09935313], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.89448565], dtype=float32), 'logistic': array([0.290185], dtype=float32), 'probabilities': array([0.709815, 0.290185], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.16824764], dtype=float32), 
'logistic': array([0.541963], dtype=float32), 'probabilities': array([0.458037, 0.541963], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5046203], dtype=float32), 'logistic': array([0.07553492], dtype=float32), 'probabilities': array([0.92446506, 0.07553492], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7958386], dtype=float32), 'logistic': array([0.05754947], dtype=float32), 'probabilities': array([0.9424505 , 0.05754946], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.50175864], dtype=float32), 'logistic': array([0.62287253], dtype=float32), 'probabilities': array([0.37712744, 0.62287253], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.164569], dtype=float32), 'logistic': array([0.23783804], dtype=float32), 'probabilities': array([0.7621619 , 0.23783804], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.976288], dtype=float32), 'logistic': array([0.12171509], dtype=float32), 'probabilities': array([0.8782849, 0.1217151], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 
0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.8109518], dtype=float32), 'logistic': array([0.8594769], dtype=float32), 'probabilities': array([0.14052312, 0.8594769 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.593121], dtype=float32), 'logistic': array([0.06958245], dtype=float32), 'probabilities': array([0.93041754, 0.06958245], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.54794425], dtype=float32), 'logistic': array([0.3663415], dtype=float32), 'probabilities': array([0.63365847, 0.36634147], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-4.0652037], dtype=float32), 'logistic': array([0.01687001], dtype=float32), 'probabilities': array([0.98313004, 0.01687001], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.82108], dtype=float32), 'logistic': array([0.13930434], dtype=float32), 'probabilities': array([0.8606956 , 0.13930434], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], 
dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.6144325], dtype=float32), 'logistic': array([0.64895123], dtype=float32), 'probabilities': array([0.35104874, 0.64895123], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.98932207], dtype=float32), 'logistic': array([0.728954], dtype=float32), 'probabilities': array([0.271046, 0.728954], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9568261], dtype=float32), 'logistic': array([0.12381095], dtype=float32), 'probabilities': array([0.8761891 , 0.12381094], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.22995609], dtype=float32), 'logistic': array([0.557237], dtype=float32), 'probabilities': array([0.44276297, 0.557237 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.00116032], dtype=float32), 'logistic': array([0.5002901], dtype=float32), 'probabilities': array([0.4997099, 0.5002901], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': 
array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9677322], dtype=float32), 'logistic': array([0.0489051], dtype=float32), 'probabilities': array([0.9510949, 0.0489051], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3156126], dtype=float32), 'logistic': array([0.03503945], dtype=float32), 'probabilities': array([0.9649606 , 0.03503945], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7622488], dtype=float32), 'logistic': array([0.0593986], dtype=float32), 'probabilities': array([0.9406014 , 0.05939861], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.6985418], dtype=float32), 'logistic': array([0.6678644], dtype=float32), 'probabilities': array([0.3321356, 0.6678644], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9594928], dtype=float32), 'logistic': array([0.12352194], dtype=float32), 'probabilities': array([0.876478 , 0.12352194], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.7102058], dtype=float32), 
'logistic': array([0.846863], dtype=float32), 'probabilities': array([0.15313703, 0.846863 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5092235], dtype=float32), 'logistic': array([0.0752141], dtype=float32), 'probabilities': array([0.9247859, 0.0752141], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.4878559], dtype=float32), 'logistic': array([0.8157562], dtype=float32), 'probabilities': array([0.18424375, 0.8157562 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4678805], dtype=float32), 'logistic': array([0.07814077], dtype=float32), 'probabilities': array([0.9218592 , 0.07814077], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-4.272864], dtype=float32), 'logistic': array([0.0137501], dtype=float32), 'probabilities': array([0.9862499, 0.0137501], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7713103], dtype=float32), 'logistic': array([0.14537944], dtype=float32), 'probabilities': array([0.8546205 , 0.14537945], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.3701876], dtype=float32), 'logistic': array([0.7974104], dtype=float32), 'probabilities': array([0.20258951, 
0.7974104 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.1095134], dtype=float32), 'logistic': array([0.24796163], dtype=float32), 'probabilities': array([0.7520384 , 0.24796163], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.1032293], dtype=float32), 'logistic': array([0.89121664], dtype=float32), 'probabilities': array([0.10878334, 0.89121664], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4900424], dtype=float32), 'logistic': array([0.0765592], dtype=float32), 'probabilities': array([0.92344075, 0.07655919], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5975223], dtype=float32), 'logistic': array([0.06929805], dtype=float32), 'probabilities': array([0.9307019 , 0.06929804], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.4792538], dtype=float32), 'logistic': array([0.02990833], dtype=float32), 'probabilities': array([0.9700917 , 0.02990832], dtype=float32), 'class_ids': array([0]), 'classes': 
array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4301715], dtype=float32), 'logistic': array([0.08090071], dtype=float32), 'probabilities': array([0.9190993 , 0.08090071], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.417571], dtype=float32), 'logistic': array([0.39709815], dtype=float32), 'probabilities': array([0.6029019 , 0.39709815], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.90130264], dtype=float32), 'logistic': array([0.2887829], dtype=float32), 'probabilities': array([0.7112171 , 0.28878286], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.5168319], dtype=float32), 'logistic': array([0.9253134], dtype=float32), 'probabilities': array([0.07468659, 0.9253134 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.4171778], dtype=float32), 'logistic': array([0.19510439], dtype=float32), 'probabilities': array([0.8048956 , 0.19510439], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2518775], dtype=float32), 'logistic': array([0.09518763], dtype=float32), 'probabilities': array([0.9048124 , 0.09518764], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 
'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.46801513], dtype=float32), 'logistic': array([0.6149138], dtype=float32), 'probabilities': array([0.38508612, 0.6149138 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.8040976], dtype=float32), 'logistic': array([0.6908503], dtype=float32), 'probabilities': array([0.30914968, 0.6908503 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.4809423], dtype=float32), 'logistic': array([0.6179704], dtype=float32), 'probabilities': array([0.38202965, 0.6179704 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.1173494], dtype=float32), 'logistic': array([0.10742196], dtype=float32), 'probabilities': array([0.892578 , 0.10742195], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.4734049], dtype=float32), 'logistic': array([0.18642563], dtype=float32), 'probabilities': array([0.8135743 , 0.18642563], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.8041455], dtype=float32), 'logistic': array([0.6908605], dtype=float32), 'probabilities': array([0.30913943, 0.6908605 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.8527462], 
dtype=float32), 'logistic': array([0.05453954], dtype=float32), 'probabilities': array([0.94546044, 0.05453953], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.9145062], dtype=float32), 'logistic': array([0.01956017], dtype=float32), 'probabilities': array([0.98043984, 0.01956016], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.4801595], dtype=float32), 'logistic': array([0.02988205], dtype=float32), 'probabilities': array([0.9701179 , 0.02988205], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7644894], dtype=float32), 'logistic': array([0.05927354], dtype=float32), 'probabilities': array([0.9407264 , 0.05927354], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2089716], dtype=float32), 'logistic': array([0.22988307], dtype=float32), 'probabilities': array([0.7701169 , 0.22988304], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.004881], dtype=float32), 'logistic': array([0.73201716], dtype=float32), 'probabilities': array([0.26798284, 0.73201716], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7579482], dtype=float32), 'logistic': array([0.05963933], dtype=float32), 
'probabilities': array([0.9403607 , 0.05963933], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3676736], dtype=float32), 'logistic': array([0.03332116], dtype=float32), 'probabilities': array([0.96667886, 0.03332116], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.306526], dtype=float32), 'logistic': array([0.03534799], dtype=float32), 'probabilities': array([0.964652 , 0.03534799], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6783512], dtype=float32), 'logistic': array([0.06426295], dtype=float32), 'probabilities': array([0.9357371 , 0.06426296], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.296512], dtype=float32), 'logistic': array([0.09141225], dtype=float32), 'probabilities': array([0.90858775, 0.09141225], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9855784], dtype=float32), 'logistic': array([0.12072544], dtype=float32), 'probabilities': array([0.87927455, 0.12072543], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.1378963], dtype=float32), 'logistic': array([0.95842916], dtype=float32), 'probabilities': array([0.04157085, 0.95842916], dtype=float32), 'class_ids': 
array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5261686], dtype=float32), 'logistic': array([0.07404391], dtype=float32), 'probabilities': array([0.925956 , 0.07404391], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8902787], dtype=float32), 'logistic': array([0.1312127], dtype=float32), 'probabilities': array([0.8687873, 0.1312127], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.08677], dtype=float32), 'logistic': array([0.11038937], dtype=float32), 'probabilities': array([0.8896106 , 0.11038935], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.61418], dtype=float32), 'logistic': array([0.02623233], dtype=float32), 'probabilities': array([0.97376764, 0.02623233], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.5581039], dtype=float32), 'logistic': array([0.82608116], dtype=float32), 'probabilities': array([0.17391889, 0.82608116], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9695214], dtype=float32), 'logistic': array([0.12244029], dtype=float32), 'probabilities': array([0.87755966, 0.12244029], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], 
dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.1355128], dtype=float32), 'logistic': array([0.04166592], dtype=float32), 'probabilities': array([0.9583341 , 0.04166592], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.37605423], dtype=float32), 'logistic': array([0.40707892], dtype=float32), 'probabilities': array([0.5929211, 0.4070789], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.111026], dtype=float32), 'logistic': array([0.24767964], dtype=float32), 'probabilities': array([0.75232035, 0.24767964], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6740875], dtype=float32), 'logistic': array([0.06451982], dtype=float32), 'probabilities': array([0.9354802 , 0.06451982], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.0790519], dtype=float32), 'logistic': array([0.7463145], dtype=float32), 'probabilities': array([0.25368547, 0.7463145 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.6566277], dtype=float32), 'logistic': array([0.65850246], dtype=float32), 'probabilities': array([0.34149754, 0.65850246], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': 
array([-1.9772738], dtype=float32), 'logistic': array([0.12160975], dtype=float32), 'probabilities': array([0.8783903 , 0.12160975], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.0972533], dtype=float32), 'logistic': array([0.0432207], dtype=float32), 'probabilities': array([0.95677924, 0.04322069], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.93820065], dtype=float32), 'logistic': array([0.71873605], dtype=float32), 'probabilities': array([0.28126395, 0.71873605], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.7247912], dtype=float32), 'logistic': array([0.67366123], dtype=float32), 'probabilities': array([0.3263388 , 0.67366123], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.08573056], dtype=float32), 'logistic': array([0.47858045], dtype=float32), 'probabilities': array([0.5214195, 0.4785805], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.3712142], dtype=float32), 'logistic': array([0.9146057], dtype=float32), 'probabilities': array([0.08539426, 0.9146057 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.359585], dtype=float32), 'logistic': array([0.9136931], 
dtype=float32), 'probabilities': array([0.08630691, 0.9136931 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4900424], dtype=float32), 'logistic': array([0.0765592], dtype=float32), 'probabilities': array([0.92344075, 0.07655919], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.50157696], dtype=float32), 'logistic': array([0.6228299], dtype=float32), 'probabilities': array([0.37717015, 0.6228299 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2737905], dtype=float32), 'logistic': array([0.7813909], dtype=float32), 'probabilities': array([0.21860906, 0.7813909 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7538898], dtype=float32), 'logistic': array([0.05986735], dtype=float32), 'probabilities': array([0.9401326 , 0.05986734], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3676736], dtype=float32), 'logistic': array([0.03332116], dtype=float32), 'probabilities': array([0.96667886, 0.03332116], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.0201395], dtype=float32), 'logistic': array([0.04652429], dtype=float32), 'probabilities': array([0.9534757 , 0.04652429], 
dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.5179659], dtype=float32), 'logistic': array([0.82023877], dtype=float32), 'probabilities': array([0.17976125, 0.82023877], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7734177], dtype=float32), 'logistic': array([0.05877765], dtype=float32), 'probabilities': array([0.9412223 , 0.05877765], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8603247], dtype=float32), 'logistic': array([0.13466519], dtype=float32), 'probabilities': array([0.86533475, 0.13466519], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.004962], dtype=float32), 'logistic': array([0.732033], dtype=float32), 'probabilities': array([0.26796696, 0.732033 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.53837556], dtype=float32), 'logistic': array([0.36856556], dtype=float32), 'probabilities': array([0.63143444, 0.36856556], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.8511191], dtype=float32), 'logistic': array([0.70080185], dtype=float32), 'probabilities': array([0.29919815, 0.70080185], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], 
dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5131], dtype=float32), 'logistic': array([0.07494492], dtype=float32), 'probabilities': array([0.92505515, 0.07494492], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3725448], dtype=float32), 'logistic': array([0.08529039], dtype=float32), 'probabilities': array([0.9147096, 0.0852904], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.5432713], dtype=float32), 'logistic': array([0.97189415], dtype=float32), 'probabilities': array([0.02810579, 0.97189415], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.2720465], dtype=float32), 'logistic': array([0.56759524], dtype=float32), 'probabilities': array([0.43240473, 0.56759524], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2875702], dtype=float32), 'logistic': array([0.09215764], dtype=float32), 'probabilities': array([0.9078424 , 0.09215764], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9816355], dtype=float32), 'logistic': array([0.12114462], dtype=float32), 'probabilities': array([0.8788554 , 0.12114462], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': 
array([b'0', b'1'], dtype=object)}, {'logits': array([-2.257531], dtype=float32), 'logistic': array([0.09470184], dtype=float32), 'probabilities': array([0.9052982 , 0.09470184], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.725737], dtype=float32), 'logistic': array([0.02352841], dtype=float32), 'probabilities': array([0.9764716 , 0.02352841], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5029242], dtype=float32), 'logistic': array([0.07565343], dtype=float32), 'probabilities': array([0.92434657, 0.07565343], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.5201992], dtype=float32), 'logistic': array([0.1794322], dtype=float32), 'probabilities': array([0.8205678 , 0.17943218], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4501767], dtype=float32), 'logistic': array([0.07942563], dtype=float32), 'probabilities': array([0.92057437, 0.07942563], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.4661366], dtype=float32), 'logistic': array([0.18753055], dtype=float32), 'probabilities': array([0.8124694 , 0.18753053], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.6225607], 
dtype=float32), 'logistic': array([0.3491993], dtype=float32), 'probabilities': array([0.6508007, 0.3491993], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.8055233], dtype=float32), 'logistic': array([0.85881996], dtype=float32), 'probabilities': array([0.14118005, 0.85881996], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.804197], dtype=float32), 'logistic': array([0.0570978], dtype=float32), 'probabilities': array([0.9429022, 0.0570978], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.1836119], dtype=float32), 'logistic': array([0.03978712], dtype=float32), 'probabilities': array([0.9602129 , 0.03978711], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.22828072], dtype=float32), 'logistic': array([0.5568236], dtype=float32), 'probabilities': array([0.4431764, 0.5568236], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3763134], dtype=float32), 'logistic': array([0.08499683], dtype=float32), 'probabilities': array([0.9150032 , 0.08499683], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.5201992], dtype=float32), 'logistic': array([0.1794322], dtype=float32), 'probabilities': 
array([0.8205678 , 0.17943218], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0988805], dtype=float32), 'logistic': array([0.10920568], dtype=float32), 'probabilities': array([0.8907943 , 0.10920567], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.1438975], dtype=float32), 'logistic': array([0.75839454], dtype=float32), 'probabilities': array([0.24160549, 0.75839454], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.6566277], dtype=float32), 'logistic': array([0.65850246], dtype=float32), 'probabilities': array([0.34149754, 0.65850246], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5137055], dtype=float32), 'logistic': array([0.07490294], dtype=float32), 'probabilities': array([0.92509705, 0.07490294], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0841463], dtype=float32), 'logistic': array([0.1106473], dtype=float32), 'probabilities': array([0.88935274, 0.1106473 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.55374], dtype=float32), 'logistic': array([0.17454675], dtype=float32), 'probabilities': array([0.8254533 , 0.17454676], dtype=float32), 'class_ids': array([0]), 
'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.0549208], dtype=float32), 'logistic': array([0.7417187], dtype=float32), 'probabilities': array([0.2582813, 0.7417187], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7647927], dtype=float32), 'logistic': array([0.05925664], dtype=float32), 'probabilities': array([0.9407434 , 0.05925664], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.55439454], dtype=float32), 'logistic': array([0.63515455], dtype=float32), 'probabilities': array([0.36484545, 0.63515455], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.688309], dtype=float32), 'logistic': array([0.06366675], dtype=float32), 'probabilities': array([0.9363333 , 0.06366675], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.48183388], dtype=float32), 'logistic': array([0.38181916], dtype=float32), 'probabilities': array([0.6181808, 0.3818192], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.7472577], dtype=float32), 'logistic': array([0.9397583], dtype=float32), 'probabilities': array([0.06024171, 0.9397583 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], 
dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([4.5523725], dtype=float32), 'logistic': array([0.9895678], dtype=float32), 'probabilities': array([0.01043219, 0.9895678 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.33267885], dtype=float32), 'logistic': array([0.582411], dtype=float32), 'probabilities': array([0.41758895, 0.582411 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4900424], dtype=float32), 'logistic': array([0.0765592], dtype=float32), 'probabilities': array([0.92344075, 0.07655919], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.422395], dtype=float32), 'logistic': array([0.08148083], dtype=float32), 'probabilities': array([0.9185192 , 0.08148083], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.3515981], dtype=float32), 'logistic': array([0.20560923], dtype=float32), 'probabilities': array([0.7943908 , 0.20560922], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.50157696], dtype=float32), 'logistic': array([0.6228299], dtype=float32), 'probabilities': array([0.37717015, 0.6228299 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': 
array([-2.2333276], dtype=float32), 'logistic': array([0.09679732], dtype=float32), 'probabilities': array([0.9032027 , 0.09679732], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6148093], dtype=float32), 'logistic': array([0.06819138], dtype=float32), 'probabilities': array([0.9318086 , 0.06819138], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5189044], dtype=float32), 'logistic': array([0.07454348], dtype=float32), 'probabilities': array([0.9254565 , 0.07454349], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.1488968], dtype=float32), 'logistic': array([0.75930935], dtype=float32), 'probabilities': array([0.24069065, 0.75930935], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.9407622], dtype=float32), 'logistic': array([0.87443584], dtype=float32), 'probabilities': array([0.12556414, 0.87443584], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2425827], dtype=float32), 'logistic': array([0.22398674], dtype=float32), 'probabilities': array([0.77601326, 0.22398676], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6740763], dtype=float32), 'logistic': array([0.0645205], 
dtype=float32), 'probabilities': array([0.93547946, 0.06452049], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.67872125], dtype=float32), 'logistic': array([0.6634532], dtype=float32), 'probabilities': array([0.33654675, 0.6634532 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.79683405], dtype=float32), 'logistic': array([0.31070316], dtype=float32), 'probabilities': array([0.68929684, 0.31070316], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.50157696], dtype=float32), 'logistic': array([0.6228299], dtype=float32), 'probabilities': array([0.37717015, 0.6228299 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2651176], dtype=float32), 'logistic': array([0.77990586], dtype=float32), 'probabilities': array([0.22009417, 0.77990586], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.48387903], dtype=float32), 'logistic': array([0.38133657], dtype=float32), 'probabilities': array([0.61866343, 0.38133654], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9486164], dtype=float32), 'logistic': array([0.12470432], dtype=float32), 'probabilities': array([0.8752957, 0.1247043], 
dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0161247], dtype=float32), 'logistic': array([0.1175203], dtype=float32), 'probabilities': array([0.88247967, 0.1175203 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.25178558], dtype=float32), 'logistic': array([0.56261593], dtype=float32), 'probabilities': array([0.43738404, 0.56261593], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.761865], dtype=float32), 'logistic': array([0.05942005], dtype=float32), 'probabilities': array([0.9405799 , 0.05942005], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2505047], dtype=float32), 'logistic': array([0.09530593], dtype=float32), 'probabilities': array([0.904694 , 0.09530593], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.8238831], dtype=float32), 'logistic': array([0.30494002], dtype=float32), 'probabilities': array([0.69505996, 0.30494 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.0647948], dtype=float32), 'logistic': array([0.88743407], dtype=float32), 'probabilities': array([0.11256596, 0.88743407], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], 
dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.8179055], dtype=float32), 'logistic': array([0.30620846], dtype=float32), 'probabilities': array([0.69379157, 0.30620843], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5938683], dtype=float32), 'logistic': array([0.06953409], dtype=float32), 'probabilities': array([0.93046594, 0.06953409], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2666465], dtype=float32), 'logistic': array([0.21983187], dtype=float32), 'probabilities': array([0.7801681 , 0.21983185], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.2711403], dtype=float32), 'logistic': array([0.96342534], dtype=float32), 'probabilities': array([0.03657462, 0.96342534], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.50157696], dtype=float32), 'logistic': array([0.6228299], dtype=float32), 'probabilities': array([0.37717015, 0.6228299 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.422384], dtype=float32), 'logistic': array([0.08148165], dtype=float32), 'probabilities': array([0.91851836, 0.08148165], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': 
array([b'0', b'1'], dtype=object)}, {'logits': array([2.561176], dtype=float32), 'logistic': array([0.92832077], dtype=float32), 'probabilities': array([0.07167924, 0.92832077], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.9007814], dtype=float32), 'logistic': array([0.94788504], dtype=float32), 'probabilities': array([0.05211494, 0.94788504], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.448464], dtype=float32), 'logistic': array([0.9691853], dtype=float32), 'probabilities': array([0.0308147, 0.9691853], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.034058], dtype=float32), 'logistic': array([0.26229814], dtype=float32), 'probabilities': array([0.7377019 , 0.26229814], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.7414486], dtype=float32), 'logistic': array([0.97682995], dtype=float32), 'probabilities': array([0.02317013, 0.97682995], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.9125917], dtype=float32), 
'logistic': array([0.94846535], dtype=float32), 'probabilities': array([0.0515346 , 0.94846535], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([4.2477703], dtype=float32), 'logistic': array([0.98590547], dtype=float32), 'probabilities': array([0.01409458, 0.98590547], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.5084926], dtype=float32), 'logistic': array([0.8188377], dtype=float32), 'probabilities': array([0.1811623, 0.8188377], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.1950438], dtype=float32), 'logistic': array([0.10019644], dtype=float32), 'probabilities': array([0.8998036 , 0.10019644], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.4542805], dtype=float32), 'logistic': array([0.81065637], dtype=float32), 'probabilities': array([0.18934366, 0.81065637], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.0240715], dtype=float32), 'logistic': array([0.8833014], dtype=float32), 'probabilities': 
array([0.11669865, 0.8833014 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.9534428], dtype=float32), 'logistic': array([0.98117274], dtype=float32), 'probabilities': array([0.01882726, 0.98117274], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5131], dtype=float32), 'logistic': array([0.07494492], dtype=float32), 'probabilities': array([0.92505515, 0.07494492], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7197144], dtype=float32), 'logistic': array([0.06182003], dtype=float32), 'probabilities': array([0.93818 , 0.06182003], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.6908244], dtype=float32), 'logistic': array([0.8443326], dtype=float32), 'probabilities': array([0.15566745, 0.8443326 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3676736], dtype=float32), 'logistic': array([0.03332116], dtype=float32), 'probabilities': array([0.96667886, 0.03332116], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.6634781], dtype=float32), 'logistic': array([0.93483686], dtype=float32), 'probabilities': array([0.06516314, 0.93483686], dtype=float32), 'class_ids': array([1]), 
'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-4.1540074], dtype=float32), 'logistic': array([0.01545865], dtype=float32), 'probabilities': array([0.9845413 , 0.01545865], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.42091304], dtype=float32), 'logistic': array([0.6037017], dtype=float32), 'probabilities': array([0.3962983, 0.6037017], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.40013117], dtype=float32), 'logistic': array([0.5987192], dtype=float32), 'probabilities': array([0.40128085, 0.5987192 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.0037479], dtype=float32), 'logistic': array([0.26820517], dtype=float32), 'probabilities': array([0.73179483, 0.26820517], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.363932], dtype=float32), 'logistic': array([0.91403526], dtype=float32), 'probabilities': array([0.08596474, 0.91403526], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], 
dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.0161726], dtype=float32), 'logistic': array([0.2657736], dtype=float32), 'probabilities': array([0.7342264, 0.2657736], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.4740448], dtype=float32), 'logistic': array([0.03005983], dtype=float32), 'probabilities': array([0.9699402 , 0.03005983], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.933906], dtype=float32), 'logistic': array([0.12631889], dtype=float32), 'probabilities': array([0.87368107, 0.12631887], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8856239], dtype=float32), 'logistic': array([0.13174422], dtype=float32), 'probabilities': array([0.8682558 , 0.13174422], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7597777], dtype=float32), 'logistic': array([0.14681819], dtype=float32), 'probabilities': array([0.8531818 , 0.14681819], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2212136], dtype=float32), 'logistic': array([0.09786161], dtype=float32), 'probabilities': array([0.9021384 , 0.09786162], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': 
array([1.7063655], dtype=float32), 'logistic': array([0.84636426], dtype=float32), 'probabilities': array([0.15363573, 0.84636426], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.60261077], dtype=float32), 'logistic': array([0.6462534], dtype=float32), 'probabilities': array([0.3537466, 0.6462534], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.38720793], dtype=float32), 'logistic': array([0.59561044], dtype=float32), 'probabilities': array([0.4043896 , 0.59561044], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.1992877], dtype=float32), 'logistic': array([0.23160195], dtype=float32), 'probabilities': array([0.768398 , 0.23160197], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4703276], dtype=float32), 'logistic': array([0.07796468], dtype=float32), 'probabilities': array([0.9220353 , 0.07796468], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0002491], dtype=float32), 'logistic': array([0.11917676], dtype=float32), 'probabilities': array([0.8808232 , 0.11917677], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7538898], dtype=float32), 'logistic': array([0.05986735], 
dtype=float32), 'probabilities': array([0.9401326 , 0.05986734], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.389589], dtype=float32), 'logistic': array([0.9673776], dtype=float32), 'probabilities': array([0.03262242, 0.9673776 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.68650836], dtype=float32), 'logistic': array([0.66518974], dtype=float32), 'probabilities': array([0.33481026, 0.66518974], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.5022029], dtype=float32), 'logistic': array([0.6229769], dtype=float32), 'probabilities': array([0.37702307, 0.6229769 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.5022029], dtype=float32), 'logistic': array([0.6229769], dtype=float32), 'probabilities': array([0.37702307, 0.6229769 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.410484], dtype=float32), 'logistic': array([0.03196941], dtype=float32), 'probabilities': array([0.9680305 , 0.03196941], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.97407967], dtype=float32), 'logistic': array([0.2740681], dtype=float32), 'probabilities': array([0.72593194, 0.2740681 ], dtype=float32), 
'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.39983422], dtype=float32), 'logistic': array([0.59864783], dtype=float32), 'probabilities': array([0.40135217, 0.59864783], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.0594027], dtype=float32), 'logistic': array([0.04481326], dtype=float32), 'probabilities': array([0.9551868 , 0.04481326], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3349917], dtype=float32), 'logistic': array([0.08826613], dtype=float32), 'probabilities': array([0.91173387, 0.08826613], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.1425052], dtype=float32), 'logistic': array([0.8949663], dtype=float32), 'probabilities': array([0.10503367, 0.8949663 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.99139065], dtype=float32), 'logistic': array([0.7293625], dtype=float32), 'probabilities': array([0.27063748, 0.7293625 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.50157696], dtype=float32), 'logistic': array([0.6228299], dtype=float32), 'probabilities': array([0.37717015, 0.6228299 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 
'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.4116638], dtype=float32), 'logistic': array([0.96806705], dtype=float32), 'probabilities': array([0.03193292, 0.96806705], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2314618], dtype=float32), 'logistic': array([0.09696057], dtype=float32), 'probabilities': array([0.90303946, 0.09696058], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3833773], dtype=float32), 'logistic': array([0.08444908], dtype=float32), 'probabilities': array([0.9155509 , 0.08444907], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.8791506], dtype=float32), 'logistic': array([0.9468061], dtype=float32), 'probabilities': array([0.0531939, 0.9468061], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.19087619], dtype=float32), 'logistic': array([0.5475747], dtype=float32), 'probabilities': array([0.45242533, 0.5475747 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3847098], dtype=float32), 'logistic': array([0.08434611], dtype=float32), 'probabilities': array([0.9156538, 0.0843461], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], 
dtype=object)}, {'logits': array([2.878694], dtype=float32), 'logistic': array([0.94678307], dtype=float32), 'probabilities': array([0.0532169 , 0.94678307], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9293168], dtype=float32), 'logistic': array([0.05072322], dtype=float32), 'probabilities': array([0.9492768 , 0.05072321], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.1729797], dtype=float32), 'logistic': array([0.23631682], dtype=float32), 'probabilities': array([0.7636832 , 0.23631683], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2029834], dtype=float32), 'logistic': array([0.09948291], dtype=float32), 'probabilities': array([0.90051717, 0.0994829 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4901433], dtype=float32), 'logistic': array([0.07655207], dtype=float32), 'probabilities': array([0.9234479 , 0.07655206], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.3272507], dtype=float32), 'logistic': 
array([0.9111089], dtype=float32), 'probabilities': array([0.08889107, 0.9111089 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.65942895], dtype=float32), 'logistic': array([0.3408679], dtype=float32), 'probabilities': array([0.65913206, 0.34086788], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4685066], dtype=float32), 'logistic': array([0.07809568], dtype=float32), 'probabilities': array([0.92190427, 0.07809568], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.8788614], dtype=float32), 'logistic': array([0.97974443], dtype=float32), 'probabilities': array([0.02025558, 0.97974443], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.31261688], dtype=float32), 'logistic': array([0.5775239], dtype=float32), 'probabilities': array([0.4224761, 0.5775239], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5104554], dtype=float32), 'logistic': array([0.07512846], dtype=float32), 'probabilities': array([0.92487156, 0.07512846], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.10390847], dtype=float32), 'logistic': array([0.52595377], dtype=float32), 'probabilities': array([0.4740462 , 
0.52595377], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5777085], dtype=float32), 'logistic': array([0.07058692], dtype=float32), 'probabilities': array([0.929413 , 0.07058692], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.6879941], dtype=float32), 'logistic': array([0.15603982], dtype=float32), 'probabilities': array([0.84396017, 0.1560398 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.2229578], dtype=float32), 'logistic': array([0.03831086], dtype=float32), 'probabilities': array([0.9616892 , 0.03831086], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6770184], dtype=float32), 'logistic': array([0.06434315], dtype=float32), 'probabilities': array([0.93565685, 0.06434314], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.605623], dtype=float32), 'logistic': array([0.0687774], dtype=float32), 'probabilities': array([0.9312226 , 0.06877741], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.981054], dtype=float32), 'logistic': array([0.12120653], dtype=float32), 'probabilities': array([0.8787934 , 0.12120652], dtype=float32), 'class_ids': array([0]), 'classes': 
array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.553436], dtype=float32), 'logistic': array([0.3650676], dtype=float32), 'probabilities': array([0.6349324, 0.3650676], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4686072], dtype=float32), 'logistic': array([0.07808845], dtype=float32), 'probabilities': array([0.9219116 , 0.07808845], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.6225607], dtype=float32), 'logistic': array([0.3491993], dtype=float32), 'probabilities': array([0.6508007, 0.3491993], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4943633], dtype=float32), 'logistic': array([0.07625429], dtype=float32), 'probabilities': array([0.92374575, 0.07625428], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9486164], dtype=float32), 'logistic': array([0.12470432], dtype=float32), 'probabilities': array([0.8752957, 0.1247043], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 
'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.4327016], dtype=float32), 'logistic': array([0.03128894], dtype=float32), 'probabilities': array([0.9687111 , 0.03128894], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.2307033], dtype=float32), 'logistic': array([0.5574214], dtype=float32), 'probabilities': array([0.44257864, 0.5574214 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2043992], dtype=float32), 'logistic': array([0.7693064], dtype=float32), 'probabilities': array([0.23069353, 0.7693064 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2993449], dtype=float32), 'logistic': array([0.78572476], dtype=float32), 'probabilities': array([0.2142753 , 0.78572476], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9695214], dtype=float32), 'logistic': array([0.12244029], dtype=float32), 'probabilities': array([0.87755966, 0.12244029], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9969987], dtype=float32), 'logistic': array([0.11951841], dtype=float32), 'probabilities': array([0.8804816 , 0.11951841], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': 
array([-2.196482], dtype=float32), 'logistic': array([0.10006686], dtype=float32), 'probabilities': array([0.89993316, 0.10006686], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.026887], dtype=float32), 'logistic': array([0.11640875], dtype=float32), 'probabilities': array([0.8835913 , 0.11640874], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.5465266], dtype=float32), 'logistic': array([0.8244115], dtype=float32), 'probabilities': array([0.1755885, 0.8244115], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4900424], dtype=float32), 'logistic': array([0.0765592], dtype=float32), 'probabilities': array([0.92344075, 0.07655919], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7427081], dtype=float32), 'logistic': array([0.14896928], dtype=float32), 'probabilities': array([0.8510307 , 0.14896928], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2857865], dtype=float32), 'logistic': array([0.21656686], dtype=float32), 'probabilities': array([0.78343314, 0.21656683], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.38720793], dtype=float32), 'logistic': array([0.59561044], 
dtype=float32), 'probabilities': array([0.4043896 , 0.59561044], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.69920343], dtype=float32), 'logistic': array([0.6680111], dtype=float32), 'probabilities': array([0.33198887, 0.6680111 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.303962], dtype=float32), 'logistic': array([0.09079536], dtype=float32), 'probabilities': array([0.9092046 , 0.09079536], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.74129957], dtype=float32), 'logistic': array([0.32272002], dtype=float32), 'probabilities': array([0.67727995, 0.32272 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.2550628], dtype=float32), 'logistic': array([0.9628546], dtype=float32), 'probabilities': array([0.03714539, 0.9628546 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.1612864], dtype=float32), 'logistic': array([0.7615664], dtype=float32), 'probabilities': array([0.23843361, 0.7615664 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8925749], dtype=float32), 'logistic': array([0.13095115], dtype=float32), 'probabilities': array([0.86904883, 0.13095115], dtype=float32), 
'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.023584], dtype=float32), 'logistic': array([0.11674892], dtype=float32), 'probabilities': array([0.8832511 , 0.11674892], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.712678], dtype=float32), 'logistic': array([0.6709927], dtype=float32), 'probabilities': array([0.32900736, 0.6709927 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4183], dtype=float32), 'logistic': array([0.08178784], dtype=float32), 'probabilities': array([0.9182121 , 0.08178784], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.1059955], dtype=float32), 'logistic': array([0.75138175], dtype=float32), 'probabilities': array([0.2486182 , 0.75138175], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7606132], dtype=float32), 'logistic': array([0.05949005], dtype=float32), 'probabilities': array([0.94051 , 0.05949005], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.5158379], dtype=float32), 'logistic': array([0.18007523], dtype=float32), 'probabilities': array([0.8199248 , 0.18007523], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': 
array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.976057], dtype=float32), 'logistic': array([0.87826025], dtype=float32), 'probabilities': array([0.12173978, 0.87826025], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2734163], dtype=float32), 'logistic': array([0.09334868], dtype=float32), 'probabilities': array([0.9066513 , 0.09334867], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2068355], dtype=float32), 'logistic': array([0.09913834], dtype=float32), 'probabilities': array([0.9008617 , 0.09913834], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.5992047], dtype=float32), 'logistic': array([0.35452566], dtype=float32), 'probabilities': array([0.6454744 , 0.35452566], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.98143727], dtype=float32), 'logistic': array([0.2726067], dtype=float32), 'probabilities': array([0.7273933 , 0.27260667], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0389473], dtype=float32), 'logistic': array([0.11517397], dtype=float32), 'probabilities': array([0.884826 , 0.11517395], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], 
dtype=object)}, {'logits': array([-1.6302459], dtype=float32), 'logistic': array([0.16379666], dtype=float32), 'probabilities': array([0.83620334, 0.16379666], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.6068747], dtype=float32), 'logistic': array([0.9313027], dtype=float32), 'probabilities': array([0.06869729, 0.9313027 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.9171173], dtype=float32), 'logistic': array([0.28554565], dtype=float32), 'probabilities': array([0.71445435, 0.28554565], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.0126083], dtype=float32), 'logistic': array([0.04685951], dtype=float32), 'probabilities': array([0.95314056, 0.04685951], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.81726396], dtype=float32), 'logistic': array([0.30634478], dtype=float32), 'probabilities': array([0.69365525, 0.30634475], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.1852489], dtype=float32), 'logistic': 
array([0.03972463], dtype=float32), 'probabilities': array([0.96027535, 0.03972462], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.43871], dtype=float32), 'logistic': array([0.08026809], dtype=float32), 'probabilities': array([0.9197319 , 0.08026809], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4901636], dtype=float32), 'logistic': array([0.07655063], dtype=float32), 'probabilities': array([0.9234494 , 0.07655063], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.1543393], dtype=float32), 'logistic': array([0.8960736], dtype=float32), 'probabilities': array([0.10392642, 0.8960736 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7647927], dtype=float32), 'logistic': array([0.05925664], dtype=float32), 'probabilities': array([0.9407434 , 0.05925664], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.91063017], dtype=float32), 'logistic': array([0.7131291], dtype=float32), 'probabilities': array([0.2868709, 0.7131291], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.7234895], dtype=float32), 'logistic': array([0.93839854], dtype=float32), 'probabilities': array([0.06160144, 
0.93839854], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.6978984], dtype=float32), 'logistic': array([0.33227834], dtype=float32), 'probabilities': array([0.6677217 , 0.33227834], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.981054], dtype=float32), 'logistic': array([0.12120653], dtype=float32), 'probabilities': array([0.8787934 , 0.12120652], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.690233], dtype=float32), 'logistic': array([0.06355215], dtype=float32), 'probabilities': array([0.93644786, 0.06355215], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.4284667], dtype=float32), 'logistic': array([0.8066623], dtype=float32), 'probabilities': array([0.19333771, 0.8066623 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5385733], dtype=float32), 'logistic': array([0.0731979], dtype=float32), 'probabilities': array([0.9268021, 0.0731979], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9486164], dtype=float32), 'logistic': array([0.12470432], dtype=float32), 'probabilities': array([0.8752957, 0.1247043], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], 
dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2300525], dtype=float32), 'logistic': array([0.22617224], dtype=float32), 'probabilities': array([0.7738278 , 0.22617222], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.6011258], dtype=float32), 'logistic': array([0.64591384], dtype=float32), 'probabilities': array([0.35408616, 0.64591384], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.6225607], dtype=float32), 'logistic': array([0.3491993], dtype=float32), 'probabilities': array([0.6508007, 0.3491993], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.8371407], dtype=float32), 'logistic': array([0.86261016], dtype=float32), 'probabilities': array([0.13738982, 0.86261016], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.1980975], dtype=float32), 'logistic': array([0.23181383], dtype=float32), 'probabilities': array([0.76818615, 0.23181382], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.670245], dtype=float32), 'logistic': array([0.15839152], dtype=float32), 'probabilities': array([0.84160846, 0.15839152], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': 
array([b'0', b'1'], dtype=object)}, {'logits': array([-2.585675], dtype=float32), 'logistic': array([0.07006606], dtype=float32), 'probabilities': array([0.92993397, 0.07006606], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.1920502], dtype=float32), 'logistic': array([0.10046665], dtype=float32), 'probabilities': array([0.89953333, 0.10046665], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4672747], dtype=float32), 'logistic': array([0.07818443], dtype=float32), 'probabilities': array([0.9218155 , 0.07818442], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.3802812], dtype=float32), 'logistic': array([0.9153113], dtype=float32), 'probabilities': array([0.08468877, 0.9153113 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.296512], dtype=float32), 'logistic': array([0.09141225], dtype=float32), 'probabilities': array([0.90858775, 0.09141225], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.7952183], dtype=float32), 'logistic': array([0.6889507], dtype=float32), 'probabilities': array([0.31104928, 0.6889507 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.1393854], 
dtype=float32), 'logistic': array([0.46520993], dtype=float32), 'probabilities': array([0.53479004, 0.46520996], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.4148956], dtype=float32), 'logistic': array([0.6022612], dtype=float32), 'probabilities': array([0.39773887, 0.6022612 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.7331061], dtype=float32), 'logistic': array([0.6754865], dtype=float32), 'probabilities': array([0.3245135, 0.6754865], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2641455], dtype=float32), 'logistic': array([0.2202611], dtype=float32), 'probabilities': array([0.7797389 , 0.22026108], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.442271], dtype=float32), 'logistic': array([0.08000559], dtype=float32), 'probabilities': array([0.91999435, 0.08000559], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.678146], dtype=float32), 'logistic': array([0.8426588], dtype=float32), 'probabilities': array([0.15734112, 0.8426588 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5958064], dtype=float32), 'logistic': array([0.0694088], dtype=float32), 'probabilities': 
array([0.9305912, 0.0694088], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.3857218], dtype=float32), 'logistic': array([0.20009163], dtype=float32), 'probabilities': array([0.7999084 , 0.20009165], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2747643], dtype=float32), 'logistic': array([0.78155726], dtype=float32), 'probabilities': array([0.21844277, 0.78155726], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.9508957], dtype=float32), 'logistic': array([0.7212953], dtype=float32), 'probabilities': array([0.2787047, 0.7212953], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9294584], dtype=float32), 'logistic': array([0.0507164], dtype=float32), 'probabilities': array([0.94928366, 0.05071639], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.5125692], dtype=float32), 'logistic': array([0.9250183], dtype=float32), 'probabilities': array([0.07498172, 0.9250183 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5099707], dtype=float32), 'logistic': array([0.07516214], dtype=float32), 'probabilities': array([0.9248379 , 0.07516215], dtype=float32), 'class_ids': array([0]), 'classes': 
array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.604396], dtype=float32), 'logistic': array([0.93114394], dtype=float32), 'probabilities': array([0.06885603, 0.93114394], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0002286], dtype=float32), 'logistic': array([0.11917893], dtype=float32), 'probabilities': array([0.8808211 , 0.11917892], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.0144017], dtype=float32), 'logistic': array([0.04677948], dtype=float32), 'probabilities': array([0.9532205 , 0.04677948], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.17100972], dtype=float32), 'logistic': array([0.54264855], dtype=float32), 'probabilities': array([0.45735145, 0.54264855], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.33979052], dtype=float32), 'logistic': array([0.41586033], dtype=float32), 'probabilities': array([0.58413965, 0.41586035], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7215909], dtype=float32), 'logistic': array([0.15166637], dtype=float32), 'probabilities': array([0.84833366, 0.15166637], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 
'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.5152626], dtype=float32), 'logistic': array([0.9252049], dtype=float32), 'probabilities': array([0.07479511, 0.9252049 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2831852], dtype=float32), 'logistic': array([0.21700852], dtype=float32), 'probabilities': array([0.78299147, 0.21700852], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.8385616], dtype=float32), 'logistic': array([0.6981622], dtype=float32), 'probabilities': array([0.30183783, 0.6981622 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.14977401], dtype=float32), 'logistic': array([0.53737366], dtype=float32), 'probabilities': array([0.4626263 , 0.53737366], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.8249847], dtype=float32), 'logistic': array([0.8611632], dtype=float32), 'probabilities': array([0.13883683, 0.8611632 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.432692], dtype=float32), 'logistic': array([0.19267957], dtype=float32), 'probabilities': array([0.8073204 , 0.19267957], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.7952807], 
dtype=float32), 'logistic': array([0.94242024], dtype=float32), 'probabilities': array([0.05757973, 0.94242024], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.1734807], dtype=float32), 'logistic': array([0.9598241], dtype=float32), 'probabilities': array([0.04017598, 0.9598241 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.9257134], dtype=float32), 'logistic': array([0.8727742], dtype=float32), 'probabilities': array([0.12722579, 0.8727742 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.31698674], dtype=float32), 'logistic': array([0.57858974], dtype=float32), 'probabilities': array([0.4214103 , 0.57858974], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.23308939], dtype=float32), 'logistic': array([0.5580099], dtype=float32), 'probabilities': array([0.44199002, 0.5580099 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.93945676], dtype=float32), 'logistic': array([0.2810101], dtype=float32), 'probabilities': array([0.7189899, 0.2810101], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.6640242], dtype=float32), 'logistic': array([0.84077746], dtype=float32), 'probabilities': 
array([0.15922251, 0.84077746], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.1167802], dtype=float32), 'logistic': array([0.24660902], dtype=float32), 'probabilities': array([0.753391 , 0.24660902], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.654604], dtype=float32), 'logistic': array([0.06570581], dtype=float32), 'probabilities': array([0.93429416, 0.06570581], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.12281876], dtype=float32), 'logistic': array([0.5306662], dtype=float32), 'probabilities': array([0.46933383, 0.5306662 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7806828], dtype=float32), 'logistic': array([0.14421883], dtype=float32), 'probabilities': array([0.85578114, 0.14421885], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4896586], dtype=float32), 'logistic': array([0.07658634], dtype=float32), 'probabilities': array([0.92341363, 0.07658633], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.48106343], dtype=float32), 'logistic': array([0.61799896], dtype=float32), 'probabilities': array([0.38200104, 0.61799896], dtype=float32), 'class_ids': array([1]), 
'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.2768235], dtype=float32), 'logistic': array([0.90693927], dtype=float32), 'probabilities': array([0.0930607 , 0.90693927], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.47435695], dtype=float32), 'logistic': array([0.6164144], dtype=float32), 'probabilities': array([0.3835855, 0.6164144], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.964948], dtype=float32), 'logistic': array([0.9509652], dtype=float32), 'probabilities': array([0.04903476, 0.9509652 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.05791729], dtype=float32), 'logistic': array([0.5144753], dtype=float32), 'probabilities': array([0.48552474, 0.5144753 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4900424], dtype=float32), 'logistic': array([0.0765592], dtype=float32), 'probabilities': array([0.92344075, 0.07655919], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.265048], dtype=float32), 'logistic': array([0.03678991], dtype=float32), 'probabilities': array([0.96321005, 0.03678991], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], 
dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6916592], dtype=float32), 'logistic': array([0.06346732], dtype=float32), 'probabilities': array([0.9365326 , 0.06346732], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3841245], dtype=float32), 'logistic': array([0.08439132], dtype=float32), 'probabilities': array([0.9156087 , 0.08439132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.15954393], dtype=float32), 'logistic': array([0.4601984], dtype=float32), 'probabilities': array([0.53980166, 0.4601984 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9296596], dtype=float32), 'logistic': array([0.05070671], dtype=float32), 'probabilities': array([0.9492933 , 0.05070671], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3239946], dtype=float32), 'logistic': array([0.03475714], dtype=float32), 'probabilities': array([0.96524286, 0.03475714], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': 
array([0.6013586], dtype=float32), 'logistic': array([0.64596707], dtype=float32), 'probabilities': array([0.3540329 , 0.64596707], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.41451973], dtype=float32), 'logistic': array([0.39782885], dtype=float32), 'probabilities': array([0.6021712 , 0.39782888], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.257531], dtype=float32), 'logistic': array([0.09470184], dtype=float32), 'probabilities': array([0.9052982 , 0.09470184], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.8543033], dtype=float32), 'logistic': array([0.7014691], dtype=float32), 'probabilities': array([0.2985309, 0.7014691], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.5698239], dtype=float32), 'logistic': array([0.8277585], dtype=float32), 'probabilities': array([0.1722415, 0.8277585], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.32636482], dtype=float32), 'logistic': array([0.41912535], dtype=float32), 'probabilities': array([0.5808746 , 0.41912538], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6729586], dtype=float32), 'logistic': array([0.06458799], 
dtype=float32), 'probabilities': array([0.935412 , 0.06458799], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.7865321], dtype=float32), 'logistic': array([0.31291378], dtype=float32), 'probabilities': array([0.6870862 , 0.31291378], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.0590394], dtype=float32), 'logistic': array([0.04482882], dtype=float32), 'probabilities': array([0.9551711 , 0.04482882], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.4927127], dtype=float32), 'logistic': array([0.9236294], dtype=float32), 'probabilities': array([0.07637062, 0.9236294 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.5618186], dtype=float32), 'logistic': array([0.02760357], dtype=float32), 'probabilities': array([0.9723964 , 0.02760356], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.27543885], dtype=float32), 'logistic': array([0.5684276], dtype=float32), 'probabilities': array([0.43157235, 0.5684276 ], 
dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0826724], dtype=float32), 'logistic': array([0.11079242], dtype=float32), 'probabilities': array([0.88920754, 0.11079242], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.762885], dtype=float32), 'logistic': array([0.05936306], dtype=float32), 'probabilities': array([0.940637 , 0.05936307], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2029068], dtype=float32), 'logistic': array([0.76904154], dtype=float32), 'probabilities': array([0.2309585 , 0.76904154], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.6808465], dtype=float32), 'logistic': array([0.02458213], dtype=float32), 'probabilities': array([0.97541785, 0.02458212], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0002491], dtype=float32), 'logistic': array([0.11917676], dtype=float32), 'probabilities': array([0.8808232 , 0.11917677], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.67277503], dtype=float32), 'logistic': array([0.33787575], dtype=float32), 'probabilities': array([0.6621243 , 0.33787572], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], 
dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.79312205], dtype=float32), 'logistic': array([0.3114987], dtype=float32), 'probabilities': array([0.68850136, 0.3114987 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9426663], dtype=float32), 'logistic': array([0.05008427], dtype=float32), 'probabilities': array([0.9499157 , 0.05008427], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.8455608], dtype=float32), 'logistic': array([0.05491124], dtype=float32), 'probabilities': array([0.9450888 , 0.05491124], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2030567], dtype=float32), 'logistic': array([0.23093188], dtype=float32), 'probabilities': array([0.7690681, 0.2309319], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.2037413], dtype=float32), 'logistic': array([0.900585], dtype=float32), 'probabilities': array([0.09941501, 0.900585 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.8174108], dtype=float32), 'logistic': array([0.8602552], dtype=float32), 'probabilities': array([0.13974483, 0.8602552 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': 
array([b'0', b'1'], dtype=object)}, {'logits': array([-0.37055105], dtype=float32), 'logistic': array([0.40840787], dtype=float32), 'probabilities': array([0.5915921 , 0.40840787], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7657616], dtype=float32), 'logistic': array([0.05920263], dtype=float32), 'probabilities': array([0.9407973 , 0.05920264], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.6566277], dtype=float32), 'logistic': array([0.65850246], dtype=float32), 'probabilities': array([0.34149754, 0.65850246], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.054557], dtype=float32), 'logistic': array([0.04502114], dtype=float32), 'probabilities': array([0.9549789 , 0.04502114], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.4686661], dtype=float32), 'logistic': array([0.8128546], dtype=float32), 'probabilities': array([0.18714546, 0.8128546 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7262318], dtype=float32), 'logistic': array([0.06144311], dtype=float32), 'probabilities': array([0.93855697, 0.06144311], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8138404], 
dtype=float32), 'logistic': array([0.14017461], dtype=float32), 'probabilities': array([0.85982543, 0.14017461], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.3677834], dtype=float32), 'logistic': array([0.2029782], dtype=float32), 'probabilities': array([0.7970218, 0.2029782], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9812416], dtype=float32), 'logistic': array([0.12118654], dtype=float32), 'probabilities': array([0.8788134 , 0.12118655], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.8329594], dtype=float32), 'logistic': array([0.05556888], dtype=float32), 'probabilities': array([0.9444311 , 0.05556888], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.909467], dtype=float32), 'logistic': array([0.94831246], dtype=float32), 'probabilities': array([0.05168756, 0.94831246], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6777656], dtype=float32), 'logistic': array([0.06429818], dtype=float32), 'probabilities': array([0.93570185, 0.06429818], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4901235], dtype=float32), 'logistic': array([0.07655346], dtype=float32), 
'probabilities': array([0.9234466 , 0.07655347], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9858353], dtype=float32), 'logistic': array([0.0480699], dtype=float32), 'probabilities': array([0.9519301 , 0.04806991], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.7305622], dtype=float32), 'logistic': array([0.02341781], dtype=float32), 'probabilities': array([0.9765822 , 0.02341781], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.10355384], dtype=float32), 'logistic': array([0.47413462], dtype=float32), 'probabilities': array([0.5258653 , 0.47413462], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.12644523], dtype=float32), 'logistic': array([0.4684308], dtype=float32), 'probabilities': array([0.53156924, 0.46843073], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.1059955], dtype=float32), 'logistic': array([0.75138175], dtype=float32), 'probabilities': array([0.2486182 , 0.75138175], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9293168], dtype=float32), 'logistic': array([0.05072322], dtype=float32), 'probabilities': array([0.9492768 , 0.05072321], dtype=float32), 
'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.5827293], dtype=float32), 'logistic': array([0.35830483], dtype=float32), 'probabilities': array([0.6416952 , 0.35830483], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6044114], dtype=float32), 'logistic': array([0.06885505], dtype=float32), 'probabilities': array([0.9311449 , 0.06885505], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4262745], dtype=float32), 'logistic': array([0.08119095], dtype=float32), 'probabilities': array([0.91880906, 0.08119094], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5261686], dtype=float32), 'logistic': array([0.07404391], dtype=float32), 'probabilities': array([0.925956 , 0.07404391], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.1543393], dtype=float32), 'logistic': array([0.8960736], dtype=float32), 'probabilities': array([0.10392642, 0.8960736 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.33856195], dtype=float32), 'logistic': array([0.41615885], dtype=float32), 'probabilities': array([0.5838412 , 0.41615885], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 
'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7633388], dtype=float32), 'logistic': array([0.05933773], dtype=float32), 'probabilities': array([0.94066226, 0.05933773], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.43802434], dtype=float32), 'logistic': array([0.60778815], dtype=float32), 'probabilities': array([0.39221182, 0.60778815], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.6605018], dtype=float32), 'logistic': array([0.8403054], dtype=float32), 'probabilities': array([0.15969464, 0.8403054 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4654174], dtype=float32), 'logistic': array([0.07831839], dtype=float32), 'probabilities': array([0.92168164, 0.07831839], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.6012243], dtype=float32), 'logistic': array([0.16781057], dtype=float32), 'probabilities': array([0.83218944, 0.16781057], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', 
b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.62694246], dtype=float32), 'logistic': array([0.6517959], dtype=float32), 'probabilities': array([0.34820414, 0.6517959 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8446095], dtype=float32), 'logistic': array([0.13650705], dtype=float32), 'probabilities': array([0.86349297, 0.13650703], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.8798323], dtype=float32), 'logistic': array([0.29321253], dtype=float32), 'probabilities': array([0.70678747, 0.29321253], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.0991914], dtype=float32), 'logistic': array([0.04314062], dtype=float32), 'probabilities': array([0.95685935, 0.04314062], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3084338], dtype=float32), 'logistic': array([0.09042689], dtype=float32), 'probabilities': array([0.9095731 , 0.09042688], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7597777], dtype=float32), 
'logistic': array([0.14681819], dtype=float32), 'probabilities': array([0.8531818 , 0.14681819], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7622488], dtype=float32), 'logistic': array([0.0593986], dtype=float32), 'probabilities': array([0.9406014 , 0.05939861], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3901472], dtype=float32), 'logistic': array([0.03260481], dtype=float32), 'probabilities': array([0.96739525, 0.03260481], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.98420286], dtype=float32), 'logistic': array([0.7279414], dtype=float32), 'probabilities': array([0.27205864, 0.7279414 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.60459214], dtype=float32), 'logistic': array([0.3532938], dtype=float32), 'probabilities': array([0.6467062 , 0.35329378], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.01147], dtype=float32), 'logistic': array([0.1180039], dtype=float32), 'probabilities': array([0.88199615, 0.1180039 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9486164], dtype=float32), 'logistic': array([0.12470432], dtype=float32), 'probabilities': 
array([0.8752957, 0.1247043], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.342761], dtype=float32), 'logistic': array([0.08764288], dtype=float32), 'probabilities': array([0.9123571 , 0.08764288], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.65887266], dtype=float32), 'logistic': array([0.6590071], dtype=float32), 'probabilities': array([0.34099287, 0.6590071 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.1499498], dtype=float32), 'logistic': array([0.8956641], dtype=float32), 'probabilities': array([0.10433591, 0.8956641 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.4248641], dtype=float32), 'logistic': array([0.60464656], dtype=float32), 'probabilities': array([0.3953534 , 0.60464656], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4196322], dtype=float32), 'logistic': array([0.08168784], dtype=float32), 'probabilities': array([0.91831213, 0.08168785], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.748317], dtype=float32), 'logistic': array([0.06018177], dtype=float32), 'probabilities': array([0.93981826, 0.06018177], dtype=float32), 'class_ids': array([0]), 
'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.8463914], dtype=float32), 'logistic': array([0.05486815], dtype=float32), 'probabilities': array([0.9451319 , 0.05486815], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.373101], dtype=float32), 'logistic': array([0.08524701], dtype=float32), 'probabilities': array([0.914753 , 0.08524701], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3426602], dtype=float32), 'logistic': array([0.08765095], dtype=float32), 'probabilities': array([0.91234905, 0.08765095], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.799353], dtype=float32), 'logistic': array([0.9781049], dtype=float32), 'probabilities': array([0.02189513, 0.9781049 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.58036435], dtype=float32), 'logistic': array([0.35884875], dtype=float32), 'probabilities': array([0.64115125, 0.35884875], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.1147056], dtype=float32), 'logistic': array([0.8923243], dtype=float32), 'probabilities': array([0.10767572, 0.8923243 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], 
dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5261686], dtype=float32), 'logistic': array([0.07404391], dtype=float32), 'probabilities': array([0.925956 , 0.07404391], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6065052], dtype=float32), 'logistic': array([0.06872093], dtype=float32), 'probabilities': array([0.93127906, 0.06872093], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.501496], dtype=float32), 'logistic': array([0.62281084], dtype=float32), 'probabilities': array([0.37718916, 0.62281084], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.7377956], dtype=float32), 'logistic': array([0.3234864], dtype=float32), 'probabilities': array([0.6765136 , 0.32348636], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.37305564], dtype=float32), 'logistic': array([0.40780288], dtype=float32), 'probabilities': array([0.5921971 , 0.40780288], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3642364], dtype=float32), 'logistic': array([0.08594083], dtype=float32), 'probabilities': array([0.91405916, 0.08594082], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': 
array([-2.2840674], dtype=float32), 'logistic': array([0.09245113], dtype=float32), 'probabilities': array([0.9075489 , 0.09245112], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.5848625], dtype=float32), 'logistic': array([0.929881], dtype=float32), 'probabilities': array([0.07011902, 0.929881 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7057766], dtype=float32), 'logistic': array([0.15371232], dtype=float32), 'probabilities': array([0.84628767, 0.15371232], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.7495186], dtype=float32), 'logistic': array([0.93988615], dtype=float32), 'probabilities': array([0.06011384, 0.93988615], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2857865], dtype=float32), 'logistic': array([0.21656686], dtype=float32), 'probabilities': array([0.78343314, 0.21656683], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8839439], dtype=float32), 'logistic': array([0.13193652], dtype=float32), 'probabilities': array([0.8680635 , 0.13193652], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7959197], dtype=float32), 'logistic': array([0.05754507], 
dtype=float32), 'probabilities': array([0.942455 , 0.05754507], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9527406], dtype=float32), 'logistic': array([0.12425485], dtype=float32), 'probabilities': array([0.8757452 , 0.12425484], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.024848], dtype=float32), 'logistic': array([0.9536842], dtype=float32), 'probabilities': array([0.04631587, 0.9536842 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.3218334], dtype=float32), 'logistic': array([0.91066915], dtype=float32), 'probabilities': array([0.08933079, 0.91066915], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4524894], dtype=float32), 'logistic': array([0.0792567], dtype=float32), 'probabilities': array([0.92074335, 0.0792567 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9719923], dtype=float32), 'logistic': array([0.04870733], dtype=float32), 'probabilities': array([0.95129263, 0.04870733], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.2123423], dtype=float32), 'logistic': array([0.90135235], dtype=float32), 'probabilities': array([0.09864761, 0.90135235], dtype=float32), 
'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.8084958], dtype=float32), 'logistic': array([0.05686681], dtype=float32), 'probabilities': array([0.94313323, 0.05686681], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.20995384], dtype=float32), 'logistic': array([0.5522965], dtype=float32), 'probabilities': array([0.44770348, 0.5522965 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.0149048], dtype=float32), 'logistic': array([0.4962739], dtype=float32), 'probabilities': array([0.5037261 , 0.49627388], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.5792875], dtype=float32), 'logistic': array([0.17089641], dtype=float32), 'probabilities': array([0.8291036, 0.1708964], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.47344178], dtype=float32), 'logistic': array([0.61619806], dtype=float32), 'probabilities': array([0.38380194, 0.61619806], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.1255076], dtype=float32), 'logistic': array([0.9579327], dtype=float32), 'probabilities': array([0.04206727, 0.9579327 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 
'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2310938], dtype=float32), 'logistic': array([0.22599006], dtype=float32), 'probabilities': array([0.77400994, 0.22599004], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.9486164], dtype=float32), 'logistic': array([0.12470432], dtype=float32), 'probabilities': array([0.8752957, 0.1247043], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.6758802], dtype=float32), 'logistic': array([0.15764177], dtype=float32), 'probabilities': array([0.8423583 , 0.15764178], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.6758802], dtype=float32), 'logistic': array([0.15764177], dtype=float32), 'probabilities': array([0.8423583 , 0.15764178], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.1053352], dtype=float32), 'logistic': array([0.89142066], dtype=float32), 'probabilities': array([0.10857935, 0.89142066], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', 
b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.2504556], dtype=float32), 'logistic': array([0.22262126], dtype=float32), 'probabilities': array([0.7773787 , 0.22262126], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([4.121924], dtype=float32), 'logistic': array([0.9840454], dtype=float32), 'probabilities': array([0.01595462, 0.9840454 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4366162], dtype=float32), 'logistic': array([0.0804228], dtype=float32), 'probabilities': array([0.91957724, 0.08042281], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0104177], dtype=float32), 'logistic': array([0.11811347], dtype=float32), 'probabilities': array([0.88188654, 0.11811347], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.4168284], dtype=float32), 'logistic': array([0.9181016], dtype=float32), 'probabilities': array([0.08189841, 0.9181016 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.54403615], dtype=float32), 
'logistic': array([0.36724913], dtype=float32), 'probabilities': array([0.6327508 , 0.36724913], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9636338], dtype=float32), 'logistic': array([0.04909608], dtype=float32), 'probabilities': array([0.9509039 , 0.04909608], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5519195], dtype=float32), 'logistic': array([0.07229764], dtype=float32), 'probabilities': array([0.9277024 , 0.07229764], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.46189648], dtype=float32), 'logistic': array([0.61346394], dtype=float32), 'probabilities': array([0.386536 , 0.61346394], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7621477], dtype=float32), 'logistic': array([0.05940425], dtype=float32), 'probabilities': array([0.94059575, 0.05940425], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.0126083], dtype=float32), 'logistic': array([0.04685951], dtype=float32), 'probabilities': array([0.95314056, 0.04685951], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7296622], dtype=float32), 'logistic': array([0.06124558], dtype=float32), 'probabilities': 
array([0.9387544 , 0.06124558], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3107316], dtype=float32), 'logistic': array([0.03520486], dtype=float32), 'probabilities': array([0.9647951 , 0.03520486], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.6646388], dtype=float32), 'logistic': array([0.15914027], dtype=float32), 'probabilities': array([0.8408597 , 0.15914027], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.7685025], dtype=float32), 'logistic': array([0.85427135], dtype=float32), 'probabilities': array([0.14572866, 0.85427135], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.43624622], dtype=float32), 'logistic': array([0.3926358], dtype=float32), 'probabilities': array([0.6073642, 0.3926358], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2930777], dtype=float32), 'logistic': array([0.09169789], dtype=float32), 'probabilities': array([0.9083021 , 0.09169789], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.9271967], dtype=float32), 'logistic': array([0.05082539], dtype=float32), 'probabilities': array([0.94917464, 0.05082539], dtype=float32), 'class_ids': array([0]), 
'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.6008396], dtype=float32), 'logistic': array([0.02657527], dtype=float32), 'probabilities': array([0.9734247 , 0.02657527], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.0633711], dtype=float32), 'logistic': array([0.74333423], dtype=float32), 'probabilities': array([0.25666577, 0.74333423], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3428822], dtype=float32), 'logistic': array([0.08763321], dtype=float32), 'probabilities': array([0.9123668 , 0.08763321], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0574672], dtype=float32), 'logistic': array([0.11330004], dtype=float32), 'probabilities': array([0.8867 , 0.11330003], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.216966], dtype=float32), 'logistic': array([0.7715292], dtype=float32), 'probabilities': array([0.2284708, 0.7715292], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.9985375], dtype=float32), 'logistic': array([0.95250803], dtype=float32), 'probabilities': array([0.04749199, 0.95250803], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], 
dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.6501275], dtype=float32), 'logistic': array([0.34296083], dtype=float32), 'probabilities': array([0.65703917, 0.3429608 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.743388], dtype=float32), 'logistic': array([0.06046116], dtype=float32), 'probabilities': array([0.9395389 , 0.06046116], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6399152], dtype=float32), 'logistic': array([0.0666133], dtype=float32), 'probabilities': array([0.93338674, 0.06661331], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.6247033], dtype=float32), 'logistic': array([0.16455725], dtype=float32), 'probabilities': array([0.8354428 , 0.16455725], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.45485705], dtype=float32), 'logistic': array([0.38820657], dtype=float32), 'probabilities': array([0.6117934 , 0.38820657], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4900424], dtype=float32), 'logistic': array([0.0765592], dtype=float32), 'probabilities': array([0.92344075, 0.07655919], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': 
array([-1.3192466], dtype=float32), 'logistic': array([0.21094365], dtype=float32), 'probabilities': array([0.78905636, 0.21094365], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.37605423], dtype=float32), 'logistic': array([0.40707892], dtype=float32), 'probabilities': array([0.5929211, 0.4070789], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.4793921], dtype=float32), 'logistic': array([0.38239568], dtype=float32), 'probabilities': array([0.6176043 , 0.38239568], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6358168], dtype=float32), 'logistic': array([0.06686858], dtype=float32), 'probabilities': array([0.93313146, 0.06686859], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.34706], dtype=float32), 'logistic': array([0.08729974], dtype=float32), 'probabilities': array([0.91270024, 0.08729975], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.4827917], dtype=float32), 'logistic': array([0.81499386], dtype=float32), 'probabilities': array([0.18500613, 0.81499386], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.10792582], dtype=float32), 'logistic': array([0.5269553], 
dtype=float32), 'probabilities': array([0.47304472, 0.5269553 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0841262], dtype=float32), 'logistic': array([0.11064926], dtype=float32), 'probabilities': array([0.8893508 , 0.11064927], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.48042685], dtype=float32), 'logistic': array([0.6178487], dtype=float32), 'probabilities': array([0.38215134, 0.6178487 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.137316], dtype=float32), 'logistic': array([0.10552245], dtype=float32), 'probabilities': array([0.89447755, 0.10552246], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2999605], dtype=float32), 'logistic': array([0.7858284], dtype=float32), 'probabilities': array([0.21417166, 0.7858284 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.38928968], dtype=float32), 'logistic': array([0.5961117], dtype=float32), 'probabilities': array([0.40388831, 0.5961117 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.8880947], dtype=float32), 'logistic': array([0.05274523], dtype=float32), 'probabilities': array([0.9472548 , 0.05274523], 
dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.4113041], dtype=float32), 'logistic': array([0.19602844], dtype=float32), 'probabilities': array([0.8039716 , 0.19602844], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.34706], dtype=float32), 'logistic': array([0.08729974], dtype=float32), 'probabilities': array([0.91270024, 0.08729975], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.9498808], dtype=float32), 'logistic': array([0.95025784], dtype=float32), 'probabilities': array([0.04974214, 0.95025784], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6778057], dtype=float32), 'logistic': array([0.06429576], dtype=float32), 'probabilities': array([0.93570423, 0.06429576], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.1443386], dtype=float32), 'logistic': array([0.04131493], dtype=float32), 'probabilities': array([0.95868504, 0.04131493], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.1913788], dtype=float32), 'logistic': array([0.10052735], dtype=float32), 'probabilities': array([0.89947265, 0.10052735], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], 
dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.8657877], dtype=float32), 'logistic': array([0.05387095], dtype=float32), 'probabilities': array([0.946129 , 0.05387094], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.84821314], dtype=float32), 'logistic': array([0.29980782], dtype=float32), 'probabilities': array([0.70019215, 0.29980782], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.43984145], dtype=float32), 'logistic': array([0.60822123], dtype=float32), 'probabilities': array([0.39177877, 0.60822123], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.87488425], dtype=float32), 'logistic': array([0.294239], dtype=float32), 'probabilities': array([0.70576096, 0.294239 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.1600134], dtype=float32), 'logistic': array([0.04069853], dtype=float32), 'probabilities': array([0.9593015 , 0.04069853], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.8453041], dtype=float32), 'logistic': array([0.86357486], dtype=float32), 'probabilities': array([0.1364252 , 0.86357486], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': 
array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7159994], dtype=float32), 'logistic': array([0.06203584], dtype=float32), 'probabilities': array([0.93796414, 0.06203584], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.027527], dtype=float32), 'logistic': array([0.11634292], dtype=float32), 'probabilities': array([0.8836571 , 0.11634292], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4939191], dtype=float32), 'logistic': array([0.07628557], dtype=float32), 'probabilities': array([0.9237144 , 0.07628556], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.5261686], dtype=float32), 'logistic': array([0.07404391], dtype=float32), 'probabilities': array([0.925956 , 0.07404391], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4900424], dtype=float32), 'logistic': array([0.0765592], dtype=float32), 'probabilities': array([0.92344075, 0.07655919], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.6892314], dtype=float32), 'logistic': array([0.84412307], dtype=float32), 'probabilities': array([0.15587695, 0.84412307], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.0002286], 
dtype=float32), 'logistic': array([0.11917893], dtype=float32), 'probabilities': array([0.8808211 , 0.11917892], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.552121], dtype=float32), 'logistic': array([0.07228412], dtype=float32), 'probabilities': array([0.92771584, 0.07228412], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.340217], dtype=float32), 'logistic': array([0.08784653], dtype=float32), 'probabilities': array([0.91215354, 0.08784652], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.7466123], dtype=float32), 'logistic': array([0.9769464], dtype=float32), 'probabilities': array([0.02305354, 0.9769464 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4643068], dtype=float32), 'logistic': array([0.07839859], dtype=float32), 'probabilities': array([0.9216014 , 0.07839859], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.760916], dtype=float32), 'logistic': array([0.0594731], dtype=float32), 'probabilities': array([0.94052684, 0.0594731 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-0.5058585], dtype=float32), 'logistic': array([0.37616488], dtype=float32), 
'probabilities': array([0.62383515, 0.37616488], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.394353], dtype=float32), 'logistic': array([0.19871373], dtype=float32), 'probabilities': array([0.8012862 , 0.19871372], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.429804], dtype=float32), 'logistic': array([0.919072], dtype=float32), 'probabilities': array([0.08092804, 0.919072 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2767317], dtype=float32), 'logistic': array([0.09306845], dtype=float32), 'probabilities': array([0.9069315 , 0.09306846], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.350621], dtype=float32), 'logistic': array([0.03387484], dtype=float32), 'probabilities': array([0.9661252 , 0.03387484], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.3676736], dtype=float32), 'logistic': array([0.03332116], dtype=float32), 'probabilities': array([0.96667886, 0.03332116], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.2906396], dtype=float32), 'logistic': array([0.09190115], dtype=float32), 'probabilities': array([0.9080989 , 0.09190115], dtype=float32), 'class_ids': 
array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.7021495], dtype=float32), 'logistic': array([0.15418473], dtype=float32), 'probabilities': array([0.84581524, 0.15418473], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.891106], dtype=float32), 'logistic': array([0.947405], dtype=float32), 'probabilities': array([0.05259499, 0.947405 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4648888], dtype=float32), 'logistic': array([0.07835655], dtype=float32), 'probabilities': array([0.92164344, 0.07835656], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([3.018937], dtype=float32), 'logistic': array([0.9534223], dtype=float32), 'probabilities': array([0.04657765, 0.9534223 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.74305326], dtype=float32), 'logistic': array([0.67766315], dtype=float32), 'probabilities': array([0.32233682, 0.67766315], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.9401757], dtype=float32), 'logistic': array([0.87437147], dtype=float32), 'probabilities': array([0.12562856, 0.87437147], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], 
dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.0429635], dtype=float32), 'logistic': array([0.2605786], dtype=float32), 'probabilities': array([0.7394214, 0.2605786], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.2021129], dtype=float32), 'logistic': array([0.7689004], dtype=float32), 'probabilities': array([0.23109956, 0.7689004 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.4534457], dtype=float32), 'logistic': array([0.03066627], dtype=float32), 'probabilities': array([0.9693337 , 0.03066627], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.717829], dtype=float32), 'logistic': array([0.15215102], dtype=float32), 'probabilities': array([0.847849 , 0.15215102], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.5187134], dtype=float32), 'logistic': array([0.6268469], dtype=float32), 'probabilities': array([0.37315315, 0.6268469 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.0342655], dtype=float32), 'logistic': array([0.88434803], dtype=float32), 'probabilities': array([0.11565194, 0.88434803], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': 
array([-0.83103323], dtype=float32), 'logistic': array([0.30342665], dtype=float32), 'probabilities': array([0.6965734 , 0.30342665], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.4835986], dtype=float32), 'logistic': array([0.18488449], dtype=float32), 'probabilities': array([0.8151155 , 0.18488449], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6777656], dtype=float32), 'logistic': array([0.06429818], dtype=float32), 'probabilities': array([0.93570185, 0.06429818], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.9513098], dtype=float32), 'logistic': array([0.87558943], dtype=float32), 'probabilities': array([0.12441061, 0.87558943], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.0337292], dtype=float32), 'logistic': array([0.26236176], dtype=float32), 'probabilities': array([0.73763824, 0.26236176], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-3.5533388], dtype=float32), 'logistic': array([0.02783209], dtype=float32), 'probabilities': array([0.97216785, 0.02783209], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.041459], dtype=float32), 'logistic': array([0.8850817], 
dtype=float32), 'probabilities': array([0.11491823, 0.8850817 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([1.5367038], dtype=float32), 'logistic': array([0.8229851], dtype=float32), 'probabilities': array([0.17701496, 0.8229851 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.4166243], dtype=float32), 'logistic': array([0.08191376], dtype=float32), 'probabilities': array([0.91808623, 0.08191376], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.3841245], dtype=float32), 'logistic': array([0.08439132], dtype=float32), 'probabilities': array([0.9156087 , 0.08439132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.7616632], dtype=float32), 'logistic': array([0.05943133], dtype=float32), 'probabilities': array([0.9405686 , 0.05943132], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.49434847], dtype=float32), 'logistic': array([0.6211303], dtype=float32), 'probabilities': array([0.3788697, 0.6211303], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-1.8977379], dtype=float32), 'logistic': array([0.13036472], dtype=float32), 'probabilities': array([0.8696353, 0.1303647], dtype=float32), 
'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6399152], dtype=float32), 'logistic': array([0.0666133], dtype=float32), 'probabilities': array([0.93338674, 0.06661331], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([2.666312], dtype=float32), 'logistic': array([0.9350093], dtype=float32), 'probabilities': array([0.06499072, 0.9350093 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([0.24823135], dtype=float32), 'logistic': array([0.5617411], dtype=float32), 'probabilities': array([0.43825886, 0.5617411 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}, {'logits': array([-2.6578376], dtype=float32), 'logistic': array([0.06550758], dtype=float32), 'probabilities': array([0.93449247, 0.06550758], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QTKZMrZ0jXEJ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "e1f5f295-7c66-4c19-a90b-cd094da34861"
},
"source": [
"print(\"Predicted outcome for the first passenger in the test set:\")\n",
"print(\"Probability of survival:\", pred[0]['probabilities'][1])\n",
"print(\"Probability of death:\", pred[0]['probabilities'][0])"
],
"execution_count": 99,
"outputs": [
{
"output_type": "stream",
"text": [
"Predicted outcome for the first passenger in the test set:\n",
"Probability of survival: 0.08350867\n",
"Probability of death: 0.9164914\n"
],
"name": "stdout"
}
]
},
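{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A short sketch of how the fields in each `pred` entry relate. For this binary estimator, the `logistic` value appears to be the sigmoid of `logits`, and `probabilities` holds `[P(died), P(survived)] = [1 - p, p]`. The logit below is copied from one entry of the printed `pred` list above. Nothing else is assumed."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"# logit copied from one entry of the printed `pred` list above\n",
"logit = -2.552121\n",
"\n",
"# sigmoid(logit) = 1 / (1 + exp(-logit)) should reproduce the 'logistic' field\n",
"p_survive = 1.0 / (1.0 + np.exp(-logit))\n",
"\n",
"# 'probabilities' is [P(class 0), P(class 1)] = [1 - p, p]\n",
"probabilities = np.array([1.0 - p_survive, p_survive])\n",
"print(p_survive)       # approximately 0.07228, matching 'logistic' above\n",
"print(probabilities)   # approximately [0.92772 0.07228]"
],
"execution_count": null,
"outputs": []
},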
{
"cell_type": "markdown",
"metadata": {
"id": "IpRfDFpCo16w",
"colab_type": "text"
},
"source": [
"Let's look at more detailed predictions for a few passengers in the Titanic test set, shown alongside their details."
]
},
{
"cell_type": "code",
"metadata": {
"id": "wT0alX7yj5dm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "3eb716d0-95ce-4d1a-f212-72aac8fd8332"
},
"source": [
"for i in range(2, 20, 2):\n",
" print(\"-----\\nDetails of a person: \\n\", df_test.loc[i])\n",
" print(\"Probability of survival:\", pred[i]['probabilities'][1])\n",
" print(\"Probability of death:\", pred[i]['probabilities'][0])\n",
" print()"
],
"execution_count": 100,
"outputs": [
{
"output_type": "stream",
"text": [
"-----\n",
"Details of a person: \n",
" sex female\n",
"age 58\n",
"n_siblings_spouses 0\n",
"parch 0\n",
"fare 26.55\n",
"class First\n",
"deck C\n",
"embark_town Southampton\n",
"alone y\n",
"Name: 2, dtype: object\n",
"Probability of survival: 0.5778602\n",
"Probability of death: 0.42213982\n",
"\n",
"-----\n",
"Details of a person: \n",
" sex male\n",
"age 34\n",
"n_siblings_spouses 0\n",
"parch 0\n",
"fare 13\n",
"class Second\n",
"deck D\n",
"embark_town Southampton\n",
"alone y\n",
"Name: 4, dtype: object\n",
"Probability of survival: 0.07680218\n",
"Probability of death: 0.9231978\n",
"\n",
"-----\n",
"Details of a person: \n",
" sex female\n",
"age 8\n",
"n_siblings_spouses 3\n",
"parch 1\n",
"fare 21.075\n",
"class Third\n",
"deck unknown\n",
"embark_town Southampton\n",
"alone n\n",
"Name: 6, dtype: object\n",
"Probability of survival: 0.51535875\n",
"Probability of death: 0.4846412\n",
"\n",
"-----\n",
"Details of a person: \n",
" sex female\n",
"age 18\n",
"n_siblings_spouses 2\n",
"parch 0\n",
"fare 18\n",
"class Third\n",
"deck unknown\n",
"embark_town Southampton\n",
"alone n\n",
"Name: 8, dtype: object\n",
"Probability of survival: 0.83662623\n",
"Probability of death: 0.16337374\n",
"\n",
"-----\n",
"Details of a person: \n",
" sex female\n",
"age 28\n",
"n_siblings_spouses 0\n",
"parch 0\n",
"fare 7.75\n",
"class Third\n",
"deck unknown\n",
"embark_town Queenstown\n",
"alone y\n",
"Name: 10, dtype: object\n",
"Probability of survival: 0.03349888\n",
"Probability of death: 0.9665011\n",
"\n",
"-----\n",
"Details of a person: \n",
" sex female\n",
"age 5\n",
"n_siblings_spouses 1\n",
"parch 2\n",
"fare 27.75\n",
"class Second\n",
"deck unknown\n",
"embark_town Southampton\n",
"alone n\n",
"Name: 12, dtype: object\n",
"Probability of survival: 0.14107606\n",
"Probability of death: 0.8589239\n",
"\n",
"-----\n",
"Details of a person: \n",
" sex female\n",
"age 29\n",
"n_siblings_spouses 0\n",
"parch 0\n",
"fare 10.5\n",
"class Second\n",
"deck F\n",
"embark_town Southampton\n",
"alone y\n",
"Name: 14, dtype: object\n",
"Probability of survival: 0.56723243\n",
"Probability of death: 0.43276757\n",
"\n",
"-----\n",
"Details of a person: \n",
" sex male\n",
"age 26\n",
"n_siblings_spouses 1\n",
"parch 0\n",
"fare 14.4542\n",
"class Third\n",
"deck unknown\n",
"embark_town Cherbourg\n",
"alone n\n",
"Name: 16, dtype: object\n",
"Probability of survival: 0.10752523\n",
"Probability of death: 0.8924748\n",
"\n",
"-----\n",
"Details of a person: \n",
" sex female\n",
"age 33\n",
"n_siblings_spouses 3\n",
"parch 0\n",
"fare 15.85\n",
"class Third\n",
"deck unknown\n",
"embark_town Southampton\n",
"alone n\n",
"Name: 18, dtype: object\n",
"Probability of survival: 0.41871348\n",
"Probability of death: 0.58128655\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KWbN7IIdfs_b",
"colab_type": "text"
},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jdPm5S14ptcW",
"colab_type": "text"
},
"source": [
"## Basic Image Classification with TensorFlow and Keras"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "itoR1B6_qIXs",
"colab_type": "text"
},
"source": [
"### Import Packages"
]
},
{
"cell_type": "code",
"metadata": {
"id": "XPaBsdJMpsm_",
"colab_type": "code",
"colab": {}
},
"source": [
"# TensorFlow and tf.keras\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"\n",
"# Helper libraries\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output\n",
"\n",
"# Hide warnings\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
],
"execution_count": 101,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DyldeMVrqMIs",
"colab_type": "text"
},
"source": [
"### Import the Fashion MNIST Dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_Fi0XZzLqMnP",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 150
},
"outputId": "e33fd5e5-747e-4247-98c5-c54ddebbc33f"
},
"source": [
"fashion_mnist = keras.datasets.fashion_mnist\n",
"\n",
"(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()"
],
"execution_count": 102,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n",
"32768/29515 [=================================] - 0s 0us/step\n",
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n",
"26427392/26421880 [==============================] - 0s 0us/step\n",
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n",
"8192/5148 [===============================================] - 0s 0us/step\n",
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n",
"4423680/4422102 [==============================] - 0s 0us/step\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J2HtkM88s_z2",
"colab_type": "text"
},
"source": [
"### About the Fashion MNIST dataset\n",
"\n",
"The images are 28x28 NumPy arrays, with pixel values ranging from 0 to 255. The labels are an array of integers, ranging from 0 to 9. These correspond to the class of clothing the image represents:\n",
"\n",
"| label | class |\n",
"|---|---|\n",
"| 0 | T-shirt/Top |\n",
"| 1 | Trouser |\n",
"| 2 | Pullover |\n",
"| 3 | Dress |\n",
"| 4 | Coat |\n",
"| 5 | Sandal |\n",
"| 6 | Shirt |\n",
"| 7 | Sneaker |\n",
"| 8 | Bag |\n",
"| 9 | Ankle Boot |\n",
"\n",
"Each image is mapped to a single label. Since the class names are not included with the dataset, store them here to use later when plotting the images:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rnPYUtVDrU1O",
"colab_type": "code",
"colab": {}
},
"source": [
"class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
" 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']"
],
"execution_count": 103,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7uWnRWxOrWfR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "de51656d-1670-4325-8b0d-0217415cfe7f"
},
"source": [
"train_images.shape"
],
"execution_count": 104,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(60000, 28, 28)"
]
},
"metadata": {
"tags": []
},
"execution_count": 104
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "0TkBPuASrdwO",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "1965ba87-9ca8-499f-c72e-c3c4302f3802"
},
"source": [
"test_images.shape"
],
"execution_count": 105,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(10000, 28, 28)"
]
},
"metadata": {
"tags": []
},
"execution_count": 105
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7YlSKQPtpvbP",
"colab_type": "text"
},
"source": [
"### Preprocess the data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mqwJrnAtvM45",
"colab_type": "text"
},
"source": [
"The data must be preprocessed before training the network. If you inspect the first image in the training set, you will see that the pixel values fall in the range of 0 to 255:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "KHrlSeqkummm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 252
},
"outputId": "f20cb948-c209-4709-934b-5fdd1994237f"
},
"source": [
"plt.figure()\n",
"plt.imshow(train_images[0], cmap='binary')\n",
"plt.xticks([])\n",
"plt.yticks([])\n",
"plt.grid(False)\n",
"plt.show()"
],
"execution_count": 106,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAK0UlEQVR4nO3dzW/M6x/G8bseSqccj1VKVTwkmrQiMqS6bEJUYuEvsLKysCJi4S+wtyARG0klGisREhIaGkY0FiIlFE1plSjV1lN7tueXn+/1MTN1uJz3a3udezrT6ZWvnE/u+66YmppKAH5/M371GwDwYygrYIKyAiYoK2CCsgImKCtgYlYx//HSpUun1qxZ85PeCoC+vr40PDxc8b2sqLKuWbMmFQqF6XlXAP5PPp/PzPhnMGCCsgImKCtggrICJigrYIKyAiYoK2CCsgImKCtggrICJigrYIKyAiYoK2CCsgImKCtggrICJigrYIKyAiYoK2CCsgImijowDf++6OKwiorvHoT3wz58+CDzrq6uzKy9vb2snx19tm/fvmVms2b92j/dci50K/U748kKmKCsgAnKCpigrIAJygqYoKyACcoKmGDO+pubnJyU+cyZM2X++PFjmZ86dUrmVVVVmVl1dbVcO3fuXJlv27ZN5uXMUqM5aPR7jdaX897U/FjhyQqYoKyACcoKmKCsgAnKCpigrIAJygqYYM76m4tmctGc9erVqzK/cuWKzOvr6zOzT58+ybVjY2Myv3z5ssz379+fmdXW1sq10Z7R6PcWGR0dzcxmzNDPwFwuV9LP5MkKmKCsgAnKCpigrIAJygqYoKyACcoKmGDO+purrKwsa/2dO3dk3tfXJ3O17zPaE7pz506Z37t3T+aHDx/OzPL5vFzb3Nws88bGRpnfvn1b5ur32traKtdu3749M1NzdZ6sgAnKCpigrIAJygqYoKyACcoKmGB08xtQx15GW72iLW6FQkHmf/31l8w/fvyYmfX29sq1Ub5161aZr1+/PjNTW9RSSunmzZsy7+zslHl01Kg6RvXkyZNyrRrHqW2FPFkBE5QVMEFZAROUFTBBWQETlBUwQVkBExXR1Xb/lM/np6K53X9RMb/DYkVz1paWFplHW+Ai6rNFx3nOmTOnrJ+troyMfi9btmyR+YYNG2QefbZLly5lZk+ePJFrBwYGMrN8Pp8KhcJ3PxxPVsAEZQVMUFbABGUFTFBWwARlBUxQVsAE+1mnQTTz+5kWLVok85cvX8q8qqpK5upaxy9fvsi10Z5TNUdNKaXx8fHMLPqdd3V1yTza7xrNzgcHBzOzXbt2ybWl4skKmKCsgAnKCpigrIAJygqYoKyACcoKmGDOak6dM5uSvkIwpfjaRjWHXb58uVy7ZMkSmUd7bWfMyH6WRHPQ6HOrGW70s1PS+137+/vl2lLxZAVMUFbABGUFTFBWwARlBUxQVsAEZQVMMGedBtHML5plqpldtCdUnUGbUnx2r7orNKWUPn/+XPJrV1dXy3xkZETmak4bzZfV+04ppXnz5sn8/fv3Mm9ubs7M1J22Kek7c7mfFfgDUFbABGUFTFBWwARlBUxQVsAEo5tpEB2LGW3XUqObjo4OuTY6arSmpkbm0VYx9d6iEcXz589lPnv2bJmrY1BnzdJ/utExqdHnHh4elvmBAwcys56eHrn269evmZkaA/JkBUxQVsAEZQVMUFbABGUFTFBWwARlBUwwZ50Gam6WUrwNTWlqapJ5tE0tmjeWMwMeGhqSa6MrHRcvXixz9XuNPlc0A46uyqyvr5f52bNnM7NDhw7JtS0tLZmZ2lbIkxUwQVkBE5QVMEFZAROUFTBBWQETlBUw8a/OWdVevXKvJoyOA1V7J6Pr/SLR3spytLe3yzw6UlNd2ZhSfGSnEu2VjebPExMTMi9nPh19J9F3Hv093r9/PzNbsGCBXFsqnqyACcoKmKCsgAnKCp
igrIAJygqYoKyAiWkdEJazN/Jnzip/tuvXr8v8/PnzMu/q6srMcrmcXKuuRUxJn72bUnzmsfpeovcW/T1E703NYaP3HV03GYnmz+r1Ozs75do9e/aU9J54sgImKCtggrICJigrYIKyAiYoK2CCsgImpnW4qeao5Xr79q3MBwYGZN7b21vy2mhupl47pfhsX7VXN5oXvnnzRuZ1dXUyj872VefzDg4OyrXR5x4bG5N5a2trZvbhwwe59saNGzKP9rNGe1LV/uju7m65tlQ8WQETlBUwQVkBE5QVMEFZAROUFTAxraObW7duyfzYsWOZ2evXr+Xad+/eyTz6X/FqPLJw4UK5NhpJzZ8/X+bRCEMdoxodJarGGyml1NHRIfOtW7fK/P3795lZNPbp6+uTeUQd9zk6OirXrlq1SubRSCwaK6krJcv93Fl4sgImKCtggrICJigrYIKyAiYoK2CCsgImip6zquMlDx48KNeqrWjlXtFXztGT0ZGY0awzyiMjIyOZ2bNnz+TaI0eOyDx6bydOnJD5ihUrMrNoztrW1ibzdevWyfzRo0eZWbQ1UG1hSym+jjK6YlT9vS5btkyuLRVPVsAEZQVMUFbABGUFTFBWwARlBUxQVsBEUXPW4eHhdObMmcw8mgmuXbs2M1P7A1OKj56M5m5KNHNTc9CU4r2TK1eulPn4+HhmVltbK9fu27dP5hcuXJB5dP3g06dPM7PoO7t7967Mr127JnM104/2CEez8+hKx4ias0av/eLFi5LW8mQFTFBWwARlBUxQVsAEZQVMUFbABGUFTBQ1Z509e7bcqxfNG9WsNJqbrV69uuTXTklfXajOxk0ppcWLF8u8oaFB5tF7U/tCoz2j0ZnGe/fulXlzc7PM1Rm40Ww7+k6j85rVntToc1dWVso8moVG+6fVWc8qS0lfEarmwzxZAROUFTBBWQETlBUwQVkBE5QVMFH06EaNZ6L/3V1fX5+ZRdutoishozFATU1NSVlK8Ra6aDtWtH5iYiIzi642VNvIUkppyZIlMn/w4IHM582bl5lF47RFixbJXH3ulPT3Eh1dGx1FGq1X2xZTSunVq1eZ2YIFC+Tanp6ezExdNcmTFTBBWQETlBUwQVkBE5QVMEFZAROUFTBR1Jw1l8ulzZs3Z+bRdqzTp09nZnV1dXJtdD1gtJVMzSuj7VLRzE1tv0spnrOq9x6traiokHkul5O5utIxJT07j7apRe89mo2Xs6Uyeu0oj7bYqTmuOr41JX28rHpdnqyACcoKmKCsgAnKCpigrIAJygqYoKyAiYro2MR/yufzU4VCoeQfdvHixczs+PHjcu3Q0JDMoz2paq4W7cOdnJyUebSfNdpzquaR0fcTzVmjWWc0Y1Z59NrF/G0Vu14difsjotl49Deh9rNu2rRJrj137lxmls/nU6FQ+O6XypMVMEFZAROUFTBBWQETlBUwQVkBE5QVMFHUftaU9Mwxmk3t3r27pCyllK5evSrzo0ePylxdXTgyMiLXRvPCaI4azfTUGbbRz47mjdEcNrqmU+21VWcKpxT/XsoR7TeN9vFGs/MdO3bIvLGxMTNrbW2Va0vFkxUwQVkBE5QVMEFZAROUFTBBWQETlBUwUfScNZql/ixtbW0y7+7uLvm1Hz58KPPobtjoHtL+/n6ZNzQ0ZGbRPDE6Txl/Dp6sgAnKCpigrIAJygqYoKyACcoKmCh6dPMn2rhxY1l5pKmpqaz1QEo8WQEblBUwQVkBE5QVMEFZAROUFTBBWQETlBUwQVkBE5QVMEFZAROUFTBBWQETlBUwQVkBExXRlYL/8x9XVLxOKT37eW8H+M9rmJqaqvleUFRZAfw6/DMYMEFZAROUFTBBWQETlBUwQVkBE5QVMEFZAROUFTDxN/Yv91E2qyW8AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "mTlTY3dZunAS",
"colab_type": "code",
"colab": {}
},
"source": [
"# scale pixel values from the range [0, 255] to [0, 1]\n",
"# before feeding them into the neural network\n",
"train_images = train_images / 255.0\n",
"test_images = test_images / 255.0"
],
"execution_count": 107,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "uvqr9Nx0upZe",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 588
},
"outputId": "32c6b9f1-6edf-4f53-b4b6-df5e739438fe"
},
"source": [
"# display the first 25 training images with their class names to check that the labels look correct\n",
"plt.figure(figsize=(10,10))\n",
"for i in range(25):\n",
" plt.subplot(5,5,i+1)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.grid(False)\n",
" plt.imshow(train_images[i], cmap='binary')\n",
" plt.xlabel(class_names[train_labels[i]])\n",
"plt.show()"
],
"execution_count": 108,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x720 with 25 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SuLwEd6qwiMg",
"colab_type": "text"
},
"source": [
"### Build the model\n",
"\n",
"Building the neural network requires configuring the layers of the model, then compiling the model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VIB9WuvDwoDm",
"colab_type": "text"
},
"source": [
"#### Set up the layers\n",
"\n",
"Most of deep learning consists of chaining together simple layers. Most layers, such as `tf.keras.layers.Dense`, have parameters that are learned during training."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Tt5wpWXWupw9",
"colab_type": "code",
"colab": {}
},
"source": [
"minst_model = keras.Sequential([\n",
" # transforms the format of the images from a two-dimensional array (of 28 by 28 pixels) \n",
" # to a one-dimensional array (of 28 * 28 = 784 pixels). \n",
" keras.layers.Flatten(input_shape=(28, 28)),\n",
"\n",
" # two Dense (fully connected) layers follow\n",
" # first: 128 units with ReLU activation and an L2 weight penalty (factor 0.001)\n",
" keras.layers.Dense(128, activation='relu',\n",
" kernel_regularizer=tf.keras.regularizers.l2(0.001)),\n",
"\n",
" # output layer: 10 units with no activation, i.e. one raw logit per class;\n",
" # the class with the largest logit is the predicted class\n",
" keras.layers.Dense(10)\n",
"])"
],
"execution_count": 109,
"outputs": []
},
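{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The layer stack above can be mirrored in plain NumPy to make the shapes concrete. This is a sketch with random weights (an assumption; the trained Keras weights are not used). `Flatten` is a reshape to 784 values. The first `Dense` layer is an affine map followed by ReLU. The output layer returns 10 raw logits. `l2(0.001)` adds `0.001 * sum(W1**2)` to the training loss. The parameter count should agree with what `minst_model.summary()` reports."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"batch = rng.random((4, 28, 28))      # stand-in for 4 scaled images\n",
"\n",
"x = batch.reshape(len(batch), -1)    # Flatten: (4, 28, 28) -> (4, 784)\n",
"\n",
"# random weights, for shape illustration only\n",
"W1, b1 = rng.normal(0.0, 0.05, (784, 128)), np.zeros(128)\n",
"W2, b2 = rng.normal(0.0, 0.05, (128, 10)), np.zeros(10)\n",
"\n",
"hidden = np.maximum(0.0, x @ W1 + b1)  # Dense(128, activation='relu')\n",
"logits = hidden @ W2 + b2              # Dense(10): one raw logit per class\n",
"l2_penalty = 0.001 * np.sum(W1 ** 2)   # term added to the loss by l2(0.001)\n",
"\n",
"n_params = W1.size + b1.size + W2.size + b2.size\n",
"print(x.shape, hidden.shape, logits.shape)  # (4, 784) (4, 128) (4, 10)\n",
"print(n_params)                             # 101770"
],
"execution_count": null,
"outputs": []
},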
{
"cell_type": "code",
"metadata": {
"id": "9d5LNXdIuqHk",
"colab_type": "code",
"colab": {}
},
"source": [
"minst_model.compile(optimizer='adam',\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])"
],
"execution_count": 110,
"outputs": []
},
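{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`SparseCategoricalCrossentropy(from_logits=True)` takes integer class labels and raw logits. The NumPy sketch below shows the per-batch value it computes. It assumes the default mean-over-batch reduction; during training, the `l2(0.001)` penalty is added separately. The computation applies a numerically stable log-softmax to each row, takes the log-probability of the true class, negates it, and averages. The logits and labels are made-up toy values."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"def sparse_ce_from_logits(logits, labels):\n",
"    # subtract the row max before exponentiating so exp() cannot overflow\n",
"    shifted = logits - logits.max(axis=1, keepdims=True)\n",
"    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
"    # log-probability of the true class in each row, negated and averaged\n",
"    return -log_probs[np.arange(len(labels)), labels].mean()\n",
"\n",
"toy_logits = np.array([[2.0, 0.5, -1.0],\n",
"                       [0.1, 0.2, 3.0]])\n",
"toy_labels = np.array([0, 2])\n",
"print(sparse_ce_from_logits(toy_logits, toy_labels))  # about 0.1755"
],
"execution_count": null,
"outputs": []
},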
{
"cell_type": "markdown",
"metadata": {
"id": "W2RBmcCozMzW",
"colab_type": "text"
},
"source": [
"### Train the model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "NWV4vGkouplK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 700
},
"outputId": "38140c84-a70e-46c7-e190-2149cb244c09"
},
"source": [
"minst_model.fit(train_images, train_labels, epochs=20)"
],
"execution_count": 111,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.6395 - accuracy: 0.8181\n",
"Epoch 2/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.5075 - accuracy: 0.8495\n",
"Epoch 3/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4792 - accuracy: 0.8566\n",
"Epoch 4/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4612 - accuracy: 0.8607\n",
"Epoch 5/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4501 - accuracy: 0.8638\n",
"Epoch 6/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4392 - accuracy: 0.8675\n",
"Epoch 7/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4341 - accuracy: 0.8700\n",
"Epoch 8/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4266 - accuracy: 0.8697\n",
"Epoch 9/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4210 - accuracy: 0.8719\n",
"Epoch 10/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4180 - accuracy: 0.8737\n",
"Epoch 11/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4142 - accuracy: 0.8737\n",
"Epoch 12/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4100 - accuracy: 0.8759\n",
"Epoch 13/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4110 - accuracy: 0.8746\n",
"Epoch 14/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4077 - accuracy: 0.8763\n",
"Epoch 15/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4042 - accuracy: 0.8778\n",
"Epoch 16/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.4044 - accuracy: 0.8774\n",
"Epoch 17/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.3997 - accuracy: 0.8783\n",
"Epoch 18/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.3984 - accuracy: 0.8778\n",
"Epoch 19/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.3983 - accuracy: 0.8777\n",
"Epoch 20/20\n",
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.3953 - accuracy: 0.8780\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7fd2afa5f438>"
]
},
"metadata": {
"tags": []
},
"execution_count": 111
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUzi6zuYVtfu",
"colab_type": "text"
},
"source": [
"### Evaluate the model on the test data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fs_kdfrhupLM",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "e487f639-1e56-4d65-8612-d680192574a2"
},
"source": [
"test_loss, test_acc = minst_model.evaluate(test_images, test_labels, verbose=2)\n",
"\n",
"print('\\nTest accuracy:', test_acc)"
],
"execution_count": 112,
"outputs": [
{
"output_type": "stream",
"text": [
"313/313 - 0s - loss: 0.4441 - accuracy: 0.8642\n",
"\n",
"Test accuracy: 0.8641999959945679\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j56WsYeZS_4p",
"colab_type": "text"
},
"source": [
"### Use Model for predictions\n",
"\n",
"With the model trained, you can use it to make predictions about some images. The model outputs linear [logits](https://developers.google.com/machine-learning/glossary#logits), so attach a softmax layer to convert them to probabilities, which are easier to interpret."
]
},
{
"cell_type": "code",
"metadata": {
"id": "6tGUXaW5unbD",
"colab_type": "code",
"colab": {}
},
"source": [
"probability_model = tf.keras.Sequential([minst_model, \n",
" tf.keras.layers.Softmax()])\n",
"\n",
"predictions = probability_model.predict(test_images)"
],
"execution_count": 113,
"outputs": []
},
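{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The `Softmax` layer maps each row of logits to a probability distribution. A small NumPy sketch of the same computation follows, using toy logits rather than model output. It also shows why `np.argmax` picks the same class from logits or from probabilities: softmax preserves the ordering of its inputs."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"def softmax(logits):\n",
"    # shifting by the row max leaves the result unchanged but avoids overflow\n",
"    z = logits - logits.max(axis=-1, keepdims=True)\n",
"    e = np.exp(z)\n",
"    return e / e.sum(axis=-1, keepdims=True)\n",
"\n",
"toy_logits = np.array([[1.0, 2.0, 3.0],\n",
"                       [4.0, -1.0, 0.5]])\n",
"probs = softmax(toy_logits)\n",
"print(probs.sum(axis=-1))  # each row sums to 1 (up to rounding)\n",
"print(np.argmax(toy_logits, axis=-1), np.argmax(probs, axis=-1))  # [2 0] [2 0]"
],
"execution_count": null,
"outputs": []
},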
{
"cell_type": "markdown",
"metadata": {
"id": "eZUYG7pPUMzi",
"colab_type": "text"
},
"source": [
"Here, the model has produced a prediction for each image in the test set.\n",
"\n",
"Each prediction is an array of 10 numbers, one per class. Each number represents the model's \"confidence\" that the image shows the corresponding article of clothing, and the 10 values sum to 1."
]
},
{
"cell_type": "code",
"metadata": {
"id": "X_rVIutDunLU",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "5f31f3e9-7998-4ac0-c525-94bfee747159"
},
"source": [
"np.argmax(predictions[0])"
],
"execution_count": 114,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"9"
]
},
"metadata": {
"tags": []
},
"execution_count": 114
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ccXm_DqXYPJP",
"colab_type": "text"
},
"source": [
"So, the model is most confident that this image is an ankle boot, or `class_names[9]`. Examining the test label shows that this classification is correct:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "w-O59ojZumwY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "6753367e-5113-44f5-956a-be6b95b2d807"
},
"source": [
"test_labels[0]"
],
"execution_count": 115,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"9"
]
},
"metadata": {
"tags": []
},
"execution_count": 115
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wGlJu6J4Yx0g",
"colab_type": "text"
},
"source": [
"### Visualize the predictions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "yIQqrB7Tumc7",
"colab_type": "code",
"colab": {}
},
"source": [
"def plot_image(i, predictions_array, true_label, img):\n",
"  true_label, img = true_label[i], img[i]\n",
"  plt.grid(False)\n",
"  plt.xticks([])\n",
"  plt.yticks([])\n",
"\n",
"  plt.imshow(img, cmap=plt.cm.binary)\n",
"\n",
"  predicted_label = np.argmax(predictions_array)\n",
"  if predicted_label == true_label:\n",
"    color = 'blue'\n",
"  else:\n",
"    color = 'red'\n",
"\n",
"  plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label],\n",
"                                100*np.max(predictions_array),\n",
"                                class_names[true_label]),\n",
"                                color=color)\n",
"\n",
"def plot_value_array(i, predictions_array, true_label):\n",
"  true_label = true_label[i]\n",
"  plt.grid(False)\n",
"  plt.xticks(range(10))\n",
"  plt.yticks([])\n",
"  thisplot = plt.bar(range(10), predictions_array, color=\"#777777\")\n",
"  plt.ylim([0, 1])\n",
"  predicted_label = np.argmax(predictions_array)\n",
"\n",
"  thisplot[predicted_label].set_color('red')\n",
"  thisplot[true_label].set_color('blue')"
],
"execution_count": 116,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ohdenMcKZErQ",
"colab_type": "text"
},
"source": [
"Let's plot several images with their predictions. Note that the model can be wrong even when very confident."
]
},
{
"cell_type": "code",
"metadata": {
"id": "_xE1Q9L8Y74m",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 728
},
"outputId": "623788f5-0a4c-48f3-8d66-78bef24c1eb1"
},
"source": [
"num_rows = 5\n",
"num_cols = 3\n",
"\n",
"num_images = num_rows * num_cols\n",
"plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))\n",
"\n",
"for i in range(num_images):\n",
" plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)\n",
" plot_image(i, predictions[i], test_labels, test_images)\n",
" plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)\n",
" plot_value_array(i, predictions[i], test_labels)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
],
"execution_count": 117,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x720 with 30 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NgXr0mXFeOgb",
"colab_type": "text"
},
"source": [
"### Try Predicting with your own input images"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nFR38XWJePEm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 285
},
"outputId": "4ce7fc6d-247e-4011-8475-6710f5fc0aa9"
},
"source": [
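"import io\n",
"import urllib.request\n",
"from skimage import color\n",
"from skimage.transform import resize\n",
"url = 'https://external-content.duckduckgo.com/iu/?u=http%3A%2F%2Fpngimg.com%2Fuploads%2Fdress_shirt%2Fdress_shirt_PNG8117.png&f=1&nofb=1'\n",
"# download the PNG first: newer Matplotlib versions no longer accept URLs in plt.imread\n",
"with urllib.request.urlopen(url) as response:\n",
"  image = plt.imread(io.BytesIO(response.read()), format='png')\n",
"if image.ndim == 3 and image.shape[-1] == 4:\n",
"  image = color.rgba2rgb(image)  # drop the alpha channel before converting to grayscale\n",
"image = color.rgb2gray(image)\n",
"plt.imshow(image, cmap='gray')\n",
"image = resize(image, (28, 28), anti_aliasing=True)\n",
"image = np.expand_dims(image, 0)  # add a batch dimension: (28, 28) -> (1, 28, 28)\n",
"print(image.shape)"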
],
"execution_count": 118,
"outputs": [
{
"output_type": "stream",
"text": [
"(1, 28, 28)\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bVYmLWDDhFrw",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"outputId": "652e77bf-0bb4-4f86-b5ce-e60e2512d150"
},
"source": [
"predictions_single = probability_model.predict(image)\n",
"\n",
"# the downloaded image is a shirt, so mark class 6 ('Shirt') as its true label\n",
"plot_value_array(0, predictions_single[0], [6])\n",
"_ = plt.xticks(range(10), class_names, rotation=45)"
],
"execution_count": 122,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_zhXc9YFqo2c",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "8643e7fc-2f54-4dd0-9008-20dd02ca72f5"
},
"source": [
"np.argmax(predictions_single[0])"
],
"execution_count": 123,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"2"
]
},
"metadata": {
"tags": []
},
"execution_count": 123
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tNit8BrnpwN0",
"colab_type": "text"
},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J5-6jVbKq8Ex",
"colab_type": "text"
},
"source": [
"## Iris Flower Classification with Tensorflow"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6H6GeQkDrOjd",
"colab_type": "text"
},
"source": [
"### Import Packages"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1EZFlJNypCQh",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output\n",
"import urllib\n",
"\n",
"# hide warnings\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"\n",
"import tensorflow as tf"
],
"execution_count": 124,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ehbZbkqP-rsg",
"colab_type": "text"
},
"source": [
"### Load the Iris Flower Dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "V90WMOI2-0EF",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 103
},
"outputId": "ed7a9307-834e-4406-e2e1-10f3d45d4884"
},
"source": [
"categories = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'species']\n",
"species = ['Setosa', 'Versicolor', 'Virginica']\n",
"\n",
"train_path = tf.keras.utils.get_file('iris_training.csv', \n",
" 'https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv')\n",
"test_path = tf.keras.utils.get_file('iris_test.csv', \n",
" 'https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv')\n",
"\n",
"iris_train = pd.read_csv(train_path, names=categories, header=0)\n",
"iris_test = pd.read_csv(test_path, names=categories, header=0)\n",
"\n",
"y_train = iris_train.pop('species')\n",
"y_test = iris_test.pop('species')"
],
"execution_count": 125,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv\n",
"8192/2194 [==============================] - 0s 0us/step\n",
"Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv\n",
"8192/573 [==============================] - 0s 0us/step\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wmoZ5vmfBFMH",
"colab_type": "text"
},
"source": [
"### Create Base Feature Columns"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zcYLpO7uBQS8",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 53
},
"outputId": "d65b658b-e92d-4286-caf7-030fa4feb9dd"
},
"source": [
"feature_cols = list()\n",
"\n",
"for key in iris_train.keys():\n",
" feature_cols.append(tf.feature_column.numeric_column(key=key))\n",
"\n",
"print(feature_cols)"
],
"execution_count": 126,
"outputs": [
{
"output_type": "stream",
"text": [
"[NumericColumn(key='SepalLength', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None), NumericColumn(key='SepalWidth', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None), NumericColumn(key='PetalLength', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None), NumericColumn(key='PetalWidth', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None)]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SCMis3OYAxWi",
"colab_type": "text"
},
"source": [
"### Define an Input Function"
]
},
{
"cell_type": "code",
"metadata": {
"id": "CUQVVu8X81Po",
"colab_type": "code",
"colab": {}
},
"source": [
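"def input_fn(features, target, training=True, batch_size=256):\n",
"  # build a dataset of (feature dict, label) pairs from the DataFrame and label Series\n",
"  dataset = tf.data.Dataset.from_tensor_slices((dict(features), target))\n",
"\n",
"  # shuffle and repeat only during training; evaluation reads each row once\n",
"  if training:\n",
"    dataset = dataset.shuffle(1000).repeat()\n",
"\n",
"  return dataset.batch(batch_size)"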
],
"execution_count": 127,
"outputs": []
},
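{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To see what `input_fn` yields, here is a small self-contained sketch that repeats the same pipeline on a hypothetical three-row DataFrame (`toy_features` and `toy_labels` are made-up names; the values are taken from the example measurements table further below). With `training=False` nothing is shuffled or repeated, so one batch of size 256 holds all three rows; with `training=True` the dataset repeats forever, which is why `train()` is given an explicit `steps` value.\n",
"\n",
"~~~python\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"\n",
"def input_fn(features, target, training=True, batch_size=256):\n",
"  # same pipeline as above: (feature dict, label) pairs\n",
"  dataset = tf.data.Dataset.from_tensor_slices((dict(features), target))\n",
"  if training:\n",
"    dataset = dataset.shuffle(1000).repeat()  # endless, shuffled stream for training\n",
"  return dataset.batch(batch_size)\n",
"\n",
"# Hypothetical toy data with the same column names as the Iris CSV\n",
"toy_features = pd.DataFrame({'SepalLength': [5.1, 5.9, 6.9],\n",
"                             'SepalWidth': [3.3, 3.0, 3.1],\n",
"                             'PetalLength': [1.7, 4.2, 5.4],\n",
"                             'PetalWidth': [0.5, 1.5, 2.1]})\n",
"toy_labels = pd.Series([0, 1, 2])\n",
"\n",
"for features, labels in input_fn(toy_features, toy_labels, training=False).take(1):\n",
"  print({name: values.numpy() for name, values in features.items()})\n",
"  print(labels.numpy())  # [0 1 2]\n",
"~~~"
]
},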
{
"cell_type": "markdown",
"metadata": {
"id": "xJ41FeGDC892",
"colab_type": "text"
},
"source": [
"### Building a Model\n",
"\n",
"TensorFlow's Estimator API provides several pre-made classifiers. Two of the most commonly used are:\n",
"- `LinearClassifier`\n",
"- `DNNClassifier`"
]
},
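{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Roughly speaking, a linear classifier computes one weighted sum of the features per class and turns those scores into probabilities with a softmax, while a DNN classifier first passes the features through hidden layers. The NumPy sketch below illustrates that difference for the 4 Iris features and 3 species; the weights are random placeholders rather than trained values, and the Estimators' real implementations involve much more (feature columns, regularization, the training loop).\n",
"\n",
"~~~python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"def softmax(z):\n",
"  z = z - z.max()  # subtract the max for numerical stability\n",
"  e = np.exp(z)\n",
"  return e / e.sum()\n",
"\n",
"# one flower: SepalLength, SepalWidth, PetalLength, PetalWidth\n",
"x = np.array([5.9, 3.0, 4.2, 1.5])\n",
"\n",
"# Linear model: one score (logit) per class, computed directly from the features\n",
"W, b = rng.normal(size=(4, 3)), np.zeros(3)\n",
"linear_probs = softmax(x @ W + b)\n",
"\n",
"# DNN: a ReLU hidden layer (8 placeholder units, biases omitted) before the output layer\n",
"W1, W2 = rng.normal(size=(4, 8)), rng.normal(size=(8, 3))\n",
"hidden = np.maximum(0.0, x @ W1)\n",
"dnn_probs = softmax(hidden @ W2)\n",
"\n",
"print(linear_probs, linear_probs.sum())  # 3 probabilities summing to 1\n",
"print(dnn_probs, dnn_probs.sum())\n",
"~~~\n",
"\n",
"ReLU is used here because it is the default `activation_fn` of `DNNClassifier`."
]
},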
{
"cell_type": "markdown",
"metadata": {
"id": "5OcqMGSPW0Td",
"colab_type": "text"
},
"source": [
"#### `LinearClassifier`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "dyubbq1CXIYB",
"colab_type": "code",
"colab": {}
},
"source": [
"# Build a Linear Classifier\n",
"iris_linear_classifier = tf.estimator.LinearClassifier(feature_columns=feature_cols,\n",
" n_classes=3,\n",
" optimizer=tf.keras.optimizers.Ftrl(\n",
" learning_rate=0.1,\n",
" l1_regularization_strength=0.001\n",
" ))\n",
"\n",
"clear_output()"
],
"execution_count": 128,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "G1VRoGZ3cHuR",
"colab_type": "text"
},
"source": [
"##### Training the Classifier Model with training data "
]
},
{
"cell_type": "code",
"metadata": {
"id": "5ZDS2h5Nbw6K",
"colab_type": "code",
"colab": {}
},
"source": [
"iris_linear_classifier.train(lambda: input_fn(iris_train, y_train, training=True), steps = 5000)\n",
"clear_output()"
],
"execution_count": 129,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hvTXV882c8Ws",
"colab_type": "text"
},
"source": [
"##### Evaluate the Classifier Model with training data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2x415iD_c8_l",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "90dc3298-3dda-433f-c32f-3501b984b7ed"
},
"source": [
"result = iris_linear_classifier.evaluate(lambda: input_fn(iris_train, y_train, training=False))\n",
"clear_output()\n",
"print(\"Train Set Accuracy : %.2f \"%(result['accuracy']*100))"
],
"execution_count": 130,
"outputs": [
{
"output_type": "stream",
"text": [
"Train Set Accuracy : 98.33 \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bQVLN6jccZ8L",
"colab_type": "text"
},
"source": [
"##### Evaluate the Classifier Model with test data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "eRC02xFecbnZ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "f7c1e608-75ae-4e36-cf61-791609040406"
},
"source": [
"result = iris_linear_classifier.evaluate(lambda: input_fn(iris_test, y_test, training=False))\n",
"clear_output()\n",
"print(\"Test Set Accuracy : %.2f \"%(result['accuracy']*100))"
],
"execution_count": 131,
"outputs": [
{
"output_type": "stream",
"text": [
"Test Set Accuracy : 96.67 \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FTNWyYzRcx7O",
"colab_type": "text"
},
"source": [
"##### Use the model for prediction"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e8nJcBsNcvKE",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 133
},
"outputId": "37915c93-f38a-4ca4-ee79-7ec861702dd2"
},
"source": [
"predict = dict()\n",
"categories = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']\n",
"print('Please enter the flower measurements (in cm)')\n",
"for category in categories:\n",
"  # keep asking until the input parses as a number (e.g. 5 or 5.1)\n",
"  while True:\n",
"    val = input(category + ': ')\n",
"    try:\n",
"      predict[category] = [float(val)]\n",
"      break\n",
"    except ValueError:\n",
"      print('Please enter a numeric value.')\n",
"\n",
"def input_func(features, batch_size=256):\n",
"  # prediction needs no labels, shuffling or repetition: just batch the features\n",
"  return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)\n",
"\n",
"predictions = iris_linear_classifier.predict(input_fn=lambda: input_func(predict))\n",
"clear_output()\n",
"for pred in predictions:\n",
"  class_id = pred['class_ids'][0]\n",
"  prob = pred['probabilities'][class_id]\n",
"\n",
"  print('Prediction is: \"{}\" ({:.2f}%)'.format(species[class_id], prob * 100))"
],
"execution_count": 133,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from /tmp/tmpx7p2d6pl/model.ckpt-5000\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"Prediction is: \"Versicolor\" (95.65%)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vCQkafAODfk_",
"colab_type": "text"
},
"source": [
"#### `DNNClassifier`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RAEPRpMN96EV",
"colab_type": "code",
"colab": {}
},
"source": [
"# Build a DNN classifier with three hidden layers of 128, 32 and 8 nodes\n",
"iris_dnn_classifier = tf.estimator.DNNClassifier(feature_columns=feature_cols,\n",
" \n",
" # 3 hidden layers of 128, 32 and 8 nodes\n",
" hidden_units = [128, 32, 8],\n",
" \n",
" # number of classes for number of classifications\n",
" # here, three species of flowers\n",
" n_classes=3,\n",
" \n",
" # Adam optimizer; the learning rate decays as 0.1 * 0.96^(step / 10000)\n",
" optimizer=lambda: tf.keras.optimizers.Adam(\n",
" learning_rate=tf.compat.v1.train.exponential_decay(\n",
" learning_rate=0.1,\n",
" global_step=tf.compat.v1.train.get_global_step(),\n",
" decay_steps=10000,\n",
" decay_rate=0.96)))\n",
"\n",
"clear_output()"
],
"execution_count": 134,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "pdEDmhngHmyJ",
"colab_type": "text"
},
"source": [
"##### Training the Classifier Model on training data "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JLmZwWOqKYBk",
"colab_type": "text"
},
"source": [
"While training the model, we want the loss to decrease steadily and end up as small as possible.\n",
"\n",
"In the training log, the second-to-last line reports the loss at the final step:\n",
"~~~\n",
"INFO:tensorflow:Loss for final step: 0.047325253.\n",
"~~~\n",
"in a log such as:\n",
"~~~\n",
"....\n",
"INFO:tensorflow:global_step/sec: 370.872\n",
"INFO:tensorflow:loss = 0.025431797, step = 4800 (0.270 sec)\n",
"INFO:tensorflow:global_step/sec: 393.163\n",
"INFO:tensorflow:loss = 0.027603408, step = 4900 (0.251 sec)\n",
"INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...\n",
"INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmpnx65c0fz/model.ckpt.\n",
"INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...\n",
"INFO:tensorflow:Loss for final step: 0.047325253.\n",
"<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7fd0a824afd0>\n",
"~~~\n",
"\n",
"To reduce the loss further, we can:\n",
"- add or remove hidden layers\n",
"- add or remove nodes (units) in each layer\n",
"\n",
"The choice of optimizer and learning rate matters as well; here we use Adam with an exponentially decaying learning rate.\n",
"\n",
"If the loss at the final step is still greater than 1, we should tweak the hyperparameters (hyperparameter tuning), then rebuild and retrain the model."
]
},
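{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To put the logged loss values in context: assuming the canned classifiers' default loss, softmax cross-entropy averaged over the batch, each example contributes -log(p), where p is the probability the model assigns to the true class. A plain-Python sketch with made-up probabilities:\n",
"\n",
"~~~python\n",
"import math\n",
"\n",
"def cross_entropy(probs, true_class):\n",
"  # loss for one example: negative natural log of the true class's probability\n",
"  return -math.log(probs[true_class])\n",
"\n",
"print(round(cross_entropy([1/3, 1/3, 1/3], 1), 4))     # uniform guess over 3 species: 1.0986 (= ln 3)\n",
"print(round(cross_entropy([0.05, 0.90, 0.05], 1), 4))  # confident and correct: 0.1054\n",
"print(round(cross_entropy([0.90, 0.05, 0.05], 1), 4))  # confident and wrong: 2.9957\n",
"~~~\n",
"\n",
"This suggests one way to read the 'greater than 1' rule of thumb above: with three classes, a model that guesses uniformly already scores about 1.10, so a final loss above 1 means the model is barely better than chance. The DNN run below indeed starts close to that value (loss = 0.9548463 at step 0)."
]
},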
{
"cell_type": "code",
"metadata": {
"id": "30vp6heKGctA",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "08d4f44a-4cc9-4b4f-fc3f-42386ff2aed3"
},
"source": [
"iris_dnn_classifier.train(lambda: input_fn(iris_train, y_train, training=True), steps = 5000)"
],
"execution_count": 135,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n",
"WARNING:tensorflow:Layer dnn is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because its dtype defaults to floatx.\n",
"\n",
"If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
"\n",
"To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
"\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n",
"INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpf039um29/model.ckpt.\n",
"INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n",
"INFO:tensorflow:loss = 0.9548463, step = 0\n",
"INFO:tensorflow:global_step/sec: 340.898\n",
"INFO:tensorflow:loss = 0.06929995, step = 100 (0.297 sec)\n",
"INFO:tensorflow:global_step/sec: 391.516\n",
"INFO:tensorflow:loss = 0.051905256, step = 200 (0.254 sec)\n",
"INFO:tensorflow:global_step/sec: 413.551\n",
"INFO:tensorflow:loss = 0.032315936, step = 300 (0.243 sec)\n",
"INFO:tensorflow:global_step/sec: 383.534\n",
"INFO:tensorflow:loss = 0.05966128, step = 400 (0.259 sec)\n",
"WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 401 vs previous value: 401. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n",
"INFO:tensorflow:global_step/sec: 404.175\n",
"INFO:tensorflow:loss = 0.03130503, step = 500 (0.249 sec)\n",
"INFO:tensorflow:global_step/sec: 411.51\n",
"INFO:tensorflow:loss = 0.041695025, step = 600 (0.243 sec)\n",
"INFO:tensorflow:global_step/sec: 412.703\n",
"INFO:tensorflow:loss = 0.04021568, step = 700 (0.244 sec)\n",
"INFO:tensorflow:global_step/sec: 388.643\n",
"INFO:tensorflow:loss = 0.0441709, step = 800 (0.257 sec)\n",
"INFO:tensorflow:global_step/sec: 401.047\n",
"INFO:tensorflow:loss = 0.027175507, step = 900 (0.249 sec)\n",
"INFO:tensorflow:global_step/sec: 412.232\n",
"INFO:tensorflow:loss = 0.029116593, step = 1000 (0.242 sec)\n",
"INFO:tensorflow:global_step/sec: 409.903\n",
"INFO:tensorflow:loss = 0.035351362, step = 1100 (0.244 sec)\n",
"INFO:tensorflow:global_step/sec: 380.454\n",
"INFO:tensorflow:loss = 0.15880767, step = 1200 (0.261 sec)\n",
"INFO:tensorflow:global_step/sec: 411.295\n",
"INFO:tensorflow:loss = 0.054302238, step = 1300 (0.243 sec)\n",
"INFO:tensorflow:global_step/sec: 386.703\n",
"INFO:tensorflow:loss = 0.03428327, step = 1400 (0.260 sec)\n",
"INFO:tensorflow:global_step/sec: 406.562\n",
"INFO:tensorflow:loss = 0.034377772, step = 1500 (0.244 sec)\n",
"INFO:tensorflow:global_step/sec: 382.043\n",
"INFO:tensorflow:loss = 0.07531372, step = 1600 (0.263 sec)\n",
"INFO:tensorflow:global_step/sec: 403.809\n",
"INFO:tensorflow:loss = 0.028360832, step = 1700 (0.248 sec)\n",
"INFO:tensorflow:global_step/sec: 403.088\n",
"INFO:tensorflow:loss = 0.035252944, step = 1800 (0.246 sec)\n",
"WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1894 vs previous value: 1894. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n",
"INFO:tensorflow:global_step/sec: 400.621\n",
"INFO:tensorflow:loss = 0.04120819, step = 1900 (0.251 sec)\n",
"INFO:tensorflow:global_step/sec: 385.629\n",
"INFO:tensorflow:loss = 0.03402465, step = 2000 (0.259 sec)\n",
"INFO:tensorflow:global_step/sec: 406.34\n",
"INFO:tensorflow:loss = 0.025264615, step = 2100 (0.247 sec)\n",
"INFO:tensorflow:global_step/sec: 400.315\n",
"INFO:tensorflow:loss = 0.019488363, step = 2200 (0.250 sec)\n",
"INFO:tensorflow:global_step/sec: 407.913\n",
"INFO:tensorflow:loss = 0.034090046, step = 2300 (0.245 sec)\n",
"INFO:tensorflow:global_step/sec: 383.383\n",
"INFO:tensorflow:loss = 0.025606574, step = 2400 (0.261 sec)\n",
"INFO:tensorflow:global_step/sec: 406.737\n",
"INFO:tensorflow:loss = 0.027330484, step = 2500 (0.244 sec)\n",
"INFO:tensorflow:global_step/sec: 409.151\n",
"INFO:tensorflow:loss = 0.10570714, step = 2600 (0.246 sec)\n",
"INFO:tensorflow:global_step/sec: 417.941\n",
"INFO:tensorflow:loss = 0.03930735, step = 2700 (0.239 sec)\n",
"INFO:tensorflow:global_step/sec: 384.056\n",
"INFO:tensorflow:loss = 0.026937068, step = 2800 (0.260 sec)\n",
"INFO:tensorflow:global_step/sec: 402.361\n",
"INFO:tensorflow:loss = 0.03000135, step = 2900 (0.249 sec)\n",
"INFO:tensorflow:global_step/sec: 407.095\n",
"INFO:tensorflow:loss = 0.049963575, step = 3000 (0.244 sec)\n",
"INFO:tensorflow:global_step/sec: 413.01\n",
"INFO:tensorflow:loss = 0.10355372, step = 3100 (0.241 sec)\n",
"INFO:tensorflow:global_step/sec: 384.285\n",
"INFO:tensorflow:loss = 0.02253997, step = 3200 (0.262 sec)\n",
"INFO:tensorflow:global_step/sec: 403.511\n",
"INFO:tensorflow:loss = 0.025050739, step = 3300 (0.248 sec)\n",
"INFO:tensorflow:global_step/sec: 405.741\n",
"INFO:tensorflow:loss = 0.022113536, step = 3400 (0.245 sec)\n",
"INFO:tensorflow:global_step/sec: 402.251\n",
"INFO:tensorflow:loss = 0.021794006, step = 3500 (0.250 sec)\n",
"INFO:tensorflow:global_step/sec: 395.529\n",
"INFO:tensorflow:loss = 0.037575096, step = 3600 (0.252 sec)\n",
"INFO:tensorflow:global_step/sec: 395.242\n",
"INFO:tensorflow:loss = 0.06161868, step = 3700 (0.251 sec)\n",
"INFO:tensorflow:global_step/sec: 402.537\n",
"INFO:tensorflow:loss = 0.02247855, step = 3800 (0.250 sec)\n",
"INFO:tensorflow:global_step/sec: 409.245\n",
"INFO:tensorflow:loss = 0.02687079, step = 3900 (0.244 sec)\n",
"INFO:tensorflow:global_step/sec: 385.271\n",
"INFO:tensorflow:loss = 0.029623002, step = 4000 (0.258 sec)\n",
"INFO:tensorflow:global_step/sec: 393.757\n",
"INFO:tensorflow:loss = 0.020733567, step = 4100 (0.254 sec)\n",
"INFO:tensorflow:global_step/sec: 409.733\n",
"INFO:tensorflow:loss = 0.020996619, step = 4200 (0.246 sec)\n",
"INFO:tensorflow:global_step/sec: 411.107\n",
"INFO:tensorflow:loss = 0.059230465, step = 4300 (0.244 sec)\n",
"INFO:tensorflow:global_step/sec: 393.461\n",
"INFO:tensorflow:loss = 0.032371342, step = 4400 (0.253 sec)\n",
"INFO:tensorflow:global_step/sec: 387.777\n",
"INFO:tensorflow:loss = 0.06218221, step = 4500 (0.260 sec)\n",
"INFO:tensorflow:global_step/sec: 389.653\n",
"INFO:tensorflow:loss = 0.022168107, step = 4600 (0.255 sec)\n",
"INFO:tensorflow:global_step/sec: 381.999\n",
"INFO:tensorflow:loss = 0.02789202, step = 4700 (0.263 sec)\n",
"INFO:tensorflow:global_step/sec: 391.465\n",
"INFO:tensorflow:loss = 0.02288571, step = 4800 (0.255 sec)\n",
"INFO:tensorflow:global_step/sec: 396.413\n",
"INFO:tensorflow:loss = 0.034475923, step = 4900 (0.252 sec)\n",
"INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...\n",
"INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmpf039um29/model.ckpt.\n",
"INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...\n",
"INFO:tensorflow:Loss for final step: 0.02225061.\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7fd2acef9978>"
]
},
"metadata": {
"tags": []
},
"execution_count": 135
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LftIQMQQMCpb",
"colab_type": "text"
},
"source": [
"##### Evaluate the Classifier Model on test data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "VqkrdWhZIodg",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "7fa3354f-5a3f-445a-a9f8-648d75e162ba"
},
"source": [
"result = iris_dnn_classifier.evaluate(lambda: input_fn(iris_test, y_test, training=False))\n",
"clear_output()\n",
"print(\"Test Set Accuracy : %.2f \"%(result['accuracy']*100))"
],
"execution_count": 136,
"outputs": [
{
"output_type": "stream",
"text": [
"Test Set Accuracy : 96.67 \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OnBfUEJiJxXY",
"colab_type": "text"
},
"source": [
"##### Evaluate the Classifier Model on the training data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "8a1JpCAhNdVR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "91223f37-1d81-43e2-8d68-49493106482c"
},
"source": [
"result = iris_dnn_classifier.evaluate(lambda: input_fn(iris_train, y_train, training=False))\n",
"clear_output()\n",
"print(\"Train Set Accuracy : %.2f \"%(result['accuracy']*100))"
],
"execution_count": 137,
"outputs": [
{
"output_type": "stream",
"text": [
"Train Set Accuracy : 99.17 \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tTu_jWriN0Em",
"colab_type": "text"
},
"source": [
"##### Using the model for prediction"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "siH038gNOsBw",
"colab_type": "text"
},
"source": [
"Ask the user to input data on various features of the flower. The model will try to predict the species based on the data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQMwyDa_jIxz",
"colab_type": "text"
},
"source": [
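"Try with data like this (measurements in cm, entered in the order the prompts appear):\n",
"\n",
"| Feature | Setosa | Versicolor | Virginica |\n",
"|---|---|---|---|\n",
"| SepalLength | 5.1 | 5.9 | 6.9 |\n",
"| SepalWidth | 3.3 | 3.0 | 3.1 |\n",
"| PetalLength | 1.7 | 4.2 | 5.4 |\n",
"| PetalWidth | 0.5 | 1.5 | 2.1 |\n"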
]
},
{
"cell_type": "code",
"metadata": {
"id": "sJ4KyapzN55n",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 133
},
"outputId": "329155f0-74af-40ae-f276-890a780d6201"
},
"source": [
"predict = dict()\n",
"categories = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']\n",
"print('Please enter the flower measurements (in cm)')\n",
"for category in categories:\n",
"  # keep asking until the input parses as a number (e.g. 5 or 5.1)\n",
"  while True:\n",
"    val = input(category + ': ')\n",
"    try:\n",
"      predict[category] = [float(val)]\n",
"      break\n",
"    except ValueError:\n",
"      print('Please enter a numeric value.')\n",
"\n",
"def input_func(features, batch_size=256):\n",
"  # prediction needs no labels, shuffling or repetition: just batch the features\n",
"  return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)\n",
"\n",
"predictions = iris_dnn_classifier.predict(input_fn=lambda: input_func(predict))\n",
"clear_output()\n",
"for pred in predictions:\n",
"  class_id = pred['class_ids'][0]\n",
"  prob = pred['probabilities'][class_id]\n",
"\n",
"  print('Prediction is: \"{}\" ({:.2f}%)'.format(species[class_id], prob * 100))"
],
"execution_count": 139,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from /tmp/tmpf039um29/model.ckpt-5000\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"Prediction is: \"Versicolor\" (99.92%)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bxP1xdM3WLF9",
"colab_type": "text"
},
"source": [
"So, this model predicts the species of a flower from user-provided measurements with high accuracy (96.67% on the test set).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7jtaXBHKWngG",
"colab_type": "text"
},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kihZge-mdEd1",
"colab_type": "text"
},
"source": [
"## Basic Text Classification: Sentiment Analysis on the IMDB Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qOu-aBfvuG-G",
"colab_type": "text"
},
"source": [
"### Import Packages"
]
},
{
"cell_type": "code",
"metadata": {
"id": "BYUGeBsKdDru",
"colab_type": "code",
"colab": {}
},
"source": [
"import matplotlib.pyplot as plt\n",
"import os\n",
"import re\n",
"import shutil\n",
"import string\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras import losses\n",
"from tensorflow.keras import preprocessing\n",
"from tensorflow.keras.layers.experimental.preprocessing import TextVectorization"
],
"execution_count": 140,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_7X8VOGmuifv",
"colab_type": "text"
},
"source": [
"### Download and explore the IMDB dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "oRuz2AHRulRF",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "32c81e72-25e1-4b6b-edd0-443c62dea0d2"
},
"source": [
"url = \"https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n",
"\n",
"dataset = tf.keras.utils.get_file(\"aclImdb_v1.tar.gz\", url,\n",
" untar=True, cache_dir='.',\n",
" cache_subdir='')\n",
"\n",
"dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')\n",
"train_dir = os.path.join(dataset_dir, 'train')\n",
"\n",
"remove_dir = os.path.join(train_dir, 'unsup')\n",
"shutil.rmtree(remove_dir)"
],
"execution_count": 141,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading data from https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
"84131840/84125825 [==============================] - 8s 0us/step\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ztdkS8g9yfIj",
"colab_type": "text"
},
"source": [
"#### Create the training dataset "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wgLeCwE5dDNg",
"colab_type": "text"
},
"source": [
"Next, we load the data from disk and prepare it in a format suitable for training. To do so, we use the [`text_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text_dataset_from_directory) utility, which expects the following directory structure:\n",
"\n",
"~~~\n",
"main_directory/\n",
"...class_a/\n",
"......a_text_1.txt\n",
"......a_text_2.txt\n",
"...class_b/\n",
"......b_text_1.txt\n",
"......b_text_2.txt\n",
"~~~\n",
"\n",
"To prepare a dataset for binary classification, we need two folders on disk, corresponding to `class_a` and `class_b`. These are the positive and negative movie reviews, found in `aclImdb/train/pos` and `aclImdb/train/neg`. Because `aclImdb/train` also contains an `unsup` folder of unlabeled reviews, we deleted that folder with `shutil.rmtree` before using this utility."
]
},
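{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A small sketch of the expected layout, built in a temporary directory with two hypothetical classes (`class_a`, `class_b`) holding two tiny text files each. It calls the same `text_dataset_from_directory` utility as this notebook; newer TensorFlow releases also expose it as `tf.keras.utils.text_dataset_from_directory`. Labels are assigned from the subdirectory names in alphanumeric order, which is why the IMDB `neg` folder becomes label 0 and `pos` becomes label 1.\n",
"\n",
"~~~python\n",
"import os\n",
"import tempfile\n",
"import tensorflow as tf\n",
"\n",
"root = tempfile.mkdtemp()\n",
"\n",
"# Build root/class_a/*.txt and root/class_b/*.txt with made-up texts\n",
"samples = {'class_a': ['a text one', 'a text two'],\n",
"           'class_b': ['b text one', 'b text two']}\n",
"for class_name, texts in samples.items():\n",
"  os.makedirs(os.path.join(root, class_name))\n",
"  for i, text in enumerate(texts):\n",
"    with open(os.path.join(root, class_name, 'text_%d.txt' % (i + 1)), 'w') as f:\n",
"      f.write(text)\n",
"\n",
"ds = tf.keras.preprocessing.text_dataset_from_directory(root, batch_size=4, shuffle=False)\n",
"print(ds.class_names)  # ['class_a', 'class_b'], i.e. labels 0 and 1\n",
"for texts, labels in ds:\n",
"  print(texts.numpy(), labels.numpy())\n",
"~~~"
]
},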
{
"cell_type": "code",
"metadata": {
"id": "H8pbC7JKdF-d",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "99ffec47-a92c-4811-fb8c-d86434df3c61"
},
"source": [
"batch = 42\n",
"seed = 42\n",
"\n",
"# training dataset\n",
"train_ds = preprocessing.text_dataset_from_directory('aclImdb/train',\n",
" batch_size = batch,\n",
" seed = seed,\n",
" validation_split = 0.2,\n",
" subset = 'training')\n"
],
"execution_count": 142,
"outputs": [
{
"output_type": "stream",
"text": [
"Found 25000 files belonging to 2 classes.\n",
"Using 20000 files for training.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Eg1A6-TpyNEQ",
"colab_type": "text"
},
"source": [
"#### Create the validation and test datasets"
]
},
{
"cell_type": "code",
"metadata": {
"id": "n2_OXaSGdGqm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "b70e31a8-6c06-44ae-f6e8-bfc2b4d66dda"
},
"source": [
"# validation dataset\n",
"valid_ds = preprocessing.text_dataset_from_directory('aclImdb/train',\n",
" batch_size = batch,\n",
" seed = seed,\n",
" validation_split = 0.2,\n",
" subset = 'validation')"
],
"execution_count": 143,
"outputs": [
{
"output_type": "stream",
"text": [
"Found 25000 files belonging to 2 classes.\n",
"Using 5000 files for validation.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "N6DMhd4udHRn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "96b5fc4d-9dae-4d40-924c-042c08ee7056"
},
"source": [
"# test dataset\n",
"test_ds = tf.keras.preprocessing.text_dataset_from_directory('aclImdb/test', \n",
" batch_size=batch)"
],
"execution_count": 145,
"outputs": [
{
"output_type": "stream",
"text": [
"Found 25000 files belonging to 2 classes.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yZ7XUwmzywvf",
"colab_type": "text"
},
"source": [
"#### Prepare the dataset for training\n",
"\n",
"Here, we will standardize, tokenize, and vectorize the data.\n",
"\n",
"- **Standardization** refers to preprocessing the text, typically to remove punctuation or HTML elements to simplify the dataset.\n",
"- **Tokenization** refers to splitting strings into tokens (for example, splitting a sentence into individual words, by splitting on whitespace).\n",
"- **Vectorization** refers to converting tokens into numbers so they can be fed into a neural network. All three tasks can be handled by the `TextVectorization` layer."
]
},
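{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To make the three steps concrete, here is a minimal plain-Python sketch of what they do to a review. It is an illustration only, not how `TextVectorization` is implemented. The helper names (`standardize`, `tokenize`, `build_vocab`, `vectorize`) are hypothetical, and reserving index 0 for padding and index 1 for unknown words is an assumption chosen to match the 1s and trailing 0s in the vectorized review printed further down.\n",
"\n",
"```python\n",
"import string\n",
"from collections import Counter\n",
"\n",
"def standardize(text):\n",
"    # lowercase, turn HTML line breaks into spaces, then drop ASCII punctuation\n",
"    text = text.lower().replace('<br />', ' ')\n",
"    return text.translate(str.maketrans('', '', string.punctuation))\n",
"\n",
"def tokenize(text):\n",
"    # split on runs of whitespace\n",
"    return text.split()\n",
"\n",
"def build_vocab(texts, max_tokens):\n",
"    # hypothetical layout: 0 = padding, 1 = out-of-vocabulary, then tokens by frequency\n",
"    counts = Counter(tok for t in texts for tok in tokenize(standardize(t)))\n",
"    most_common = [tok for tok, _ in counts.most_common(max_tokens - 2)]\n",
"    return ['', '[UNK]'] + most_common\n",
"\n",
"def vectorize(text, vocab, sequence_length):\n",
"    index = {tok: i for i, tok in enumerate(vocab)}\n",
"    # tokens missing from the vocabulary map to 1\n",
"    ids = [index.get(tok, 1) for tok in tokenize(standardize(text))]\n",
"    # truncate long texts, right-pad short ones with 0\n",
"    ids = ids[:sequence_length]\n",
"    return ids + [0] * (sequence_length - len(ids))\n",
"\n",
"corpus = ['A great movie!<br />Great acting.', 'A boring movie.']\n",
"vocab = build_vocab(corpus, max_tokens=5)\n",
"print(vocab)  # ['', '[UNK]', 'a', 'great', 'movie']\n",
"print(vectorize('Great, boring plot', vocab, sequence_length=6))  # [3, 1, 1, 0, 0, 0]\n",
"```\n",
"\n",
"Replacing `<br />` with a space before stripping punctuation matters: otherwise `movie!<br />great` would become the tokens `moviebr` and `great`."
]
},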
{
"cell_type": "code",
"metadata": {
"id": "gTbWucn_dIM9",
"colab_type": "code",
"colab": {}
},
"source": [
"def custom_standardization(input_data):\n",
" # convert a string into lowercase\n",
" lowercase = tf.strings.lower(input_data)\n",
" # replace every '<br />' HTML line break with a space\n",
" stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')\n",
" # remove all punctuation from the stripped string\n",
" return tf.strings.regex_replace(stripped_html,\n",
" '[%s]' % re.escape(string.punctuation),'')\n",
"\n",
"max_features = 10000\n",
"sequence_length = 250\n",
"\n",
"# apply text vectorization\n",
"vectorize_layer = TextVectorization(\n",
" # use custom standardizer\n",
" standardize=custom_standardization,\n",
" max_tokens=max_features,\n",
" output_mode='int',\n",
" output_sequence_length=sequence_length)"
],
"execution_count": 146,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "8VW4dAdl9foo",
"colab_type": "text"
},
"source": [
"Now, we will call `adapt` to fit the state of the preprocessing layer to the dataset. This makes the layer build an index (vocabulary) that maps strings to integers."
]
},
{
"cell_type": "code",
"metadata": {
"id": "cBrZb8jodH7u",
"colab_type": "code",
"colab": {}
},
"source": [
"# Make a text-only dataset (without labels), then call adapt\n",
"train_text = train_ds.map(lambda x, y: x)\n",
"vectorize_layer.adapt(train_text)"
],
"execution_count": 147,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "WDAjf1kS-r0d",
"colab_type": "text"
},
"source": [
"Next, create a helper function that applies this layer, so we can see the result of using it to preprocess some data."
]
},
{
"cell_type": "code",
"metadata": {
"id": "hoSMeTyBdHBR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 470
},
"outputId": "d984afe2-cf03-4d74-9e1b-d9311f17f35b"
},
"source": [
"def vectorize_text(text, label):\n",
" text = tf.expand_dims(text, -1)\n",
" return vectorize_layer(text), label\n",
"\n",
"# retrieve one batch of reviews and labels from the dataset\n",
"text_batch, label_batch = next(iter(train_ds))\n",
"first_review, first_label = text_batch[0], label_batch[0]\n",
"print(\"Review\", first_review)\n",
"print(\"Label\", train_ds.class_names[first_label])\n",
"print(\"Vectorized review\", vectorize_text(first_review, first_label))"
],
"execution_count": 148,
"outputs": [
{
"output_type": "stream",
"text": [
"Review tf.Tensor(b'boring stuff we got here. His 5 minute shorts are better than this. know why? because there only 5 minutes and not 91 minutes or how ever long this is. <br /><br />The plot is kinda... eh.. the last half hour is alright the rest is boring and not funny =( I had my hopes up, the trailer made it look funny but the pace of this movie is pretty slow and sadly not funny. Just plain boring klaymen running into each other and trying to make us laugh.. not working.<br /><br />Maybe next time knox.<br /><br />Maybe re-cutting this movie and adding better scenes would do a lot of healing but for now its just not good.', shape=(), dtype=string)\n",
"Label neg\n",
"Vectorized review (<tf.Tensor: shape=(1, 250), dtype=int64, numpy=\n",
"array([[ 346, 524, 71, 184, 128, 24, 661, 909, 3043, 23, 122,\n",
" 70, 11, 118, 133, 84, 47, 61, 661, 226, 3, 21,\n",
" 1, 226, 41, 87, 121, 209, 11, 7, 2, 111, 7,\n",
" 1812, 7959, 2, 229, 363, 563, 7, 2757, 2, 351, 7,\n",
" 346, 3, 21, 160, 10, 66, 54, 1796, 56, 2, 1420,\n",
" 90, 9, 162, 160, 18, 2, 1052, 5, 11, 17, 7,\n",
" 179, 602, 3, 1030, 21, 160, 40, 1029, 346, 1, 633,\n",
" 81, 245, 78, 3, 258, 6, 96, 167, 472, 21, 773,\n",
" 278, 357, 58, 9521, 278, 1, 11, 17, 3, 2954, 122,\n",
" 136, 59, 82, 4, 171, 5, 8460, 18, 15, 148, 29,\n",
" 40, 21, 49, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0]])>, <tf.Tensor: shape=(), dtype=int32, numpy=0>)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YNPDsrXW_yeF",
"colab_type": "text"
},
"source": [
"Each token has been replaced by an integer. We can look up the token (string) that each integer corresponds to by calling `.get_vocabulary()` on the layer."
]
},
{
"cell_type": "code",
"metadata": {
"id": "AjVH5nn2dGeP",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "4fe1d31f-2c1f-4417-dec9-f99b3f4233be"
},
"source": [
"print(\"2166 ---> \",vectorize_layer.get_vocabulary()[2166])\n",
"print(\" 259 ---> \",vectorize_layer.get_vocabulary()[259 ])\n",
"print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))"
],
"execution_count": 149,
"outputs": [
{
"output_type": "stream",
"text": [
"2166 ---> grow\n",
" 259 ---> course\n",
"Vocabulary size: 10000\n"
],
"name": "stdout"
}
]
},
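{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Going from integers back to words is a plain list lookup. A minimal sketch with a toy vocabulary laid out the way `get_vocabulary()` returns it here (padding `''` first, then `'[UNK]'`); the `decode` helper is hypothetical. With the real layer you could, for example, pass `vectorize_layer.get_vocabulary()` and the ids from `vectorize_text(first_review, first_label)[0].numpy()[0]`.\n",
"\n",
"```python\n",
"def decode(ids, vocab):\n",
"    # look up each id's token, skipping padding (id 0)\n",
"    return ' '.join(vocab[i] for i in ids if i != 0)\n",
"\n",
"# toy vocabulary: index 0 is padding, index 1 is the out-of-vocabulary token\n",
"vocab = ['', '[UNK]', 'the', 'movie', 'was', 'boring']\n",
"print(decode([2, 3, 4, 5, 1, 0, 0], vocab))  # the movie was boring [UNK]\n",
"```\n",
"\n",
"The round trip is lossy: every word outside the 10,000-token vocabulary comes back as `[UNK]`."
]
},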
{
"cell_type": "markdown",
"metadata": {
"id": "ADQSgikkAIjd",
"colab_type": "text"
},
"source": [
"As a final preprocessing step, we apply the `TextVectorization` layer we created earlier to the train, validation, and test datasets."
]
},
{
"cell_type": "code",
"metadata": {
"id": "IqBKXgNXdFY-",
"colab_type": "code",
"colab": {}
},
"source": [
"train_ds = train_ds.map(vectorize_text)\n",
"valid_ds = valid_ds.map(vectorize_text)\n",
"test_ds = test_ds.map(vectorize_text)"
],
"execution_count": 150,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "JFURiwzmATC3",
"colab_type": "text"
},
"source": [
"### Configure Dataset for performance\n",
"\n",
"There are two important methods we should use when loading data to make sure that I/O does not block training.\n",
"\n",
"`.cache()` keeps data in memory after it's loaded off disk. This will ensure the dataset does not become a bottleneck while training our model. If our dataset is too large to fit into memory, we can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.\n",
"\n",
"`.prefetch()` overlaps data preprocessing and model execution while training.\n",
"\n",
"You can learn more about both methods, as well as how to cache data to disk, in the [data performance guide](https://www.tensorflow.org/guide/data_performance)."
]
},
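{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A small sketch of the same two calls on a toy pipeline, assuming TensorFlow 2.x. `Dataset.range` and the doubling `map` are stand-ins for loading and preprocessing, and `tf.data.experimental.AUTOTUNE` is the named form of the buffer-size value `-1`, which asks tf.data to tune the buffer size at runtime.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
"\n",
"# toy pipeline: the map() stands in for expensive preprocessing\n",
"ds = tf.data.Dataset.range(5).map(lambda x: x * 2)\n",
"# cache() keeps elements after the first full pass; prefetch() prepares the\n",
"# next elements in the background while the current one is being consumed\n",
"ds = ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
"\n",
"for epoch in range(2):\n",
"    # the second pass is served from the cache rather than re-running map()\n",
"    print(epoch, [int(x) for x in ds.as_numpy_iterator()])  # same values on both passes\n",
"```"
]
},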
{
"cell_type": "code",
"metadata": {
"id": "q9-Mi5zBASUN",
"colab_type": "code",
"colab": {}
},
"source": [
"AUTOTUNE = tf.data.experimental.AUTOTUNE  # named constant (equal to -1): let tf.data tune the buffer size\n",
"\n",
"train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
"valid_ds = valid_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
"test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)"
],
"execution_count": 151,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q9HKXkj4AT5p",
"colab_type": "text"
},
"source": [
"### Build the Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "krsAV6giHL4P",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 316
},
"outputId": "a80fcd83-09e0-4987-c4c0-2a3506815d75"
},
"source": [
"# each word index is mapped to a 16-dimensional embedding vector\n",
"embedding_dim = 16\n",
"\n",
"model = tf.keras.Sequential([\n",
" # layers.Embedding : Turns positive integers (indexes) into dense vectors of fixed size.\n",
" layers.Embedding(max_features + 1, embedding_dim),\n",
" # layers.Dropout : apply Dropout regularization\n",
" layers.Dropout(0.2),\n",
" # layers.GlobalAveragePooling1D : averages the embedding vectors over the sequence dimension\n",
" layers.GlobalAveragePooling1D(),\n",
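" layers.Dropout(0.2),\n",
" # layers.Dense(1) : a single output unit with no activation, so the model outputs a raw logit\n",
" layers.Dense(1)])\n",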
"\n",
"model.summary()"
],
"execution_count": 152,
"outputs": [
{
"output_type": "stream",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding (Embedding) (None, None, 16) 160016 \n",
"_________________________________________________________________\n",
"dropout (Dropout) (None, None, 16) 0 \n",
"_________________________________________________________________\n",
"global_average_pooling1d (Gl (None, 16) 0 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 16) 0 \n",
"_________________________________________________________________\n",
"dense_3 (Dense) (None, 1) 17 \n",
"=================================================================\n",
"Total params: 160,033\n",
"Trainable params: 160,033\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
],
"name": "stdout"
}
]
},
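{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A rough NumPy sketch of what this model computes for a batch of already-vectorized reviews. Randomly initialised arrays stand in for the trained weights, and the `Dropout` layers are left out because they are only active during training. Each index selects a row of the embedding table, `GlobalAveragePooling1D` averages the 250 rows, and the `Dense` layer turns the 16-value average into one raw score. Since the `Embedding` layer is not created with `mask_zero=True`, the padding positions are included in the average, and the sketch does the same. The last lines reproduce the parameter counts in the summary.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"max_features, embedding_dim, sequence_length = 10000, 16, 250\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# random stand-ins for the trained weights\n",
"embedding_table = rng.normal(size=(max_features + 1, embedding_dim))\n",
"dense_kernel = rng.normal(size=(embedding_dim, 1))\n",
"dense_bias = np.zeros(1)\n",
"\n",
"def forward(token_ids):\n",
"    vectors = embedding_table[token_ids]  # (batch, sequence_length, embedding_dim)\n",
"    pooled = vectors.mean(axis=1)         # average over positions -> (batch, embedding_dim)\n",
"    return pooled @ dense_kernel + dense_bias  # one raw logit per review -> (batch, 1)\n",
"\n",
"token_batch = rng.integers(0, max_features + 1, size=(2, sequence_length))\n",
"print(forward(token_batch).shape)  # (2, 1)\n",
"\n",
"# parameter counts from model.summary()\n",
"embedding_params = (max_features + 1) * embedding_dim  # 160016\n",
"dense_params = embedding_dim + 1                       # 16 weights + 1 bias = 17\n",
"print(embedding_params + dense_params)                 # 160033\n",
"```"
]
},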
{
"cell_type": "markdown",
"metadata": {
"id": "G554OPzMPaNN",
"colab_type": "text"
},
"source": [
"#### Loss function and optimizer\n",
"\n",
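"A model needs a loss function and an optimizer for training. Since this is a binary classification problem and the model outputs a single raw score (a logit from the single-unit `Dense` layer, which has no activation), we'll use the `losses.BinaryCrossentropy` loss function with `from_logits=True`. For the same reason, `BinaryAccuracy` is given `threshold=0.0`: a logit above 0 corresponds to a probability above 0.5."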
]
},
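{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The loss and the accuracy threshold can be checked by hand. A minimal NumPy sketch with arbitrary, made-up labels and logits: `bce_from_logits` is a hypothetical helper using the numerically stable form of binary cross-entropy on logits (the same quantity as applying a sigmoid and then the usual cross-entropy formula), and the last lines show why `threshold=0.0` on logits matches a 0.5 cut-off on probabilities.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def bce_from_logits(labels, logits):\n",
"    # stable form of -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))], averaged over the batch\n",
"    per_example = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))\n",
"    return float(per_example.mean())\n",
"\n",
"labels = np.array([1.0, 0.0, 1.0, 0.0])\n",
"logits = np.array([2.0, -1.0, -0.5, 0.3])  # arbitrary example scores\n",
"\n",
"print(bce_from_logits(labels, logits))\n",
"\n",
"# a logit above 0 is exactly a probability above 0.5\n",
"print(np.array_equal(logits > 0.0, sigmoid(logits) > 0.5))  # True\n",
"print(np.mean((logits > 0.0) == labels.astype(bool)))       # 0.5\n",
"```"
]
},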
{
"cell_type": "code",
"metadata": {
"id": "j0T4PbQ1PQHS",
"colab_type": "code",
"colab": {}
},
"source": [
"model.compile(loss=losses.BinaryCrossentropy(from_logits=True),\n",
" optimizer='adam',\n",
" metrics=tf.metrics.BinaryAccuracy(threshold=0.0))"
],
"execution_count": 153,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ytzxNz4vPfEd",
"colab_type": "text"
},
"source": [
"### Training the model on the training data\n",
"\n",
"Train the model by passing the training `Dataset` object to the `fit` method, along with the validation dataset."
]
},
{
"cell_type": "code",
"metadata": {
"id": "83Cz151GPyoK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 370
},
"outputId": "4b4c7894-6050-4862-eee0-86208eb9e2a3"
},
"source": [
"train_hist = model.fit(train_ds,\n",
" validation_data = valid_ds,\n",
" epochs = 10)"
],
"execution_count": 155,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"477/477 [==============================] - 12s 26ms/step - loss: 0.6733 - binary_accuracy: 0.6831 - val_loss: 0.6378 - val_binary_accuracy: 0.7624\n",
"Epoch 2/10\n",
"477/477 [==============================] - 3s 5ms/step - loss: 0.5825 - binary_accuracy: 0.7861 - val_loss: 0.5370 - val_binary_accuracy: 0.8088\n",
"Epoch 3/10\n",
"477/477 [==============================] - 3s 6ms/step - loss: 0.4862 - binary_accuracy: 0.8293 - val_loss: 0.4574 - val_binary_accuracy: 0.8356\n",
"Epoch 4/10\n",
"477/477 [==============================] - 3s 6ms/step - loss: 0.4148 - binary_accuracy: 0.8561 - val_loss: 0.4046 - val_binary_accuracy: 0.8522\n",
"Epoch 5/10\n",
"477/477 [==============================] - 3s 5ms/step - loss: 0.3677 - binary_accuracy: 0.8692 - val_loss: 0.3699 - val_binary_accuracy: 0.8620\n",
"Epoch 6/10\n",
"477/477 [==============================] - 3s 5ms/step - loss: 0.3333 - binary_accuracy: 0.8794 - val_loss: 0.3464 - val_binary_accuracy: 0.8670\n",
"Epoch 7/10\n",
"477/477 [==============================] - 3s 6ms/step - loss: 0.3082 - binary_accuracy: 0.8885 - val_loss: 0.3296 - val_binary_accuracy: 0.8714\n",
"Epoch 8/10\n",
"477/477 [==============================] - 3s 5ms/step - loss: 0.2875 - binary_accuracy: 0.8958 - val_loss: 0.3173 - val_binary_accuracy: 0.8720\n",
"Epoch 9/10\n",
"477/477 [==============================] - 3s 5ms/step - loss: 0.2712 - binary_accuracy: 0.9020 - val_loss: 0.3081 - val_binary_accuracy: 0.8728\n",
"Epoch 10/10\n",
"477/477 [==============================] - 3s 5ms/step - loss: 0.2561 - binary_accuracy: 0.9081 - val_loss: 0.3009 - val_binary_accuracy: 0.8766\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vDV8_ULIRE8C",
"colab_type": "text"
},
"source": [
"### Evaluate the model on the test data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7I913ELpQ8lC",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 66
},
"outputId": "0be83132-8c3b-4157-e98d-c3fae1b0e4db"
},
"source": [
"loss, accuracy = model.evaluate(test_ds)\n",
"\n",
"print(\"Loss: \", loss)\n",
"print(\"Accuracy: \", accuracy)"
],
"execution_count": 156,
"outputs": [
{
"output_type": "stream",
"text": [
"596/596 [==============================] - 8s 14ms/step - loss: 0.3185 - binary_accuracy: 0.8706\n",
"Loss: 0.3185163140296936\n",
"Accuracy: 0.8705999851226807\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_1DAa8IFRQsQ",
"colab_type": "text"
},
"source": [
"### Accuracy vs. Loss over time\n",
"\n",
"`model.fit()` returns a `History` object that contains a dictionary with everything that happened during training:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e3Nwgt7RQ9Hc",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "2cf7bd80-444d-4fd5-dbfb-3a743d757185"
},
"source": [
"train_hist_dict = train_hist.history\n",
"train_hist_dict.keys()"
],
"execution_count": 157,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"dict_keys(['loss', 'binary_accuracy', 'val_loss', 'val_binary_accuracy'])"
]
},
"metadata": {
"tags": []
},
"execution_count": 157
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HNW5_tESR8Zm",
"colab_type": "text"
},
"source": [
"There are four entries: one for each monitored metric during training and validation. We can use these to plot the training and validation loss for comparison, as well as the training and validation accuracy:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0YGiLE7zQ83f",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 294
},
"outputId": "7be7081b-5d65-428b-cff6-aa50f365deee"
},
"source": [
"acc = train_hist_dict['binary_accuracy']\n",
"val_acc = train_hist_dict['val_binary_accuracy']\n",
"loss = train_hist_dict['loss']\n",
"val_loss = train_hist_dict['val_loss']\n",
"\n",
"epochs = range(1, len(acc) + 1)\n",
"\n",
"# \"bo\" is for \"blue dot\"\n",
"plt.plot(epochs, loss, 'bo', label='Training loss')\n",
"# b is for \"solid blue line\"\n",
"plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
"plt.title('Training and validation loss')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"\n",
"plt.show()"
],
"execution_count": 158,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de3xU1bn/8c8T7gii3FSugRakXMNdQRC1PYpQUPGGqUo5ilC8oVVRVDi2WFs5/iwVrahFbWPRQ1vEeq0KglorSBFEsVIINl4R5WZAuTy/P9aETEISEsjMnmS+79drXpnZs2fPk0Hnm73WXmuZuyMiIukrI+oCREQkWgoCEZE0pyAQEUlzCgIRkTSnIBARSXMKAhGRNKcgkEplZs+a2cWVvW+UzCzXzL6fgOO6mX03dv+3ZnZLefY9iPfJNrMXDrbOMo47xMzyKvu4knw1oy5Aomdm2+Me1ge+AfbEHl/m7jnlPZa7D03EvtWdu4+vjOOYWSawHqjl7rtjx84Byv1vKOlHQSC4e4OC+2aWC1zi7i8W38/MahZ8uYhI9aGmISlVwam/md1gZp8Cc8zsSDP7q5ltNLOvYvdbxb1mkZldErs/xsxeNbMZsX3Xm9nQg9y3nZktNrNtZvaimc0ysz+UUnd5avyZmb0WO94LZtY07vkLzWyDmW0ysyllfD79zexTM6sRt+1MM1sZu9/PzP5uZpvN7BMzu8fMapdyrIfN7Odxj6+LveZjMxtbbN9hZvZPM9tqZv8xs2lxTy+O/dxsZtvN7PiCzzbu9QPMbKmZbYn9HFDez6YsZva92Os3m9lqMxsR99zpZvZu7JgfmdlPY9ubxv59NpvZl2a2xMz0vZRk+sDlQI4GGgNtgXGE/2bmxB63AXYA95Tx+v7A+0BT4FfAQ2ZmB7HvY8CbQBNgGnBhGe9ZnhovAH4MNAdqAwVfTJ2B+2LHbxF7v1aUwN3/AXwNnFzsuI/F7u8BJsV+n+OBU4CflFE3sRpOi9XzA6ADULx/4mvgIuAIYBgwwczOiD03OPbzCHdv4O5/L3bsxsDTwMzY73YX8LSZNSn2O+z32Ryg5lrAU8ALsdddAeSY2bGxXR4iNDM2BLoCL8e2XwvkAc2Ao4CbAM17k2QKAjmQvcBUd//G3Xe4+yZ3/5O757v7NmA6cGIZr9/g7g+4+x7gEeAYwv/w5d7XzNoAfYFb3f1bd38VWFDaG5azxjnu/i933wE8AWTFtp8N/NXdF7v7N8Atsc+gNH8ERgOYWUPg9Ng23P0td3/D3Xe7ey5wfwl1lOTcWH3vuPvXhOCL//0Wufsqd9/r7itj71ee40IIjg/c/fexuv4IrAF+GLdPaZ9NWY4DGgB3xP6NXgb+SuyzAXYBnc3scHf/yt2Xx20/Bmjr7rvcfYlrArSkUxDIgWx0950FD8ysvpndH2s62UpoijgivnmkmE8L7rh7fuxugwru2wL4Mm4bwH9KK7icNX4adz8/rqYW8ceOfRFvKu29CH/9n2VmdYCzgOXuviFWR8dYs8ensTpuJ5wdHEiRGoANxX6//ma2MNb0tQUYX87jFhx7Q7FtG4CWcY9L+2wOWLO7x4dm/HFHEUJyg5m9YmbHx7bfCawFXjCzdWY2uXy/hlQmBYEcSPG/zq4FjgX6u/vhFDZFlNbcUxk+ARqbWf24ba3L2P9Qavwk/tix92xS2s7u/i7hC28oRZuFIDQxrQE6xOq46WBqIDRvxXuMcEbU2t0bAb+NO+6B/pr+mNBkFq8N8FE56jrQcVsXa9/fd1x3X+ruIwnNRvMJZxq4+zZ3v9bd2wMjgGvM7JRDrEUqSEEgFdWQ0Oa+OdbePDXRbxj7C3sZMM3Masf+mvxhGS85lBrnAcPN7IRYx+5tHPj/k8eAqwiB83/F6tgKbDezTsCEctbwBDDGzDrHgqh4/Q0JZ0g7zawfIYAKbCQ0ZbUv5djPAB3N7AIzq2
lm5wGdCc04h+IfhLOH682slpkNIfwbzY39m2WbWSN330X4TPYCmNlwM/turC9oC6FfpaymOEkABYFU1N1APeAL4A3guSS9bzahw3UT8HPgccJ4h5IcdI3uvhqYSPhy/wT4itCZWZaCNvqX3f2LuO0/JXxJbwMeiNVcnhqejf0OLxOaTV4utstPgNvMbBtwK7G/rmOvzSf0ibwWuxLnuGLH3gQMJ5w1bQKuB4YXq7vC3P1bwhf/UMLnfi9wkbuvie1yIZAbayIbT/j3hNAZ/iKwHfg7cK+7LzyUWqTiTP0yUhWZ2ePAGndP+BmJSHWnMwKpEsysr5l9x8wyYpdXjiS0NYvIIdLIYqkqjgb+TOi4zQMmuPs/oy1JpHpQ05CISJpT05CISJqrck1DTZs29czMzKjLEBGpUt56660v3L1ZSc9VuSDIzMxk2bJlUZchIlKlmFnxEeX7qGlIRCTNKQhERNKcgkBEJM1VuT4CEUm+Xbt2kZeXx86dOw+8s0Sqbt26tGrVilq1apX7NQoCETmgvLw8GjZsSGZmJqWvKyRRc3c2bdpEXl4e7dq1K/fr0qJpKCcHMjMhIyP8zNEy3iIVsnPnTpo0aaIQSHFmRpMmTSp85lbtzwhycmDcOMiPLWmyYUN4DJCdXfrrRKQohUDVcDD/TtX+jGDKlMIQKJCfH7aLiEgaBMGHH1Zsu4iknk2bNpGVlUVWVhZHH300LVu23Pf422+/LfO1y5Yt48orrzzgewwYMKBSal20aBHDhw+vlGMlS7UPgjbFF/k7wHYROXSV3S/XpEkTVqxYwYoVKxg/fjyTJk3a97h27drs3r271Nf26dOHmTNnHvA9Xn/99UMrsgqr9kEwfTrUr190W/36YbuIVL6CfrkNG8C9sF+usi/SGDNmDOPHj6d///5cf/31vPnmmxx//PH07NmTAQMG8P777wNF/0KfNm0aY8eOZciQIbRv375IQDRo0GDf/kOGDOHss8+mU6dOZGdnUzBL8zPPPEOnTp3o3bs3V1555QH/8v/yyy8544wz6N69O8cddxwrV64E4JVXXtl3RtOzZ0+2bdvGJ598wuDBg8nKyqJr164sWbKkcj+wMlT7zuKCDuEpU0JzUJs2IQTUUSySGGX1y1X2/3d5eXm8/vrr1KhRg61bt7JkyRJq1qzJiy++yE033cSf/vSn/V6zZs0aFi5cyLZt2zj22GOZMGHCftfc//Of/2T16tW0aNGCgQMH8tprr9GnTx8uu+wyFi9eTLt27Rg9evQB65s6dSo9e/Zk/vz5vPzyy1x00UWsWLGCGTNmMGvWLAYOHMj27dupW7cus2fP5tRTT2XKlCns2bOH/OIfYgJV+yCA8B+fvvhFkiOZ/XLnnHMONWrUAGDLli1cfPHFfPDBB5gZu3btKvE1w4YNo06dOtSpU4fmzZvz2Wef0apVqyL79OvXb9+2rKwscnNzadCgAe3bt993ff7o0aOZPXt2mfW9+uqr+8Lo5JNPZtOmTWzdupWBAwdyzTXXkJ2dzVlnnUWrVq3o27cvY8eOZdeuXZxxxhlkZWUd0mdTEdW+aUhEkiuZ/XKHHXbYvvu33HILJ510Eu+88w5PPfVUqdfS16lTZ9/9GjVqlNi/UJ59DsXkyZN58MEH2bFjBwMHDmTNmjUMHjyYxYsX07JlS8aMGcOjjz5aqe9ZFgWBiFSqqPrltmzZQsuWLQF4+OGHK/34xx57LOvWrSM3NxeAxx9//ICvGTRoEDmxzpFFixbRtGlTDj/8cP7973/TrVs3brjhBvr27cuaNWvYsGEDRx11FJdeeimXXHIJy5cvr/TfoTQKAhGpVNnZMHs2tG0LZuHn7NmJb569/vrrufHGG+nZs2el/wUPUK9ePe69915OO+00evfuTcOGDWnUqFGZr5k2bRpvvfUW3bt3Z/LkyTzyyCMA3H333XTt2pXu3btTq1Ythg4dyqJFi+jRowc9e/bk8ccf56qrrqr036E0VW7N4j
59+rgWphFJrvfee4/vfe97UZcRue3bt9OgQQPcnYkTJ9KhQwcmTZoUdVn7Kenfy8zecvc+Je2vMwIRkXJ64IEHyMrKokuXLmzZsoXLLrss6pIqRVpcNSQiUhkmTZqUkmcAh0pnBCIiaU5BICKS5hQEIiJpTkEgIpLmFAQikvJOOukknn/++SLb7r77biZMmFDqa4YMGULBpeann346mzdv3m+fadOmMWPGjDLfe/78+bz77rv7Ht966628+OKLFSm/RKk0XXXaBMGzz8KwYVDK9CMiksJGjx7N3Llzi2ybO3duuSZ+gzBr6BFHHHFQ7108CG677Ta+//3vH9SxUlXaBEF+PjzzDEybFnUlIlJRZ599Nk8//fS+RWhyc3P5+OOPGTRoEBMmTKBPnz506dKFqVOnlvj6zMxMvvjiCwCmT59Ox44dOeGEE/ZNVQ1hjEDfvn3p0aMHo0aNIj8/n9dff50FCxZw3XXXkZWVxb///W/GjBnDvHnzAHjppZfo2bMn3bp1Y+zYsXzzzTf73m/q1Kn06tWLbt26sWbNmjJ/v6inq06bcQSjRsHYsfCLX8Cpp8LgwVFXJFI1XX01rFhRucfMyoK77y79+caNG9OvXz+effZZRo4cydy5czn33HMxM6ZPn07jxo3Zs2cPp5xyCitXrqR79+4lHuett95i7ty5rFixgt27d9OrVy969+4NwFlnncWll14KwM0338xDDz3EFVdcwYgRIxg+fDhnn312kWPt3LmTMWPG8NJLL9GxY0cuuugi7rvvPq6++moAmjZtyvLly7n33nuZMWMGDz74YKm/X9TTVafNGQHAr38N7dvDhRdCCc2FIpLC4puH4puFnnjiCXr16kXPnj1ZvXp1kWac4pYsWcKZZ55J/fr1OfzwwxkxYsS+59555x0GDRpEt27dyMnJYfXq1WXW8/7779OuXTs6duwIwMUXX8zixYv3PX/WWWcB0Lt3730T1ZXm1Vdf5cILLwRKnq565syZbN68mZo1a9K3b1/mzJnDtGnTWLVqFQ0bNizz2OWRNmcEAA0ahFWSBg6EiRMrf8UkkXRQ1l/uiTRy5EgmTZrE8uXLyc/Pp3fv3qxfv54ZM2awdOlSjjzySMaMGVPq9NMHMmbMGObPn0+PHj14+OGHWbRo0SHVWzCV9aFMYz158mSGDRvGM888w8CBA3n++ef3TVf99NNPM2bMGK655houuuiiQ6o1rc4IAPr3h6lT4bHHFAQiVUmDBg046aSTGDt27L6zga1bt3LYYYfRqFEjPvvsM5599tkyjzF48GDmz5/Pjh072LZtG0899dS+57Zt28YxxxzDrl279k0dDdCwYUO2bdu237GOPfZYcnNzWbt2LQC///3vOfHEEw/qd4t6uuq0OiMocOON8Nxz8JOfhLODzMyoKxKR8hg9ejRnnnnmviaigmmbO3XqROvWrRk4cGCZr+/VqxfnnXcePXr0oHnz5vTt23ffcz/72c/o378/zZo1o3///vu+/M8//3wuvfRSZs6cua+TGKBu3brMmTOHc845h927d9O3b1/Gjx9/UL9XwVrK3bt3p379+kWmq164cCEZGRl06dKFoUOHMnfuXO68805q1apFgwYNKmUBm7Sdhnr9eujRI3RSLVwIsdXuRKQEmoa6atE01OXUrh3MmgVLlsAvfxl1NSIi0UnbIAD40Y/gvPNCn8HSpVFXIyISjbQOAjO47z445piwjN727VFXJJK6qlozcro6mH+ntA4CgCOPhEcfhbVr4Zproq5GJDXVrVuXTZs2KQxSnLuzadMm6tatW6HXpeVVQ8UNGQLXXx/6CoYOhTPPjLoikdTSqlUr8vLy2LhxY9SlyAHUrVuXVq1aVeg1aXvVUHHffgvHHw8bNsDKldCiRaW/hYhIZHTVUDnUrh0GmOXnw49/DHv3Rl2RiEhyKAjidOoEd90FL7wAv/lN1NWIiCRHQoPAzE4zs/fNbK2ZTS5ln3PN7F0zW21mjy
WynvK47DL44Q/hhhtg1aqoqxERSbyEBYGZ1QBmAUOBzsBoM+tcbJ8OwI3AQHfvAlydqHrKywwefBAaNYILLoCDnL9KRKTKSOQZQT9grbuvc/dvgbnAyGL7XArMcvevANz98wTWU27Nm8PDD8M774R5iUREqrNEBkFL4D9xj/Ni2+J1BDqa2Wtm9oaZnVbSgcxsnJktM7Nlybp8behQuPzyMOXuCy8k5S1FRCIRdWdxTaADMAQYDTxgZvstLOrus929j7v3adasWdKK+9WvoHNnuPhiiK1yJyJS7SQyCD4CWsc9bhXbFi8PWODuu9x9PfAvQjCkhHr1wiWlX34Jl1wCVWzIhYhIuSQyCJYCHcysnZnVBs4HFhTbZz7hbAAza0poKlqXwJoqLCsLbr8dnnwydCKLiFQ3CQsCd98NXA48D7wHPOHuq83sNjMrWCj0eWCTmb0LLASuc/dNiarpYE2aBKecEhbt/te/oq5GRKRyJbSPwN2fcfeO7v4dd58e23aruy+I3Xd3v8bdO7t7N3efm8h6DlZGBjzyCNSpE2Yp3bXr4I6TkxNWQ8vICD+1VKaIpIKoO4urjJYt4YEHYNkymDat4q/PyYFx48JcRu7h57hxCgMRiZ6CoAJGjQrzEP3iF2Fls4qYMiXMYxQvPz9sFxGJkoKggn79a2jfHi68ELZsKf/rPvywYttFRJJFQVBBDRuG5py8PJg4sfyva9OmYttFRJJFQXAQ+veHW28NgfBYOafJmz4d6tcvuq1+/bBdRCRKCoKDdNNNMGAATJgQOn4PJDsbZs+Gtm3DxHZt24bH2dmJr1VEpCxaoewQrFsXBpxlZcHChVCjRtQViYiUTCuUJUj79nDPPeEKol/+MupqREQOjoLgEF14IZx7LkydGsYYiIhUNQqCQ2QGv/0tHH10WMjm66+jrkhEpGIUBJXgyCPh97+HtWvDvEQiIlWJgqCSDBkC110XpqGYPz/qakREyk9BUIl+9jPo1SusXfDJJ1FXIyJSPgqCSlS7dhhklp8PY8bA3r1RVyQicmAKgkrWqRP87/+GdY5/85uoqxEROTAFQQKMHw/Dh8MNN8CqVVFXIyJSNgVBApjBQw9Bo0ZhComdO6OuSESkdAqCBGneHObMCWcEN94YdTUiIqVTECTQ6afD5ZfD3XeHPgMRkVSkIEiwX/0KOncOVxF98UXU1YiI7E9BkGD16oVLSr/4Ai69NKxXLCKSShQESZCVFdY5nj8/dCKLiKQSBUGSTJoEp5wCV10FH3wQdTUiIoUUBEmSkQEPPwx16sB558H27VFXJCISKAiSqFWrMEvp22/DOefArl1RVyQioiBIumHD4P774bnnwuR06jwWkajVjLqAdFQwO+mtt0KLFqEjWUQkKgqCiNx8M3z0EdxxRwiDK66IuiIRSVcKgoiYwaxZ8Nln4Uqio48O/QYiIsmmPoII1agBjz0GAwbAj34EixZFXZGIpCMFQcTq1YMFC+A734GRI2HlyqgrEpF0oyBIAY0bh6uIGjaEoUPhww+jrkhE0omCIEW0aRPC4Ouv4dRTYdOmqCsSkXShIEghXbvCk0/C+vXwwx+GtY9FRBJNQZBiTjwxzFb6xhswejTs3h11RSJS3SkIUtCoUWHh+wULYOJEjT4WkcTSOIIUNXEifPwx3H57GHA2dWrUFYlIdaUgSGE//3kIg2nTQhhcemnUFYlIdZTQpiEzO83M3jeztWY2uYTnx5jZRjNbEbtdksh6qhozmD07XFI6fnxoKhIRqWwJCwIzqwHMAoYCnYHRZta5hF0fd/es2O3BRNVTVdWqBf/3f9C7d1jH4PXXo65IRKqbRJ4R9APWuvs6d/8WmAuMTOD7VVuHHQZPPw2tW4fLSt97L+qKRKQ6SWQQtAT+E/c4L7atuFFmttLM5plZ65IOZGbjzGyZmS3buHFjImpNec2ahQFntWrBaaeFmUtFRCpD1JePPgVkunt34G/AIyXt5O6z3b2Pu/dp1q
xZUgtMJe3bw7PPwpdfhn6DzZujrkhEqoNEBsFHQPxf+K1i2/Zx903u/k3s4YNA7wTWUy307Al//jOsWQNnnAE7d0ZdkYhUdYkMgqVABzNrZ2a1gfOBIte9mNkxcQ9HAGr9Locf/AAefhheeQUuugj27o26IhGpyhIWBO6+G7gceJ7wBf+Eu682s9vMbERstyvNbLWZvQ1cCYxJVD3VzQUXwIwZ4Yqiq6+u2OjjnBzIzISMjPAzJydRVYpIVWBexeYv6NOnjy9btizqMlLGtdfCXXeFJS9vuOHA++fkwLhxRSe0q18/jFfIzk5cnSISLTN7y937lPRc1J3FcojuvDNMTjd5Mjz66IH3nzJl/1lN8/PDdhFJT5pioorLyIA5c+Dzz+G//xuaNw+Xl5amtEVvtBiOSPrSGUE1UKdOuJKoa1c4+2xYurT0fdu0qdh2Ean+yhUEZnaYmWXE7nc0sxFmViuxpUlFHH54GGPQrBkMGwZr15a83/TpoU8gXv36YbuIpKfynhEsBuqaWUvgBeBC4OFEFSUH5+ij4fnnwxVEp54Kn322/z7Z2aFjuG3bMKld27bqKBZJd+UNAnP3fOAs4F53Pwfokriy5GB17Ah//St8+imcfjps27b/PtnZkJsbxh/k5ioERNJduYPAzI4HsoGnY9tqJKYkOVT9+4fxBW+/HfoMvv026opEJJWVNwiuBm4E/hIbFNYeWJi4suRQnX46PPAAvPACjB2r0cciUrpyXT7q7q8ArwDEOo2/cPcrE1mYHLof/ziscHbzzWGFs1/9KuqKRCQVlfeqocfM7HAzOwx4B3jXzK5LbGlSGW66CX7ykzDw7O67o65GRFJReZuGOrv7VuAM4FmgHeHKIUlxZjBzJpx1FkyaBHPnRl2RiKSa8gZBrdi4gTOABe6+C6hakxSlsRo1whxDgwaF2UpffjnqikQklZQ3CO4HcoHDgMVm1hbYmqiipPLVrQtPPhkuLz3jDFixIuqKRCRVlCsI3H2mu7d099M92ACclODapJIdeWRY7rJRo7DCWW5u1BWJSCoob2dxIzO7q2DdYDP7X8LZgVQxrVqFMNi5E447Dl58MeqKRCRq5W0a+h2wDTg3dtsKzElUUZJYXbrAkiXQuDH813/BLbfA7t1RVyUiUSlvEHzH3ae6+7rY7X+A9oksTBKra9cwS+mPfww//zmcfDLk5UVdlYhEobxBsMPMTih4YGYDgR2JKUmS5bDD4KGH4A9/gH/+E7Ky4OmnD/w6EaleyhsE44FZZpZrZrnAPcBlCatKkio7G956C1q3huHDw/KXmp9IJH2U96qht929B9Ad6O7uPYGTE1qZJFXHjvD3v8PEiWEN5BNOgHXroq5KRJKhQiuUufvW2AhjgGsSUI9EqG5duOce+NOf4F//gp49Yd68qKsSkUQ7lKUqrdKqkJRy1lmhz+B734NzzglzFe3cGXVVIpIohxIEmmKiGmvXLlxiet11cN99YY2DNWuirkpEEqHMIDCzbWa2tYTbNqBFkmqUiNSqFaaufuaZMJ11nz7w+99HXZWIVLYyg8DdG7r74SXcGrp7udYykKpv6NAwN1Hv3mHSujFjYPv2qKsSkcpyKE1DkkZatoSXXoJbb4VHH4W+fWHlyqirEpHKoCCQcqtZE/7nf8L8RJs3h36D++8HV2+RSJWmIJAKO/lkePttGDwYxo+H88+HLVuirkpEDpaCQA5K8+bw7LNwxx1h3EGvXmHuIhGpehQEctAyMuCGG2Dx4jB76cCB8P/+n5qKRKoaBYEcsgEDwgC000+Ha66BkSNh06aoqxKR8lIQSKVo3Bj+8heYOROefz7MZPrqq1FXJSLloSCQSmMGV1wBr78OderAkCFw++2wd2/J++fkQGZmaGLKzAyPRST5FARS6Xr3huXLwzxFU6bAaafBZ58V3ScnB8aNgw0bQp/Chg3hscJAJPkUBJIQhx8Ojz0GDzwQ5izq0aPo+shTpkB+ftHX5OeH7S
KSXAoCSRgzuOSScFlpkyZhfeSbbw5XGH34YcmvKW27iCSOgkASrmtXePPNsD7y9Olw0knQopQpC9u0SW5tIqIgkCSJXx95xYowErlOnaL71K8fgkJEkktBIElVsD7yd78L33wDDRuG7W3bwuzZ4XkRSa6EBoGZnWZm75vZWjObXMZ+o8zMzaxPIuuR1FCwPvLll8O2bdC6dVgAZ9SoqCsTSU8JCwIzqwHMAoYCnYHRZta5hP0aAlcB/0hULZJ66taF3/wG/va30C9w+eVhVbS77oKvv466OpH0ksgzgn7AWndf5+7fAnOBkSXs9zPgl4BWxU1D3/9+uLx04ULo3BmuvTYMLrv9dti6NerqRNJDIoOgJfCfuMd5sW37mFkvoLW7P13WgcxsnJktM7NlGzdurPxKJVJmYRTySy/Ba69Bv35hPEHbtjB1Knz5ZdQVilRvkXUWm1kGcBdw7YH2dffZ7t7H3fs0a9Ys8cVJZAYMgKefhmXLwmWmt90WAmHyZPj886irE6meEhkEHwGt4x63im0r0BDoCiwys1zgOGCBOowFwjQVf/4zrFoFP/wh3HlnaDK6+mr46KMDvlxEKiCRQbAU6GBm7cysNnA+sKDgSXff4u5N3T3T3TOBN4AR7r4sgTVJFdO1a5iq4r334Lzz4J57oH17mDABcnOjrk6kekhYELj7buBy4HngPeAJd19tZreZ2YhEva9UTx07wpw58MEHYYTy734HHTqE+x98EHV1IlWbeRVbTqpPnz6+bJlOGtJdXh7MmAH33w/ffhvOFqZMgS5doq5MJDWZ2VvuXmLTu0YWS5XUqhXcfXdoHvrpT2HBgtCMNGpUmAJbRMpPQSBV2lFHwS9/GdYzuOWWcAlq794wfDi88UbU1YlUDQoCqRaaNAmXmm7YECaue+MNOP74MGDtlVfC4jciUjIFgVQrjRrBTTeFJqMZM+Cdd8JgtcGDw1rKCgSR/SkIpFpq0CBMV7F+fZjTKDc3LJnZv3/oT1AgiBRSEEi1Vq9emNDu3/8Oy2Zu2gQjR0JWFjzxBOzZE3WFItFTEEhaqF07LJv5/vvw6KNhLYTzzgtXGk2YEKaxyMgIo5dzcqKuViS5FASSVmrWhAsvhNWr4fHHw03kHrkAAAxySURBVJTXv/1tWCvZPXQ2jxunMJD0oiCQtFSjBpx7bpj5tLj8fJg4McxzpL4ESQcKAklr//lPydu3bIHu3eF734Obbw7rLCsUpLpSEEhaa9Om5O2tWsF994Wfv/gF9OwZ5ju68cYwclmhINWJgkDS2vTpUL9+0W3168Mdd8D48fDii/DppzB7dlhK8847w8jl734XbrgBli5VKEjVpyCQtJadHb7k27YN/QVt24bH2dmF+zRrBpdeCi+8AJ99Bg8+GM4O7rorrKbWrl2Y7+iNNxQKUjVp9lGRg/TVV/DkkzBvXgiJXbugdesw8d3ZZ4cpLjL0p5akiLJmH1UQiFSCzZvhqadCKDz3XJgau0WLEArnnBOW4KxRI+oqJZ1pGmqRBDviiDA+4cknYePGMA6hX7/QzDR4cOh0njgRFi3SaGZJPQoCkUp2+OFwwQXwl7+EUJg7FwYODCusnXRSOFOYMCFMmb17d9TViigIRBKqYcMwlcW8eSEUnngizIb66KNhiuxjjgkjmQv6GESioD4CkQjk54e+hHnzQt/C9u3QuDGMGBGCYtCgcDVSSSOfRQ6GOotFUtiOHeGMYN48+OtfQ8czhLOFE04IoXDCCWGkszqc5WApCESqiL174d13YckSePXVcPvww/Bcw4bhktSCYOjXb//BcCKlURCIVGEfflgYCq++GlZdc4datcIo5xNOCLeBA6Fp06irlVSlIBBJcTk5MGVK+NJv0yZMfRE/ujneV1/B668XBsObb4ZxCxAmySsIhhNOUD+DFFIQiKSwnJxw5VB+fuG2+vX3n+qiNDt3wrJlhcHw2muF/QwtWhSGwqBB0K2b+hnSlYJAJIVlZoYFcY
pr2zastVxRe/eGhXcKgmHJksLpths2DKOcC4KhX7+wnKdUfwoCkRSWkVHyZHVm4Uu9MpSnn2HQoNDP0KRJ5bynpBYFgUgKq+wzgvKI72dYsiRMpx3fz3D88aEZqWvX8LN5c/U1VHUKApEUdqh9BJWheD/Dm2+GkdAFmjYNoVBw69YNunSBRo2SU58curKCoGayixGRogq+7Mt71VAi1K1b2Klc4PPPQxNS/O2RR2DbtsJ9WrcuGg5du0KnTup3qGp0RiAi5eYewio+HFatgvfeK2xaysgIK7jFh0PXrmFbTf3pGRmdEYhIpShYxa1tWxg2rHD77t2wdu3+ATF/fmGHd+3aof8hPhy6dg1nQOp/iJbOCEQkYXbsgDVriobDO+8UXs4K4ZLW+GAouDVvHl3d1ZE6i0UkpWzZEsY6xAfEqlWwaVPhPs2bw7HHhtHR8bf27cNAOS0DWjFqGhKRlNKoURjYNmBA4Tb3oh3Uq1bBBx/Ayy/DRx8VHWtRu3ZonioeEgW3Jk3U3FQRCgIR2acicx5VNjM46qhwO+WUos99802oaf36/W/Ll8MXXxTdv0GD0kOiXbvwvBRSEIgIsP94hg0bwmNI7qWsJalTBzp0CLeSbNsWBt8VD4l168KSoF9/XXT/pk0Lm5mKh0SbNuGMI52oj0BEgGhGOCeDezhjKOlsYv368DvHLxOakQEtWxYNh9atQ7/EMceEn02aVL0+CvURiMgBFSyAU97tVYUZNGsWbv367f/8nj3w8cclh8SLL4bniv+9XKtWCIWCYIgPieKBURX6KhIaBGZ2GvBroAbwoLvfUez58cBEYA+wHRjn7u8msiYRKVmbNiWfEbRpk/xakqlGjfAXf+vWMHjw/s9/8w188kkIhIKf8ff/9S945RX48sv9X1u7Nhx9dNlh0aJFWK86ysBIWBCYWQ1gFvADIA9YamYLin3RP+buv43tPwK4CzgtUTWJSOmmTy95zqPp06OrKRXUqROazTIzy95v584QDiWFxccfh/EUCxeGCf+Kq127MBjKOstIVGAk8oygH7DW3dcBmNlcYCSwLwjcfWvc/ocBVavDQqQaSYU5j6qyunUL+xTKsmPHgQPj5ZcLFxeKN3MmXHFF5deeyCBoCcSNHyQP6F98JzObCFwD1AZOLulAZjYOGAfQprqfp4pEKDtbX/yJVq9euFqpffuy9ysIjPiwOPHExNQUeWexu88CZpnZBcDNwMUl7DMbmA3hqqHkVigiknzlDYzKkMgLoD4CWsc9bhXbVpq5wBkJrEdEREqQyCBYCnQws3ZmVhs4H1gQv4OZxQ8PGQZ8kMB6RKSKyMkJnbMZGeFnTk7UFVVvCWsacvfdZnY58Dzh8tHfuftqM7sNWObuC4DLzez7wC7gK0poFhKR9JLKI5yrK40sFpGUUl1HOEetrJHFVWyQtIhUd9V1hHMqUxCISEop7QpxXTmeOAoCEUkp06eHEc3xNMI5sRQEIpJSsrNh9uzQJ1CwRvLs2eooTqTIB5SJiBSnEc7JpTMCEZFSpMt4Bp0RiIiUIJ3GM+iMQESkBFOmFJ2SG8LjKVOiqSeRFAQiIiVIp/EMCgIRkRKk03gGBYGISAnSaTyDgkBEpATpNJ5BQSAiUors7DDR3d694WdUIZDoy1h1+aiISApLxmWsOiMQEUlhybiMVUEgIpLCknEZq4JARCSFJeMyVgWBiEgKS8ZlrAoCEZEUlozLWHXVkIhIikv0tNw6IxARSXMKAhGRNKcgEBFJcwoCEZE0pyAQEUlz5u5R11AhZrYR2BB1HYeoKfBF1EWkEH0ehfRZFKXPo6hD+Tzaunuzkp6ockFQHZjZMnfvE3UdqUKfRyF9FkXp8ygqUZ+HmoZERNKcgkBEJM0pCKIxO+oCUow+j0L6LIrS51FUQj4P9RGIiKQ5nRGIiKQ5BYGISJpTECSRmbU2s4Vm9q6ZrTazq6KuKWpmVsPM/mlmf426lqiZ2RFmNs/M1pjZe2Z2fN
Q1RcnMJsX+P3nHzP5oZnWjrilZzOx3Zva5mb0Tt62xmf3NzD6I/Tyyst5PQZBcu4Fr3b0zcBww0cw6R1xT1K4C3ou6iBTxa+A5d+8E9CCNPxczawlcCfRx965ADeD8aKtKqoeB04ptmwy85O4dgJdijyuFgiCJ3P0Td18eu7+N8D96y2irio6ZtQKGAQ9GXUvUzKwRMBh4CMDdv3X3zdFWFbmaQD0zqwnUBz6OuJ6kcffFwJfFNo8EHondfwQ4o7LeT0EQETPLBHoC/4i2kkjdDVwP7I26kBTQDtgIzIk1lT1oZodFXVRU3P0jYAbwIfAJsMXdX4i2qsgd5e6fxO5/ChxVWQdWEETAzBoAfwKudvetUdcTBTMbDnzu7m9FXUuKqAn0Au5z957A11TiqX9VE2v/HkkIyBbAYWb2o2irSh0ervuvtGv/FQRJZma1CCGQ4+5/jrqeCA0ERphZLjAXONnM/hBtSZHKA/LcveAMcR4hGNLV94H17r7R3XcBfwYGRFxT1D4zs2MAYj8/r6wDKwiSyMyM0Ab8nrvfFXU9UXL3G929lbtnEjoBX3b3tP2Lz90/Bf5jZsfGNp0CvBthSVH7EDjOzOrH/r85hTTuPI9ZAFwcu38x8GRlHVhBkFwDgQsJf/2uiN1Oj7ooSRlXADlmthLIAm6PuJ7IxM6M5gHLgVWE76q0mW7CzP4I/B041szyzOy/gTuAH5jZB4Qzpjsq7f00xYSISHrTGYGISJpTEIiIpDkFgYhImlMQiIikOQWBiEiaUxCIxJjZnrjLeleYWaWN7DWzzPiZJEVSSc2oCxBJITvcPSvqIkSSTWcEIgdgZrlm9iszW2Vmb5rZd2PbM83sZTNbaWYvmVmb2PajzOwvZvZ27FYwNUINM3sgNsf+C2ZWL7b/lbE1Klaa2dyIfk1JYwoCkUL1ijUNnRf33BZ37wbcQ5g1FeA3wCPu3h3IAWbGts8EXnH3HoT5glbHtncAZrl7F2AzMCq2fTLQM3ac8Yn65URKo5HFIjFmtt3dG5SwPRc42d3XxSYN/NTdm5jZF8Ax7r4rtv0Td29qZhuBVu7+TdwxMoG/xRYVwcxuAGq5+8/N7DlgOzAfmO/u2xP8q4oUoTMCkfLxUu5XxDdx9/dQ2Ec3DJhFOHtYGluIRSRpFAQi5XNe3M+/x+6/TuHyidnAktj9l4AJsG9N5kalHdTMMoDW7r4QuAFoBOx3ViKSSPrLQ6RQPTNbEff4OXcvuIT0yNisoN8Ao2PbriCsKHYdYXWxH8e2XwXMjs0YuYcQCp9QshrAH2JhYcBMLVEpyaY+ApEDiPUR9HH3L6KuRSQR1DQkIpLmdEYgIpLmdEYgIpLmFAQiImlOQSAikuYUBCIiaU5BICKS5v4/ZjylCowYTRwAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "0DESR45GQ8WS",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 294
},
"outputId": "b89e2407-e493-43ee-fa0b-53d92082c84d"
},
"source": [
"plt.plot(epochs, acc, 'bo', label='Training acc')\n",
"plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
"plt.title('Training and validation accuracy')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Accuracy')\n",
"plt.legend(loc='lower right')\n",
"\n",
"plt.show()"
],
"execution_count": 159,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zfavK608SbG4",
"colab_type": "text"
},
"source": [
"In these plots, the dots represent the training loss and accuracy, and the solid lines are the validation loss and accuracy.\n",
"\n",
"Notice that the training loss decreases and the training accuracy increases with each epoch. This is expected when using gradient descent optimization, which minimizes the loss on the training data step by step.\n",
"\n",
"The validation metrics improve as well, but more slowly: from epoch 7 to epoch 10, validation accuracy only moves from 0.8714 to 0.8766, while training accuracy climbs from 0.8885 to 0.9081, so the gap between the two keeps widening. A growing gap like this is an early sign of overfitting: the model performs better on the training data than it does on data it has never seen before. If training continues for more epochs, validation loss and accuracy will typically stop improving while the model learns representations specific to the training data that do not generalize to test data.\n",
"\n",
"For this particular case, we could prevent overfitting by simply stopping the training when the validation accuracy is no longer increasing. One way to do so is to use the `tf.keras.callbacks.EarlyStopping` callback."
]
},
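{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`tf.keras.callbacks.EarlyStopping` is passed to `model.fit` through its `callbacks` argument. To show the rule it applies, here is a small, simplified plain-Python sketch (the `stopping_epoch` helper is hypothetical): training stops once the monitored value has failed to improve for `patience` consecutive epochs. Run on the `val_loss` values logged during training, it never triggers, because validation loss fell in every one of the 10 epochs; the second list is made up to show a case where it does.\n",
"\n",
"```python\n",
"def stopping_epoch(val_losses, patience=2, min_delta=0.0):\n",
"    # returns (best_epoch, stop_epoch); stop_epoch is None if training runs to the end\n",
"    best, best_epoch, wait = float('inf'), 0, 0\n",
"    for epoch, loss in enumerate(val_losses, start=1):\n",
"        if loss < best - min_delta:\n",
"            # improvement: remember it and reset the patience counter\n",
"            best, best_epoch, wait = loss, epoch, 0\n",
"        else:\n",
"            wait += 1\n",
"            if wait >= patience:\n",
"                return best_epoch, epoch\n",
"    return best_epoch, None\n",
"\n",
"logged = [0.6378, 0.5370, 0.4574, 0.4046, 0.3699, 0.3464, 0.3296, 0.3173, 0.3081, 0.3009]\n",
"print(stopping_epoch(logged))   # (10, None)\n",
"made_up = [0.50, 0.40, 0.35, 0.36, 0.37, 0.38]\n",
"print(stopping_epoch(made_up))  # (3, 5)\n",
"\n",
"# the Keras equivalent, assuming the model and datasets defined in this notebook:\n",
"# early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)\n",
"# model.fit(train_ds, validation_data=valid_ds, epochs=50, callbacks=[early_stop])\n",
"```\n",
"\n",
"Monitoring `val_loss` is one choice; `monitor='val_binary_accuracy'` would follow the accuracy curve instead."
]
},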
{
"cell_type": "markdown",
"metadata": {
"id": "sJZbc_SmARUq",
"colab_type": "text"
},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wdm8YKwJf0pF",
"colab_type": "text"
},
"source": [
"## Clustering in Tensorflow\n",
"\n",
"As a first step toward understanding a data set, we often group similar examples together. Grouping unlabeled examples is called clustering.\n",
"\n",
"As the examples are **unlabeled**, clustering relies on **unsupervised machine learning**. If the examples are **labeled**, then clustering becomes **classification**.\n",
"\n",
"Two widely used families of clustering algorithms are:\n",
"- K-Means Clustering\n",
"- Hierarchical Clustering\n",
"\n",
"Clustering has a myriad of uses in a variety of industries. Some common applications for clustering include the following:\n",
"\n",
" - market segmentation\n",
" - social network analysis\n",
" - search result grouping\n",
" - medical imaging\n",
" - image segmentation\n",
" - anomaly detection\n",
"\n",
"### **Hidden Markov Model**\n",
"\n",
"\n",
"\"The Hidden Markov Model is a finite set of states, each of which is associated with a (generally multidimensional) probability distribution. Transitions among the states are governed by a set of probabilities called transition probabilities.\" (http://jedlik.phy.bme.hu/~gerjanos/HMM/node4.html)\n",
"\n",
"A hidden Markov model uses probabilities to predict future events or states.\n",
"\n",
"**States:** In each Markov model we have a finite set of states. These states could be something like \"warm\" and \"cold\" or \"high\" and \"low\" or even \"red\", \"green\" and \"blue\". These states are \"hidden\" within the model, which means we do not directly observe them.\n",
"\n",
"**Observations:** In each state, the outcome (observation) is drawn from a probability distribution associated with that state. An example of this is the following: \n",
"> *On a hot day Tim has an 80% chance of being happy and a 20% chance of being sad.*\n",
"\n",
"**Transitions:** Each state has a set of probabilities defining the likelihood of moving to every state, including staying in the same one. An example is the following: \n",
"> *A cold day has a 30% chance of being followed by a hot day and a 70% chance of being followed by another cold day.*\n",
"\n",
"To create a hidden Markov model, we need:\n",
"- States\n",
"- Observation Distribution\n",
"- Transition Distribution\n",
"\n",
"Read more about it from the links given below:\n",
"\n",
"**Links**:\n",
"- https://medium.com/@kangeugine/hidden-markov-model-7681c22f5b9\n",
"- https://en.wikipedia.org/wiki/Hidden_Markov_model\n",
"- https://www.tensorflow.org/probability/\n",
"- https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/HiddenMarkovModel\n",
"- https://www.tensorflow.org/probability/api_docs/python/tfp/edward2/HiddenMarkovModel\n"
]
},
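{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The three ingredients above describe a generative process: draw the first hidden state from the initial distribution, draw each following state from the transition probabilities of the current state, and draw one observation per day from that state's distribution. The cell below is a minimal plain-Python sketch of that process, not a TensorFlow Probability API; the helper `sample_hmm` is hypothetical, and the numbers are those of the weather model in the next section (0 = cold, 1 = hot, normally distributed temperatures)."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import random\n",
"\n",
"# Parameters copied from the weather model below (0 = cold, 1 = hot)\n",
"INITIAL = [0.8, 0.2]          # P(state on the first day)\n",
"TRANSITION = [[0.7, 0.3],     # P(next state | today is cold)\n",
"              [0.2, 0.8]]     # P(next state | today is hot)\n",
"OBS_MEAN = [0.0, 15.0]        # mean temperature in each state\n",
"OBS_STD = [5.0, 10.0]         # std. deviation of temperature in each state\n",
"\n",
"\n",
"def sample_hmm(num_steps, rng):\n",
"    # Hypothetical helper (num_steps >= 1): sample the hidden states first,\n",
"    # then one observation per state\n",
"    state = rng.choices([0, 1], weights=INITIAL)[0]\n",
"    states = [state]\n",
"    for _ in range(num_steps - 1):\n",
"        state = rng.choices([0, 1], weights=TRANSITION[state])[0]\n",
"        states.append(state)\n",
"    observations = [rng.gauss(OBS_MEAN[s], OBS_STD[s]) for s in states]\n",
"    return states, observations\n",
"\n",
"\n",
"states, temps = sample_hmm(7, random.Random(0))\n",
"print('hidden states:', states)  # not observable in a real HMM setting\n",
"print('observed temps:', [round(t, 1) for t in temps])"
],
"execution_count": null,
"outputs": []
},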
{
"cell_type": "code",
"metadata": {
"id": "y2VCgiG9OtYb",
"colab_type": "code",
"colab": {}
},
"source": [
"from IPython.display import clear_output\n",
"import tensorflow_probability as tfp\n",
"import tensorflow as tf"
],
"execution_count": 162,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "58I0VyKMsVNk",
"colab_type": "text"
},
"source": [
"### A Weather Model\n",
"\n",
"\n",
"We will model a simple weather system and try to predict the expected (average) temperature on each day, given the following information:\n",
"1. Cold days are encoded by a 0 and hot days are encoded by a 1.\n",
"2. The first day in our sequence has an 80% chance of being cold.\n",
"3. A cold day has a 30% chance of being followed by a hot day.\n",
"4. A hot day has a 20% chance of being followed by a cold day.\n",
"5. On each day, the temperature is normally distributed with mean 0 and standard deviation 5 on a cold day, and mean 15 and standard deviation 10 on a hot day.\n"
]
},
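{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Because each day's state depends only on the previous day's state, these rules fully determine the expected temperature on every day. As a short worked derivation, write $X_k$ for the temperature on day $k$ (starting at $k = 0$), $\\pi_k$ for the row vector of state probabilities (cold, hot) on that day, $T$ for the transition matrix (rows: today's state, columns: tomorrow's state) and $\\mu$ for the per-state mean temperatures:\n",
"\n",
"$$\\pi_0 = \\begin{bmatrix} 0.8 & 0.2 \\end{bmatrix}, \\qquad T = \\begin{bmatrix} 0.7 & 0.3 \\\\ 0.2 & 0.8 \\end{bmatrix}, \\qquad \\mu = \\begin{bmatrix} 0 \\\\ 15 \\end{bmatrix}$$\n",
"\n",
"$$\\pi_{k+1} = \\pi_k T, \\qquad \\mathbb{E}[X_k] = \\pi_k \\mu$$\n",
"\n",
"For example, $\\mathbb{E}[X_0] = 0.8 \\cdot 0 + 0.2 \\cdot 15 = 3$ and $\\pi_1 = \\begin{bmatrix} 0.8 \\cdot 0.7 + 0.2 \\cdot 0.2 & 0.8 \\cdot 0.3 + 0.2 \\cdot 0.8 \\end{bmatrix} = \\begin{bmatrix} 0.6 & 0.4 \\end{bmatrix}$, so $\\mathbb{E}[X_1] = 0.4 \\cdot 15 = 6$. As $k$ grows, $\\pi_k$ approaches the stationary distribution $\\begin{bmatrix} 0.4 & 0.6 \\end{bmatrix}$ (which satisfies $\\pi = \\pi T$), so the expected temperature approaches $0.6 \\cdot 15 = 9$. The standard deviations (5 and 10) affect only the spread of each day's temperature, not its expected value."
]
},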
{
"cell_type": "markdown",
"metadata": {
"id": "0sfPi_HYyxxt",
"colab_type": "text"
},
"source": [
"#### Define the distribution variables of the model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5p62mw_ilpLk",
"colab_type": "code",
"colab": {}
},
"source": [
"tfd = tfp.distributions\n",
"\n",
"# State encoding: index 0 = cold day, index 1 = hot day\n",
"\n",
"# Initial distribution (first day):\n",
"# 80% chance of cold, 20% chance of hot\n",
"init_dist = tfd.Categorical(probs=[0.8, 0.2])\n",
"\n",
"# Transition distribution (every following day), one row per current state:\n",
"# row 0, after a cold day: 70% chance of cold, 30% chance of hot\n",
"# row 1, after a hot day:  20% chance of cold, 80% chance of hot\n",
"trans_dist = tfd.Categorical(probs=[[0.7, 0.3],\n",
"                                    [0.2, 0.8]])\n",
"\n",
"# Observation distribution (daily temperature), one Normal per state:\n",
"# cold day: mean = 0,  std. deviation = 5\n",
"# hot day:  mean = 15, std. deviation = 10\n",
"obs_dist = tfd.Normal(loc=[0., 15.], scale=[5., 10.])"
],
"execution_count": 163,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "FSvWongiy-Wt",
"colab_type": "text"
},
"source": [
"#### Create our `HiddenMarkovModel`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rspJP8eWzpct",
"colab_type": "code",
"colab": {}
},
"source": [
"# Combine the three distributions into an HMM over a one-week sequence\n",
"hmm_model = tfd.HiddenMarkovModel(\n",
"    initial_distribution=init_dist,\n",
"    transition_distribution=trans_dist,\n",
"    observation_distribution=obs_dist,\n",
"    # num_steps: number of days in the sequence (7 days = one week)\n",
"    num_steps=7)\n",
"\n",
"clear_output()"
],
"execution_count": 164,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ep6l8huz20Vw",
"colab_type": "text"
},
"source": [
"#### Get Average Temperature Values For a Week "
]
},
{
"cell_type": "code",
"metadata": {
"id": "6fA53GyV2El3",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
},
"outputId": "9af16ccd-d553-4991-fcb8-30a150e762a9"
},
"source": [
"mean = hmm_model.mean()\n",
"\n",
"print(mean.numpy())"
],
"execution_count": 165,
"outputs": [
{
"output_type": "stream",
"text": [
"[3. 6. 7.5 8.249999 8.625001 8.812501 8.90625 ]\n"
],
"name": "stdout"
}
]
},
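{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"As a sanity check, the values above can be reproduced without TensorFlow: the expected temperature on a given day is the probability of each state on that day times that state's mean temperature, and the state probabilities are pushed forward one day at a time through the transition matrix. The cell below is a plain-Python sketch of that computation (the function `expected_temperatures` is a hypothetical helper); the small differences in the last digits of the output above presumably come from TensorFlow computing in `float32`."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"# Plain-Python check of hmm_model.mean(): push the state probabilities\n",
"# forward one day at a time and weight each state's mean temperature.\n",
"initial = [0.8, 0.2]          # P(cold), P(hot) on the first day\n",
"transition = [[0.7, 0.3],     # from a cold day\n",
"              [0.2, 0.8]]     # from a hot day\n",
"state_means = [0.0, 15.0]     # mean temperature of a cold / hot day\n",
"\n",
"\n",
"def expected_temperatures(num_steps):\n",
"    # Hypothetical helper: expected temperature for each of num_steps days\n",
"    probs = initial\n",
"    means = []\n",
"    for _ in range(num_steps):\n",
"        means.append(sum(p * m for p, m in zip(probs, state_means)))\n",
"        # next day's distribution: row vector probs times the transition matrix\n",
"        probs = [sum(probs[i] * transition[i][j] for i in range(2))\n",
"                 for j in range(2)]\n",
"    return means\n",
"\n",
"\n",
"print([round(m, 6) for m in expected_temperatures(7)])\n",
"# [3.0, 6.0, 7.5, 8.25, 8.625, 8.8125, 8.90625]"
],
"execution_count": null,
"outputs": []
},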
{
"cell_type": "markdown",
"metadata": {
"id": "GT5TapMY5C8-",
"colab_type": "text"
},
"source": [
"---"
]
}
]
}