Created
July 4, 2016 09:50
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implementing KL annealing in Keras\n",
"- [Generating Sentences from a Continuous Space](https://arxiv.org/pdf/1511.06349.pdf)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using Theano backend.\n",
"Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN 5005)\n"
]
}
],
"source": [
"from keras import backend as K"
]
},
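{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"hp_lambda = K.variable(0)  # KL weight; starts at 0, updated later via K.set_value\n",
"\n",
"def original_loss(y_true, y_pred):\n",
"    # Reconstruction term. The positional argument order of this backend\n",
"    # function has changed between Keras releases, so check your version.\n",
"    loss = K.categorical_crossentropy(y_true, y_pred)\n",
"    # Constant stand-in for the KL term, used only to show the effect of hp_lambda\n",
"    KL = K.variable([100])\n",
"    return loss + hp_lambda * KL"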
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_pred = K.placeholder(ndim=1)\n",
"y_true = K.placeholder(ndim=1)\n",
"f = K.function([y_true, y_pred], original_loss(y_true, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"When lambda=0.0: [ 1.19209304e-07]\n"
]
}
],
"source": [
"K.set_value(hp_lambda, 0)\n",
"print(\"When lambda={}: {}\".format(K.get_value(hp_lambda), f([[0, 1, 0], [0, 1, 0]])))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"When lambda=1.0: [ 100.]\n"
]
}
],
"source": [
"K.set_value(hp_lambda, 1)\n",
"print(\"When lambda={}: {}\".format(K.get_value(hp_lambda), f([[0, 1, 0], [0, 1, 0]])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Calling `K.set_value` once per epoch is enough to anneal the weight.\n",
"- There seem to be two ways to do this:\n",
"    1. Call `set_value` manually every epoch\n",
"    2. Write a callback that updates `hp_lambda`\n",
"- For the latter, `LearningRateScheduler` in [keras/callbacks.py](https://github.com/fchollet/keras/blob/master/keras/callbacks.py) looks like a useful reference.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
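
The callback approach from the notebook's closing notes could look roughly like the sketch below. It assumes a linear warm-up schedule, which the notebook does not specify. The names `linear_kl_weight`, `KLWeightScheduler`, `warmup_epochs`, and `set_weight` are hypothetical and introduced only for illustration. So that the sketch runs without Keras installed, the class is a plain object and a dict stands in for `hp_lambda`. In real use it would presumably subclass `keras.callbacks.Callback` and pass `lambda v: K.set_value(hp_lambda, v)` as the setter.

```python
def linear_kl_weight(epoch, warmup_epochs=10, max_weight=1.0):
    """Ramp the KL weight linearly from 0 to max_weight over warmup_epochs.

    The linear shape is an assumption made for simplicity.
    """
    if warmup_epochs <= 0:
        return max_weight
    return max_weight * min(1.0, float(epoch) / warmup_epochs)


class KLWeightScheduler(object):
    """Callback-style object that updates the KL weight at the start of each epoch.

    Hypothetical helper: with Keras, this would derive from
    keras.callbacks.Callback so that fit() can drive on_epoch_begin.
    """

    def __init__(self, schedule, set_weight):
        self.schedule = schedule      # maps epoch index -> KL weight
        self.set_weight = set_weight  # writes the weight into the loss variable

    def on_epoch_begin(self, epoch, logs=None):
        self.set_weight(self.schedule(epoch))


# Stand-in for the backend variable so the sketch runs without Keras.
# With Keras, the setter would be: lambda v: K.set_value(hp_lambda, v)
state = {"hp_lambda": 0.0}
scheduler = KLWeightScheduler(linear_kl_weight,
                              lambda v: state.update(hp_lambda=v))

# Simulate what a training loop would do at the start of each epoch.
history = []
for epoch in range(12):
    scheduler.on_epoch_begin(epoch)
    history.append(state["hp_lambda"])

print(history)
```

Because `original_loss` reads `hp_lambda` as a backend variable instead of a Python constant, updating it between epochs should change the loss without recompiling the model. A Keras version of this object would be passed to `model.fit` through its `callbacks` argument.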