-
-
Save katsugeneration/f2bd61d7745fe8cc7b2fd9730d50bb8b to your computer and use it in GitHub Desktop.
TensorFlow Indexing.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "TensorFlow Indexing.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyMtVRoWsnwzqUl8DjZzokeq", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/katsugeneration/f2bd61d7745fe8cc7b2fd9730d50bb8b/tensorflow-indexing.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JkzQGL8bLetP", | |
"colab_type": "code", | |
"outputId": "d2cb477a-50bf-4a14-abf2-2ff8d0e6c8e7", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 632 | |
} | |
}, | |
"source": [ | |
"!pip install tensorflow==2.1.0" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: tensorflow==2.1.0 in /usr/local/lib/python3.6/dist-packages (2.1.0)\n", | |
"Requirement already satisfied: tensorflow-estimator<2.2.0,>=2.1.0rc0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (2.1.0)\n", | |
"Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (1.24.3)\n", | |
"Requirement already satisfied: numpy<2.0,>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (1.18.2)\n", | |
"Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (3.10.0)\n", | |
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (1.1.0)\n", | |
"Requirement already satisfied: keras-applications>=1.0.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (1.0.8)\n", | |
"Requirement already satisfied: keras-preprocessing>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (1.1.0)\n", | |
"Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (0.34.2)\n", | |
"Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (0.9.0)\n", | |
"Requirement already satisfied: google-pasta>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (0.2.0)\n", | |
"Requirement already satisfied: gast==0.2.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (0.2.2)\n", | |
"Requirement already satisfied: tensorboard<2.2.0,>=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (2.1.1)\n", | |
"Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (1.12.0)\n", | |
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (3.2.0)\n", | |
"Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (0.8.1)\n", | |
"Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (1.12.1)\n", | |
"Requirement already satisfied: scipy==1.4.1; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.1.0) (1.4.1)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorflow==2.1.0) (46.0.0)\n", | |
"Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.8->tensorflow==2.1.0) (2.8.0)\n", | |
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (2.21.0)\n", | |
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (1.0.0)\n", | |
"Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (1.7.2)\n", | |
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (0.4.1)\n", | |
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (3.2.1)\n", | |
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (1.24.3)\n", | |
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (3.0.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (2019.11.28)\n", | |
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (2.8)\n", | |
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (0.2.8)\n", | |
"Requirement already satisfied: rsa<4.1,>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (4.0)\n", | |
"Requirement already satisfied: cachetools<3.2,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (3.1.1)\n", | |
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (1.3.0)\n", | |
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.6/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (0.4.8)\n", | |
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.2.0,>=2.1.0->tensorflow==2.1.0) (3.1.0)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "cgs8yFIPLI0A", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import tensorflow as tf\n", | |
"import numpy as np" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "KXKJTu4oMNTf", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# NumpyでのIndexing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LDfdb4PVMJF7", | |
"colab_type": "code", | |
"outputId": "12057173-9b73-4e3f-b3e6-64072450c874", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 119 | |
} | |
}, | |
"source": [ | |
"a = np.ones((4, 4, 4))\n", | |
"b = np.array([[2, 3]])\n", | |
"c = np.array([1, 2])\n", | |
"d = np.array([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)\n", | |
"\n", | |
"print(a[:, 2:3].shape)\n", | |
"print(a[:, [2, 3]].shape)\n", | |
"print(a[..., 2:3].shape)\n", | |
"print(a[b].shape)\n", | |
"print(a[b, c].shape)\n", | |
"print(a[d])" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(4, 1, 4)\n", | |
"(4, 2, 4)\n", | |
"(4, 4, 1)\n", | |
"(1, 2, 4, 4)\n", | |
"(1, 2, 4)\n", | |
"[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aQlWzm8fM58e", | |
"colab_type": "code", | |
"outputId": "c4ed88d5-a488-4eb7-f8be-0fca5ce26c85", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
} | |
}, | |
"source": [ | |
"a = np.ones((4, 4, 4))\n", | |
"a[:, 2:3] = np.ones((4, 1, 4)) * 2\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]]]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7FfC5TEtNPfw", | |
"colab_type": "code", | |
"outputId": "a04e0035-4082-4f1a-8323-66bd762c4320", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
} | |
}, | |
"source": [ | |
"a = np.ones((4, 4, 4))\n", | |
"a[:, [2, 3]] = np.ones((4, 2, 4)) * 2\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mJmTALemNaZv", | |
"colab_type": "code", | |
"outputId": "6f50cdf1-a86e-4104-e981-53cde2022066", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
} | |
}, | |
"source": [ | |
"a = np.ones((4, 4, 4))\n", | |
"a[..., 2:3] = np.ones((4, 4, 1)) * 2\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[[[1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]]\n", | |
"\n", | |
" [[1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]]\n", | |
"\n", | |
" [[1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]]\n", | |
"\n", | |
" [[1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]\n", | |
" [1. 1. 2. 1.]]]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Z6eRDA9xNhwW", | |
"colab_type": "code", | |
"outputId": "a88d31e3-9733-4317-a61a-015725f1860b", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
} | |
}, | |
"source": [ | |
"a = np.ones((4, 4, 4))\n", | |
"a[b] = np.ones((1, 2, 4, 4)) * 2\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "17Q-2ms6QsI4", | |
"colab_type": "code", | |
"outputId": "303d7e01-63cb-4487-dba9-d0572b6dabd0", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
} | |
}, | |
"source": [ | |
"a = np.ones((4, 4, 4))\n", | |
"a[b, c] = np.ones((1, 2, 4)) * 2\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]]]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "jF4DrNWlAc6J", | |
"colab_type": "code", | |
"outputId": "af6579ba-6765-47e8-c2af-04b15f49f113", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
} | |
}, | |
"source": [ | |
"a = np.ones((4, 4, 4))\n", | |
"a[d] = np.ones((4, 4, 4))[d] * 2\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[[[2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "2RqxmoKZM0Gh", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# TensorFlowでのIndexing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TbrArbh8MiVH", | |
"colab_type": "code", | |
"outputId": "a8a5b339-47b3-4cb6-8ca6-c6d4a1ccee8a", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 119 | |
} | |
}, | |
"source": [ | |
"a = tf.ones((4, 4, 4))\n", | |
"b = tf.constant([[2, 3]])\n", | |
"c = tf.constant([1, 2])\n", | |
"d = tf.constant([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)\n", | |
"\n", | |
"print(a[:, 2:3].shape)\n", | |
"print(tf.gather(a, [2, 3], axis=1).shape)\n", | |
"print(a[..., 2:3].shape)\n", | |
"print(tf.gather(a, b, axis=0).shape)\n", | |
"print(tf.gather_nd(a, tf.stack([b, [c]], axis=-1)).shape)\n", | |
"print(a[d])" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(4, 1, 4)\n", | |
"(4, 2, 4)\n", | |
"(4, 4, 1)\n", | |
"(1, 2, 4, 4)\n", | |
"(1, 2, 4)\n", | |
"tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(16,), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "1Ptc8UAVQ5rr", | |
"colab_type": "code", | |
"outputId": "a452bb3a-320c-45f0-fb90-7d36ff3c439b", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 544 | |
} | |
}, | |
"source": [ | |
"print(tf.gather_nd(a, [2, 3]))\n", | |
"print()\n", | |
"print(tf.gather_nd(a, [2, 3, 3]))\n", | |
"print()\n", | |
"print(tf.gather_nd(a, [[2, 3], [1, 2]]))\n", | |
"print()\n", | |
"print(tf.gather_nd(a, [[[[2,3]]]]))\n", | |
"print()\n", | |
"print(tf.gather_nd(a, [[[2], [3]], [[1], [2]]]))" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor([1. 1. 1. 1.], shape=(4,), dtype=float32)\n", | |
"\n", | |
"tf.Tensor(1.0, shape=(), dtype=float32)\n", | |
"\n", | |
"tf.Tensor(\n", | |
"[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]], shape=(2, 4), dtype=float32)\n", | |
"\n", | |
"tf.Tensor([[[[1. 1. 1. 1.]]]], shape=(1, 1, 1, 4), dtype=float32)\n", | |
"\n", | |
"tf.Tensor(\n", | |
"[[[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]]\n", | |
"\n", | |
"\n", | |
" [[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]]], shape=(2, 2, 4, 4), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DictFPc5OPvW", | |
"colab_type": "code", | |
"outputId": "814af2fc-6ed1-469c-c763-00e4c79ad5af", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 357 | |
} | |
}, | |
"source": [ | |
"a = tf.Variable(tf.ones((4, 4, 4)))\n", | |
"a = a[:, 2:3].assign(tf.ones((4, 1, 4)) * 2)\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"<tf.Variable 'UnreadVariable' shape=(4, 4, 4) dtype=float32, numpy=\n", | |
"array([[[1., 1., 1., 1.],\n", | |
" [1., 1., 1., 1.],\n", | |
" [2., 2., 2., 2.],\n", | |
" [1., 1., 1., 1.]],\n", | |
"\n", | |
" [[1., 1., 1., 1.],\n", | |
" [1., 1., 1., 1.],\n", | |
" [2., 2., 2., 2.],\n", | |
" [1., 1., 1., 1.]],\n", | |
"\n", | |
" [[1., 1., 1., 1.],\n", | |
" [1., 1., 1., 1.],\n", | |
" [2., 2., 2., 2.],\n", | |
" [1., 1., 1., 1.]],\n", | |
"\n", | |
" [[1., 1., 1., 1.],\n", | |
" [1., 1., 1., 1.],\n", | |
" [2., 2., 2., 2.],\n", | |
" [1., 1., 1., 1.]]], dtype=float32)>\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bmngHVB-VesZ", | |
"colab_type": "code", | |
"outputId": "dbbc9fb1-939e-4491-e9a3-91be8e272e6a", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 357 | |
} | |
}, | |
"source": [ | |
"a = tf.Variable(tf.ones((4, 4, 4)))\n", | |
"first = tf.tile(tf.expand_dims(tf.range(4), axis=1), (1, 2))\n", | |
"indexes = tf.tile([[2, 3]], (4, 1))\n", | |
"indices = tf.stack([first, indexes], axis=-1)\n", | |
"a = tf.tensor_scatter_nd_update(a, indices, tf.ones((4, 2, 4)) * 2)\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor(\n", | |
"[[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]], shape=(4, 4, 4), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VBEhEQkna6yp", | |
"colab_type": "code", | |
"outputId": "b17731e9-3bed-46c2-cd2c-1620dd095067", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 357 | |
} | |
}, | |
"source": [ | |
"a = tf.Variable(tf.ones((4, 4, 4)))\n", | |
"a[..., 2:3].assign(np.ones((4, 4, 1)) * 2)\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"<tf.Variable 'Variable:0' shape=(4, 4, 4) dtype=float32, numpy=\n", | |
"array([[[1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.]],\n", | |
"\n", | |
" [[1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.]],\n", | |
"\n", | |
" [[1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.]],\n", | |
"\n", | |
" [[1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.],\n", | |
" [1., 1., 2., 1.]]], dtype=float32)>\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dohp01hWulc2", | |
"colab_type": "code", | |
"outputId": "fb476830-ea93-4ea2-a248-514e6af61d4d", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 357 | |
} | |
}, | |
"source": [ | |
"a = tf.Variable(tf.ones((4, 4, 4)))\n", | |
"a = tf.tensor_scatter_nd_update(a, tf.expand_dims(b, axis=-1), tf.ones((1, 2, 4, 4)) * 2)\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor(\n", | |
"[[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]], shape=(4, 4, 4), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2KvdK9Jwu82Y", | |
"colab_type": "code", | |
"outputId": "d2c999a0-4238-4559-84f4-70c6d0b49101", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 357 | |
} | |
}, | |
"source": [ | |
"indices = tf.stack([b, [c]], axis=-1)\n", | |
"a = tf.tensor_scatter_nd_update(a, indices, tf.ones((1, 2, 4)) * 2)\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor(\n", | |
"[[[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]\n", | |
" [2. 2. 2. 2.]]], shape=(4, 4, 4), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "27rtLRJfAKce", | |
"colab_type": "code", | |
"outputId": "e0d1dad2-a57c-4afa-b4c4-527b1115457e", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 357 | |
} | |
}, | |
"source": [ | |
"a = tf.Variable(tf.ones((4, 4, 4)))\n", | |
"d = tf.constant([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)\n", | |
"d = tf.cast(d, dtype=a.dtype)\n", | |
"a = (1 - d) * a + d * np.ones((4, 4, 4)) * 2\n", | |
"print(a)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor(\n", | |
"[[[2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]\n", | |
"\n", | |
" [[2. 2. 2. 2.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]\n", | |
" [1. 1. 1. 1.]]], shape=(4, 4, 4), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LgLGCkTdA5zG", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment