Skip to content

Instantly share code, notes, and snippets.

@kyamagu
Last active July 19, 2022 13:26
Show Gist options
  • Save kyamagu/4d268f9cbb951ccf3fc6af139739e3df to your computer and use it in GitHub Desktop.
Save kyamagu/4d268f9cbb951ccf3fc6af139739e3df to your computer and use it in GitHub Desktop.
Keras warmup learning rate
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "7d79506f",
"metadata": {},
"source": [
"# Keras warmup learning rate\n",
"\n",
"The following demonstrates how to implement learning rate warmup for any schedule."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0637e3be",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"from typing import Any, Dict, Optional, Union\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "871a419f",
"metadata": {},
"source": [
"## Class definition"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5027a04",
"metadata": {},
"outputs": [],
"source": [
"@tf.keras.utils.register_keras_serializable()\n",
"class Warmup(tf.keras.optimizers.schedules.LearningRateSchedule):  # type: ignore\n",
"    \"\"\"Apply warmup steps to the given learning rate schedule.\n",
"\n",
"    Warmup linearly interpolates from the initial learning rate to the\n",
"    first value of the base schedule (`base_schedule(0)`) over\n",
"    `warmup_steps` steps. From step `warmup_steps` onwards the base schedule\n",
"    is evaluated at `step - warmup_steps`, so it starts from its own step 0.\n",
"\n",
"    Args:\n",
"        base_schedule: A `LearningRateSchedule` object, or a serialized\n",
"            schedule accepted by `tf.keras.optimizers.schedules.deserialize`.\n",
"        warmup_steps: A scalar `int32` or `int64` `Tensor` or a Python number.\n",
"            Must be non-negative; 0 disables warmup.\n",
"        initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a\n",
"            Python number. Defaults to 0.0.\n",
"        name: String. Optional name of the operation. Defaults to \"Warmup\".\n",
"\n",
"    Raises:\n",
"        ValueError: If `warmup_steps` is negative.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        base_schedule: Union[\n",
"            tf.keras.optimizers.schedules.LearningRateSchedule, str, Dict[str, Any]\n",
"        ],\n",
"        warmup_steps: int,\n",
"        initial_learning_rate: float = 0.0,\n",
"        name: Optional[str] = None,\n",
"    ):\n",
"        super().__init__()\n",
"        if warmup_steps < 0:\n",
"            raise ValueError(f\"Warmup steps must be non-negative: {warmup_steps}\")\n",
"        if not isinstance(\n",
"            base_schedule, tf.keras.optimizers.schedules.LearningRateSchedule\n",
"        ):\n",
"            # `get_config` stores the base schedule in serialized form, so a\n",
"            # config round trip may pass that serialized form back in here.\n",
"            base_schedule = tf.keras.optimizers.schedules.deserialize(base_schedule)\n",
"        self.base_schedule = base_schedule\n",
"        self.warmup_steps = warmup_steps\n",
"        self.initial_learning_rate = initial_learning_rate\n",
"        self.name = name\n",
"\n",
"        # Ramp that ends where the base schedule begins. `decay_steps` is kept\n",
"        # at 1 or more so the ramp never divides by zero when `warmup_steps`\n",
"        # is 0; in that case `__call__` never selects the ramp anyway.\n",
"        self.warmup_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n",
"            initial_learning_rate=self.initial_learning_rate,\n",
"            end_learning_rate=base_schedule(0),\n",
"            decay_steps=max(self.warmup_steps, 1),\n",
"            name=self.name or \"Warmup\",\n",
"        )\n",
"\n",
"    def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:\n",
"        \"\"\"Return the learning rate at `step` (a scalar or a tensor of steps).\"\"\"\n",
"        global_step = tf.convert_to_tensor(step)\n",
"        warmup_steps = tf.cast(self.warmup_steps, dtype=global_step.dtype)\n",
"        # `tf.where` evaluates both branches for every step, so clamp the base\n",
"        # schedule's step at 0 rather than feed it negative steps during warmup.\n",
"        base_step = tf.maximum(tf.subtract(global_step, warmup_steps), 0)\n",
"        base_learning_rate = tf.convert_to_tensor(self.base_schedule(base_step))\n",
"        # Both branches of `tf.where` must share a dtype; use the base schedule's.\n",
"        warmup_learning_rate = tf.cast(\n",
"            self.warmup_schedule(global_step), dtype=base_learning_rate.dtype\n",
"        )\n",
"        return tf.where(\n",
"            tf.less(global_step, warmup_steps),\n",
"            warmup_learning_rate,\n",
"            base_learning_rate,\n",
"        )\n",
"\n",
"    def get_config(self) -> Dict[str, Any]:\n",
"        \"\"\"Return the constructor arguments, with the base schedule serialized.\"\"\"\n",
"        return {\n",
"            \"base_schedule\": tf.keras.optimizers.schedules.serialize(\n",
"                self.base_schedule\n",
"            ),\n",
"            \"warmup_steps\": self.warmup_steps,\n",
"            \"initial_learning_rate\": self.initial_learning_rate,\n",
"            \"name\": self.name,\n",
"        }"
]
},
{
"cell_type": "markdown",
"id": "a0b4e731",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"`Warmup` can wrap any base schedule: pass it the schedule and the number of warmup steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff2ab4cb",
"metadata": {},
"outputs": [],
"source": [
"steps = 100\n",
"warmup_steps = 10\n",
"\n",
"# Pick any base schedule of your choice.\n",
"base_schedule = tf.keras.optimizers.schedules.ExponentialDecay(\n",
"    initial_learning_rate=1e-4,\n",
"    decay_steps=30,\n",
"    decay_rate=0.1,\n",
")\n",
"schedule = Warmup(base_schedule, warmup_steps)\n",
"\n",
"# Evaluate the schedule at every step in a single call and plot the result.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(schedule(tf.range(steps)).numpy())\n",
"# The trailing semicolon keeps the return value of `ax.set` out of the cell output.\n",
"ax.set(title=\"Exponential decay with warmup\", xlabel=\"Step\", ylabel=\"Learning rate\");"
]
},
{
"cell_type": "markdown",
"id": "2e765631",
"metadata": {},
"source": [
"## Serialization\n",
"\n",
"The class is serializable thanks to the `get_config` method and the `register_keras_serializable` decorator."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1627992",
"metadata": {},
"outputs": [],
"source": [
"# Round-trip the schedule through its serialized form. No `custom_objects`\n",
"# argument is passed; the class is found through its registration decorator.\n",
"serialized_schedule = tf.keras.optimizers.schedules.serialize(schedule)\n",
"restored_schedule = tf.keras.optimizers.schedules.deserialize(serialized_schedule)\n",
"\n",
"# Fail loudly if the round trip loses the class or any constructor argument.\n",
"assert isinstance(restored_schedule, Warmup)\n",
"assert restored_schedule.get_config() == schedule.get_config()\n",
"\n",
"schedule = restored_schedule"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ailab-psd-processing",
"language": "python",
"name": "ailab-psd-processing"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment