Last active
July 19, 2022 13:26
-
-
Save kyamagu/4d268f9cbb951ccf3fc6af139739e3df to your computer and use it in GitHub Desktop.
Keras warmup learning rate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "7d79506f", | |
"metadata": {}, | |
"source": [ | |
"# Keras warmup learning rate\n", | |
"\n", | |
"The following demonstrates how to implement learning rate warmup for any schedule." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0637e3be", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"import tensorflow as tf\n", | |
"from typing import Any, Dict, Optional, Union\n", | |
"\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "871a419f", | |
"metadata": {}, | |
"source": [ | |
"## Class definition" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b5027a04", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@tf.keras.utils.register_keras_serializable()\n", | |
"class Warmup(tf.keras.optimizers.schedules.LearningRateSchedule): # type: ignore\n", | |
" \"\"\"Apply warmup steps to the given learning rate schedule.\n", | |
"\n", | |
" Warmup linearly interpolates from the initial learning rate to the\n", | |
"\n", | |
" Args:\n", | |
" base_schedule: A `LearningRateSchedule` object.\n", | |
" warmup_steps: A scalar `int32` or `int64` `Tensor` or a Python number.\n", | |
" Must be positive.\n", | |
" initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a\n", | |
" Python number. Defaults to 0.0.\n", | |
" name: String. Optional name of the operation. Defaults to \"Warmup\".\n", | |
" \"\"\"\n", | |
"\n", | |
" def __init__(\n", | |
" self,\n", | |
" base_schedule: Union[\n", | |
" tf.keras.optimizers.schedules.LearningRateSchedule, str, Dict[str, Any]\n", | |
" ],\n", | |
" warmup_steps: int,\n", | |
" initial_learning_rate: float = 0.0,\n", | |
" name: Optional[str] = None,\n", | |
" ):\n", | |
" super().__init__()\n", | |
" if warmup_steps < 0:\n", | |
" raise ValueError(f\"Warmup steps must be positive: {warmup_steps}\")\n", | |
" if not isinstance(\n", | |
" base_schedule, tf.keras.optimizers.schedules.LearningRateSchedule\n", | |
" ):\n", | |
" base_schedule = tf.keras.optimizers.schedules.deserialize(base_schedule)\n", | |
" self.base_schedule = base_schedule\n", | |
" self.warmup_steps = warmup_steps\n", | |
" self.initial_learning_rate = initial_learning_rate\n", | |
" self.name = name\n", | |
"\n", | |
" self.warmup_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n", | |
" initial_learning_rate=self.initial_learning_rate,\n", | |
" end_learning_rate=base_schedule(0),\n", | |
" decay_steps=self.warmup_steps,\n", | |
" name=self.name or \"Warmup\",\n", | |
" )\n", | |
"\n", | |
" def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:\n", | |
" global_step = tf.convert_to_tensor(step)\n", | |
" warmup_steps = tf.cast(self.warmup_steps, dtype=global_step.dtype)\n", | |
" return tf.where(\n", | |
" tf.less(global_step, warmup_steps),\n", | |
" self.warmup_schedule(global_step),\n", | |
" self.base_schedule(tf.subtract(global_step, warmup_steps)),\n", | |
" )\n", | |
"\n", | |
" def get_config(self) -> Dict[str, Any]:\n", | |
" return {\n", | |
" \"base_schedule\": tf.keras.optimizers.schedules.serialize(\n", | |
" self.base_schedule\n", | |
" ),\n", | |
" \"warmup_steps\": self.warmup_steps,\n", | |
" \"initial_learning_rate\": self.initial_learning_rate,\n", | |
" \"name\": self.name,\n", | |
" }" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a0b4e731", | |
"metadata": {}, | |
"source": [ | |
"## Usage\n", | |
"\n", | |
"One can combine any base schedule." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ff2ab4cb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"steps = 100\n", | |
"warmup_steps = 10\n", | |
"\n", | |
"# Pick any base schedule of your choice.\n", | |
"base_schedule = tf.keras.optimizers.schedules.ExponentialDecay(\n", | |
" initial_learning_rate=1e-4,\n", | |
" decay_steps=30,\n", | |
" decay_rate=0.1,\n", | |
")\n", | |
"schedule = Warmup(base_schedule, warmup_steps)\n", | |
"\n", | |
"plt.plot(schedule(tf.range(steps)).numpy())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "2e765631", | |
"metadata": {}, | |
"source": [ | |
"## Serialization\n", | |
"\n", | |
"The class is serializable thanks to the `get_config` method and the `register_keras_serializable` decorator." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c1627992", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"schedule = tf.keras.optimizers.schedules.deserialize(\n", | |
" tf.keras.optimizers.schedules.serialize(schedule)\n", | |
")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "ailab-psd-processing", | |
"language": "python", | |
"name": "ailab-psd-processing" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment