Last active
August 20, 2020 08:26
-
-
Save OnlyBelter/31fc0f61dd94117d170111bb50236197 to your computer and use it in GitHub Desktop.
custom constraint function in TF2.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "custom constraint function in TF2.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/OnlyBelter/31fc0f61dd94117d170111bb50236197/custom%20constraint%20function%20in%20TF2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RFaiLjDlyYOP", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 973 | |
}, | |
"outputId": "1f87433a-4fae-4744-d8e3-e42f18c831e1" | |
}, | |
"source": [ | |
"!pip install tensorflow==2.2.0" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting tensorflow==2.2.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/3d/be/679ce5254a8c8d07470efb4a4c00345fae91f766e64f1c2aece8796d7218/tensorflow-2.2.0-cp36-cp36m-manylinux2010_x86_64.whl (516.2MB)\n", | |
"\u001b[K |████████████████████████████████| 516.2MB 30kB/s \n", | |
"\u001b[?25hCollecting tensorflow-estimator<2.3.0,>=2.2.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a4/f5/926ae53d6a226ec0fda5208e0e581cffed895ccc89e36ba76a8e60895b78/tensorflow_estimator-2.2.0-py2.py3-none-any.whl (454kB)\n", | |
"\u001b[K |████████████████████████████████| 460kB 49.3MB/s \n", | |
"\u001b[?25hRequirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (0.3.3)\n", | |
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (3.3.0)\n", | |
"Collecting tensorboard<2.3.0,>=2.2.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/1d/74/0a6fcb206dcc72a6da9a62dd81784bfdbff5fedb099982861dc2219014fb/tensorboard-2.2.2-py3-none-any.whl (3.0MB)\n", | |
"\u001b[K |████████████████████████████████| 3.0MB 45.6MB/s \n", | |
"\u001b[?25hRequirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.15.0)\n", | |
"Requirement already satisfied: keras-preprocessing>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.1.2)\n", | |
"Requirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (0.2.0)\n", | |
"Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.6.3)\n", | |
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.1.0)\n", | |
"Requirement already satisfied: numpy<2.0,>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.18.5)\n", | |
"Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.31.0)\n", | |
"Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (0.34.2)\n", | |
"Requirement already satisfied: scipy==1.4.1; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.4.1)\n", | |
"Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (2.10.0)\n", | |
"Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (3.12.4)\n", | |
"Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.12.1)\n", | |
"Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (0.9.0)\n", | |
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.2.2)\n", | |
"Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.17.2)\n", | |
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (0.4.1)\n", | |
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (2.23.0)\n", | |
"Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (49.2.0)\n", | |
"Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.7.0)\n", | |
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.0.1)\n", | |
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.7.0)\n", | |
"Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (4.6)\n", | |
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (4.1.1)\n", | |
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (0.2.8)\n", | |
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.3.0)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.0.4)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.24.3)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (2020.6.20)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (2.10)\n", | |
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.1.0)\n", | |
"Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<5,>=3.1.4; python_version >= \"3\"->google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (0.4.8)\n", | |
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.1.0)\n", | |
"Installing collected packages: tensorflow-estimator, tensorboard, tensorflow\n", | |
" Found existing installation: tensorflow-estimator 2.3.0\n", | |
" Uninstalling tensorflow-estimator-2.3.0:\n", | |
" Successfully uninstalled tensorflow-estimator-2.3.0\n", | |
" Found existing installation: tensorboard 2.3.0\n", | |
" Uninstalling tensorboard-2.3.0:\n", | |
" Successfully uninstalled tensorboard-2.3.0\n", | |
" Found existing installation: tensorflow 2.3.0\n", | |
" Uninstalling tensorflow-2.3.0:\n", | |
" Successfully uninstalled tensorflow-2.3.0\n", | |
"Successfully installed tensorboard-2.2.2 tensorflow-2.2.0 tensorflow-estimator-2.2.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9kdSSwqN7plW", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "48628b14-a295-4eac-d39f-a01f7a560ecf" | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"from tensorflow import keras\n", | |
"tf.__version__" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"'2.2.0'" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 2 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pIIghQ18qj89", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 191 | |
}, | |
"outputId": "d63895db-4536-4d8a-e6d6-8f39b8c490f9" | |
}, | |
"source": [ | |
"x = np.random.randint(10, size=(10, 3))\n", | |
"x" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[3, 1, 1],\n", | |
" [8, 1, 7],\n", | |
" [8, 2, 1],\n", | |
" [1, 4, 9],\n", | |
" [5, 1, 9],\n", | |
" [6, 7, 1],\n", | |
" [4, 6, 0],\n", | |
" [5, 3, 9],\n", | |
" [3, 0, 0],\n", | |
" [8, 6, 2]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Iym6UbmYrtWG", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 191 | |
}, | |
"outputId": "9b989e9a-b1bb-4e6e-a14e-9cd7f08782ea" | |
}, | |
"source": [ | |
"y = np.random.randint(20, size=(10, 1))\n", | |
"y" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[11],\n", | |
" [ 1],\n", | |
" [ 6],\n", | |
" [12],\n", | |
" [18],\n", | |
" [ 0],\n", | |
" [ 7],\n", | |
" [15],\n", | |
" [18],\n", | |
" [17]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gzCzordRtHLO", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class SoftMax(tf.keras.constraints.Constraint):\n", | |
" \"\"\"Constrains weight tensors to be centered around `ref_value`.\n", | |
" refer to: https://keras.io/api/layers/constraints/\n", | |
" \"\"\"\n", | |
"\n", | |
" def __call__(self, w):\n", | |
" return tf.nn.softmax(w, axis=0)" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Kk-u8esCr-p6", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def regression_model(x, y, lr, epochs, verbose=1, result_dir='.'):\n", | |
" print(y.shape)\n", | |
" model = keras.Sequential([keras.layers.Dense(y.shape[1], activation='softplus', \n", | |
" kernel_constraint=SoftMax(),\n", | |
" bias_constraint=keras.constraints.MinMaxNorm(min_value=-100, max_value=100),\n", | |
" input_shape=[x.shape[1]])])\n", | |
" opt = keras.optimizers.RMSprop(learning_rate=lr)\n", | |
" # opt = keras.optimizers.SGD(lr=lr, momentum=0.9, nesterov=True, decay=1e-4)\n", | |
" model.compile(optimizer=opt, loss='mse',\n", | |
" metrics=['mae'])\n", | |
" history = model.fit(x, y, epochs=epochs, batch_size=8, verbose=verbose)\n", | |
" return model" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2k86JOFTssGi", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 728 | |
}, | |
"outputId": "a894642e-65ce-4a7d-993d-379ce06d99bb" | |
}, | |
"source": [ | |
"model = regression_model(x=x, y=y, lr=1.2, epochs=20)" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"(10, 1)\n", | |
"Epoch 1/20\n", | |
"2/2 [==============================] - 0s 3ms/step - loss: 59.5285 - mae: 6.3197\n", | |
"Epoch 2/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 84.6337 - mae: 7.7409\n", | |
"Epoch 3/20\n", | |
"2/2 [==============================] - 0s 1ms/step - loss: 50.9881 - mae: 5.9794\n", | |
"Epoch 4/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 48.8637 - mae: 5.9349\n", | |
"Epoch 5/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 47.3336 - mae: 5.5083\n", | |
"Epoch 6/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 51.5201 - mae: 6.3404\n", | |
"Epoch 7/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 44.7624 - mae: 5.4833\n", | |
"Epoch 8/20\n", | |
"2/2 [==============================] - 0s 1ms/step - loss: 47.1421 - mae: 5.4710\n", | |
"Epoch 9/20\n", | |
"2/2 [==============================] - 0s 3ms/step - loss: 49.1786 - mae: 5.7191\n", | |
"Epoch 10/20\n", | |
"2/2 [==============================] - 0s 1ms/step - loss: 57.8849 - mae: 6.4477\n", | |
"Epoch 11/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 47.6442 - mae: 5.7809\n", | |
"Epoch 12/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 50.2724 - mae: 5.6781\n", | |
"Epoch 13/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 47.2194 - mae: 5.5349\n", | |
"Epoch 14/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 44.0376 - mae: 5.3864\n", | |
"Epoch 15/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 53.2405 - mae: 6.0772\n", | |
"Epoch 16/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 52.0939 - mae: 5.4610\n", | |
"Epoch 17/20\n", | |
"2/2 [==============================] - 0s 3ms/step - loss: 43.9738 - mae: 5.4538\n", | |
"Epoch 18/20\n", | |
"2/2 [==============================] - 0s 1ms/step - loss: 46.2081 - mae: 5.6604\n", | |
"Epoch 19/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 44.6553 - mae: 5.2687\n", | |
"Epoch 20/20\n", | |
"2/2 [==============================] - 0s 2ms/step - loss: 52.0702 - mae: 5.7864\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6TyRURiztSMR", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 69 | |
}, | |
"outputId": "9517798e-2a78-4a33-8ac0-8139ce26347c" | |
}, | |
"source": [ | |
"weights = model.get_weights()\n", | |
"weights[0], weights[1]" | |
], | |
"execution_count": 56, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([[0.20204712],\n", | |
" [0.35213786],\n", | |
" [0.44581497]], dtype=float32), array([-1.1485989], dtype=float32))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 56 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "d4nkDsdfteEw", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "ae694d5e-abbf-420a-c3d9-7b249c50b432" | |
}, | |
"source": [ | |
"np.sum(weights[0])" | |
], | |
"execution_count": 57, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.99999994" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 57 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "musxy45QvULB", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 69 | |
}, | |
"outputId": "caa75a49-64ae-4727-d096-95d2b04f7b3a" | |
}, | |
"source": [ | |
"tf.nn.softmax(np.array([[1.0,2.0,3.0], [2.0, 3.0, 10.0]]), axis=0)" | |
], | |
"execution_count": 52, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor: shape=(2, 3), dtype=float64, numpy=\n", | |
"array([[2.68941421e-01, 2.68941421e-01, 9.11051194e-04],\n", | |
" [7.31058579e-01, 7.31058579e-01, 9.99088949e-01]])>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 52 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7iFwDiR_yXJ8", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9hFdsOKqvl4-", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment