@mani3
Last active December 10, 2019 07:06
Custom differentiable in Swift for TensorFlow
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('inline', 'module://ipykernel.pylab.backend_inline')\n"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import TensorFlow\n",
"import Python\n",
"\n",
"%include \"EnableIPythonDisplay.swift\"\n",
"IPythonDisplay.shell.enable_matplotlib(\"inline\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.125\r\n"
]
}
],
"source": [
"import Glibc\n",
"print(pow(0.5, 3))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"error: <Cell 3>:2:33: error: expression is not differentiable\nvalueWithGradient(at: a) { a in pow(a, 3) }\n ^\n\n<Cell 3>:2:33: note: cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files\nvalueWithGradient(at: a) { a in pow(a, 3) }\n ^\n\n"
]
}
],
"source": [
"let a: Float = 0.5\n",
"valueWithGradient(at: a) { a in pow(a, 3) }"
]
},
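{
"cell_type": "markdown",
"metadata": {},
"source": [
"If a function does not support `@differentiable` (here, Glibc's `pow`), the compiler reports the following error:\n",
"\n",
"```\n",
"error: expression is not differentiable\n",
"```\n",
"\n",
"So let's write our own partial derivatives for `pow`."
]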
},
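{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the two partial derivatives that the custom pullback below has to return are the standard calculus results for $f(x, n) = x^n$ (the logarithm term assumes $x > 0$):\n",
"\n",
"$$\n",
"\\frac{\\partial f}{\\partial x} = n\\,x^{n-1}, \\qquad \\frac{\\partial f}{\\partial n} = x^{n}\\ln x\n",
"$$\n",
"\n",
"Given an upstream gradient $v$, the pullback should therefore return $\\left(v\\,n\\,x^{n-1},\\; v\\,x^{n}\\ln x\\right)$. At $x = 0.5$, $n = 3$, $v = 1$ this gives $3 \\cdot 0.5^2 = 0.75$ and $0.125 \\cdot \\ln 0.5 \\approx -0.0866$."
]
},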
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"▿ 2 elements\n",
" - value : 0.125\n",
" ▿ gradient : 2 elements\n",
" - .0 : 0.75\n",
" - .1 : -0.08664339756999316\n"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
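],
"source": [
"// Thin wrapper around Glibc's pow, which is not marked @differentiable itself.\n",
"func myPow(_ x: Double, _ num: Double) -> Double {\n",
"    return pow(x, num)\n",
"}\n",
"\n",
"// Custom derivative for myPow: returns the value and a pullback that maps an\n",
"// upstream gradient `chain` to the gradient with respect to (x, num).\n",
"@derivative(of: myPow)\n",
"func myPowDerivative(_ x: Double, _ num: Double) -> (value: Double, pullback: (Double) -> (Double, Double)) {\n",
"    return (\n",
"        value: myPow(x, num),\n",
"        pullback: {\n",
"            chain in (\n",
"                chain * num * myPow(x, num - 1), // d/dx of x^num = num * x^(num - 1)\n",
"                chain * myPow(x, num) * log(x)   // d/dnum of x^num = x^num * log(x), requires x > 0\n",
"            )\n",
"        }\n",
"    )\n",
"}\n",
"\n",
"// power only calls myPow, so it is differentiable via the custom derivative above.\n",
"@differentiable\n",
"func power(_ x: Double, _ num: Double) -> Double {\n",
"    return myPow(x, num)\n",
"}\n",
"\n",
"let a: Double = 0.5\n",
"let b: Double = 3\n",
"\n",
"valueWithGradient(at: a, b) { a, b in power(a, b) }"
]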
}
],
"metadata": {
"kernelspec": {
"display_name": "Swift",
"language": "swift",
"name": "swift"
},
"language_info": {
"file_extension": ".swift",
"mimetype": "text/x-swift",
"name": "swift",
"version": ""
}
},
"nbformat": 4,
"nbformat_minor": 2
}
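The notebook's `myPowDerivative` pairs a value with a pullback closure. The same idea can be sketched in plain Swift without Swift for TensorFlow, which also makes it easy to sanity-check the hand-written partial derivatives against central finite differences. This is a minimal sketch assuming only Foundation's `pow` and `log`. `powValueAndPullback` and `centralDifference` are hypothetical helper names, not part of any Swift for TensorFlow API, and nothing here registers a derivative with the compiler.

```swift
import Foundation

// Hypothetical helper: hand-written "value and pullback" for f(x, n) = x^n,
// mirroring the shape of myPowDerivative but without any AD library.
// Assumes x > 0 so that log(x) is defined.
func powValueAndPullback(_ x: Double, _ n: Double)
    -> (value: Double, pullback: (Double) -> (Double, Double)) {
    let value = pow(x, n)
    return (value: value, pullback: { v in
        (v * n * pow(x, n - 1),  // v * df/dx
         v * value * log(x))     // v * df/dn
    })
}

// Hypothetical helper: central finite difference, used only as a numerical check.
func centralDifference(_ f: (Double) -> Double, at t: Double, h: Double = 1e-6) -> Double {
    (f(t + h) - f(t - h)) / (2 * h)
}

let x = 0.5, n = 3.0
let (value, pullback) = powValueAndPullback(x, n)
let (dx, dn) = pullback(1.0)  // seed the upstream gradient with 1

// Perturb one argument at a time to approximate each partial derivative.
let numericDx = centralDifference({ pow($0, n) }, at: x)
let numericDn = centralDifference({ pow(x, $0) }, at: n)

print(value, dx, dn)
print(abs(dx - numericDx) < 1e-6, abs(dn - numericDn) < 1e-6)
```

With the same inputs as the notebook (`x = 0.5`, `n = 3`), the analytic gradient should agree with the notebook's output (0.75 and about -0.0866), and both finite-difference checks should print `true`.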