Skip to content

Instantly share code, notes, and snippets.

@katsugeneration
Last active March 20, 2020 14:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save katsugeneration/f2bd61d7745fe8cc7b2fd9730d50bb8b to your computer and use it in GitHub Desktop.
Save katsugeneration/f2bd61d7745fe8cc7b2fd9730d50bb8b to your computer and use it in GitHub Desktop.
TensorFlow Indexing.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "TensorFlow Indexing.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMtVRoWsnwzqUl8DjZzokeq",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/katsugeneration/f2bd61d7745fe8cc7b2fd9730d50bb8b/tensorflow-indexing.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "JkzQGL8bLetP",
"colab_type": "code",
"colab": {}
},
"source": [
"# Pin the TensorFlow version the outputs below were produced with.\n",
"# %pip (not !pip) installs into the environment of the running kernel; -q keeps the log out of the notebook.\n",
"%pip install -q tensorflow==2.1.0"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "cgs8yFIPLI0A",
"colab_type": "code",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"import numpy as np"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "KXKJTu4oMNTf",
"colab_type": "text"
},
"source": [
"# NumPyでのIndexing"
]
},
{
"cell_type": "code",
"metadata": {
"id": "LDfdb4PVMJF7",
"colab_type": "code",
"outputId": "12057173-9b73-4e3f-b3e6-64072450c874",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
}
},
"source": [
"a = np.ones((4, 4, 4))\n",
"b = np.array([[2, 3]])\n",
"c = np.array([1, 2])\n",
"d = np.array([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)\n",
"\n",
"print(a[:, 2:3].shape)\n",
"print(a[:, [2, 3]].shape)\n",
"print(a[..., 2:3].shape)\n",
"print(a[b].shape)\n",
"print(a[b, c].shape)\n",
"print(a[d])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"(4, 1, 4)\n",
"(4, 2, 4)\n",
"(4, 4, 1)\n",
"(1, 2, 4, 4)\n",
"(1, 2, 4)\n",
"[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aQlWzm8fM58e",
"colab_type": "code",
"outputId": "c4ed88d5-a488-4eb7-f8be-0fca5ce26c85",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
}
},
"source": [
"a = np.ones((4, 4, 4))\n",
"a[:, 2:3] = np.ones((4, 1, 4)) * 2\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"[[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "7FfC5TEtNPfw",
"colab_type": "code",
"outputId": "a04e0035-4082-4f1a-8323-66bd762c4320",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
}
},
"source": [
"a = np.ones((4, 4, 4))\n",
"a[:, [2, 3]] = np.ones((4, 2, 4)) * 2\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"[[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "mJmTALemNaZv",
"colab_type": "code",
"outputId": "6f50cdf1-a86e-4104-e981-53cde2022066",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
}
},
"source": [
"a = np.ones((4, 4, 4))\n",
"a[..., 2:3] = np.ones((4, 4, 1)) * 2\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"[[[1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]]\n",
"\n",
" [[1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]]\n",
"\n",
" [[1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]]\n",
"\n",
" [[1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]\n",
" [1. 1. 2. 1.]]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Z6eRDA9xNhwW",
"colab_type": "code",
"outputId": "a88d31e3-9733-4317-a61a-015725f1860b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
}
},
"source": [
"a = np.ones((4, 4, 4))\n",
"a[b] = np.ones((1, 2, 4, 4)) * 2\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"[[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "17Q-2ms6QsI4",
"colab_type": "code",
"outputId": "303d7e01-63cb-4487-dba9-d0572b6dabd0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
}
},
"source": [
"a = np.ones((4, 4, 4))\n",
"a[b, c] = np.ones((1, 2, 4)) * 2\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"[[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "jF4DrNWlAc6J",
"colab_type": "code",
"outputId": "af6579ba-6765-47e8-c2af-04b15f49f113",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
}
},
"source": [
"a = np.ones((4, 4, 4))\n",
"a[d] = np.ones((4, 4, 4))[d] * 2\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"[[[2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2RqxmoKZM0Gh",
"colab_type": "text"
},
"source": [
"# TensorFlowでのIndexing"
]
},
{
"cell_type": "code",
"metadata": {
"id": "TbrArbh8MiVH",
"colab_type": "code",
"outputId": "a8a5b339-47b3-4cb6-8ca6-c6d4a1ccee8a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
}
},
"source": [
"a = tf.ones((4, 4, 4))\n",
"b = tf.constant([[2, 3]])\n",
"c = tf.constant([1, 2])\n",
"d = tf.constant([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)\n",
"\n",
"print(a[:, 2:3].shape)\n",
"print(tf.gather(a, [2, 3], axis=1).shape)\n",
"print(a[..., 2:3].shape)\n",
"print(tf.gather(a, b, axis=0).shape)\n",
"print(tf.gather_nd(a, tf.stack([b, [c]], axis=-1)).shape)\n",
"print(a[d])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"(4, 1, 4)\n",
"(4, 2, 4)\n",
"(4, 4, 1)\n",
"(1, 2, 4, 4)\n",
"(1, 2, 4)\n",
"tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(16,), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "1Ptc8UAVQ5rr",
"colab_type": "code",
"outputId": "a452bb3a-320c-45f0-fb90-7d36ff3c439b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 544
}
},
"source": [
"print(tf.gather_nd(a, [2, 3]))\n",
"print()\n",
"print(tf.gather_nd(a, [2, 3, 3]))\n",
"print()\n",
"print(tf.gather_nd(a, [[2, 3], [1, 2]]))\n",
"print()\n",
"print(tf.gather_nd(a, [[[[2,3]]]]))\n",
"print()\n",
"print(tf.gather_nd(a, [[[2], [3]], [[1], [2]]]))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([1. 1. 1. 1.], shape=(4,), dtype=float32)\n",
"\n",
"tf.Tensor(1.0, shape=(), dtype=float32)\n",
"\n",
"tf.Tensor(\n",
"[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]], shape=(2, 4), dtype=float32)\n",
"\n",
"tf.Tensor([[[[1. 1. 1. 1.]]]], shape=(1, 1, 1, 4), dtype=float32)\n",
"\n",
"tf.Tensor(\n",
"[[[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]]\n",
"\n",
"\n",
" [[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]]], shape=(2, 2, 4, 4), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DictFPc5OPvW",
"colab_type": "code",
"outputId": "814af2fc-6ed1-469c-c763-00e4c79ad5af",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
}
},
"source": [
"a = tf.Variable(tf.ones((4, 4, 4)))\n",
"a = a[:, 2:3].assign(tf.ones((4, 1, 4)) * 2)\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"<tf.Variable 'UnreadVariable' shape=(4, 4, 4) dtype=float32, numpy=\n",
"array([[[1., 1., 1., 1.],\n",
" [1., 1., 1., 1.],\n",
" [2., 2., 2., 2.],\n",
" [1., 1., 1., 1.]],\n",
"\n",
" [[1., 1., 1., 1.],\n",
" [1., 1., 1., 1.],\n",
" [2., 2., 2., 2.],\n",
" [1., 1., 1., 1.]],\n",
"\n",
" [[1., 1., 1., 1.],\n",
" [1., 1., 1., 1.],\n",
" [2., 2., 2., 2.],\n",
" [1., 1., 1., 1.]],\n",
"\n",
" [[1., 1., 1., 1.],\n",
" [1., 1., 1., 1.],\n",
" [2., 2., 2., 2.],\n",
" [1., 1., 1., 1.]]], dtype=float32)>\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bmngHVB-VesZ",
"colab_type": "code",
"outputId": "dbbc9fb1-939e-4491-e9a3-91be8e272e6a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
}
},
"source": [
"a = tf.Variable(tf.ones((4, 4, 4)))\n",
"first = tf.tile(tf.expand_dims(tf.range(4), axis=1), (1, 2))\n",
"indexes = tf.tile([[2, 3]], (4, 1))\n",
"indices = tf.stack([first, indexes], axis=-1)\n",
"a = tf.tensor_scatter_nd_update(a, indices, tf.ones((4, 2, 4)) * 2)\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]], shape=(4, 4, 4), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VBEhEQkna6yp",
"colab_type": "code",
"outputId": "b17731e9-3bed-46c2-cd2c-1620dd095067",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
}
},
"source": [
"a = tf.Variable(tf.ones((4, 4, 4)))\n",
"a[..., 2:3].assign(np.ones((4, 4, 1)) * 2)\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"<tf.Variable 'Variable:0' shape=(4, 4, 4) dtype=float32, numpy=\n",
"array([[[1., 1., 2., 1.],\n",
" [1., 1., 2., 1.],\n",
" [1., 1., 2., 1.],\n",
" [1., 1., 2., 1.]],\n",
"\n",
" [[1., 1., 2., 1.],\n",
" [1., 1., 2., 1.],\n",
" [1., 1., 2., 1.],\n",
" [1., 1., 2., 1.]],\n",
"\n",
" [[1., 1., 2., 1.],\n",
" [1., 1., 2., 1.],\n",
" [1., 1., 2., 1.],\n",
" [1., 1., 2., 1.]],\n",
"\n",
" [[1., 1., 2., 1.],\n",
" [1., 1., 2., 1.],\n",
" [1., 1., 2., 1.],\n",
" [1., 1., 2., 1.]]], dtype=float32)>\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dohp01hWulc2",
"colab_type": "code",
"outputId": "fb476830-ea93-4ea2-a248-514e6af61d4d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
}
},
"source": [
"a = tf.Variable(tf.ones((4, 4, 4)))\n",
"a = tf.tensor_scatter_nd_update(a, tf.expand_dims(b, axis=-1), tf.ones((1, 2, 4, 4)) * 2)\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]\n",
" [2. 2. 2. 2.]]], shape=(4, 4, 4), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "2KvdK9Jwu82Y",
"colab_type": "code",
"outputId": "d2c999a0-4238-4559-84f4-70c6d0b49101",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
}
},
"source": [
"# Start from fresh ones so this cell mirrors the NumPy a[b, c] = ... example\n",
"# instead of reusing the tensor left over from the previous cell.\n",
"a = tf.Variable(tf.ones((4, 4, 4)))\n",
"# Pair b and c element-wise into (row, col) index tuples: shape (1, 2, 2).\n",
"indices = tf.stack([b, [c]], axis=-1)\n",
"a = tf.tensor_scatter_nd_update(a, indices, tf.ones((1, 2, 4)) * 2)\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]]], shape=(4, 4, 4), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "27rtLRJfAKce",
"colab_type": "code",
"outputId": "e0d1dad2-a57c-4afa-b4c4-527b1115457e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
}
},
"source": [
"a = tf.Variable(tf.ones((4, 4, 4)))\n",
"d = tf.constant([[[True, True, True, True], [False, False, False, False], [False, False, False, False], [False, False, False, False]]] * 4)\n",
"# Tensors have no in-place a[d] = x; select element-wise with tf.where instead.\n",
"# Unlike the arithmetic blend (1 - d) * a + d * x, tf.where never multiplies the\n",
"# unselected operand, so NaN/Inf there cannot leak into the result, and d stays a bool mask.\n",
"a = tf.where(d, tf.ones((4, 4, 4)) * 2, a)\n",
"print(a)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[[2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]\n",
"\n",
" [[2. 2. 2. 2.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]\n",
" [1. 1. 1. 1.]]], shape=(4, 4, 4), dtype=float32)\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment