@artste
Created May 8, 2019 16:50
s4tf_replacing_test.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "s4tf_replacing_test.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "swift",
"display_name": "Swift"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/artste/0ff97bc5f80693b101be206ca3a79c5b/s4tf_replacing_test.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NucAA_M-N-0a",
"colab_type": "text"
},
"source": [
"# replacing(with:where:)\n",
"*Replaces elements of this tensor with other in the lanes **where mask is true**.*\n",
"\n",
"PRECONDITION\n",
"\n",
"self and other must have the same shape. If self and other are scalar, then mask must also be scalar. If self and other have rank greater than or equal to 1, then mask must either have the same shape as self or be a 1-D Tensor such that mask.scalarCount == self.shape[0].\n",
"\n",
"DOCS: https://www.tensorflow.org/swift/api_docs/Structs/Tensor#replacingwith:where:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kZRlD4utdPuX",
"colab_type": "code",
"outputId": "7e8940d9-93ce-4fed-8df6-78a54f2c6138",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 527
}
},
"source": [
"import TensorFlow\n",
"\n",
"typealias TF = Tensor<Float>\n",
"\n",
"// Initialize random numbers\n",
"let rr = TF(randomNormal: [10,5])\n",
"print(\"rr = \",rr)\n",
"\n",
"// Create a mask that is true where an element is less than -1.0\n",
"let mask = rr .< -1.0\n",
"print(\"mask = \", mask)\n",
"\n",
"// Replace elements with 0 where the mask is true (per the docs)\n",
"let replaced = rr.replacing(with: TF(0).broadcast(like: rr), where: mask)\n",
"print(\"replaced = \",replaced)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"rr = [[ -0.2687458, 1.6474121, -0.061800487, -1.7457179, -0.3315584],\r\n",
" [ 1.3515772, 0.6790431, 0.14319876, 1.7425705, -1.9664636],\r\n",
" [ -0.32543635, 0.75455797, 0.9851794, -0.12352676, 0.029595692],\r\n",
" [ 0.54839504, -0.30570582, 1.7317035, 0.45856386, 0.8892455],\r\n",
" [ 0.14538142, -0.0019000744, -0.2195302, -0.3196175, -0.02673261],\r\n",
" [ -0.60845137, -0.36677366, 1.3494298, -1.3287013, -1.6256953],\r\n",
" [ 0.8583815, 1.1418674, -0.6815512, 1.0948774, 0.20448415],\r\n",
" [ -0.68553835, -0.9695941, -0.50244117, 1.037796, 0.70121026],\r\n",
" [ -0.5385072, 1.1612647, 1.7953675, 0.65119404, 1.5617983],\r\n",
" [ -1.237274, 1.0212688, -0.5734267, 0.91085374, 0.3272885]]\r\n",
"mask = [[false, false, false, true, false],\r\n",
" [false, false, false, false, true],\r\n",
" [false, false, false, false, false],\r\n",
" [false, false, false, false, false],\r\n",
" [false, false, false, false, false],\r\n",
" [false, false, false, true, true],\r\n",
" [false, false, false, false, false],\r\n",
" [false, false, false, false, false],\r\n",
" [false, false, false, false, false],\r\n",
" [ true, false, false, false, false]]\r\n",
"replaced = [[ 0.0, 0.0, 0.0, -1.7457179, 0.0],\r\n",
" [ 0.0, 0.0, 0.0, 0.0, -1.9664636],\r\n",
" [ 0.0, 0.0, 0.0, 0.0, 0.0],\r\n",
" [ 0.0, 0.0, 0.0, 0.0, 0.0],\r\n",
" [ 0.0, 0.0, 0.0, 0.0, 0.0],\r\n",
" [ 0.0, 0.0, 0.0, -1.3287013, -1.6256953],\r\n",
" [ 0.0, 0.0, 0.0, 0.0, 0.0],\r\n",
" [ 0.0, 0.0, 0.0, 0.0, 0.0],\r\n",
" [ 0.0, 0.0, 0.0, 0.0, 0.0],\r\n",
" [ -1.237274, 0.0, 0.0, 0.0, 0.0]]\r\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xYYoG1BYRFhx",
"colab_type": "text"
},
"source": [
"The output shows the opposite of the documented behavior: elements are replaced where **mask is false** and kept where mask is true."
]
},
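{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"One way to sidestep the question of which lanes `replacing(with:where:)` touches is to blend the two tensors arithmetically. The sketch below is only an illustration, not the library's method: `values`, `isBelow`, `weights` and `blended` are made-up names, and it assumes that `Tensor<Float>(_:)` converts a `Tensor<Bool>` into 1.0 for true and 0.0 for false.\n",
"\n",
"```swift\n",
"import TensorFlow\n",
"\n",
"// Hypothetical stand-ins for the notebook's `rr` and `mask`.\n",
"let values: Tensor<Float> = [[-2.0, 0.5], [1.5, -3.0]]\n",
"let isBelow = values .< -1.0\n",
"\n",
"// Assumption: this conversion maps true/false to 1.0/0.0.\n",
"let weights = Tensor<Float>(isBelow)\n",
"\n",
"// Keep `values` where the weight is 0, take the replacement where it is 1.\n",
"let replacement = Tensor<Float>(zeros: values.shape)\n",
"let blended = values * (1 - weights) + replacement * weights\n",
"print(blended)\n",
"```\n",
"\n",
"If the assumption holds, only the entries below -1.0 should end up as zero. The blend is not safe for tensors that contain NaN or infinite values, because multiplying those by 0 does not give 0."
]
},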
{
"cell_type": "markdown",
"metadata": {
"id": "uqtlykK-PvfF",
"colab_type": "text"
},
"source": [
"## THE API I WOULD LIKE\n",
"```swift\n",
"rr[mask] = TF(123)\n",
"```\n",
"\n",
"Probably a new TensorRange case is needed for that...\n",
"https://www.tensorflow.org/swift/api_docs/Enums/TensorRange#==_:_:"
]
}
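,
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Until something like a mask subscript exists, a small extension can approximate the assignment. This is a rough sketch under stated assumptions: `setElements(to:where:)` is a hypothetical helper that is not part of the Swift for TensorFlow API, and it assumes that `Tensor<Scalar>(_:)` converts a `Tensor<Bool>` into ones and zeros.\n",
"\n",
"```swift\n",
"import TensorFlow\n",
"\n",
"extension Tensor where Scalar: Numeric {\n",
"    // Hypothetical helper, not a library method.\n",
"    // Sets every element to `value` where `mask` is true.\n",
"    mutating func setElements(to value: Scalar, where mask: Tensor<Bool>) {\n",
"        // Assumption: this conversion maps true/false to 1/0.\n",
"        let weights = Tensor<Scalar>(mask)\n",
"        self = self * (1 - weights) + value * weights\n",
"    }\n",
"}\n",
"\n",
"var t: Tensor<Float> = [[-2.0, 0.5], [1.5, -3.0]]\n",
"let below = t .< -1.0\n",
"t.setElements(to: 123, where: below)\n",
"```\n",
"\n",
"A mutating method keeps the call site close to `rr[mask] = TF(123)` without requiring a new `TensorRange` case, at the cost of building a temporary weights tensor."
]
}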
]
}