Skip to content

Instantly share code, notes, and snippets.

@ceshine
Created December 30, 2019 09:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save ceshine/993d2ea9950e50f443a9841c6524bf81 to your computer and use it in GitHub Desktop.
Save ceshine/993d2ea9950e50f443a9841c6524bf81 to your computer and use it in GitHub Desktop.
Learning-rate scheduler that works with TensorFlow 2.x distributed mode
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensorflow version 2.0.0\n"
]
}
],
"source": [
"import math\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.python.framework import constant_op\n",
"from tensorflow.python.framework import ops\n",
"from tensorflow.python.ops import control_flow_ops\n",
"from tensorflow.python.ops import math_ops\n",
"from tensorflow.python.keras.optimizer_v2.learning_rate_schedule import LearningRateSchedule\n",
"\n",
"from matplotlib import pyplot as plt\n",
"\n",
"print(f\"Tensorflow version {tf.__version__}\")\n",
"\n",
"class CosineDecayWithWarmup(LearningRateSchedule):\n",
"    \"\"\"Learning-rate schedule: linear warmup, then cosine decay.\n",
"\n",
"    Up to ``warmup_steps`` the rate rises linearly from\n",
"    ``initial_learning_rate`` to ``max_learning_rate``. Over the following\n",
"    ``decay_steps`` steps it follows a half cosine from\n",
"    ``max_learning_rate`` down to ``max_learning_rate * alpha``, and it\n",
"    stays at that floor afterwards because the step is clamped to\n",
"    ``warmup_steps + decay_steps``.\n",
"\n",
"    Args:\n",
"        initial_learning_rate: Starting value of the warmup ramp. Its dtype\n",
"            is used for every quantity the schedule computes.\n",
"        max_learning_rate: Peak learning rate, reached at ``warmup_steps``.\n",
"        warmup_steps: Number of linear-warmup steps.\n",
"        decay_steps: Number of cosine-decay steps after the warmup.\n",
"        alpha: Final learning rate as a fraction of ``max_learning_rate``.\n",
"        name: Optional name scope for the created ops; defaults to\n",
"            ``\"CosineDecayWithWarmup\"``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"            self,\n",
"            initial_learning_rate,\n",
"            max_learning_rate,\n",
"            warmup_steps,\n",
"            decay_steps,\n",
"            alpha=0.0,\n",
"            name=None):\n",
"        super().__init__()\n",
"        self.initial_learning_rate = initial_learning_rate\n",
"        self.max_learning_rate = max_learning_rate\n",
"        self.warmup_steps = warmup_steps\n",
"        self.decay_steps = decay_steps\n",
"        self.alpha = alpha\n",
"        self.name = name\n",
"\n",
"    @staticmethod\n",
"    def lr_warmup(steps, warmup_steps, max_learning_rate, initial_learning_rate):\n",
"        \"\"\"Interpolate linearly from the initial to the max learning rate.\"\"\"\n",
"        return initial_learning_rate + (\n",
"            max_learning_rate - initial_learning_rate\n",
"        ) * (steps / warmup_steps)\n",
"\n",
"    @staticmethod\n",
"    def cosine_decay(steps, warmup_steps, decay_steps, max_learning_rate, alpha):\n",
"        \"\"\"Anneal from max_learning_rate to max_learning_rate * alpha.\"\"\"\n",
"        completed_fraction = (steps - warmup_steps) / decay_steps\n",
"        # A plain Python float takes on the dtype of completed_fraction, so\n",
"        # this also works when the schedule computes in float64.\n",
"        cosine_decayed = 0.5 * (1.0 + math_ops.cos(\n",
"            math.pi * completed_fraction))\n",
"        decayed = (1 - alpha) * cosine_decayed + alpha\n",
"        return math_ops.multiply(max_learning_rate, decayed)\n",
"\n",
"    def __call__(self, step):\n",
"        with ops.name_scope_v2(self.name or \"CosineDecayWithWarmup\"):\n",
"            initial_learning_rate = ops.convert_to_tensor(\n",
"                self.initial_learning_rate, name=\"initial_learning_rate\")\n",
"            # Every other quantity follows the initial rate's dtype; casting\n",
"            # max_learning_rate too avoids a dtype clash when the two rates\n",
"            # are given as different numeric types.\n",
"            dtype = initial_learning_rate.dtype\n",
"            max_learning_rate = math_ops.cast(\n",
"                ops.convert_to_tensor(\n",
"                    self.max_learning_rate, name=\"max_learning_rate\"),\n",
"                dtype)\n",
"            decay_steps = math_ops.cast(self.decay_steps, dtype)\n",
"            warmup_steps = math_ops.cast(self.warmup_steps, dtype)\n",
"            total_steps = decay_steps + warmup_steps\n",
"\n",
"            # Past the end of the schedule, hold the final (floor) rate.\n",
"            global_step_recomp = math_ops.minimum(\n",
"                math_ops.cast(step, dtype), total_steps)\n",
"\n",
"            # Strict less-than: at step == warmup_steps both formulas give\n",
"            # max_learning_rate, and this keeps warmup_steps == 0 from\n",
"            # evaluating 0 / 0 (NaN) at step 0.\n",
"            return control_flow_ops.cond(\n",
"                math_ops.less(global_step_recomp, warmup_steps),\n",
"                lambda: self.lr_warmup(\n",
"                    global_step_recomp, warmup_steps, max_learning_rate,\n",
"                    initial_learning_rate),\n",
"                lambda: self.cosine_decay(\n",
"                    global_step_recomp, warmup_steps, decay_steps,\n",
"                    max_learning_rate, self.alpha))\n",
"\n",
"    def get_config(self):\n",
"        \"\"\"Return all constructor arguments so from_config can rebuild it.\"\"\"\n",
"        return {\n",
"            \"initial_learning_rate\": self.initial_learning_rate,\n",
"            \"max_learning_rate\": self.max_learning_rate,\n",
"            \"warmup_steps\": self.warmup_steps,\n",
"            \"decay_steps\": self.decay_steps,\n",
"            \"alpha\": self.alpha,\n",
"            \"name\": self.name,\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"# 200 warmup steps followed by 800 cosine-decay steps (1000 in total).\n",
"lr_schedule = CosineDecayWithWarmup(\n",
"    initial_learning_rate=1e-5,\n",
"    max_learning_rate=5e-4,\n",
"    warmup_steps=200,\n",
"    decay_steps=800,\n",
"    alpha=1e-6,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Learning rate schedule:\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"print(\"Learning rate schedule:\")\n",
"# Plot exactly the warmup + decay span configured on the schedule.\n",
"rng = list(range(lr_schedule.warmup_steps + lr_schedule.decay_steps))\n",
"plt.plot(rng, [float(lr_schedule(x)) for x in rng])\n",
"plt.xlabel(\"step\")\n",
"plt.ylabel(\"learning rate\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment