@vmoens
Created June 27, 2023 19:19
torchrl_reinforcement_ppo_solution.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/vmoens/06eb46c01bd03fe22739821f9e769faf/torchrl_reinforcement_ppo_solution.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "tcwI0yJh9PAv"
},
"outputs": [],
"source": [
"# For tips on running notebooks in Google Colab, see\n",
"# https://pytorch.org/tutorials/beginner/colab\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gyDKDGZA9PAy"
},
"source": [
"\n",
"# Reinforcement Learning (PPO) with TorchRL Tutorial\n",
"**Author**: [Vincent Moens](https://github.com/vmoens)\n",
"\n",
"This tutorial demonstrates how to use PyTorch and TorchRL to train a parametric policy network to solve the Inverted Double Pendulum task from the [OpenAI-Gym/Farama-Gymnasium\n",
"control library](https://github.com/Farama-Foundation/Gymnasium).\n",
"\n",
"Key learnings:\n",
"\n",
"- How to create an environment in TorchRL, transform its outputs, and collect data from this environment;\n",
"- How to make your classes talk to each other using TensorDict;\n",
"- The basics of building your training loop with TorchRL:\n",
"\n",
" - How to compute the advantage signal for policy gradient methods;\n",
" - How to create a stochastic policy using a probabilistic neural network;\n",
" - How to create a dynamic replay buffer and sample from it without repetition.\n",
"\n",
"We will cover six crucial components of TorchRL:\n",
"\n",
"* [environments](https://pytorch.org/rl/reference/envs.html)\n",
"* [transforms](https://pytorch.org/rl/reference/envs.html#transforms)\n",
"* [models (policy and value function)](https://pytorch.org/rl/reference/modules.html)\n",
"* [loss modules](https://pytorch.org/rl/reference/objectives.html)\n",
"* [data collectors](https://pytorch.org/rl/reference/collectors.html)\n",
"* [replay buffers](https://pytorch.org/rl/reference/data.html#replay-buffers)\n",
"\n",
"Please refer to the links above throughout this tutorial to gather information about TorchRL features and solve exercises!\n",
"Make sure you install the following dependencies:"
]
},
{
"cell_type": "code",
"source": [
"!pip3 install git+https://github.com/pytorch-labs/tensordict\n",
"!pip3 install git+https://github.com/pytorch/rl\n",
"!pip3 install gym[mujoco,atari,accept-rom-license]\n",
"!pip3 install dm_control\n",
"!pip3 install tqdm"
],
"metadata": {
"id": "HCVcbnZW6IQ9",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d876b641-ab4d-440a-be3f-86108a13ccb2"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting git+https://github.com/pytorch-labs/tensordict\n",
" Cloning https://github.com/pytorch-labs/tensordict to /tmp/pip-req-build-v8ywyhzr\n",
" Running command git clone --filter=blob:none --quiet https://github.com/pytorch-labs/tensordict /tmp/pip-req-build-v8ywyhzr\n",
" Resolved https://github.com/pytorch-labs/tensordict to commit 41424f8feb6416b4603d6480f1132caca742762b\n",
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from tensordict==0.1.2+41424f8) (2.0.1+cu118)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from tensordict==0.1.2+41424f8) (1.22.4)\n",
"Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from tensordict==0.1.2+41424f8) (2.2.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->tensordict==0.1.2+41424f8) (3.12.2)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->tensordict==0.1.2+41424f8) (4.6.3)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->tensordict==0.1.2+41424f8) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->tensordict==0.1.2+41424f8) (3.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->tensordict==0.1.2+41424f8) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->tensordict==0.1.2+41424f8) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->tensordict==0.1.2+41424f8) (3.25.2)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->tensordict==0.1.2+41424f8) (16.0.6)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->tensordict==0.1.2+41424f8) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->tensordict==0.1.2+41424f8) (1.3.0)\n",
"Collecting git+https://github.com/pytorch/rl\n",
" Cloning https://github.com/pytorch/rl to /tmp/pip-req-build-qa4yvhf3\n",
" Running command git clone --filter=blob:none --quiet https://github.com/pytorch/rl /tmp/pip-req-build-qa4yvhf3\n",
" Resolved https://github.com/pytorch/rl to commit ae1dd3ac0eebd6503c3c8f8e364505ca4bae1945\n",
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from torchrl==0.1.1+ae1dd3a) (2.0.1+cu118)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchrl==0.1.1+ae1dd3a) (1.22.4)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from torchrl==0.1.1+ae1dd3a) (23.1)\n",
"Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from torchrl==0.1.1+ae1dd3a) (2.2.1)\n",
"Requirement already satisfied: tensordict>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from torchrl==0.1.1+ae1dd3a) (0.1.2+41424f8)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->torchrl==0.1.1+ae1dd3a) (3.12.2)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->torchrl==0.1.1+ae1dd3a) (4.6.3)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->torchrl==0.1.1+ae1dd3a) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->torchrl==0.1.1+ae1dd3a) (3.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->torchrl==0.1.1+ae1dd3a) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->torchrl==0.1.1+ae1dd3a) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->torchrl==0.1.1+ae1dd3a) (3.25.2)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->torchrl==0.1.1+ae1dd3a) (16.0.6)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->torchrl==0.1.1+ae1dd3a) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->torchrl==0.1.1+ae1dd3a) (1.3.0)\n",
"Requirement already satisfied: gym[accept-rom-license,atari,mujoco] in /usr/local/lib/python3.10/dist-packages (0.25.2)\n",
"Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.10/dist-packages (from gym[accept-rom-license,atari,mujoco]) (1.22.4)\n",
"Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gym[accept-rom-license,atari,mujoco]) (2.2.1)\n",
"Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from gym[accept-rom-license,atari,mujoco]) (0.0.8)\n",
"Collecting autorom[accept-rom-license]~=0.4.2 (from gym[accept-rom-license,atari,mujoco])\n",
" Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)\n",
"Requirement already satisfied: ale-py~=0.7.5 in /usr/local/lib/python3.10/dist-packages (from gym[accept-rom-license,atari,mujoco]) (0.7.5)\n",
"Collecting mujoco==2.2.0 (from gym[accept-rom-license,atari,mujoco])\n",
" Using cached mujoco-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\n",
"Requirement already satisfied: imageio>=2.14.1 in /usr/local/lib/python3.10/dist-packages (from gym[accept-rom-license,atari,mujoco]) (2.25.1)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from mujoco==2.2.0->gym[accept-rom-license,atari,mujoco]) (1.4.0)\n",
"Requirement already satisfied: glfw in /usr/local/lib/python3.10/dist-packages (from mujoco==2.2.0->gym[accept-rom-license,atari,mujoco]) (2.6.1)\n",
"Requirement already satisfied: pyopengl in /usr/local/lib/python3.10/dist-packages (from mujoco==2.2.0->gym[accept-rom-license,atari,mujoco]) (3.1.7)\n",
"Requirement already satisfied: importlib-resources in /usr/local/lib/python3.10/dist-packages (from ale-py~=0.7.5->gym[accept-rom-license,atari,mujoco]) (5.12.0)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from autorom[accept-rom-license]~=0.4.2->gym[accept-rom-license,atari,mujoco]) (8.1.3)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from autorom[accept-rom-license]~=0.4.2->gym[accept-rom-license,atari,mujoco]) (2.27.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from autorom[accept-rom-license]~=0.4.2->gym[accept-rom-license,atari,mujoco]) (4.65.0)\n",
"Collecting AutoROM.accept-rom-license (from autorom[accept-rom-license]~=0.4.2->gym[accept-rom-license,atari,mujoco])\n",
" Downloading AutoROM.accept-rom-license-0.6.1.tar.gz (434 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m434.7/434.7 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: pillow>=8.3.2 in /usr/local/lib/python3.10/dist-packages (from imageio>=2.14.1->gym[accept-rom-license,atari,mujoco]) (8.4.0)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->autorom[accept-rom-license]~=0.4.2->gym[accept-rom-license,atari,mujoco]) (1.26.16)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->autorom[accept-rom-license]~=0.4.2->gym[accept-rom-license,atari,mujoco]) (2023.5.7)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->autorom[accept-rom-license]~=0.4.2->gym[accept-rom-license,atari,mujoco]) (2.0.12)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->autorom[accept-rom-license]~=0.4.2->gym[accept-rom-license,atari,mujoco]) (3.4)\n",
"Building wheels for collected packages: AutoROM.accept-rom-license\n",
" Building wheel for AutoROM.accept-rom-license (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for AutoROM.accept-rom-license: filename=AutoROM.accept_rom_license-0.6.1-py3-none-any.whl size=446660 sha256=cc88b888a26d48a67e965722deaad355ad25a3476d1006451e7d3944cc90bf88\n",
" Stored in directory: /root/.cache/pip/wheels/6b/1b/ef/a43ff1a2f1736d5711faa1ba4c1f61be1131b8899e6a057811\n",
"Successfully built AutoROM.accept-rom-license\n",
"Installing collected packages: mujoco, AutoROM.accept-rom-license, autorom\n",
" Attempting uninstall: mujoco\n",
" Found existing installation: mujoco 2.3.6\n",
" Uninstalling mujoco-2.3.6:\n",
" Successfully uninstalled mujoco-2.3.6\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"dm-control 1.0.13 requires mujoco>=2.3.6, but you have mujoco 2.2.0 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0mSuccessfully installed AutoROM.accept-rom-license-0.6.1 autorom-0.4.2 mujoco-2.2.0\n",
"Requirement already satisfied: dm_control in /usr/local/lib/python3.10/dist-packages (1.0.13)\n",
"Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from dm_control) (1.4.0)\n",
"Requirement already satisfied: dm-env in /usr/local/lib/python3.10/dist-packages (from dm_control) (1.6)\n",
"Requirement already satisfied: dm-tree!=0.1.2 in /usr/local/lib/python3.10/dist-packages (from dm_control) (0.1.8)\n",
"Requirement already satisfied: glfw in /usr/local/lib/python3.10/dist-packages (from dm_control) (2.6.1)\n",
"Requirement already satisfied: labmaze in /usr/local/lib/python3.10/dist-packages (from dm_control) (1.0.6)\n",
"Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from dm_control) (4.9.2)\n",
"Collecting mujoco>=2.3.6 (from dm_control)\n",
" Using cached mujoco-2.3.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\n",
"Requirement already satisfied: numpy>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from dm_control) (1.22.4)\n",
"Requirement already satisfied: protobuf>=3.19.4 in /usr/local/lib/python3.10/dist-packages (from dm_control) (3.20.3)\n",
"Requirement already satisfied: pyopengl>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from dm_control) (3.1.7)\n",
"Requirement already satisfied: pyparsing>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from dm_control) (3.1.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from dm_control) (2.27.1)\n",
"Requirement already satisfied: setuptools!=50.0.0 in /usr/local/lib/python3.10/dist-packages (from dm_control) (67.7.2)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from dm_control) (1.10.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from dm_control) (4.65.0)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->dm_control) (1.26.16)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->dm_control) (2023.5.7)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->dm_control) (2.0.12)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->dm_control) (3.4)\n",
"Installing collected packages: mujoco\n",
" Attempting uninstall: mujoco\n",
" Found existing installation: mujoco 2.2.0\n",
" Uninstalling mujoco-2.2.0:\n",
" Successfully uninstalled mujoco-2.2.0\n",
"Successfully installed mujoco-2.3.6\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.65.0)\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YoWBBpU49PA0"
},
"source": [
"\n",
"Proximal Policy Optimization (PPO) is a policy-gradient algorithm in which a\n",
"batch of data is collected and directly consumed to train the policy to maximise\n",
"the expected return subject to some proximality constraints. You can think of it\n",
"as a sophisticated version of [REINFORCE](https://link.springer.com/content/pdf/10.1007/BF00992696.pdf),\n",
"the foundational policy-optimization algorithm. For more information, see the\n",
"[Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) paper.\n",
"\n",
"PPO is usually regarded as a fast and efficient method for online, on-policy\n",
"reinforcement learning. TorchRL provides a loss module that does all the work\n",
"for you, so that you can rely on this implementation and focus on solving your\n",
"problem rather than re-inventing the wheel every time you want to train a policy.\n",
"\n",
"For completeness, here is a brief overview of what the loss computes, even though this is taken care of by our `ClipPPOLoss` module. The algorithm works as follows:\n",
"\n",
"1. We will sample a batch of data by playing the\n",
"policy in the environment for a given number of steps.\n",
"2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using\n",
"a clipped version of the REINFORCE loss.\n",
"3. The clipping will put a pessimistic bound on our loss: lower return estimates will\n",
"be favored compared to higher ones.\n",
"\n",
"The precise formula of the loss is:\n",
"\n",
"\\begin{align}L(s,a,\\theta_k,\\theta) = \\min\\left(\n",
" \\frac{\\pi_{\\theta}(a|s)}{\\pi_{\\theta_k}(a|s)} A^{\\pi_{\\theta_k}}(s,a), \\;\\;\n",
" g(\\epsilon, A^{\\pi_{\\theta_k}}(s,a))\n",
" \\right),\\end{align}\n",
"\n",
"where $g(\\epsilon, A) = (1+\\epsilon) A$ if $A \\geq 0$ and $g(\\epsilon, A) = (1-\\epsilon) A$ if $A < 0$.\n",
"\n",
"There are two components in that loss: in the first part of the minimum operator,\n",
"we simply compute an importance-weighted version of the REINFORCE loss (that is, a\n",
"REINFORCE loss that we have corrected for the fact that the current policy\n",
"configuration lags the one that was used for the data collection).\n",
"The second part of that minimum operator is a similar loss where we have clipped\n",
"the ratios when they exceed or fall below a given pair of thresholds.\n",
"\n",
"This loss ensures that whether the advantage is positive or negative, policy\n",
"updates that would produce significant shifts from the previous configuration\n",
"are being discouraged.\n",
"\n",
"This tutorial is structured as follows:\n",
"\n",
"1. First, we will define a set of hyperparameters we will be using for training.\n",
"\n",
"2. Next, we will focus on creating our environment, or simulator, using TorchRL's\n",
" wrappers and transforms.\n",
"\n",
"3. Next, we will design the policy network and the value model,\n",
" which is indispensable to the loss function. These modules will be used\n",
" to configure our loss module.\n",
"\n",
"4. Next, we will create the replay buffer and data loader.\n",
"\n",
"5. Finally, we will run our training loop and analyze the results.\n",
"\n",
"Throughout this tutorial, we'll be using the tensordict library.\n",
"TensorDict is the lingua franca of TorchRL: it helps us abstract\n",
"what a module reads and writes and care less about the specific data\n",
"description and more about the algorithm itself.\n",
"\n",
"\n"
]
},
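{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the formula above concrete, here is a minimal pure-PyTorch sketch of the clipped objective for a batch of samples. It is an illustration, not TorchRL's `ClipPPOLoss` implementation. The function name `clipped_surrogate` is hypothetical. The sketch assumes that the log-probabilities of the taken actions under the current and data-collecting policies, as well as the advantages, are already available. Since optimizers minimize, the sketch returns the negated objective.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def clipped_surrogate(log_prob, old_log_prob, advantage, clip_epsilon=0.2):\n",
"    # importance ratio pi_theta(a|s) / pi_theta_k(a|s), computed in log space for stability\n",
"    ratio = (log_prob - old_log_prob).exp()\n",
"    # first term of the minimum: importance-weighted REINFORCE objective\n",
"    unclipped = ratio * advantage\n",
"    # second term: same objective with the ratio clipped to [1 - eps, 1 + eps]\n",
"    clipped = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantage\n",
"    # element-wise minimum averaged over the batch, negated to obtain a loss to minimize\n",
"    return -torch.min(unclipped, clipped).mean()\n",
"\n",
"\n",
"# toy batch: the first ratio (about 1.65) exceeds 1 + eps, the second one equals 1\n",
"log_prob = torch.tensor([0.5, 0.0])\n",
"old_log_prob = torch.tensor([0.0, 0.0])\n",
"advantage = torch.tensor([1.0, -1.0])\n",
"print(clipped_surrogate(log_prob, old_log_prob, advantage))\n",
"```\n",
"\n",
"Taking the minimum with the clipped ratio, as done here, yields the same objective as the minimum in the formula above. Once the ratio leaves the $[1-\\epsilon, 1+\\epsilon]$ interval in the direction favored by the advantage, the clipped term is selected and that sample's gradient vanishes."
]
},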
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "FHA36J4F9PA2"
},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ZMiXFne9PA2"
},
"source": [
"## Define Hyperparameters\n",
"\n",
"We set the hyperparameters for our algorithm. Depending on the resources\n",
"available, one may choose to execute the policy on GPU or on another\n",
"device.\n",
"The ``frame_skip`` controls for how many frames a single\n",
"action is executed. The rest of the arguments that count frames\n",
"must be corrected for this value (since one environment step will\n",
"actually return ``frame_skip`` frames).\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "vv4kSXU-9PA3"
},
"outputs": [],
"source": [
"device = \"cpu\" if not torch.cuda.device_count() else \"cuda:0\"\n",
"num_cells = 256 # number of cells in each layer\n",
"lr = 3e-4\n",
"max_grad_norm = 1.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EWFyRJTj9PA3"
},
"source": [
"### Data collection parameters\n",
"\n",
"When collecting data, we will be able to choose how big each batch will be\n",
"by defining a ``frames_per_batch`` parameter. We will also define how many\n",
"frames (that is, the number of interactions with the simulator) we will allow ourselves to\n",
"use. In general, the goal of an RL algorithm is to learn to solve the task\n",
"as fast as it can in terms of environment interactions: the lower the ``total_frames``\n",
"the better.\n",
"We also define a ``frame_skip``: in some contexts, repeating the same action\n",
"multiple times over the course of a trajectory may be beneficial as it makes\n",
"the behavior more consistent and less erratic. However, \"skipping\"\n",
"too many frames will hamper training by reducing the reactivity of the actor\n",
"to observation changes.\n",
"\n",
"When using ``frame_skip`` it is good practice to\n",
"correct the other frame counts by the number of frames we are grouping\n",
"together. If we configure a total count of X frames for training but\n",
"use a ``frame_skip`` of Y, we will actually be collecting ``XY`` frames in total,\n",
"which exceeds our predefined budget.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "95CTm8jz9PA4"
},
"outputs": [],
"source": [
"frame_skip = 1\n",
"frames_per_batch = 1000 // frame_skip\n",
"# For a complete training, bring the number of frames up to 1M\n",
"total_frames = 50_000 // frame_skip"
]
},
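{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check of the frame accounting described above, the sketch below uses a hypothetical helper, `frame_budget`. It shows that dividing the step counts by `frame_skip` keeps the number of simulator frames fixed, since each environment step repeats its action for `frame_skip` frames.\n",
"\n",
"```python\n",
"def frame_budget(raw_frames_per_batch, raw_total_frames, frame_skip):\n",
"    # environment steps (one policy query each) per collected batch\n",
"    steps_per_batch = raw_frames_per_batch // frame_skip\n",
"    # environment steps over the whole training run\n",
"    total_steps = raw_total_frames // frame_skip\n",
"    # each step advances the simulator by frame_skip frames\n",
"    simulator_frames = total_steps * frame_skip\n",
"    return steps_per_batch, total_steps, simulator_frames\n",
"\n",
"\n",
"for skip in (1, 2, 4):\n",
"    print(skip, frame_budget(1000, 50_000, skip))\n",
"```\n",
"\n",
"Without the correction, 50,000 steps with a `frame_skip` of 4 would consume 200,000 simulator frames, four times the intended budget."
]
},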
{
"cell_type": "markdown",
"metadata": {
"id": "GaBSLXuW9PA5"
},
"source": [
"### PPO parameters\n",
"\n",
"At each data collection (or batch collection) we will run the optimization\n",
"over a certain number of *epochs*, each time consuming the entire data we just\n",
"acquired in a nested training loop. Here, the ``sub_batch_size`` is different from the\n",
"``frames_per_batch`` above: recall that we are working with a \"batch of data\"\n",
"coming from our collector, whose size is defined by ``frames_per_batch``, and that\n",
"we will further split into smaller sub-batches during the inner training loop.\n",
"The size of these sub-batches is controlled by ``sub_batch_size``.\n",
"\n",
"\n"
]
},
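{
"cell_type": "markdown",
"metadata": {},
"source": [
"The nested loop over sub-batches can be sketched in plain PyTorch. The sketch assumes we want to visit every collected transition exactly once per epoch: it draws a random permutation of the batch indices and splits it into chunks of `sub_batch_size`. The generator function `iterate_sub_batches` is hypothetical; in this tutorial, a TorchRL replay buffer is used for this purpose.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def iterate_sub_batches(batch_size, sub_batch_size, num_epochs, generator=None):\n",
"    # yields index tensors; one epoch is one full pass over the collected batch\n",
"    for _ in range(num_epochs):\n",
"        perm = torch.randperm(batch_size, generator=generator)\n",
"        # consecutive chunks of a permutation never repeat an index within an epoch\n",
"        for start in range(0, batch_size, sub_batch_size):\n",
"            yield perm[start:start + sub_batch_size]\n",
"\n",
"\n",
"g = torch.Generator().manual_seed(0)\n",
"chunks = list(iterate_sub_batches(1000, 64, num_epochs=1, generator=g))\n",
"print(len(chunks), chunks[-1].numel())\n",
"```\n",
"\n",
"With 1000 frames per batch and sub-batches of 64, each epoch yields 15 full sub-batches and a last one of 40 samples. Whether to keep or drop that smaller remainder is a design choice."
]
},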
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "O_w1Ct_E9PA5"
},
"outputs": [],
"source": [
"sub_batch_size = 64 # cardinality of the sub-samples gathered from the current data in the inner loop\n",
"num_epochs = 10 # optimization steps per batch of data collected\n",
"clip_epsilon = (\n",
" 0.2 # clip value for PPO loss: see the equation in the intro for more context.\n",
")\n",
"gamma = 0.99\n",
"lmbda = 0.95\n",
"entropy_eps = 1e-4"
]
},
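{
"cell_type": "markdown",
"metadata": {},
"source": [
"`gamma` and `lmbda` are typically the discount factor and the smoothing parameter of the generalized advantage estimator (GAE), used to compute the advantage signal. For intuition, here is a minimal plain-PyTorch sketch of the GAE recursion on a single trajectory, $A_t = \\delta_t + \\gamma \\lambda (1 - d_t) A_{t+1}$ with $\\delta_t = r_t + \\gamma (1 - d_t) V(s_{t+1}) - V(s_t)$. It is not TorchRL's implementation. The function name `gae` is an assumption, as is the convention that `dones[t]` flags a termination right after step $t$.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):\n",
"    # all inputs are 1-D tensors of length T describing one trajectory\n",
"    not_done = 1.0 - dones.float()\n",
"    # one-step TD errors; no bootstrapping from V(s_{t+1}) after a termination\n",
"    deltas = rewards + gamma * not_done * next_values - values\n",
"    advantages = torch.zeros_like(rewards)\n",
"    running = torch.zeros(())\n",
"    # backward recursion: A_t = delta_t + gamma * lmbda * (1 - d_t) * A_{t+1}\n",
"    for t in reversed(range(rewards.shape[0])):\n",
"        running = deltas[t] + gamma * lmbda * not_done[t] * running\n",
"        advantages[t] = running\n",
"    # value targets for the critic\n",
"    returns = advantages + values\n",
"    return advantages, returns\n",
"\n",
"\n",
"rewards = torch.tensor([1.0, 1.0, 1.0])\n",
"zeros = torch.zeros(3)\n",
"dones = torch.tensor([False, False, True])\n",
"advantages, returns = gae(rewards, zeros, zeros, dones, gamma=0.99, lmbda=1.0)\n",
"print(advantages)\n",
"```\n",
"\n",
"With `lmbda=1` and a zero value function, as in this toy call, the advantages reduce to discounted returns. With `lmbda=0`, they reduce to the one-step TD errors."
]
},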
{
"cell_type": "markdown",
"metadata": {
"id": "XKmP30GO9PA6"
},
"source": [
"## Define an environment\n",
"\n",
"In RL, an *environment* is usually the way we refer to a simulator or a\n",
"control system. Various libraries provide simulation environments for reinforcement\n",
"learning, including Gymnasium (previously OpenAI Gym), DeepMind Control Suite, and\n",
"many others.\n",
"As a general library, TorchRL's goal is to provide an interchangeable interface\n",
"to a large panel of RL simulators, allowing you to easily swap one environment\n",
"with another. For example, creating a wrapped gym environment can be achieved with a few characters:\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "fJiLpbS09PA6"
},
"outputs": [],
"source": [
"from torchrl.envs import (\n",
" Compose,\n",
" DoubleToFloat,\n",
" ObservationNorm,\n",
" TransformedEnv,\n",
")\n",
"from torchrl.envs.libs.gym import GymEnv\n",
"from torchrl.envs.utils import check_env_specs, set_exploration_mode\n",
"\n",
"base_env = GymEnv(\"InvertedDoublePendulum-v4\", device=device, frame_skip=frame_skip)"
]
},
{
"cell_type": "markdown",
"source": [
"### Exercise\n",
"\n",
"1. Play a bit with this environment: how can you reset the environment? How can you make a step?\n",
"\n",
"2. Try creating environments for yourself:\n",
"- How can you build an environment from DeepMind Control (instead of gym)?\n",
"- Can you build an Atari environment? Keep this environment (say ``atari_env``), we will need it later.\n",
"- How can you build a version of the InvertedDoublePendulum with rendering of images and states?"
],
"metadata": {
"id": "JA6YeqMg-Jre"
}
},
{
"cell_type": "code",
"source": [
"td = base_env.reset()\n",
"td = base_env.rand_step()\n",
"# creating envs\n",
"from torchrl.envs.libs.dm_control import DMControlEnv\n",
"dmenv = DMControlEnv(\"cheetah\", \"run\")\n",
"atari_env = GymEnv(\"ALE/Pong-v5\")\n",
"# leads to a bug on colab :/\n",
"# pendulum_render = GymEnv(\"InvertedDoublePendulum-v4\", from_pixels=True)"
],
"metadata": {
"id": "56vVuXKd-LrO"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "1C6-3rZC9PA6"
},
"source": [
"There are a few things to notice in this code: first, we created\n",
"the environment by calling the ``GymEnv`` wrapper. If extra keyword arguments\n",
"are passed, they will be passed on to the ``gym.make`` method, hence covering\n",
"the most common environment construction commands.\n",
"Alternatively, one could also directly create a gym environment using ``gym.make(env_name, **kwargs)``\n",
"and wrap it in a `GymWrapper` class.\n",
"\n",
"Also note the ``device`` argument: for gym, this only controls the device where\n",
"input action and observed states will be stored, but the execution will always\n",
"be done on CPU. The reason for this is simply that gym does not support on-device\n",
"execution, unless specified otherwise. For other libraries, we have control over\n",
"the execution device and, as much as we can, we try to stay consistent in terms of\n",
"storing and execution backends.\n",
"\n",
"## Using environments: Rollouts\n",
"\n",
"For fun, let's see what a simple random rollout looks like. You can\n",
"call `env.rollout(n_steps)` and get an overview of what the environment inputs\n",
"and outputs look like. Actions will automatically be drawn from the action spec\n",
"domain, so you don't need to care about designing a random sampler.\n",
"\n",
"Typically, at each step, an RL environment receives an\n",
"action as input, and outputs an observation, a reward and a done state. The\n",
"observation may be composite, meaning that it could be composed of more than one\n",
"tensor. This is not a problem for TorchRL, since the whole set of observations\n",
"is automatically packed in the output `TensorDict`. After executing a rollout\n",
"(that is, a sequence of environment steps and random action generations) over a given\n",
"number of steps, we will retrieve a `TensorDict` instance with a shape\n",
"that matches this trajectory length:"
]
},
{
"cell_type": "code",
"source": [
"rollout = base_env.rollout(3)\n",
"print(\"rollout of three steps:\", rollout)\n",
"print(\"Shape of the rollout TensorDict:\", rollout.batch_size)"
],
"metadata": {
"id": "nEsPjhWV31o9",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "0da0f4a4-cdcf-4f3d-ecbf-12e0d67e2eb5"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"rollout of three steps: TensorDict(\n",
" fields={\n",
" action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
" next: TensorDict(\n",
" fields={\n",
" done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
" observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float64, is_shared=False),\n",
" reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},\n",
" batch_size=torch.Size([3]),\n",
" device=cpu,\n",
" is_shared=False),\n",
" observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float64, is_shared=False)},\n",
" batch_size=torch.Size([3]),\n",
" device=cpu,\n",
" is_shared=False)\n",
"Shape of the rollout TensorDict: torch.Size([3])\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps\n",
"we ran it for. The ``\"next\"`` entry points to the data coming after the current step.\n",
"In most cases, the ``\"next\"`` data at time ``t`` matches the data at ``t+1``, but this\n",
"may not be the case if we are using some specific transformations (for example, multi-step).\n",
"\n",
"## Exercise\n",
"\n",
"Check that the content of ``\"next\"`` matches the content of the root by indexing tensordicts. Hint: you can index a tensordict along the time dimension (as you would index a tensor) and along the key dimension."
],
"metadata": {
"id": "0OfA7M3b4FWA"
}
},
{
"cell_type": "code",
"source": [
"## your answer here\n",
"next_rollout = rollout.get(\"next\")\n",
"print((next_rollout[\"observation\"][:-1] - rollout[\"observation\"][1:]).norm())"
],
"metadata": {
"id": "vHNBHIVO_rSz",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "25d9b1c8-b149-4a70-dc2b-12b08c061b95"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor(0., dtype=torch.float64)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Transforms\n",
"\n",
"We will append some transforms to our environments to prepare the data for\n",
"the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different\n",
"approach, more similar to other PyTorch domain libraries, through the use of transforms.\n",
"To add transforms to an environment, one should simply wrap it in a ``TransformedEnv``\n",
"instance, and append the sequence of transforms to it. The transformed environment will inherit the device and meta-data of the wrapped environment, and transform these depending on the sequence of transforms it contains.\n",
"\n",
"### Normalization\n",
"\n",
"The first transform to encode is a normalization transform.\n",
"As a rule of thumb, it is preferable to have data that loosely\n",
"match a unit Gaussian distribution: to obtain this, we will\n",
"run a certain number of random steps in the environment and compute\n",
"the summary statistics of these observations.\n",
"\n",
"We'll append two other transforms: the `DoubleToFloat` transform will\n",
"convert double entries to single-precision numbers, ready to be read by the\n",
"policy. The `StepCounter` transform will be used to count the steps before\n",
"the environment is terminated. We will use this count as a supplementary measure of performance.\n",
"\n",
"As we will see later, many of TorchRL's classes rely on `TensorDict`\n",
"to communicate. You could think of it as a python dictionary with some extra\n",
"tensor features. In practice, this means that many modules we will be working\n",
"with need to be told what key to read (``in_keys``) and what key to write\n",
"(``out_keys``) in the ``tensordict`` they will receive. Usually, if ``out_keys``\n",
"is omitted, it is assumed that the ``in_keys`` entries will be updated\n",
"in-place. For our transforms, the only entry we are interested in is referred\n",
"to as ``\"observation\"`` and our transform layers will be told to modify this\n",
"entry and this entry only:\n",
"\n",
"\n"
],
"metadata": {
"id": "V1Odiq7X32V7"
}
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "Wv1bAts59PA6"
},
"outputs": [],
"source": [
"from torchrl.envs import StepCounter\n",
"env = TransformedEnv(\n",
" base_env,\n",
" Compose(\n",
" # normalize observations\n",
" ObservationNorm(in_keys=[\"observation\"]),\n",
" # cast float64 observations to float32\n",
" DoubleToFloat(in_keys=[\"observation\"]),\n",
" # count the steps taken in each trajectory\n",
" StepCounter(),\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"source": [
"## Exercise\n",
"\n",
"1. The code above will not run because one transform is missing! The one we want should count the steps executed in a trajectory. Can you find what it is?\n",
"No need to code the transform, just import it from `torchrl.envs`.\n",
"\n",
"2. Transforms are one of the coolest things in TorchRL! Try them for yourself.\n",
"   - Can you find a transform to normalize rewards on the fly?\n",
"   - Can you find a transform to convert images from NumPy format to torch format? Try using it with the Atari environment ``atari_env`` you built just before.\n",
"\n",
"3. Can you find a transform to compute the total reward of a trajectory? Try adding it to our transformed environment after it has been created (i.e., without re-instantiating the environment).\n"
],
"metadata": {
"id": "kPqF1hS6-RuF"
}
},
{
"cell_type": "code",
"source": [
"## Your answer here\n",
"from torchrl.envs import VecNorm, ToTensorImage, RewardSum\n",
"env.append_transform(VecNorm()) # normalizes rewards on-the-fly\n",
"atari_env = TransformedEnv(atari_env, ToTensorImage())\n",
"env.append_transform(RewardSum()) # sums rewards for a trajectory\n"
],
"metadata": {
"id": "XNuOKMtN-QbI"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "3kCEfTXB9PA7"
},
"source": [
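"As you may have noticed, we created a normalization layer but did not\n",
"set its normalization parameters. To do this, `ObservationNorm` can\n",
"automatically gather the summary statistics of our environment. In the call below, random steps are collected (``num_iter=1000``), concatenated along ``cat_dim=0`` and reduced along ``reduce_dim=0``. This yields one location and one scale value per observation dimension:\n",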
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "BiLqkZi99PA7"
},
"outputs": [],
"source": [
"env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)"
]
},
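{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, gathering these statistics amounts to stacking many observations and reducing them feature by feature. The pure-Python sketch below illustrates the idea with ``cat_dim=0`` and ``reduce_dim=0``. The `random_observation` helper is a hypothetical stand-in for random environment steps, and we assume the scale is a standard deviation. TorchRL may store the resulting location and scale in a reparametrized form, so the real transform's `loc` and `scale` need not equal the raw mean and standard deviation.\n",
"\n",
"```python\n",
"import random\n",
"import statistics\n",
"\n",
"rng = random.Random(0)\n",
"\n",
"\n",
"def random_observation():\n",
"    # Hypothetical stand-in for one random environment step (3 features here).\n",
"    return [rng.gauss(5.0, 2.0), rng.gauss(-1.0, 0.5), rng.gauss(0.0, 10.0)]\n",
"\n",
"\n",
"# Collect num_iter observations; stacking them along a new first axis mirrors cat_dim=0.\n",
"num_iter = 1000\n",
"batch = [random_observation() for _ in range(num_iter)]\n",
"\n",
"# Reduce over the batch axis (reduce_dim=0): one statistic per feature.\n",
"columns = list(zip(*batch))\n",
"loc = [statistics.fmean(col) for col in columns]\n",
"scale = [statistics.stdev(col) for col in columns]\n",
"\n",
"\n",
"def normalize(obs):\n",
"    return [(x - m) / s for x, m, s in zip(obs, loc, scale)]\n",
"\n",
"\n",
"normalized = [normalize(obs) for obs in batch]\n",
"print(len(loc), len(scale))  # 3 3: one entry per feature\n",
"```\n",
"\n",
"After this step, every feature of `normalized` has (up to floating-point error) zero mean and unit standard deviation over the collected batch. That is the loose unit-Gaussian target mentioned earlier.\n"
]
},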
{
"cell_type": "markdown",
"metadata": {
"id": "_t3f4c4f9PA7"
},
"source": [
"The `ObservationNorm` transform has now been populated with a\n",
"location and a scale that will be used to normalize the data.\n",
"\n",
"Let us run a quick sanity check on the shape of our summary stats:\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "qX3oemra9PA7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f35c3e62-64ec-49a6-e117-6d7b53749f5b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"normalization constant shape: torch.Size([11])\n"
]
}
],
"source": [
"print(\"normalization constant shape:\", env.transform[0].loc.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bMWqTvhL9PA7"
},
"source": [
"An environment is defined not only by its simulator and transforms, but also\n",
"by metadata that describe what can be expected during its\n",
"execution.\n",
"For efficiency purposes, TorchRL is quite stringent when it comes to\n",
"environment specs, but you can easily check that your environment specs are\n",
"adequate.\n",
"In our example, `GymWrapper` and `GymEnv` (which inherits\n",
"from it) already take care of setting the proper specs for your environment, so\n",
"you should not have to worry about this.\n",
"\n",
"Nevertheless, let's see a concrete example using our transformed\n",
"environment by looking at its specs.\n",
"There are three specs to look at: ``observation_spec``, which defines what\n",
"is to be expected when executing an action in the environment;\n",
"``reward_spec``, which indicates the reward domain; and finally\n",
"``input_spec`` (which contains the ``action_spec``), which represents\n",
"everything an environment requires to execute a single step.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "syiBvoQj9PA7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a9d3e895-8c1f-4eea-fa54-c9729ec587bb"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"observation_spec: CompositeSpec(\n",
" observation: UnboundedContinuousTensorSpec(\n",
" shape=torch.Size([11]),\n",
" space=None,\n",
" device=cpu,\n",
" dtype=torch.float32,\n",
" domain=continuous),\n",
" step_count: UnboundedDiscreteTensorSpec(\n",
" shape=torch.Size([1]),\n",
" space=ContinuousBox(\n",
" minimum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True), \n",
" maximum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),\n",
" device=cpu,\n",
" dtype=torch.int64,\n",
" domain=continuous),\n",
" episode_reward: UnboundedContinuousTensorSpec(\n",
" shape=torch.Size([1]),\n",
" space=ContinuousBox(\n",
" minimum=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), \n",
" maximum=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),\n",
" device=cpu,\n",
" dtype=torch.float32,\n",
" domain=continuous), device=cpu, shape=torch.Size([]))\n",
"reward_spec: UnboundedContinuousTensorSpec(\n",
" shape=torch.Size([1]),\n",
" space=ContinuousBox(\n",
" minimum=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), \n",
" maximum=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),\n",
" device=cpu,\n",
" dtype=torch.float32,\n",
" domain=continuous)\n",
"input_spec: CompositeSpec(\n",
" _state_spec: CompositeSpec(\n",
" step_count: UnboundedDiscreteTensorSpec(\n",
" shape=torch.Size([1]),\n",
" space=ContinuousBox(\n",
" minimum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True), \n",
" maximum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),\n",
" device=cpu,\n",
" dtype=torch.int64,\n",
" domain=continuous), device=cpu, shape=torch.Size([])),\n",
" _action_spec: CompositeSpec(\n",
" action: BoundedTensorSpec(\n",
" shape=torch.Size([1]),\n",
" space=ContinuousBox(\n",
" minimum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), \n",
" maximum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),\n",
" device=cpu,\n",
" dtype=torch.float32,\n",
" domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))\n",
"action_spec (as defined by input_spec): BoundedTensorSpec(\n",
" shape=torch.Size([1]),\n",
" space=ContinuousBox(\n",
" minimum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), \n",
" maximum=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),\n",
" device=cpu,\n",
" dtype=torch.float32,\n",
" domain=continuous)\n"
]
}
],
"source": [
"print(\"observation_spec:\", env.observation_spec)\n",
"print(\"reward_spec:\", env.reward_spec)\n",
"print(\"input_spec:\", env.input_spec)\n",
"print(\"action_spec (as defined by input_spec):\", env.action_spec)"
]
},
{
"cell_type": "markdown",
"source": [
"## Exercise\n",
"\n",
"How would you normalize data from pixels (images)? How can you control the dimensions of your normalization statistics? Try it out with your ``atari_env``!"
],
"metadata": {
"id": "BM3kR1zh-0NS"
}
},
{
"cell_type": "code",
"source": [
"atari_env = GymEnv(\"ALE/Pong-v5\")\n",
"atari_env = TransformedEnv(atari_env, ToTensorImage())\n",
"atari_env.append_transform(ObservationNorm(in_keys=[\"pixels\"]))\n",
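"# a single mean/std over frames, channels, height and width; keep_dims keeps the reduced image dims as size-1 dims so that loc/scale broadcast\n",
"# (for per-channel statistics, one could try reduce_dim=[-4, -2, -1] with keep_dims=[-2, -1] instead)\n",
"atari_env.transform[-1].init_stats(num_iter=100, reduce_dim=[-4, -3, -2, -1], keep_dims=[-3, -2, -1], cat_dim=-4)"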
],
"metadata": {
"id": "M0qb4YoB-_ym"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "R229JigR9PA8"
},
"source": [
"The `check_env_specs` function runs a small rollout and compares its output against the environment specs. If no error is raised, we can be confident that the specs are properly defined:\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "oXLsJmWZ9PA8",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "fbd6ff8c-231e-4dbf-a7ac-85f2effc21b1"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"check_env_specs succeeded!\n"
]
}
],
"source": [
"check_env_specs(env)"
]
},
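{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below gives a simplified picture of what such a check involves. It runs a short rollout and verifies that each step's output contains the expected keys with the expected sizes and types. It is plain Python with a hypothetical spec format (each key mapped to an expected length and element type) and a hypothetical `toy_step` function. It is not a description of how `check_env_specs` is implemented.\n",
"\n",
"```python\n",
"# Simplified illustration of spec checking (not TorchRL's implementation).\n",
"# Hypothetical spec format: key -> (expected length, expected element type).\n",
"observation_spec = {'observation': (11, float), 'step_count': (1, int)}\n",
"\n",
"\n",
"def toy_step(t):\n",
"    # Hypothetical environment step returning plain lists.\n",
"    return {'observation': [0.1 * t] * 11, 'step_count': [t]}\n",
"\n",
"\n",
"def check_specs(spec, rollout):\n",
"    for i, data in enumerate(rollout):\n",
"        missing = set(spec) - set(data)\n",
"        if missing:\n",
"            raise AssertionError(f'step {i}: missing keys {sorted(missing)}')\n",
"        for key, (length, elem_type) in spec.items():\n",
"            value = data[key]\n",
"            if len(value) != length:\n",
"                raise AssertionError(f'step {i}: {key} has length {len(value)}, expected {length}')\n",
"            if not all(isinstance(v, elem_type) for v in value):\n",
"                raise AssertionError(f'step {i}: {key} contains non-{elem_type.__name__} values')\n",
"    return True\n",
"\n",
"\n",
"rollout = [toy_step(t) for t in range(3)]\n",
"print(check_specs(observation_spec, rollout))  # True\n",
"```\n"
]
},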
{
"cell_type": "markdown",
"metadata": {
"id": "Kh6uqcen9PA8"
},
"source": [
"## Policy\n",
"\n",
"PPO utilizes a stochastic policy to handle exploration. This means that our\n",
"neural network will have to output the parameters of a distribution, rather\n",
"than a single value corresponding to the action taken.\n",
"\n",
"As the action space is continuous, we use a Tanh-Normal distribution to respect the\n",
"action space boundaries. TorchRL provides such a distribution, and the only\n",
"thing we need to care about is to build a neural network that outputs the\n",
"right number of parameters for the policy to work with (a location, or mean,\n",
"and a scale):\n",
"\n",
"\\begin{align}f_{\\theta}(\\text{observation}) = \\mu_{\\theta}(\\text{observation}), \\sigma^{+}_{\\theta}(\\text{observation})\\end{align}\n",
"\n",
"The only extra difficulty here is to split our output into two\n",
"equal parts and map the second one to a strictly positive space.\n",
"\n",
"We design the policy in three steps:\n",
"\n",
"1. Define a neural network ``D_obs`` -> ``2 * D_action``, since our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``;\n",
"\n",
"2. Append a `NormalParamExtractor` to extract a location and a scale (it splits the input into two equal parts\n",
" and applies a positive transformation to the scale parameter);\n",
"\n",
"3. Create a probabilistic `TensorDictModule` that can create this distribution and sample from it.\n",
"\n",
"\n"
]
},
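{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 2 can be sketched in a few lines of plain Python. The `extract_normal_params` helper is hypothetical. We use a plain softplus as the positive mapping purely for illustration; the mapping that `NormalParamExtractor` applies by default may be a different (for example, shifted) variant.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def softplus(x):\n",
"    # Numerically stable log(1 + exp(x)), which is always positive in exact arithmetic.\n",
"    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))\n",
"\n",
"\n",
"def extract_normal_params(hidden):\n",
"    # Split a vector of size 2 * D_action into loc and a positive scale.\n",
"    if len(hidden) % 2 != 0:\n",
"        raise ValueError('expected an even number of outputs (2 * D_action)')\n",
"    d_action = len(hidden) // 2\n",
"    loc = hidden[:d_action]\n",
"    scale = [softplus(x) for x in hidden[d_action:]]\n",
"    return loc, scale\n",
"\n",
"\n",
"loc, scale = extract_normal_params([0.3, -2.0])  # D_action = 1, as for our pendulum\n",
"print(loc)  # [0.3]\n",
"print(scale[0] > 0)  # True\n",
"```\n"
]
},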
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "iuFoQbOc9PA8",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "12e46143-8a0a-4464-c9c1-1b39999d2070"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
" warnings.warn('Lazy modules are a new feature under heavy development '\n"
]
}
],
"source": [
"from tensordict.nn import TensorDictModule\n",
"from tensordict.nn.distributions import NormalParamExtractor\n",
"from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator\n",
"from torch import nn\n",
"\n",
"actor_net = nn.Sequential(\n",
" nn.LazyLinear(num_cells, device=device),\n",
" nn.Tanh(),\n",
" nn.LazyLinear(num_cells, device=device),\n",
" nn.Tanh(),\n",
" nn.LazyLinear(num_cells, device=device),\n",
" nn.Tanh(),\n",
" nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A7OnyzKm9PA9"
},
"source": [
"To enable the policy to \"talk\" with the environment through the ``tensordict``\n",
"data carrier, we wrap the ``nn.Module`` in a `TensorDictModule`. This\n",
"class will simply read the ``in_keys`` it is provided with and write the\n",
"outputs in-place at the registered ``out_keys``.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "nHhrBOrF9PA9"
},
"outputs": [],
"source": [
"policy_module = TensorDictModule(\n",
" actor_net, in_keys=[\"observation\"], out_keys=[\"hidden\"]\n",
")"
]
},
{
"cell_type": "markdown",
"source": [
"## Exercise\n",
"\n",
"The module we have just written reads an observation and writes a `hidden` entry in the tensordict it receives.\n",
"Using ``env.reset()``, print what a call to our `policy_module` looks like.\n",
"Then, using `NormalParamExtractor`, `TensorDictModule` and `tensordict.nn.TensorDictSequential`, combine this module with another one to obtain location (`loc`) and scale (`scale`) entries in the tensordict."
],
"metadata": {
"id": "qXNl-3C6_Maj"
}
},
{
"cell_type": "code",
"source": [
"## your answer here\n",
"from tensordict.nn import TensorDictSequential, NormalParamExtractor\n",
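"# split the hidden output into loc and a strictly positive scale\n",
"extractor = NormalParamExtractor()\n",
"td_extractor = TensorDictModule(extractor, in_keys=[\"hidden\"], out_keys=[\"loc\", \"scale\"])\n",
"# run policy_module, then td_extractor, on the same tensordict\n",
"combined_policy_module = TensorDictSequential(policy_module, td_extractor)"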
],
"metadata": {
"id": "1c5CF-xM_LlZ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4249726e-fd20-42a5-b59a-fd18d4161d74"
},
"execution_count": 30,
"outputs": [
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3cVfph29PA9"
},
"source": [
"We now need to build a distribution out of the location and scale of our\n",
"normal distribution.\n",
"To do so, we instruct the `ProbabilisticActor`\n",
"class to build a `TanhNormal` out of the location and scale\n",
"parameters. We also provide the minimum and maximum values of this\n",
"distribution, which we gather from the environment specs.\n",
"\n",
"The names of the ``in_keys`` (and hence the names of the ``out_keys`` from\n",
"the `TensorDictModule` above) cannot be set to arbitrary values,\n",
"as the `TanhNormal` distribution constructor expects\n",
"``loc`` and ``scale`` keyword arguments. That being said,\n",
"`ProbabilisticActor` also accepts ``Dict[str, str]`` typed ``in_keys``,\n",
"where each key is a distribution keyword argument and each value is the\n",
"tensordict entry from which that argument should be read.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "eyJxQlhH9PA9"
},
"outputs": [],
"source": [
"policy_module = ProbabilisticActor(\n",
" module=combined_policy_module,\n",
" spec=env.action_spec,\n",
" in_keys=[\"loc\", \"scale\"],\n",
" distribution_class=TanhNormal,\n",
" distribution_kwargs={\n",
" \"min\": env.action_spec.space.minimum,\n",
" \"max\": env.action_spec.space.maximum,\n",
" },\n",
" return_log_prob=True,\n",
" # we'll need the log-prob for the numerator of the importance weights\n",
")"
]
},
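{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the ``min``/``max`` bounds and ``return_log_prob=True`` give us, here is a plain-Python sketch of a tanh-squashed Gaussian. We draw $u \\sim \\mathcal{N}(\\mu, \\sigma)$, squash it with $\\tanh$ and rescale it affinely to the action bounds. We then correct the Gaussian log-density with the log-Jacobian of that map. This is the standard change of variables; we assume, without guaranteeing it, that TorchRL's `TanhNormal` uses the same parametrization. The bounds $[-1, 1]$ are an assumed example.\n",
"\n",
"```python\n",
"import math\n",
"import random\n",
"\n",
"\n",
"def tanh_normal_sample(loc, scale, low, high, rng):\n",
"    # Sample u ~ N(loc, scale) and map it into (low, high) through tanh.\n",
"    u = rng.gauss(loc, scale)\n",
"    t = math.tanh(u)\n",
"    action = low + (high - low) * (t + 1.0) / 2.0\n",
"    # Gaussian log-density of u.\n",
"    log_prob_u = -0.5 * ((u - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2.0 * math.pi)\n",
"    # log |d action / d u| = log((high - low) / 2) + log(1 - tanh(u)^2).\n",
"    # Real implementations guard against 1 - tanh(u)^2 underflowing to 0 for large |u|.\n",
"    log_det = math.log((high - low) / 2.0) + math.log(1.0 - t * t)\n",
"    return action, log_prob_u - log_det\n",
"\n",
"\n",
"rng = random.Random(0)\n",
"samples = [tanh_normal_sample(0.0, 1.0, -1.0, 1.0, rng) for _ in range(1000)]\n",
"print(all(-1.0 < a < 1.0 for a, _ in samples))  # True: every action respects the bounds\n",
"```\n",
"\n",
"The loss later compares this log-probability between the current policy and the policy that collected the data.\n"
]
},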
{
"cell_type": "markdown",
"source": [
"## Exercise\n",
"\n",
"We now have a policy and an environment, which means that we can already observe how the two interact with each other.\n",
"Try it for yourself: call `reset` on the environment and observe its output. How would you pass this to the policy we have built? What output do you expect?"
],
"metadata": {
"id": "vRKHHRVG5-be"
}
},
{
"cell_type": "code",
"source": [
"## your answer here\n",
"td = env.reset()\n",
"policy_module(td)"
],
"metadata": {
"id": "8WxNRO5j7H55",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e34d22ab-1c23-48eb-dfc5-cdacf215aa8c"
},
"execution_count": 32,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorDict(\n",
" fields={\n",
" action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
" episode_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" hidden: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" sample_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)"
]
},
"metadata": {},
"execution_count": 32
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H-MnqQ5u9PA9"
},
"source": [
"## Value network\n",
"\n",
"The value network is a crucial component of the PPO algorithm, even though it\n",
"won't be used at inference time. This module will read the observations and\n",
"return an estimate of the discounted return for the following trajectory.\n",
"This allows us to amortize learning by relying on a utility estimate\n",
"that is learned on-the-fly during training. Our value network shares the same\n",
"structure as the policy, but for simplicity we assign it its own set of\n",
"parameters.\n",
"\n",
"\n"
]
},
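{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the discounted return that the value network learns to estimate can be computed backwards over a batch of transitions, resetting at episode boundaries. This is a plain-Python sketch with made-up rewards and discount. The last, unfinished episode is not bootstrapped here, whereas a real estimator would typically add a value estimate for it.\n",
"\n",
"```python\n",
"def discounted_returns(rewards, dones, gamma):\n",
"    # G_t = r_t + gamma * G_{t+1}, where G_{t+1} is dropped after a terminal step.\n",
"    returns = [0.0] * len(rewards)\n",
"    running = 0.0\n",
"    for t in reversed(range(len(rewards))):\n",
"        if dones[t]:\n",
"            running = 0.0\n",
"        running = rewards[t] + gamma * running\n",
"        returns[t] = running\n",
"    return returns\n",
"\n",
"\n",
"rewards = [1.0, 1.0, 1.0, 1.0]\n",
"dones = [False, True, False, False]  # the first episode ends at t = 1\n",
"print(discounted_returns(rewards, dones, gamma=0.5))  # [1.5, 1.0, 1.5, 1.0]\n",
"```\n"
]
},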
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "Ynin5tOn9PA9",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "be5c4643-0b5f-48c6-9cdb-981983aa9814"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
" warnings.warn('Lazy modules are a new feature under heavy development '\n"
]
}
],
"source": [
"value_net = nn.Sequential(\n",
" nn.LazyLinear(num_cells, device=device),\n",
" nn.Tanh(),\n",
" nn.LazyLinear(num_cells, device=device),\n",
" nn.Tanh(),\n",
" nn.LazyLinear(num_cells, device=device),\n",
" nn.Tanh(),\n",
" nn.LazyLinear(1, device=device),\n",
")\n",
"\n",
"value_module = ValueOperator(\n",
" module=value_net,\n",
" in_keys=[\"observation\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tkQlC4iI9PA9"
},
"source": [
"Let's try our policy and value modules. As we said earlier, using\n",
"`TensorDictModule` makes it possible to directly read the output\n",
"of the environment to run these modules, as they know what information to read\n",
"and where to write it:\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"id": "YkpaKu0U9PA9",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5224e02d-6c70-42ab-ed3d-1a3322dc50d3"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Running policy: TensorDict(\n",
" fields={\n",
" action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
" episode_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" hidden: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" sample_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n",
"Running value: TensorDict(\n",
" fields={\n",
" done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
" episode_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
" step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
" batch_size=torch.Size([]),\n",
" device=cpu,\n",
" is_shared=False)\n"
]
}
],
"source": [
"print(\"Running policy:\", policy_module(env.reset()))\n",
"print(\"Running value:\", value_module(env.reset()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Jdx-isbn9PA-"
},
"source": [
"## Data collector\n",
"\n",
"TorchRL provides a set of `DataCollector` classes. Briefly, these\n",
"classes reset an environment, compute an action\n",
"given the latest observation, execute a step in the environment, and repeat\n",
"the last two steps until the environment reaches a stop signal (or ``\"done\"``\n",
"state).\n",
"\n",
"They allow you to control how many frames to collect at each iteration\n",
"(through the ``frames_per_batch`` parameter),\n",
"when to reset the environment (through the ``max_frames_per_traj`` argument),\n",
"on which ``device`` the policy should be executed, etc. They are also\n",
"designed to work efficiently with batched and multiprocessed environments.\n",
"\n",
"The simplest data collector is the `SyncDataCollector`: it is an\n",
"iterator that you can use to get batches of data of a given length, and\n",
"that will stop once a total number of frames (``total_frames``) has been\n",
"collected.\n",
"Other data collectors (``MultiSyncDataCollector`` and\n",
"``MultiaSyncDataCollector``) execute the same operations in a synchronous\n",
"and an asynchronous manner, respectively, over a set of multiprocessed workers.\n",
"\n",
"As for the policy and environment before, the data collector will return\n",
"`TensorDict` instances with a total number of elements that will\n",
"match ``frames_per_batch``. Using `TensorDict` to pass data to the\n",
"training loop allows you to write data loading pipelines\n",
"that are fully agnostic to the specifics of the rollout content.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "oj36Dw2x9PA-"
},
"outputs": [],
"source": [
"from torchrl.collectors import SyncDataCollector\n",
"\n",
"collector = SyncDataCollector(\n",
" env,\n",
" policy_module,\n",
" frames_per_batch=frames_per_batch,\n",
" total_frames=total_frames,\n",
" split_trajs=False,\n",
" device=device,\n",
")"
]
},
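{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy version of the collection loop described above may help. In this plain-Python sketch, the `ToyEnv` class (3-step episodes), the constant policy and `toy_collector` are all hypothetical. The sketch only mirrors the reset, act, step and reset-when-done cycle, and the batching by ``frames_per_batch`` until ``total_frames`` is reached. It does not reflect how `SyncDataCollector` is implemented.\n",
"\n",
"```python\n",
"class ToyEnv:\n",
"    # Hypothetical environment whose episodes last exactly horizon steps.\n",
"    def __init__(self, horizon=3):\n",
"        self.horizon = horizon\n",
"        self.t = 0\n",
"\n",
"    def reset(self):\n",
"        self.t = 0\n",
"        return {'observation': 0.0, 'step_count': 0}\n",
"\n",
"    def step(self, action):\n",
"        self.t += 1\n",
"        return {\n",
"            'observation': float(self.t),\n",
"            'reward': action,\n",
"            'done': self.t >= self.horizon,\n",
"            'step_count': self.t,\n",
"        }\n",
"\n",
"\n",
"def toy_collector(env, policy, frames_per_batch, total_frames):\n",
"    # Yield batches of frames_per_batch transitions until total_frames are collected.\n",
"    obs = env.reset()\n",
"    collected = 0\n",
"    while collected < total_frames:\n",
"        batch = []\n",
"        for _ in range(frames_per_batch):\n",
"            action = policy(obs)\n",
"            next_obs = env.step(action)\n",
"            batch.append({'observation': obs['observation'], 'action': action, 'next': next_obs})\n",
"            # Start a new episode as soon as the current one is done.\n",
"            obs = env.reset() if next_obs['done'] else next_obs\n",
"        collected += len(batch)\n",
"        yield batch\n",
"\n",
"\n",
"batches = list(toy_collector(ToyEnv(), policy=lambda obs: 1.0, frames_per_batch=4, total_frames=12))\n",
"print(len(batches), [len(b) for b in batches])  # 3 [4, 4, 4]\n",
"```\n",
"\n",
"Batches cut across episode boundaries: the first batch ends one step into the second episode, which is why each transition keeps its own done flag.\n"
]
},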
{
"cell_type": "markdown",
"metadata": {
"id": "P64eS7_Y9PA-"
},
"source": [
"## Replay buffer\n",
"\n",
"Replay buffers are a common building block of off-policy RL algorithms.\n",
"In on-policy contexts, a replay buffer is refilled every time a batch of\n",
"data is collected, and its data is repeatedly consumed for a certain number\n",
"of epochs.\n",
"\n",
"TorchRL's replay buffers are built using a common container,\n",
"`ReplayBuffer`, which takes as arguments the components of the buffer:\n",
"a storage, a writer, a sampler and possibly some transforms.\n",
"Only the storage (which indicates the replay buffer capacity) is mandatory.\n",
"We also specify a sampler without replacement to avoid sampling\n",
"the same item more than once in one epoch.\n",
"Using a replay buffer for PPO is not mandatory and we could simply\n",
"sample the sub-batches from the collected batch, but using these classes\n",
"makes it easy for us to build the inner training loop in a reproducible way."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"id": "tIkcHEy_9PA-"
},
"outputs": [],
"source": [
"from torchrl.data.replay_buffers import TensorDictReplayBuffer\n",
"from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement\n",
"from torchrl.data.replay_buffers.storages import LazyTensorStorage\n",
"\n",
"replay_buffer = TensorDictReplayBuffer(\n",
" storage=LazyTensorStorage(frames_per_batch),\n",
" sampler=SamplerWithoutReplacement(),\n",
" batch_size=256,\n",
")"
]
},
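{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sampling without replacement means that one pass over the buffer visits every stored index exactly once, in random order, split into minibatches of ``batch_size`` elements. The plain-Python sketch below (the `minibatch_indices` helper is hypothetical) shows this for a buffer of 1000 elements and a batch size of 256. It does not model how `SamplerWithoutReplacement` treats the final, smaller minibatch.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def minibatch_indices(buffer_size, batch_size, rng):\n",
"    # One epoch: shuffle all indices once, then cut them into consecutive chunks.\n",
"    indices = list(range(buffer_size))\n",
"    rng.shuffle(indices)\n",
"    return [indices[i:i + batch_size] for i in range(0, buffer_size, batch_size)]\n",
"\n",
"\n",
"rng = random.Random(0)\n",
"epoch = minibatch_indices(buffer_size=1000, batch_size=256, rng=rng)\n",
"print([len(b) for b in epoch])  # [256, 256, 256, 232]\n",
"print(sorted(i for b in epoch for i in b) == list(range(1000)))  # True: no index repeats\n",
"```\n"
]
},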
{
"cell_type": "markdown",
"source": [
"## Exercise\n",
"\n",
"As environments do, replay buffers also support transforms.\n",
"Using `ReplayBuffer.append_transform` and `torchrl.envs.RewardScaling`, pass a transform to the replay buffer that scales the reward to half of its value."
],
"metadata": {
"id": "Ho4An2Mc_ch7"
}
},
{
"cell_type": "code",
"source": [
"## your answer here\n",
"from torchrl.envs import RewardScaling\n",
"replay_buffer.append_transform(RewardScaling(scale=0.5, loc=0.0, in_keys=[(\"next\", \"reward\")]))"
],
"metadata": {
"id": "WeFe89bI_aEQ"
},
"execution_count": 50,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "BnYxRQ7s9PA-"
},
"source": [
"## Loss function\n",
"\n",
"The PPO loss can be directly imported from TorchRL for convenience using the\n",
"`ClipPPOLoss` class. This is the easiest way of utilizing PPO:\n",
"it hides away the mathematical operations of PPO and the control flow that\n",
"goes with it.\n",
"\n",
"PPO requires some \"advantage estimation\" to be computed. In short, the advantage\n",
"estimates how much better the return obtained after taking an action is than\n",
"the value predicted for the current state. The GAE estimator used here\n",
"controls the bias/variance tradeoff of this estimate through its ``lmbda`` parameter.\n",
"To compute the advantage, one just needs to (1) build the advantage module, which\n",
"utilizes our value operator, and (2) pass each batch of data through it before each\n",
"epoch.\n",
"The GAE module will update the input ``tensordict`` with new ``\"advantage\"`` and\n",
"``\"value_target\"`` entries.\n",
"The ``\"value_target\"`` is a gradient-free tensor that represents the empirical\n",
"value the value network should predict for the input observation.\n",
"Both of these will be used by `ClipPPOLoss` to\n",
"return the policy and value losses.\n"
]
},
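{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generalized Advantage Estimation builds on one-step temporal-difference errors $\\delta_t = r_t + \\gamma V(s_{t+1})(1 - d_t) - V(s_t)$, where $d_t$ is the done flag. These are accumulated backwards as $A_t = \\delta_t + \\gamma \\lambda (1 - d_t) A_{t+1}$, and the value target is then $A_t + V(s_t)$. The plain-Python sketch below implements this recursion on made-up numbers; it is not TorchRL's code. The final standardization step reflects our understanding of ``average_gae=True`` and should be treated as an assumption.\n",
"\n",
"```python\n",
"import statistics\n",
"\n",
"\n",
"def gae(rewards, values, next_values, dones, gamma, lmbda, standardize=True):\n",
"    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)\n",
"    # A_t = delta_t + gamma * lmbda * (1 - d_t) * A_{t+1}\n",
"    advantages = [0.0] * len(rewards)\n",
"    running = 0.0\n",
"    for t in reversed(range(len(rewards))):\n",
"        not_done = 0.0 if dones[t] else 1.0\n",
"        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]\n",
"        running = delta + gamma * lmbda * not_done * running\n",
"        advantages[t] = running\n",
"    # Regression target for the value network (computed here from the unstandardized advantages).\n",
"    value_targets = [a + v for a, v in zip(advantages, values)]\n",
"    if standardize:\n",
"        # Assumed behaviour of average_gae=True: zero-mean, unit-std advantages (needs >= 2 samples).\n",
"        mean = statistics.fmean(advantages)\n",
"        std = max(statistics.stdev(advantages), 1e-8)\n",
"        advantages = [(a - mean) / std for a in advantages]\n",
"    return advantages, value_targets\n",
"\n",
"\n",
"adv, targets = gae([1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [False, True], gamma=0.5, lmbda=1.0, standardize=False)\n",
"print(adv, targets)  # [1.5, 1.0] [1.5, 1.0]\n",
"```\n",
"\n",
"With $\\lambda = 1$ and a value function that is zero everywhere, the advantages reduce to the discounted returns, which is a handy sanity check.\n"
]
},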
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"id": "EyE4MeP19PBC"
},
"outputs": [],
"source": [
"from torchrl.objectives import ClipPPOLoss\n",
"from torchrl.objectives.value import GAE\n",
"\n",
"advantage_module = GAE(\n",
" gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True\n",
")\n",
"\n",
"loss_module = ClipPPOLoss(\n",
" actor=policy_module,\n",
" critic=value_module,\n",
" clip_epsilon=clip_epsilon,\n",
" entropy_bonus=bool(entropy_eps),\n",
" entropy_coef=entropy_eps,\n",
")\n",
"\n",
"optim = torch.optim.Adam(loss_module.parameters(), lr)\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
" optim, total_frames // frames_per_batch, 0.0\n",
")"
]
},
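{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ClipPPOLoss` hides the arithmetic, but the per-sample clipped surrogate at its core is short. In the plain-Python sketch below, the importance weight is $r = \\exp(\\log \\pi_{\\text{new}}(a \\mid s) - \\log \\pi_{\\text{old}}(a \\mid s))$. This is why the actor was asked to return its log-probability. The value loss and entropy bonus are left out. The second helper is the closed-form cosine-annealing schedule documented for PyTorch's `CosineAnnealingLR`, with ``eta_min=0.0`` as in the cell above. Neither function is TorchRL or PyTorch code.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def clipped_surrogate_loss(log_prob_new, log_prob_old, advantage, clip_epsilon):\n",
"    # Importance weight r = pi_new(a|s) / pi_old(a|s), computed from log-probabilities.\n",
"    ratio = math.exp(log_prob_new - log_prob_old)\n",
"    clipped_ratio = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)\n",
"    # PPO maximizes min(r * A, clip(r) * A); the loss to minimize is its negative.\n",
"    return -min(ratio * advantage, clipped_ratio * advantage)\n",
"\n",
"\n",
"def cosine_lr(step, base_lr, t_max, eta_min=0.0):\n",
"    # Closed-form cosine annealing for 0 <= step <= t_max.\n",
"    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max)) / 2.0\n",
"\n",
"\n",
"# With a positive advantage, pushing the ratio past 1 + clip_epsilon earns nothing extra.\n",
"print(clipped_surrogate_loss(math.log(2.0), 0.0, advantage=1.0, clip_epsilon=0.2))  # about -1.2\n",
"# The learning rate decays from base_lr at step 0 to eta_min at step t_max.\n",
"print(cosine_lr(0, 3e-4, 50), cosine_lr(25, 3e-4, 50), cosine_lr(50, 3e-4, 50))\n",
"```\n",
"\n",
"Clipping makes the objective flat once the policy has moved too far from the one that collected the data. This keeps PPO updates conservative.\n"
]
},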
{
"cell_type": "markdown",
"metadata": {
"id": "2w-rKB5c9PBC"
},
"source": [
"## Training loop\n",
"We now have all the pieces needed to code our training loop.\n",
"The steps include:\n",
"\n",
"* Collect data\n",
"\n",
"  * Compute advantage\n",
"\n",
"    * Loop over the collected data to compute loss values\n",
"    * Back propagate\n",
"    * Optimize\n",
"    * Repeat\n",
"\n",
"  * Repeat\n",
"\n",
"* Repeat\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RBpNV5-h9PBD",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "67c7ac97-cbda-4da1-e631-8213dcd4fd22"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n",
" 0%| | 0/50000 [00:11<?, ?it/s]\n",
"\n",
" 2%|▏ | 1000/50000 [00:07<05:55, 137.89it/s]\u001b[A/usr/local/lib/python3.10/dist-packages/tensordict/nn/probabilistic.py:79: DeprecationWarning: set_interaction_mode is deprecated for naming clarity. Please use set_interaction_type with InteractionType enum instead.\n",
" _insert_interaction_mode_deprecation_warning(\"set_\")\n",
"\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward=-0.0169 (init=-0.0169), step count (max): 11, lr policy: 0.0003: 2%|▏ | 1000/50000 [00:07<05:55, 137.89it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward=-0.0169 (init=-0.0169), step count (max): 11, lr policy: 0.0003: 4%|▍ | 2000/50000 [00:12<05:03, 157.90it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.1121 (init=-0.0169), step count (max): 16, lr policy: 0.0003: 4%|▍ | 2000/50000 [00:12<05:03, 157.90it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.1121 (init=-0.0169), step count (max): 16, lr policy: 0.0003: 6%|▌ | 3000/50000 [00:21<05:42, 137.42it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.2313 (init=-0.0169), step count (max): 13, lr policy: 0.0003: 6%|▌ | 3000/50000 [00:21<05:42, 137.42it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.2313 (init=-0.0169), step count (max): 13, lr policy: 0.0003: 8%|▊ | 4000/50000 [00:28<05:33, 138.02it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.2865 (init=-0.0169), step count (max): 18, lr policy: 0.0003: 8%|▊ | 4000/50000 [00:28<05:33, 138.02it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.2865 (init=-0.0169), step count (max): 18, lr policy: 0.0003: 10%|█ | 5000/50000 [00:39<06:19, 118.53it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.3637 (init=-0.0169), step count (max): 21, lr policy: 0.0003: 10%|█ | 5000/50000 [00:39<06:19, 118.53it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.3637 (init=-0.0169), step count (max): 21, lr policy: 0.0003: 12%|█▏ | 6000/50000 [00:45<05:42, 128.52it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.4334 (init=-0.0169), step count (max): 28, lr policy: 0.0003: 12%|█▏ | 6000/50000 [00:45<05:42, 128.52it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.4334 (init=-0.0169), step count (max): 28, lr policy: 0.0003: 14%|█▍ | 7000/50000 [00:51<05:13, 137.30it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.3815 (init=-0.0169), step count (max): 27, lr policy: 0.0003: 14%|█▍ | 7000/50000 [00:51<05:13, 137.30it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.3815 (init=-0.0169), step count (max): 27, lr policy: 0.0003: 16%|█▌ | 8000/50000 [00:58<04:52, 143.80it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.3853 (init=-0.0169), step count (max): 30, lr policy: 0.0003: 16%|█▌ | 8000/50000 [00:58<04:52, 143.80it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.3853 (init=-0.0169), step count (max): 30, lr policy: 0.0003: 18%|█▊ | 9000/50000 [01:05<04:55, 138.72it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.3376 (init=-0.0169), step count (max): 32, lr policy: 0.0003: 18%|█▊ | 9000/50000 [01:05<04:55, 138.72it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.3376 (init=-0.0169), step count (max): 32, lr policy: 0.0003: 20%|██ | 10000/50000 [01:11<04:29, 148.59it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.2928 (init=-0.0169), step count (max): 30, lr policy: 0.0003: 20%|██ | 10000/50000 [01:11<04:29, 148.59it/s]\u001b[A\n",
"eval cumulative reward: 4.9523 (init: 4.9523), eval step-count: 9, average reward= 0.2928 (init=-0.0169), step count (max): 30, lr policy: 0.0003: 22%|██▏ | 11000/50000 [01:18<04:20, 149.90it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.3027 (init=-0.0169), step count (max): 31, lr policy: 0.0003: 22%|██▏ | 11000/50000 [01:18<04:20, 149.90it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.3027 (init=-0.0169), step count (max): 31, lr policy: 0.0003: 24%|██▍ | 12000/50000 [01:23<04:01, 157.49it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2565 (init=-0.0169), step count (max): 24, lr policy: 0.0003: 24%|██▍ | 12000/50000 [01:23<04:01, 157.49it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2565 (init=-0.0169), step count (max): 24, lr policy: 0.0003: 26%|██▌ | 13000/50000 [01:30<03:59, 154.33it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2208 (init=-0.0169), step count (max): 40, lr policy: 0.0003: 26%|██▌ | 13000/50000 [01:30<03:59, 154.33it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2208 (init=-0.0169), step count (max): 40, lr policy: 0.0003: 28%|██▊ | 14000/50000 [01:36<03:48, 157.44it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2097 (init=-0.0169), step count (max): 43, lr policy: 0.0003: 28%|██▊ | 14000/50000 [01:36<03:48, 157.44it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2097 (init=-0.0169), step count (max): 43, lr policy: 0.0003: 30%|███ | 15000/50000 [01:43<03:44, 155.86it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2436 (init=-0.0169), step count (max): 34, lr policy: 0.0002: 30%|███ | 15000/50000 [01:43<03:44, 155.86it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2436 (init=-0.0169), step count (max): 34, lr policy: 0.0002: 32%|███▏ | 16000/50000 [01:48<03:29, 162.47it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2096 (init=-0.0169), step count (max): 32, lr policy: 0.0002: 32%|███▏ | 16000/50000 [01:48<03:29, 162.47it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2096 (init=-0.0169), step count (max): 32, lr policy: 0.0002: 34%|███▍ | 17000/50000 [01:55<03:27, 159.37it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2414 (init=-0.0169), step count (max): 54, lr policy: 0.0002: 34%|███▍ | 17000/50000 [01:55<03:27, 159.37it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.2414 (init=-0.0169), step count (max): 54, lr policy: 0.0002: 36%|███▌ | 18000/50000 [02:00<03:14, 164.68it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.1786 (init=-0.0169), step count (max): 39, lr policy: 0.0002: 36%|███▌ | 18000/50000 [02:00<03:14, 164.68it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.1786 (init=-0.0169), step count (max): 39, lr policy: 0.0002: 38%|███▊ | 19000/50000 [02:06<03:06, 166.15it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.1887 (init=-0.0169), step count (max): 38, lr policy: 0.0002: 38%|███▊ | 19000/50000 [02:06<03:06, 166.15it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.1887 (init=-0.0169), step count (max): 38, lr policy: 0.0002: 40%|████ | 20000/50000 [02:13<03:03, 163.09it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.1988 (init=-0.0169), step count (max): 37, lr policy: 0.0002: 40%|████ | 20000/50000 [02:13<03:03, 163.09it/s]\u001b[A\n",
"eval cumulative reward: 6.9856 (init: 4.9523), eval step-count: 18, average reward= 0.1988 (init=-0.0169), step count (max): 37, lr policy: 0.0002: 42%|████▏ | 21000/50000 [02:18<02:52, 167.77it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1385 (init=-0.0169), step count (max): 37, lr policy: 0.0002: 42%|████▏ | 21000/50000 [02:18<02:52, 167.77it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1385 (init=-0.0169), step count (max): 37, lr policy: 0.0002: 44%|████▍ | 22000/50000 [02:25<02:52, 162.20it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1715 (init=-0.0169), step count (max): 40, lr policy: 0.0002: 44%|████▍ | 22000/50000 [02:25<02:52, 162.20it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1715 (init=-0.0169), step count (max): 40, lr policy: 0.0002: 46%|████▌ | 23000/50000 [02:30<02:40, 168.73it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1808 (init=-0.0169), step count (max): 35, lr policy: 0.0002: 46%|████▌ | 23000/50000 [02:30<02:40, 168.73it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1808 (init=-0.0169), step count (max): 35, lr policy: 0.0002: 48%|████▊ | 24000/50000 [02:37<02:39, 162.83it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1405 (init=-0.0169), step count (max): 55, lr policy: 0.0002: 48%|████▊ | 24000/50000 [02:37<02:39, 162.83it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1405 (init=-0.0169), step count (max): 55, lr policy: 0.0002: 50%|█████ | 25000/50000 [02:42<02:29, 167.77it/s]\u001b[A\n",
"eval cumulative reward: 0.7546 (init: 4.9523), eval step-count: 17, average reward= 0.1733 (init=-0.0169), step count (max): 38, lr policy: 0.0002: 50%|█████ | 25000/50000 [02:42<02:29, 167.77it/s]\u001b[A"
]
}
],
"source": [
"logs = defaultdict(list)\n",
"pbar = tqdm(total=total_frames * frame_skip)\n",
"eval_str = \"\"\n",
"\n",
"# We iterate over the collector until it reaches the total number of frames it was\n",
"# designed to collect:\n",
"for i, tensordict_data in enumerate(collector):\n",
"    # we now have a batch of data to work with. Let's learn something from it.\n",
"    for _ in range(num_epochs):\n",
"        # We'll need an \"advantage\" signal to make PPO work.\n",
"        # We re-compute it at each epoch as its value depends on the value\n",
"        # network which is updated in the inner loop.\n",
"        advantage_module(tensordict_data)\n",
"        data_view = tensordict_data.reshape(-1)\n",
"        replay_buffer.empty()\n",
"        replay_buffer.extend(data_view.cpu())\n",
"        for subdata in replay_buffer:\n",
"            loss_vals = loss_module(subdata.to(device))\n",
"            loss_value = (\n",
"                loss_vals[\"loss_objective\"]\n",
"                + loss_vals[\"loss_critic\"]\n",
"                + loss_vals[\"loss_entropy\"]\n",
"            )\n",
"\n",
"            # Optimization: backward, grad clipping and optimization step\n",
"            loss_value.backward()\n",
"            # this is not strictly mandatory but it's good practice to keep\n",
"            # your gradient norm bounded\n",
"            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)\n",
"            optim.step()\n",
"            optim.zero_grad()\n",
"\n",
"    logs[\"reward\"].append(tensordict_data[\"next\", \"reward\"].mean().item())\n",
"    pbar.update(tensordict_data.numel() * frame_skip)\n",
"    cum_reward_str = (\n",
"        f\"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})\"\n",
"    )\n",
"    logs[\"step_count\"].append(tensordict_data[\"step_count\"].max().item())\n",
"    stepcount_str = f\"step count (max): {logs['step_count'][-1]}\"\n",
"    logs[\"lr\"].append(optim.param_groups[0][\"lr\"])\n",
"    lr_str = f\"lr policy: {logs['lr'][-1]: 4.4f}\"\n",
"    if i % 10 == 0:\n",
"        # We evaluate the policy once every 10 batches of data.\n",
"        # Evaluation is rather simple: execute the policy without exploration\n",
"        # (take the expected value of the action distribution) for a given\n",
"        # number of steps (1000, which is our ``env`` horizon).\n",
"        # The ``rollout`` method of the ``env`` can take a policy as argument:\n",
"        # it will then execute this policy at each step.\n",
"        with set_exploration_mode(\"mean\"), torch.no_grad():\n",
"            # execute a rollout with the trained policy\n",
"            eval_rollout = env.rollout(1000, policy_module)\n",
"            logs[\"eval reward\"].append(eval_rollout[\"next\", \"reward\"].mean().item())\n",
"            logs[\"eval reward (sum)\"].append(\n",
"                eval_rollout[\"next\", \"reward\"].sum().item()\n",
"            )\n",
"            logs[\"eval step_count\"].append(eval_rollout[\"step_count\"].max().item())\n",
"            eval_str = (\n",
"                f\"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} \"\n",
"                f\"(init: {logs['eval reward (sum)'][0]: 4.4f}), \"\n",
"                f\"eval step-count: {logs['eval step_count'][-1]}\"\n",
"            )\n",
"            del eval_rollout\n",
"    pbar.set_description(\", \".join([eval_str, cum_reward_str, stepcount_str, lr_str]))\n",
"\n",
"    # We're also using a learning rate scheduler. Like the gradient clipping,\n",
"    # this is a nice-to-have but not necessary for PPO to work.\n",
"    scheduler.step()"
]
},
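{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the inner loop above more concrete, here is a minimal plain-PyTorch sketch of what one epoch does, assuming `loss_module` is a clipped PPO loss such as TorchRL's `ClipPPOLoss`. The batch is shuffled and split into minibatches so that every sample is used exactly once (the role played by the replay buffer's sampler), and each minibatch is scored with the clipped surrogate objective before a gradient step with clipped gradient norm.\n",
"\n",
"This is an illustration, not TorchRL's implementation: the toy tensors, the `clip_epsilon=0.2` default and the helper names `minibatch_indices` and `clipped_surrogate_loss` are hypothetical, and the real loss module may differ in details (for instance advantage normalization, or how the critic and entropy terms are scaled)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch: one PPO epoch in plain PyTorch, on toy data\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"\n",
"def minibatch_indices(num_samples, batch_size):\n",
"    # Shuffle once, then slice: each sample is visited exactly once per epoch\n",
"    perm = torch.randperm(num_samples)\n",
"    return [perm[i:i + batch_size] for i in range(0, num_samples, batch_size)]\n",
"\n",
"\n",
"def clipped_surrogate_loss(log_prob, old_log_prob, advantage, clip_epsilon=0.2):\n",
"    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log-space\n",
"    ratio = torch.exp(log_prob - old_log_prob)\n",
"    unclipped = ratio * advantage\n",
"    clipped = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon) * advantage\n",
"    # PPO maximizes the element-wise minimum (a pessimistic bound),\n",
"    # so the loss to minimize is the negative of its mean\n",
"    return -torch.min(unclipped, clipped).mean()\n",
"\n",
"\n",
"# Toy batch standing in for one collected batch of data\n",
"num_samples, batch_size = 10, 4\n",
"old_log_prob = torch.randn(num_samples)\n",
"advantage = torch.randn(num_samples)\n",
"# Toy trainable tensor standing in for the policy parameters\n",
"log_prob_shift = torch.zeros(num_samples, requires_grad=True)\n",
"sketch_optim = torch.optim.SGD([log_prob_shift], lr=0.1)\n",
"\n",
"seen = []\n",
"for idx in minibatch_indices(num_samples, batch_size):\n",
"    log_prob = old_log_prob[idx] + log_prob_shift[idx]\n",
"    loss = clipped_surrogate_loss(log_prob, old_log_prob[idx], advantage[idx])\n",
"    loss.backward()\n",
"    # Same safeguard as in the training loop: bound the gradient norm\n",
"    torch.nn.utils.clip_grad_norm_([log_prob_shift], max_norm=1.0)\n",
"    sketch_optim.step()\n",
"    sketch_optim.zero_grad()\n",
"    seen.append(idx)\n",
"\n",
"print(f'{len(seen)} minibatches, {sum(len(i) for i in seen)} samples in total')"
]
},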
{
"cell_type": "markdown",
"source": [
"## Exercise\n",
"\n",
"To push a bit further, have a go at one of these challenges:\n",
"- How would you build a value network and policy that share a common backbone? How would that affect your training loop? What design decision would you need to make?\n",
"- What if you wanted a different optimizer for the policy and the value network? How would you build that? How would that change your training loop?"
],
"metadata": {
"id": "RARhLCgF-fc0"
}
},
{
"cell_type": "code",
"source": [
"# These are open questions; I'm happy to give some hints during the talk."
],
"metadata": {
"id": "v5cqkz_ESWHZ"
},
"execution_count": null,
"outputs": []
},
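{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hint for the second question, here is one possible sketch, assuming the policy and the value network share no parameters: give each network its own optimizer (here with different learning rates), run a single backward pass on the summed loss, then step both optimizers. The two small `nn.Sequential` networks, the learning rates and the toy losses are hypothetical placeholders, not the modules built earlier in this notebook.\n",
"\n",
"In the training loop, `step()` and `zero_grad()` would then be called on both optimizers, and each optimizer whose learning rate should be annealed would need its own scheduler. With a shared backbone (the first question), every parameter should still belong to exactly one optimizer; otherwise the backbone would be updated twice per iteration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch: separate optimizers for the policy and value networks\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Toy stand-ins for the actor and critic (observation dimension 4)\n",
"toy_policy_net = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 2))\n",
"toy_value_net = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 1))\n",
"\n",
"# One optimizer per network, each with its own learning rate\n",
"policy_optim = torch.optim.Adam(toy_policy_net.parameters(), lr=3e-4)\n",
"value_optim = torch.optim.Adam(toy_value_net.parameters(), lr=1e-3)\n",
"\n",
"obs = torch.randn(8, 4)\n",
"# Toy losses standing in for the 'loss_objective' and 'loss_critic' terms\n",
"policy_loss = toy_policy_net(obs).pow(2).mean()\n",
"value_loss = toy_value_net(obs).pow(2).mean()\n",
"value_params_before = [p.detach().clone() for p in toy_value_net.parameters()]\n",
"\n",
"# No parameter is shared, so one backward pass on the sum fills each\n",
"# network's gradients from its own loss term only\n",
"(policy_loss + value_loss).backward()\n",
"for opt in (policy_optim, value_optim):\n",
"    opt.step()\n",
"    opt.zero_grad()"
]
},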
{
"cell_type": "markdown",
"metadata": {
"id": "70uOF0VJ9PBD"
},
"source": [
"## Results\n",
"\n",
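"Before the 1M step cap of a full training run is reached, the algorithm should\n",
"have reached a max step count of 1000 steps, which is the maximum number of\n",
"steps before the trajectory is truncated. The run above uses a much smaller\n",
"budget of 50,000 frames, so its step counts may stay well below that limit.\n",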
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iIlN7FzW9PBD"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.subplot(2, 2, 1)\n",
"plt.plot(logs[\"reward\"])\n",
"plt.title(\"Training rewards (average)\")\n",
"plt.subplot(2, 2, 2)\n",
"plt.plot(logs[\"step_count\"])\n",
"plt.title(\"Max step count (training)\")\n",
"plt.subplot(2, 2, 3)\n",
"plt.plot(logs[\"eval reward (sum)\"])\n",
"plt.title(\"Return (test)\")\n",
"plt.subplot(2, 2, 4)\n",
"plt.plot(logs[\"eval step_count\"])\n",
"plt.title(\"Max step count (test)\")\n",
"plt.show()"
]
},
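{
"cell_type": "markdown",
"metadata": {},
"source": [
"The progress bar also reports the policy learning rate, which decays from 0.0003 to 0.0002 over the logged batches. The scheduler is created in an earlier cell that is not part of this excerpt; the sketch below assumes it is `torch.optim.lr_scheduler.CosineAnnealingLR`, stepped once per collected batch and annealing from `3e-4` to `0` over 50 batches (50,000 frames at 1,000 frames per batch, as the progress-bar increments suggest). Under these assumptions, it recomputes the schedule on a dummy parameter so it can be compared with `logs[\"lr\"]`; the rounded value first drops to 0.0002 at batch index 14, in line with the progress log above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch: the learning-rate curve of an assumed cosine schedule\n",
"import math\n",
"\n",
"import torch\n",
"\n",
"# Assumed settings (the real values live in an earlier cell)\n",
"initial_lr = 3e-4\n",
"num_batches = 50  # total_frames // frames_per_batch, assumed\n",
"\n",
"# A dummy parameter is enough to drive an optimizer and its scheduler\n",
"dummy_param = torch.zeros(1, requires_grad=True)\n",
"demo_optim = torch.optim.Adam([dummy_param], lr=initial_lr)\n",
"demo_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(demo_optim, num_batches, 0.0)\n",
"\n",
"lrs = []\n",
"for _ in range(num_batches):\n",
"    # Same order as the training loop: record the lr, then step the scheduler\n",
"    lrs.append(demo_optim.param_groups[0]['lr'])\n",
"    demo_optim.step()  # no-op here (no gradients), but keeps the usual call order\n",
"    demo_scheduler.step()\n",
"\n",
"# Closed form of cosine annealing with eta_min = 0, for comparison\n",
"closed_form = [initial_lr * (1 + math.cos(math.pi * t / num_batches)) / 2 for t in range(num_batches)]\n",
"print([f'{lr:.4f}' for lr in lrs[12:16]])"
]
},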
{
"cell_type": "markdown",
"metadata": {
"id": "6wHb-TZN9PBE"
},
"source": [
"## Conclusion and next steps\n",
"\n",
"In this tutorial, we have learned:\n",
"\n",
"1. How to create and customize an environment with `torchrl`;\n",
"2. How to write a model and a loss function;\n",
"3. How to set up a typical training loop.\n",
"\n",
"If you want to experiment with this tutorial a bit more, you can apply the following modifications:\n",
"\n",
"* From an efficiency perspective,\n",
" we could run several simulations in parallel to speed up data collection.\n",
" Check `torchrl.envs.ParallelEnv` for further information.\n",
"\n",
"* From a logging perspective, we could enable rendering in the environment and\n",
" add a `torchrl.record.VideoRecorder` transform to get a visual rendering of the\n",
" inverted pendulum in action. Check `torchrl.record` to learn more.\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}