Last active
May 6, 2019 18:53
-
-
Save rxwei/2739515f77a62d26add66947cb179911 to your computer and use it in GitHub Desktop.
Differentiable reduction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Differentiable reduction", | |
"version": "0.3.2", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "swift", | |
"display_name": "Swift" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/rxwei/2739515f77a62d26add66947cb179911/differentiable-reduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "kZRlD4utdPuX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"public extension Array where Element: Differentiable {\n", | |
"  /// Folds the array into a single value from left to right, exactly like\n", | |
"  /// `reduce(_:_:)`, but with a registered derivative so the fold can be\n", | |
"  /// differentiated with respect to both the array and `initialResult`.\n", | |
"  ///\n", | |
"  /// - Parameters:\n", | |
"  ///   - initialResult: The value the fold starts from.\n", | |
"  ///   - nextPartialResult: A differentiable closure that combines the running\n", | |
"  ///     result with the next element.\n", | |
"  /// - Returns: The final accumulated value (`initialResult` for an empty array).\n", | |
"  func differentiableReduce<Result: Differentiable>(\n", | |
"    _ initialResult: Result,\n", | |
"    _ nextPartialResult: @differentiable (Result, Element) -> Result\n", | |
"  ) -> Result {\n", | |
"    return reduce(initialResult, nextPartialResult)\n", | |
"  }\n", | |
"\n", | |
"  /// Derivative of `differentiableReduce(_:_:)` with respect to `self` and\n", | |
"  /// `initialResult`.\n", | |
"  ///\n", | |
"  /// The forward pass performs the fold and records one pullback per step.\n", | |
"  /// The returned pullback replays those steps in reverse order, threading the\n", | |
"  /// cotangent of the running result backwards and collecting one cotangent\n", | |
"  /// per element; for an empty array it returns an empty array cotangent and\n", | |
"  /// passes the incoming cotangent straight through to `initialResult`.\n", | |
"  @usableFromInline\n", | |
"  @differentiating(differentiableReduce(_:_:), wrt: (self, initialResult))\n", | |
"  internal func reduceDerivative<Result: Differentiable>(\n", | |
"    _ initialResult: Result,\n", | |
"    _ nextPartialResult: @differentiable (Result, Element) -> Result\n", | |
"  ) -> (value: Result,\n", | |
"        pullback: (Result.CotangentVector) -> (Array.CotangentVector, Result.CotangentVector)) {\n", | |
"    // Read once up front; the pullback below uses this local rather than `self`.\n", | |
"    let elementCount = count\n", | |
"    var stepPullbacks: [(Result.CotangentVector) -> (Result.CotangentVector, Element.CotangentVector)] = []\n", | |
"    stepPullbacks.reserveCapacity(elementCount)\n", | |
"\n", | |
"    // Forward pass: fold left to right, keeping each step's pullback.\n", | |
"    var accumulated = initialResult\n", | |
"    for element in self {\n", | |
"      let (next, stepPullback) = Swift.valueWithPullback(at: accumulated, element, in: nextPartialResult)\n", | |
"      accumulated = next\n", | |
"      stepPullbacks.append(stepPullback)\n", | |
"    }\n", | |
"\n", | |
"    // Reverse pass: element cotangents come out last-to-first, so they are\n", | |
"    // flipped back into array order before being wrapped.\n", | |
"    return (value: accumulated, pullback: { seed in\n", | |
"      var resultCotangent = seed\n", | |
"      var cotangentsLastToFirst: [Element.CotangentVector] = []\n", | |
"      cotangentsLastToFirst.reserveCapacity(elementCount)\n", | |
"      for stepPullback in stepPullbacks.reversed() {\n", | |
"        let (previousCotangent, elementCotangent) = stepPullback(resultCotangent)\n", | |
"        resultCotangent = previousCotangent\n", | |
"        cotangentsLastToFirst.append(elementCotangent)\n", | |
"      }\n", | |
"      return (CotangentVector(cotangentsLastToFirst.reversed()), resultCotangent)\n", | |
"    })\n", | |
"  }\n", | |
"}" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "5dZqoBgCseON", | |
"colab_type": "code", | |
"outputId": "2f2e7d7b-b56f-432c-bf5a-c2dd61ab6cc3", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 170 | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"let xx = [1.0, 2.0, 3.0, 4.0, 5.0]\n", | |
"let initial = 1.0\n", | |
"// Value and gradient of initial * xx[0] * xx[1] * ... * xx[4] with respect to xx.\n", | |
"valueWithGradient(at: xx) { array in\n", | |
"  array.differentiableReduce(initial) { product, factor in product * factor }\n", | |
"}" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"▿ 2 elements\n", | |
" - value : 120.0\n", | |
" ▿ gradient : DifferentiableView\n", | |
" ▿ _base : 5 elements\n", | |
" - 0 : 120.0\n", | |
" - 1 : 60.0\n", | |
" - 2 : 40.0\n", | |
" - 3 : 30.0\n", | |
" - 4 : 24.0\n" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 2 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "lqfaQ2Ozusdi", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment