Last active
October 29, 2020 22:52
TF2高阶操作
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "TF2高阶操作", | |
"provenance": [], | |
"collapsed_sections": [], | |
"toc_visible": true, | |
"authorship_tag": "ABX9TyPHyuBfRr9HLFoUP2veMVy/", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/visualDust/358c4e9c7c654acc2a64a74c4688e44d/tf2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "tr5--0KmJYes" | |
}, | |
"source": [ | |
"# Merge and Split \n",
"* tf.concat\n", | |
"* tf.split\n", | |
"* tf.stack\n", | |
"* tf.unstack" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "waPKeLXlPaXH" | |
}, | |
"source": [ | |
"# importing\n", | |
"import os\n", | |
"# changing env log level\n", | |
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n", | |
"import tensorflow as tf\n", | |
"from tensorflow.keras import layers\n", | |
"from tensorflow.keras import optimizers" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XLzdji8OJvjZ" | |
}, | |
"source": [ | |
"## concat\n", | |
"concat joins tensors along an existing axis. For example, suppose grades for six classes are collected by two people: one records the first four classes and the other records the last two. With 35 students per class and 8 subjects each, the two score tables have shapes [4,35,8] and [2,35,8], and the concatenated table has shape [6,35,8]."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lVSJ_nAJJO5a", | |
"outputId": "a9151573-0419-46b7-eeda-f3f1216e3f2b", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"a = tf.ones([4,35,8])\n", | |
"b = tf.ones([2,35,8])\n", | |
"# concat joins the tensors along axis 0; all other axes must already match (no broadcasting)\n",
"c = tf.concat([a,b],axis = 0)\n", | |
"print(c.shape)" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(6, 35, 8)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "96WPotIHQCXh" | |
}, | |
"source": [ | |
"A similar scenario along a different axis: two people record grades for a single class of 35 students. The first records the first 32 students and the second records the last 3; concatenating along the student axis yields the full class table."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0IiGnOf4QVem", | |
"outputId": "6d69b825-425d-4780-cfea-19f3fda26300", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"a = tf.ones([1,32,8])\n", | |
"b = tf.ones([1,3,8])\n", | |
"# concat joins the tensors along axis 1; the other axes (0 and 2 here) must match\n",
"c = tf.concat([a,b],axis = 1)\n", | |
"print(c.shape)" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(1, 35, 8)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "DPFtTqW7Tqwi" | |
}, | |
"source": [ | |
"Another case: four classes took 16 exams in total, with the scores recorded in two tables of 8 exams each. To merge them into a single table, concatenate along the score axis:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pz67APchTrKb", | |
"outputId": "9b3794d6-161c-4248-fb79-d96e123220b7", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"a = tf.ones([4,35,8])\n", | |
"b = tf.ones([4,35,8])\n", | |
"c = tf.concat([a,b],axis = -1)\n", | |
"print(c.shape)" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(4, 35, 16)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "dc2WbmwXQbob" | |
}, | |
"source": [ | |
"Note that what distinguishes these scenarios is the axis along which the tensors are concatenated: axis 0 in the first scenario, axis 1 in the second. \n",
"* Constraint for concat: all dimensions except the one being concatenated along must be equal; only the concat axis may differ in size."
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "dg8ut_bQSbXn" | |
}, | |
"source": [ | |
"## stack\n", | |
"stack combines tensors along a new axis. For example, suppose we have grade tensors with structure [class, student, score] for two different schools, and we want to put them into one table while still being able to tell the schools apart."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "iMNRV5ZGQruP", | |
"outputId": "250b340f-3bc0-446f-d868-7b0a9eae00fc", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"a = tf.ones([4,35,8])\n", | |
"b = tf.ones([4,35,8])\n", | |
"# add a new dim and combine them\n", | |
"c = tf.stack([a,b],axis = 0)\n", | |
"print(c.shape)" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(2, 4, 35, 8)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "TZs7oJELUq4C" | |
}, | |
"source": [ | |
"stack lets you choose where the new axis is inserted. For example, to put the school axis last:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Xqz1djfuUvOt", | |
"outputId": "654f566b-00e0-4474-f7f4-8e7f29128ac9", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"a = tf.ones([4,35,8])\n", | |
"b = tf.ones([4,35,8])\n", | |
"# add a new dim and combine them\n", | |
"c = tf.stack([a,b],axis = 3)\n", | |
"print(c.shape)" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(4, 35, 8, 2)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "TsEcng9fVdq2" | |
}, | |
"source": [ | |
"By convention, though, the larger-scoped axis (school) usually goes first. \n",
"* Constraint for stack: all input tensors must have exactly the same shape."
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Btqr09g2Wtkr" | |
}, | |
"source": [ | |
"## unstack\n", | |
"stack has a counterpart, unstack, which splits a tensor along a given axis into as many parts as that axis's size."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JebktaPDWtT_", | |
"outputId": "ce7818e6-83aa-4c11-e8b5-825673311d1c", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"a = tf.ones([4,35,8])\n", | |
"b = tf.ones([4,35,8])\n", | |
"# add a new dim and combine them\n", | |
"c = tf.stack([a,b],axis = 0)\n", | |
"print(\"shape of the origin : \",c.shape)\n", | |
"# unstack\n", | |
"a_2,b_2 = tf.unstack(c,axis = 0)\n", | |
"print(\"after unstack : a2:\",a_2.shape,\",b2:\",b_2.shape)" | |
], | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"shape of the origin : (2, 4, 35, 8)\n", | |
"after unstack : a2: (4, 35, 8) ,b2: (4, 35, 8)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
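{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an extra illustration (this cell is an addition, not part of the original notes): unstacking along the last axis instead returns a Python list with one tensor per index of that axis."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# unstack the [2,4,35,8] tensor along axis 3: 8 tensors of shape [2,4,35]\n",
"parts = tf.unstack(c,axis = 3)\n",
"print(\"number of parts :\",len(parts),\", each of shape\",parts[0].shape)"
],
"execution_count": null,
"outputs": []
},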
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "yPmuIestPS24" | |
}, | |
"source": [ | |
"## split \n", | |
"unstack always produces axis-size parts, which limits its use. split is more flexible, with two main forms: \n",
"* If num_or_size_splits is an integer, e.g. num_or_size_splits=2, split divides the tensor into that many equal parts along the given axis. \n",
"* If num_or_size_splits is a list, e.g. num_or_size_splits=[1,2,3], split produces one part per list entry (three parts here), where each entry is that part's size along the axis; the entries must sum to the axis size."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WqGiWlWPPX37", | |
"outputId": "c9869449-2d96-4096-a3fd-6c9ef0bafb97", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"a = tf.ones([4,35,8])\n", | |
"b = tf.ones([4,35,8])\n", | |
"# add a new dim and combine them\n", | |
"c = tf.stack([a,b],axis = 0)\n", | |
"print(\"shape of the origin : \",c.shape)\n", | |
"# split into two parts along axis 0\n",
"res = tf.split(c,axis = 0,num_or_size_splits=2)\n",
"print(\"after split into two parts, len = \",len(res),\", shape = \",res[0].shape,\" and \",res[1].shape)"
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"shape of the origin : (2, 4, 35, 8)\n", | |
"after split into two parts, len = 2 , shape = (1, 4, 35, 8) and (1, 4, 35, 8)\n",
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "yRfteAaWRDl0", | |
"outputId": "b51f805d-6d7b-4b51-dc4b-461e1c522832", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"a = tf.ones([4,35,8])\n", | |
"b = tf.ones([4,35,8])\n", | |
"# add a new dim and combine them\n", | |
"c = tf.stack([a,b],axis = 0)\n", | |
"print(\"shape of the origin : \",c.shape)\n", | |
"# split into three parts on axis 3, sizes 2, 2, 4 (summing to 8)\n",
"res = tf.split(c,axis = 3,num_or_size_splits=[2,2,4])" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"shape of the origin : (2, 4, 35, 8)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
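{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the list entries are absolute sizes along the chosen axis and must sum to its size (2+2+4 = 8 here). Printing the result shapes makes this visible (this cell is an extra illustration, not part of the original notes):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# the three parts split from the [2,4,35,8] tensor on axis 3\n",
"for r in res:\n",
"    print(r.shape)"
],
"execution_count": null,
"outputs": []
},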
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "iv7nxDvyaHYV" | |
}, | |
"source": [ | |
"# Data statistics\n",
"* tf.norm: tensor norm (1-norm, 2-norm, ..., infinity norm)\n",
"* tf.reduce_min: minimum\n",
"* tf.reduce_max: maximum\n",
"* tf.argmin: position of the minimum\n",
"* tf.argmax: position of the maximum\n",
"* tf.equal: tensor comparison\n",
"* tf.unique: distinct values"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "UBUKML45eX_2" | |
}, | |
"source": [ | |
"## tf.norm \n", | |
"For simplicity, consider only vector norms for now. The 2-norm of a vector is \n",
"$$\\|x\\|_2 = \\sqrt{\\sum_{i=1}^{n} x_i^2}$$ \n",
"and the general n-norm is \n",
"$$\\|x\\|_n = \\left(\\sum_{i=1}^{n} |x_i|^n\\right)^{1/n}$$ \n",
"A norm can be understood as a function mapping a vector to a scalar. Vectors cannot be compared directly, but their norms can. Loosely speaking, the 2-norm is the Euclidean distance from the vector to the origin."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "__bO63gbe3rc", | |
"outputId": "97648975-0ce5-4243-8cd1-0599d1fdf759", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"origin = tf.ones([2,2])\n", | |
"# 2-norm (the default)\n",
"print(\"origin = \",origin,\"\\nafter norm: \",tf.norm(origin))\n",
"# verify that this matches the square-sum-sqrt formula above\n",
"print(\"origin = \",origin,\"\\nafter square-sum-sqrt : \",tf.sqrt(tf.reduce_sum(tf.square(origin))))\n",
"print(\"They are the same\")"
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = tf.Tensor(\n", | |
"[[1. 1.]\n", | |
" [1. 1.]], shape=(2, 2), dtype=float32) \n", | |
"after norm: tf.Tensor(2.0, shape=(), dtype=float32)\n", | |
"origin = tf.Tensor(\n", | |
"[[1. 1.]\n", | |
" [1. 1.]], shape=(2, 2), dtype=float32) \n", | |
"after square-sum-sqrt : tf.Tensor(2.0, shape=(), dtype=float32)\n",
"They are the same\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JcRv1-WphNh5", | |
"outputId": "453bc006-bb13-46d3-bb68-5727347ac973", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"origin = tf.ones([4,28,28,3])\n", | |
"# more complex example\n", | |
"print(\"origin = \",origin.shape,\"\\nafter norm: \",tf.norm(origin)) \n", | |
"# verify again against the square-sum-sqrt formula\n",
"print(\"origin = \",origin.shape,\"\\nafter square-sum-sqrt : \",tf.sqrt(tf.reduce_sum(tf.square(origin))))\n",
"print(\"They are the same\")" | |
], | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = (4, 28, 28, 3) \n", | |
"after norm: tf.Tensor(96.99484, shape=(), dtype=float32)\n", | |
"origin = (4, 28, 28, 3) \n", | |
"after square-sum-sqrt : tf.Tensor(96.99484, shape=(), dtype=float32)\n",
"They are the same\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "9mo0TUQgCUMT" | |
}, | |
"source": [ | |
"Besides operating on the whole tensor, norm can also act along a specific axis: roughly, unstack along that axis and take the norm of each resulting vector."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AgpNVByuC419", | |
"outputId": "19fc6cc8-73df-4d50-8e26-1958d6515298", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"origin = tf.ones([4,28,28,3])\n", | |
"# norm working on specific axis\n", | |
"print(\"origin = \",origin.shape,\"\\nafter norm on axis = 3 : \",tf.norm(origin,axis = 3))" | |
], | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = (4, 28, 28, 3) \n", | |
"after norm on axis = 3 : tf.Tensor(\n", | |
"[[[1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" ...\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]]\n", | |
"\n", | |
" [[1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" ...\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]]\n", | |
"\n", | |
" [[1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" ...\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]]\n", | |
"\n", | |
" [[1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" ...\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]\n", | |
" [1.7320508 1.7320508 1.7320508 ... 1.7320508 1.7320508 1.7320508]]], shape=(4, 28, 28), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "D10dcSG2EW46" | |
}, | |
"source": [ | |
"Besides the default 2-norm, norm can compute an n-norm by passing the ord parameter. For example:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "NBYQNNIjEfyz", | |
"outputId": "78703c1a-2381-4127-e181-8341118365de", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"origin = tf.ones([2,2])\n", | |
"# 1-norm\n",
"print(\"ord = 1 : \",tf.norm(origin,ord=1))\n",
"# 2-norm\n",
"print(\"ord = 2 : \",tf.norm(origin,ord=2))\n",
"# 3-norm\n",
"print(\"ord = 3 : \",tf.norm(origin,ord=3))\n",
"# 4-norm\n",
"print(\"ord = 4 : \",tf.norm(origin,ord=4))\n",
"# 5-norm\n",
"print(\"ord = 5 : \",tf.norm(origin,ord=5))"
], | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"ord = 1 : tf.Tensor(4.0, shape=(), dtype=float32)\n", | |
"ord = 2 : tf.Tensor(2.0, shape=(), dtype=float32)\n", | |
"ord = 3 : tf.Tensor(1.587401, shape=(), dtype=float32)\n", | |
"ord = 4 : tf.Tensor(1.4142135, shape=(), dtype=float32)\n", | |
"ord = 5 : tf.Tensor(1.319508, shape=(), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
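{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an added cell, not part of the original notes), the 1-norm is just the sum of absolute values, and omitting ord gives the 2-norm:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"origin = tf.ones([2,2])\n",
"# sum of absolute values: should match tf.norm(origin,ord=1), i.e. 4.0\n",
"print(tf.reduce_sum(tf.abs(origin)))\n",
"# the default ord is the 2-norm\n",
"print(tf.norm(origin) == tf.norm(origin,ord=2))"
],
"execution_count": null,
"outputs": []
},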
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Pl35JjaQGasK" | |
}, | |
"source": [ | |
"## tf.reduce_min / max / mean / sum\n", | |
"* tf.reduce_min\n", | |
"* tf.reduce_max\n", | |
"* tf.reduce_mean\n", | |
"* tf.reduce_sum\n", | |
"\n", | |
"These compute the minimum, maximum, mean, and sum. The reduce in the name indicates a flattening step: without an axis argument, a [10,4] tensor is flattened to 40 elements before taking the min, max, etc.; with axis=2, a [10,4,10] tensor is reduced along that axis, yielding a [10,4] result where each entry is the min, max, etc. over the 10 values along axis 2."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "imKwvLzTIext", | |
"outputId": "f6e479b7-c103-4e1d-cc78-2d776777ebef", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"origin = tf.random.normal([4,10])\n", | |
"print(\"origin = \",origin,\"\\nreduce_min = \",tf.reduce_min(origin),\"\\nreduce_max = \",tf.reduce_max(origin),\"\\nreduce_mean = \",tf.reduce_mean(origin))" | |
], | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = tf.Tensor(\n", | |
"[[-1.4178391e+00 1.0799263e+00 1.8141992e+00 -2.9427743e-01\n", | |
" 3.5776252e-01 -6.9446379e-01 -7.1207196e-01 9.6388352e-01\n", | |
" -2.1230397e+00 4.8318788e-01]\n", | |
" [-4.1854006e-01 -2.2664030e-01 -9.8776561e-01 3.3819950e-01\n", | |
" 2.4363371e-02 -3.2178679e+00 -2.8521428e-01 -5.3039378e-01\n", | |
" -1.0285269e+00 -1.2320877e+00]\n", | |
" [ 6.0093373e-01 1.3320454e-02 9.5860285e-01 1.4495020e+00\n", | |
" 5.1962131e-01 1.1331964e+00 -1.0149366e+00 -5.1126540e-02\n", | |
" -5.0443190e-01 3.9746460e-01]\n", | |
" [-4.1444901e-01 -1.2171540e+00 -8.4814447e-01 1.4405949e+00\n", | |
" 7.2787516e-04 1.2379333e+00 1.0925928e+00 -9.9176753e-01\n", | |
" 3.8999468e-02 1.0164096e+00]], shape=(4, 10), dtype=float32) \n", | |
"reduce_min = tf.Tensor(-3.2178679, shape=(), dtype=float32) \n", | |
"reduce_max = tf.Tensor(1.8141992, shape=(), dtype=float32) \n", | |
"reduce_mean = tf.Tensor(-0.08123293, shape=(), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KePndTq3ML45", | |
"outputId": "0ae30f7c-b02d-4b2b-e5cc-428eddb1c654", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"origin = tf.random.normal([4,10])\n", | |
"print(\"origin = \",origin)\n", | |
"print(\"\\nreduce_min on axis 1 = \",tf.reduce_min(origin,axis = 1))\n", | |
"print(\"\\nreduce_max on axis 1 = \",tf.reduce_max(origin,axis = 1))\n", | |
"print(\"\\nreduce_mean on axis 1 = \",tf.reduce_mean(origin,axis = 1))" | |
], | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = tf.Tensor(\n", | |
"[[-3.1204236 0.67563623 -0.9232384 1.1589053 0.8515049 -0.47955766\n", | |
" -1.723766 0.12821583 -0.6078169 -0.07115268]\n", | |
" [-0.03351626 0.5452725 0.4999855 -0.13481826 0.6798329 0.23792107\n", | |
" -0.6113948 1.3868407 0.24892737 -0.41333905]\n", | |
" [-0.9676226 -0.3656622 -0.688232 1.721823 0.6695465 -0.44504106\n", | |
" 0.90125936 0.5428907 1.4090685 -0.9626962 ]\n", | |
" [-0.87203074 0.9285623 0.56897074 -1.4624474 1.8943952 -0.5554827\n", | |
" -0.8351434 -0.3565093 -1.5708245 -1.1640625 ]], shape=(4, 10), dtype=float32)\n", | |
"\n", | |
"reduce_min on axis 1 = tf.Tensor([-3.1204236 -0.6113948 -0.9676226 -1.5708245], shape=(4,), dtype=float32)\n", | |
"\n", | |
"reduce_max on axis 1 = tf.Tensor([1.1589053 1.3868407 1.721823 1.8943952], shape=(4,), dtype=float32)\n", | |
"\n", | |
"reduce_mean on axis 1 = tf.Tensor([-0.4111693 0.24057117 0.18153341 -0.34245723], shape=(4,), dtype=float32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
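{
"cell_type": "markdown",
"metadata": {},
"source": [
"tf.reduce_sum works the same way (this cell is an extra illustration, not part of the original notes): without axis it sums every element; with an axis it sums along that axis only."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"origin = tf.ones([4,10])\n",
"# all 40 ones summed\n",
"print(\"reduce_sum = \",tf.reduce_sum(origin))\n",
"# summing along axis 1 leaves shape (4,), each entry the sum of 10 ones\n",
"print(\"reduce_sum on axis 1 = \",tf.reduce_sum(origin,axis = 1))"
],
"execution_count": null,
"outputs": []
},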
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "j_TVY0HPNPLb" | |
}, | |
"source": [ | |
"## tf.argmax/argmin \n", | |
"* tf.argmax\n", | |
"* tf.argmin\n", | |
"\n", | |
"These return the positions of the minimum and maximum. Without an axis argument, they operate along axis 0, returning one index per remaining position."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zR7Vtf61Nu7v", | |
"outputId": "b6a7b6e0-c81c-429b-9a04-d407a94d3f7c", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"origin = tf.nn.relu(tf.random.normal([4,3])*100)\n", | |
"print(\"origin = \",origin)\n", | |
"# position of the maximum along axis 0 (the default)\n",
"print(\"argmax : \",tf.argmax(origin))" | |
], | |
"execution_count": 33, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = tf.Tensor(\n", | |
"[[ 0. 0. 0. ]\n", | |
" [ 0. 0. 54.485176]\n", | |
" [86.468796 74.000046 0. ]\n", | |
" [ 0. 29.033602 69.07481 ]], shape=(4, 3), dtype=float32)\n", | |
"argmax : tf.Tensor([2 2 3], shape=(3,), dtype=int64)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Jl9brJX2PK38", | |
"outputId": "0413e344-8e8c-4e30-87e5-d71e99cb8e92", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
} | |
}, | |
"source": [ | |
"origin = tf.nn.relu(tf.random.normal([4,3,2])*100)\n", | |
"print(\"origin = \",origin)\n", | |
"# argmax along axis 0 (the default); reducing axis 0 of a [4,3,2] tensor leaves a [3,2] grid, so one index is returned per position in it\n",
"print(\"argmax : \",tf.argmax(origin))" | |
], | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = tf.Tensor(\n", | |
"[[[ 0. 44.180485]\n", | |
" [ 60.554047 54.124874]\n", | |
" [ 11.455048 54.916447]]\n", | |
"\n", | |
" [[ 0. 70.35009 ]\n", | |
" [147.33435 110.680046]\n", | |
" [ 0. 59.37093 ]]\n", | |
"\n", | |
" [[ 0. 20.160051]\n", | |
" [ 0. 24.07408 ]\n", | |
" [ 0. 0. ]]\n", | |
"\n", | |
" [[ 84.53291 0. ]\n", | |
" [131.28426 103.82523 ]\n", | |
" [192.91162 31.05115 ]]], shape=(4, 3, 2), dtype=float32)\n", | |
"argmax : tf.Tensor(\n", | |
"[[3 1]\n", | |
" [1 1]\n", | |
" [3 1]], shape=(3, 2), dtype=int64)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "kKwO6iFZQb1c" | |
}, | |
"source": [ | |
"## tf.equal\n", | |
"tf.equal compares two tensors element-wise and returns a boolean tensor."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Xy3xNeBqXJ9J", | |
"outputId": "ffe37aac-7c46-4964-effa-bf4f7eedc508", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"a = tf.constant([1,2,3,4,5])\n", | |
"b = tf.constant(range(5))\n", | |
"print(\"a = \",a,\", b = \",b)\n", | |
"result = tf.equal(a,b)\n", | |
"print(\"Equal : \",result)\n", | |
"cast_to_int = tf.reduce_sum(tf.cast(result, dtype=tf.int32))\n", | |
"print(\"To int32 : \",cast_to_int)" | |
], | |
"execution_count": 37, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"a = tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32) , b = tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n", | |
"Equal : tf.Tensor([False False False False False], shape=(5,), dtype=bool)\n", | |
"To int32 : tf.Tensor(0, shape=(), dtype=int32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "xPUymLEAZHzf" | |
}, | |
"source": [ | |
"tf.equal is handy for computing accuracy: run the model on a test set, compare the predictions against the test labels with tf.equal, cast the booleans to numbers, and sum them; the ratio of matches to the total gives the accuracy (the equal positions are the correctly predicted ones)."
] | |
}, | |
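{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that accuracy computation (an added cell; the prediction and label tensors below are made up for illustration):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"y_pred = tf.constant([0, 1, 2, 2, 1])  # hypothetical model predictions\n",
"y_true = tf.constant([0, 1, 2, 1, 1])  # hypothetical test labels\n",
"correct = tf.reduce_sum(tf.cast(tf.equal(y_pred, y_true), tf.float32))\n",
"accuracy = correct / tf.cast(tf.size(y_true), tf.float32)\n",
"print(\"accuracy = \", accuracy.numpy())  # 4 of 5 match -> 0.8"
],
"execution_count": null,
"outputs": []
},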
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "QWOFm0wDhTKm" | |
}, | |
"source": [ | |
"## tf.unique \n", | |
"tf.unique returns a \"set\" of the distinct elements in a tensor, together with an idx tensor giving, for each original element, its index in that set. For example:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hEjDBo8siJjZ", | |
"outputId": "59c49ba0-3b02-49b6-d29c-b86a32a186dd", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"origin = tf.constant([4,2,2,4,3])\n", | |
"result, idx = tf.unique(origin) \n", | |
"print(\"origin = \",origin,\"\\nunique : \",result,\", \\nidx : \",idx)" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = tf.Tensor([4 2 2 4 3], shape=(5,), dtype=int32) \n", | |
"unique : tf.Tensor([4 2 3], shape=(3,), dtype=int32) , \n", | |
"idx : tf.Tensor([0 1 1 0 2], shape=(5,), dtype=int32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "We6BP_RfjP69" | |
}, | |
"source": [ | |
"Recall gather from the TF2 basics notes: the original tensor can be reconstructed from the unique result and idx."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nGRU51jujB8V", | |
"outputId": "be9a8a66-4c1e-4b22-93b2-7765afa0f841", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"origin = tf.constant([4,2,2,4,3])\n", | |
"result, idx = tf.unique(origin) \n", | |
"print(\"origin = \",origin,\"\\nunique : \",result,\", \\nidx : \",idx)\n", | |
"print(\"using tf.gather : \",tf.gather(result,idx))" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"origin = tf.Tensor([4 2 2 4 3], shape=(5,), dtype=int32) \n", | |
"unique : tf.Tensor([4 2 3], shape=(3,), dtype=int32) , \n", | |
"idx : tf.Tensor([0 1 1 0 2], shape=(5,), dtype=int32)\n", | |
"using tf.gather : tf.Tensor([4 2 2 4 3], shape=(5,), dtype=int32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |