@ssghost
Created January 18, 2024 22:25
rl-baselines-zoo.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ssghost/bb6cb0b2d37e7ac733f70e855de546d9/rl-baselines-zoo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJy9QoDC7XA7"
},
"source": [
"# RL Baselines3 Zoo: Training in Colab\n",
"\n",
"GitHub Repo: [https://github.com/DLR-RM/rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo)\n",
"\n",
"Stable-Baselines3 Repo: [https://github.com/DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3)\n",
"\n",
"# Install Dependencies\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "twhd0l3yiy2L"
},
"outputs": [],
"source": [
"# Optional: uncomment the next line to auto-format cells with jupyter-black\n",
"# %load_ext jupyter_black"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AXVDDlTn02M9",
"outputId": "8165522f-c698-4628-b988-9803db771e04"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]\n",
"Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64 InRelease\n",
"Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease\n",
"Get:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [119 kB]\n",
"Get:5 http://security.ubuntu.com/ubuntu jammy-security InRelease [110 kB]\n",
"Get:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [109 kB]\n",
"Hit:7 https://ppa.launchpadcontent.net/c2d4u.team/c2d4u4.0+/ubuntu jammy InRelease\n",
"Get:8 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 Packages [1,622 kB]\n",
"Hit:9 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease\n",
"Hit:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease\n",
"Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease\n",
"Get:12 http://archive.ubuntu.com/ubuntu jammy-updates/multiverse amd64 Packages [50.4 kB]\n",
"Get:13 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 Packages [1,321 kB]\n",
"Get:14 http://archive.ubuntu.com/ubuntu jammy-backports/main amd64 Packages [78.7 kB]\n",
"Get:15 http://archive.ubuntu.com/ubuntu jammy-backports/universe amd64 Packages [32.7 kB]\n",
"Get:16 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1,056 kB]\n",
"Get:17 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [1,344 kB]\n",
"Fetched 5,845 kB in 2s (2,435 kB/s)\n",
"Reading package lists... Done\n",
"Reading package lists... Done\n",
"Building dependency tree... Done\n",
"Reading state information... Done\n",
"cmake is already the newest version (3.22.1-1ubuntu1.22.04.1).\n",
"ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).\n",
"The following additional packages will be installed:\n",
" freeglut3 libegl-dev libfontenc1 libgl-dev libgl1-mesa-dev libgles-dev libgles1 libglu1-mesa\n",
" libglu1-mesa-dev libglvnd-core-dev libglvnd-dev libglx-dev libice-dev libopengl-dev libsm-dev\n",
" libxfont2 libxkbfile1 libxt-dev swig4.0 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils\n",
" xserver-common\n",
"Suggested packages:\n",
" libice-doc libsm-doc libxt-doc swig-doc swig-examples swig4.0-examples swig4.0-doc\n",
"The following NEW packages will be installed:\n",
" freeglut3 freeglut3-dev libegl-dev libfontenc1 libgl-dev libgl1-mesa-dev libgles-dev libgles1\n",
" libglu1-mesa libglu1-mesa-dev libglvnd-core-dev libglvnd-dev libglx-dev libice-dev libopengl-dev\n",
" libsm-dev libxfont2 libxkbfile1 libxt-dev swig swig4.0 x11-xkb-utils xfonts-base xfonts-encodings\n",
" xfonts-utils xserver-common xvfb\n",
"0 upgraded, 27 newly installed, 0 to remove and 32 not upgraded.\n",
"Need to get 10.2 MB of archives.\n",
"After this operation, 24.2 MB of additional disk space will be used.\n",
"Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 freeglut3 amd64 2.8.1-6 [74.0 kB]\n",
"Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libglx-dev amd64 1.4.0-1 [14.1 kB]\n",
"Get:3 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgl-dev amd64 1.4.0-1 [101 kB]\n",
"Get:4 http://archive.ubuntu.com/ubuntu jammy/main amd64 libglvnd-core-dev amd64 1.4.0-1 [12.7 kB]\n",
"Get:5 http://archive.ubuntu.com/ubuntu jammy/main amd64 libegl-dev amd64 1.4.0-1 [18.0 kB]\n",
"Get:6 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgles1 amd64 1.4.0-1 [11.5 kB]\n",
"Get:7 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgles-dev amd64 1.4.0-1 [49.4 kB]\n",
"Get:8 http://archive.ubuntu.com/ubuntu jammy/main amd64 libopengl-dev amd64 1.4.0-1 [3,400 B]\n",
"Get:9 http://archive.ubuntu.com/ubuntu jammy/main amd64 libglvnd-dev amd64 1.4.0-1 [3,162 B]\n",
"Get:10 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 libgl1-mesa-dev amd64 23.0.4-0ubuntu1~22.04.1 [6,510 B]\n",
"Get:11 http://archive.ubuntu.com/ubuntu jammy/main amd64 libglu1-mesa amd64 9.0.2-1 [145 kB]\n",
"Get:12 http://archive.ubuntu.com/ubuntu jammy/main amd64 libglu1-mesa-dev amd64 9.0.2-1 [231 kB]\n",
"Get:13 http://archive.ubuntu.com/ubuntu jammy/main amd64 libice-dev amd64 2:1.0.10-1build2 [51.4 kB]\n",
"Get:14 http://archive.ubuntu.com/ubuntu jammy/main amd64 libsm-dev amd64 2:1.2.3-1build2 [18.1 kB]\n",
"Get:15 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxt-dev amd64 1:1.2.1-1 [396 kB]\n",
"Get:16 http://archive.ubuntu.com/ubuntu jammy/universe amd64 freeglut3-dev amd64 2.8.1-6 [126 kB]\n",
"Get:17 http://archive.ubuntu.com/ubuntu jammy/main amd64 libfontenc1 amd64 1:1.1.4-1build3 [14.7 kB]\n",
"Get:18 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxfont2 amd64 1:2.0.5-1build1 [94.5 kB]\n",
"Get:19 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxkbfile1 amd64 1:1.1.0-1build3 [71.8 kB]\n",
"Get:20 http://archive.ubuntu.com/ubuntu jammy/universe amd64 swig4.0 amd64 4.0.2-1ubuntu1 [1,110 kB]\n",
"Get:21 http://archive.ubuntu.com/ubuntu jammy/universe amd64 swig all 4.0.2-1ubuntu1 [5,632 B]\n",
"Get:22 http://archive.ubuntu.com/ubuntu jammy/main amd64 x11-xkb-utils amd64 7.7+5build4 [172 kB]\n",
"Get:23 http://archive.ubuntu.com/ubuntu jammy/main amd64 xfonts-encodings all 1:1.0.5-0ubuntu2 [578 kB]\n",
"Get:24 http://archive.ubuntu.com/ubuntu jammy/main amd64 xfonts-utils amd64 1:7.7+6build2 [94.6 kB]\n",
"Get:25 http://archive.ubuntu.com/ubuntu jammy/main amd64 xfonts-base all 1:1.0.5 [5,896 kB]\n",
"Get:26 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 xserver-common all 2:21.1.4-2ubuntu1.7~22.04.7 [28.6 kB]\n",
"Get:27 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 xvfb amd64 2:21.1.4-2ubuntu1.7~22.04.7 [865 kB]\n",
"Fetched 10.2 MB in 1s (10.5 MB/s)\n",
"Selecting previously unselected package freeglut3:amd64.\n",
"(Reading database ... 121658 files and directories currently installed.)\n",
"Preparing to unpack .../00-freeglut3_2.8.1-6_amd64.deb ...\n",
"Unpacking freeglut3:amd64 (2.8.1-6) ...\n",
"Selecting previously unselected package libglx-dev:amd64.\n",
"Preparing to unpack .../01-libglx-dev_1.4.0-1_amd64.deb ...\n",
"Unpacking libglx-dev:amd64 (1.4.0-1) ...\n",
"Selecting previously unselected package libgl-dev:amd64.\n",
"Preparing to unpack .../02-libgl-dev_1.4.0-1_amd64.deb ...\n",
"Unpacking libgl-dev:amd64 (1.4.0-1) ...\n",
"Selecting previously unselected package libglvnd-core-dev:amd64.\n",
"Preparing to unpack .../03-libglvnd-core-dev_1.4.0-1_amd64.deb ...\n",
"Unpacking libglvnd-core-dev:amd64 (1.4.0-1) ...\n",
"Selecting previously unselected package libegl-dev:amd64.\n",
"Preparing to unpack .../04-libegl-dev_1.4.0-1_amd64.deb ...\n",
"Unpacking libegl-dev:amd64 (1.4.0-1) ...\n",
"Selecting previously unselected package libgles1:amd64.\n",
"Preparing to unpack .../05-libgles1_1.4.0-1_amd64.deb ...\n",
"Unpacking libgles1:amd64 (1.4.0-1) ...\n",
"Selecting previously unselected package libgles-dev:amd64.\n",
"Preparing to unpack .../06-libgles-dev_1.4.0-1_amd64.deb ...\n",
"Unpacking libgles-dev:amd64 (1.4.0-1) ...\n",
"Selecting previously unselected package libopengl-dev:amd64.\n",
"Preparing to unpack .../07-libopengl-dev_1.4.0-1_amd64.deb ...\n",
"Unpacking libopengl-dev:amd64 (1.4.0-1) ...\n",
"Selecting previously unselected package libglvnd-dev:amd64.\n",
"Preparing to unpack .../08-libglvnd-dev_1.4.0-1_amd64.deb ...\n",
"Unpacking libglvnd-dev:amd64 (1.4.0-1) ...\n",
"Selecting previously unselected package libgl1-mesa-dev:amd64.\n",
"Preparing to unpack .../09-libgl1-mesa-dev_23.0.4-0ubuntu1~22.04.1_amd64.deb ...\n",
"Unpacking libgl1-mesa-dev:amd64 (23.0.4-0ubuntu1~22.04.1) ...\n",
"Selecting previously unselected package libglu1-mesa:amd64.\n",
"Preparing to unpack .../10-libglu1-mesa_9.0.2-1_amd64.deb ...\n",
"Unpacking libglu1-mesa:amd64 (9.0.2-1) ...\n",
"Selecting previously unselected package libglu1-mesa-dev:amd64.\n",
"Preparing to unpack .../11-libglu1-mesa-dev_9.0.2-1_amd64.deb ...\n",
"Unpacking libglu1-mesa-dev:amd64 (9.0.2-1) ...\n",
"Selecting previously unselected package libice-dev:amd64.\n",
"Preparing to unpack .../12-libice-dev_2%3a1.0.10-1build2_amd64.deb ...\n",
"Unpacking libice-dev:amd64 (2:1.0.10-1build2) ...\n",
"Selecting previously unselected package libsm-dev:amd64.\n",
"Preparing to unpack .../13-libsm-dev_2%3a1.2.3-1build2_amd64.deb ...\n",
"Unpacking libsm-dev:amd64 (2:1.2.3-1build2) ...\n",
"Selecting previously unselected package libxt-dev:amd64.\n",
"Preparing to unpack .../14-libxt-dev_1%3a1.2.1-1_amd64.deb ...\n",
"Unpacking libxt-dev:amd64 (1:1.2.1-1) ...\n",
"Selecting previously unselected package freeglut3-dev:amd64.\n",
"Preparing to unpack .../15-freeglut3-dev_2.8.1-6_amd64.deb ...\n",
"Unpacking freeglut3-dev:amd64 (2.8.1-6) ...\n",
"Selecting previously unselected package libfontenc1:amd64.\n",
"Preparing to unpack .../16-libfontenc1_1%3a1.1.4-1build3_amd64.deb ...\n",
"Unpacking libfontenc1:amd64 (1:1.1.4-1build3) ...\n",
"Selecting previously unselected package libxfont2:amd64.\n",
"Preparing to unpack .../17-libxfont2_1%3a2.0.5-1build1_amd64.deb ...\n",
"Unpacking libxfont2:amd64 (1:2.0.5-1build1) ...\n",
"Selecting previously unselected package libxkbfile1:amd64.\n",
"Preparing to unpack .../18-libxkbfile1_1%3a1.1.0-1build3_amd64.deb ...\n",
"Unpacking libxkbfile1:amd64 (1:1.1.0-1build3) ...\n",
"Selecting previously unselected package swig4.0.\n",
"Preparing to unpack .../19-swig4.0_4.0.2-1ubuntu1_amd64.deb ...\n",
"Unpacking swig4.0 (4.0.2-1ubuntu1) ...\n",
"Selecting previously unselected package swig.\n",
"Preparing to unpack .../20-swig_4.0.2-1ubuntu1_all.deb ...\n",
"Unpacking swig (4.0.2-1ubuntu1) ...\n",
"Selecting previously unselected package x11-xkb-utils.\n",
"Preparing to unpack .../21-x11-xkb-utils_7.7+5build4_amd64.deb ...\n",
"Unpacking x11-xkb-utils (7.7+5build4) ...\n",
"Selecting previously unselected package xfonts-encodings.\n",
"Preparing to unpack .../22-xfonts-encodings_1%3a1.0.5-0ubuntu2_all.deb ...\n",
"Unpacking xfonts-encodings (1:1.0.5-0ubuntu2) ...\n",
"Selecting previously unselected package xfonts-utils.\n",
"Preparing to unpack .../23-xfonts-utils_1%3a7.7+6build2_amd64.deb ...\n",
"Unpacking xfonts-utils (1:7.7+6build2) ...\n",
"Selecting previously unselected package xfonts-base.\n",
"Preparing to unpack .../24-xfonts-base_1%3a1.0.5_all.deb ...\n",
"Unpacking xfonts-base (1:1.0.5) ...\n",
"Selecting previously unselected package xserver-common.\n",
"Preparing to unpack .../25-xserver-common_2%3a21.1.4-2ubuntu1.7~22.04.7_all.deb ...\n",
"Unpacking xserver-common (2:21.1.4-2ubuntu1.7~22.04.7) ...\n",
"Selecting previously unselected package xvfb.\n",
"Preparing to unpack .../26-xvfb_2%3a21.1.4-2ubuntu1.7~22.04.7_amd64.deb ...\n",
"Unpacking xvfb (2:21.1.4-2ubuntu1.7~22.04.7) ...\n",
"Setting up freeglut3:amd64 (2.8.1-6) ...\n",
"Setting up libglvnd-core-dev:amd64 (1.4.0-1) ...\n",
"Setting up libice-dev:amd64 (2:1.0.10-1build2) ...\n",
"Setting up libsm-dev:amd64 (2:1.2.3-1build2) ...\n",
"Setting up libfontenc1:amd64 (1:1.1.4-1build3) ...\n",
"Setting up libxt-dev:amd64 (1:1.2.1-1) ...\n",
"Setting up libgles1:amd64 (1.4.0-1) ...\n",
"Setting up xfonts-encodings (1:1.0.5-0ubuntu2) ...\n",
"Setting up libglx-dev:amd64 (1.4.0-1) ...\n",
"Setting up libglu1-mesa:amd64 (9.0.2-1) ...\n",
"Setting up libxkbfile1:amd64 (1:1.1.0-1build3) ...\n",
"Setting up libopengl-dev:amd64 (1.4.0-1) ...\n",
"Setting up swig4.0 (4.0.2-1ubuntu1) ...\n",
"Setting up libxfont2:amd64 (1:2.0.5-1build1) ...\n",
"Setting up libgl-dev:amd64 (1.4.0-1) ...\n",
"Setting up swig (4.0.2-1ubuntu1) ...\n",
"Setting up libegl-dev:amd64 (1.4.0-1) ...\n",
"Setting up x11-xkb-utils (7.7+5build4) ...\n",
"Setting up xfonts-utils (1:7.7+6build2) ...\n",
"Setting up xfonts-base (1:1.0.5) ...\n",
"Setting up libglu1-mesa-dev:amd64 (9.0.2-1) ...\n",
"Setting up xserver-common (2:21.1.4-2ubuntu1.7~22.04.7) ...\n",
"Setting up libgles-dev:amd64 (1.4.0-1) ...\n",
"Setting up xvfb (2:21.1.4-2ubuntu1.7~22.04.7) ...\n",
"Setting up libglvnd-dev:amd64 (1.4.0-1) ...\n",
"Setting up libgl1-mesa-dev:amd64 (23.0.4-0ubuntu1~22.04.1) ...\n",
"Setting up freeglut3-dev:amd64 (2.8.1-6) ...\n",
"Processing triggers for libc-bin (2.35-0ubuntu3.4) ...\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link\n",
"\n",
"/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link\n",
"\n",
"Processing triggers for man-db (2.10.2-1) ...\n",
"Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...\n"
]
}
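],
"source": [
"# swig and cmake are build tools (swig is needed to compile box2d-py); ffmpeg records videos; freeglut3-dev and xvfb enable headless rendering\n",
"!apt-get update && apt-get install -y swig cmake ffmpeg freeglut3-dev xvfb"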
]
},
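{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (an addition, not part of the upstream zoo notebook): confirm that the tools installed above are on `PATH`. This assumes the executable names `swig`, `cmake`, `ffmpeg` and `Xvfb` provided by the Ubuntu packages; a `None` result means the tool was not found."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"\n",
"# Print the resolved path of each system tool, or None if it is missing from PATH\n",
"for tool in [\"swig\", \"cmake\", \"ffmpeg\", \"Xvfb\"]:\n",
"    print(f\"{tool}: {shutil.which(tool)}\")"
]
},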
{
"cell_type": "markdown",
"metadata": {
"id": "kDjF3qRg7oGH"
},
"source": [
"## Clone RL Baselines3 Zoo Repo"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SCjGikdT1DFy",
"outputId": "f6dff688-dd43-4d15-b46b-9ac07f6e845e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'rl-baselines3-zoo'...\n",
"remote: Enumerating objects: 5435, done.\u001b[K\n",
"remote: Counting objects: 100% (281/281), done.\u001b[K\n",
"remote: Compressing objects: 100% (178/178), done.\u001b[K\n",
"remote: Total 5435 (delta 158), reused 164 (delta 88), pack-reused 5154\u001b[K\n",
"Receiving objects: 100% (5435/5435), 3.86 MiB | 8.99 MiB/s, done.\n",
"Resolving deltas: 100% (3576/3576), done.\n"
]
}
],
"source": [
"!git clone https://github.com/DLR-RM/rl-baselines3-zoo"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "REMQlh-ezyVt",
"outputId": "f5d031ff-1b5e-4ff6-86dd-0ff6f47e2d5e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/rl-baselines3-zoo\n"
]
}
],
"source": [
"%cd /content/rl-baselines3-zoo/"
]
},
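{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (an addition, not part of the upstream zoo notebook): the pip step below installs from `requirements.txt`, so verify that `%cd` landed in the root of the cloned repository first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"# Fail early if the working directory is not the repo root containing requirements.txt\n",
"assert Path(\"requirements.txt\").is_file(), \"requirements.txt not found; check the %cd path above\"\n",
"print(\"Working directory:\", Path.cwd())"
]
},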
{
"cell_type": "markdown",
"metadata": {
"id": "5tmD_QTBqTMb"
},
"source": [
"### Install pip dependencies"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OWIDzgJTqShY",
"outputId": "077f13ca-9744-47bc-ecd7-5ab0b73b7d65"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting gym==0.26.2 (from -r requirements.txt (line 1))\n",
" Downloading gym-0.26.2.tar.gz (721 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m721.7/721.7 kB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Collecting stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1 (from -r requirements.txt (line 2))\n",
" Downloading stable_baselines3-2.3.0a1-py3-none-any.whl (181 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m181.7/181.7 kB\u001b[0m \u001b[31m12.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting sb3-contrib<3.0,>=2.3.0a1 (from -r requirements.txt (line 3))\n",
" Downloading sb3_contrib-2.3.0a1-py3-none-any.whl (80 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m80.7/80.7 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting box2d-py==2.3.8 (from -r requirements.txt (line 4))\n",
" Downloading box2d-py-2.3.8.tar.gz (374 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m374.5/374.5 kB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Collecting pybullet (from -r requirements.txt (line 5))\n",
" Downloading pybullet-3.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (103.2 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m103.2/103.2 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting pybullet_envs_gymnasium>=0.4.0 (from -r requirements.txt (line 6))\n",
" Downloading pybullet_envs_gymnasium-0.4.0-py3-none-any.whl (22 kB)\n",
"Collecting optuna~=3.0 (from -r requirements.txt (line 9))\n",
" Downloading optuna-3.5.0-py3-none-any.whl (413 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m413.4/413.4 kB\u001b[0m \u001b[31m49.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 10)) (6.0.1)\n",
"Requirement already satisfied: cloudpickle>=2.2.1 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 11)) (2.2.1)\n",
"Requirement already satisfied: plotly in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 12)) (5.15.0)\n",
"Collecting rliable>=1.0.5 (from -r requirements.txt (line 15))\n",
" Downloading rliable-1.0.8-py3-none-any.whl (19 kB)\n",
"Collecting wandb (from -r requirements.txt (line 16))\n",
" Downloading wandb-0.16.2-py3-none-any.whl (2.2 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m90.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting huggingface_sb3<4.0,>=3.0 (from -r requirements.txt (line 17))\n",
" Downloading huggingface_sb3-3.0-py3-none-any.whl (9.7 kB)\n",
"Requirement already satisfied: seaborn in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 18)) (0.13.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 19)) (4.66.1)\n",
"Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 20)) (13.7.0)\n",
"Requirement already satisfied: moviepy in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 21)) (1.0.3)\n",
"Collecting ruff (from -r requirements.txt (line 22))\n",
" Downloading ruff-0.1.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.5 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m98.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.10/dist-packages (from gym==0.26.2->-r requirements.txt (line 1)) (1.23.5)\n",
"Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from gym==0.26.2->-r requirements.txt (line 1)) (0.0.8)\n",
"Collecting gymnasium<0.30,>=0.28.1 (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m953.9/953.9 kB\u001b[0m \u001b[31m77.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: torch>=1.13 in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.1.0+cu121)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.5.3)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (3.7.1)\n",
"Requirement already satisfied: opencv-python in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (4.8.0.76)\n",
"Requirement already satisfied: pygame in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.5.2)\n",
"Requirement already satisfied: tensorboard>=2.9.1 in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.15.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (5.9.5)\n",
"Collecting shimmy[atari]~=1.3.0 (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading Shimmy-1.3.0-py3-none-any.whl (37 kB)\n",
"Requirement already satisfied: pillow in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (9.4.0)\n",
"Requirement already satisfied: sphinx<8,>=5 in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (5.0.2)\n",
"Collecting sphinx-autobuild (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading sphinx_autobuild-2021.3.14-py3-none-any.whl (9.9 kB)\n",
"Collecting sphinx-rtd-theme>=1.3.0 (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading sphinx_rtd_theme-2.0.0-py2.py3-none-any.whl (2.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.8/2.8 MB\u001b[0m \u001b[31m101.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting sphinxcontrib.spelling (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading sphinxcontrib_spelling-8.0.0-py3-none-any.whl (16 kB)\n",
"Collecting sphinx-copybutton (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading sphinx_copybutton-0.5.2-py3-none-any.whl (13 kB)\n",
"Requirement already satisfied: pytest in /usr/local/lib/python3.10/dist-packages (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (7.4.4)\n",
"Collecting pytest-cov (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading pytest_cov-4.1.0-py3-none-any.whl (21 kB)\n",
"Collecting pytest-env (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading pytest_env-1.1.3-py3-none-any.whl (6.2 kB)\n",
"Collecting pytest-xdist (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading pytest_xdist-3.5.0-py3-none-any.whl (42 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.0/42.0 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting mypy (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.5 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.5/12.5 MB\u001b[0m \u001b[31m96.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting black<24,>=23.9.1 (from stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m73.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting alembic>=1.5.0 (from optuna~=3.0->-r requirements.txt (line 9))\n",
" Downloading alembic-1.13.1-py3-none-any.whl (233 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m233.4/233.4 kB\u001b[0m \u001b[31m30.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting colorlog (from optuna~=3.0->-r requirements.txt (line 9))\n",
" Downloading colorlog-6.8.0-py3-none-any.whl (11 kB)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from optuna~=3.0->-r requirements.txt (line 9)) (23.2)\n",
"Requirement already satisfied: sqlalchemy>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from optuna~=3.0->-r requirements.txt (line 9)) (2.0.24)\n",
"Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly->-r requirements.txt (line 12)) (8.2.3)\n",
"Collecting arch==5.3.0 (from rliable>=1.0.5->-r requirements.txt (line 15))\n",
" Downloading arch-5.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (905 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m905.4/905.4 kB\u001b[0m \u001b[31m57.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: scipy>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from rliable>=1.0.5->-r requirements.txt (line 15)) (1.11.4)\n",
"Requirement already satisfied: absl-py>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from rliable>=1.0.5->-r requirements.txt (line 15)) (1.4.0)\n",
"Requirement already satisfied: statsmodels>=0.11 in /usr/local/lib/python3.10/dist-packages (from arch==5.3.0->rliable>=1.0.5->-r requirements.txt (line 15)) (0.14.1)\n",
"Collecting property-cached>=1.6.4 (from arch==5.3.0->rliable>=1.0.5->-r requirements.txt (line 15))\n",
" Downloading property_cached-1.6.4-py2.py3-none-any.whl (7.8 kB)\n",
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 16)) (8.1.7)\n",
"Collecting GitPython!=3.1.29,>=1.0.0 (from wandb->-r requirements.txt (line 16))\n",
" Downloading GitPython-3.1.41-py3-none-any.whl (196 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.4/196.4 kB\u001b[0m \u001b[31m25.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 16)) (2.31.0)\n",
"Collecting sentry-sdk>=1.0.0 (from wandb->-r requirements.txt (line 16))\n",
" Downloading sentry_sdk-1.39.2-py2.py3-none-any.whl (254 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m254.1/254.1 kB\u001b[0m \u001b[31m30.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb->-r requirements.txt (line 16))\n",
" Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n",
"Collecting setproctitle (from wandb->-r requirements.txt (line 16))\n",
" Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 16)) (67.7.2)\n",
"Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 16)) (1.4.4)\n",
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 16)) (3.20.3)\n",
"Requirement already satisfied: huggingface-hub~=0.8 in /usr/local/lib/python3.10/dist-packages (from huggingface_sb3<4.0,>=3.0->-r requirements.txt (line 17)) (0.20.2)\n",
"Requirement already satisfied: wasabi in /usr/local/lib/python3.10/dist-packages (from huggingface_sb3<4.0,>=3.0->-r requirements.txt (line 17)) (1.1.2)\n",
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->-r requirements.txt (line 20)) (3.0.0)\n",
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->-r requirements.txt (line 20)) (2.16.1)\n",
"Requirement already satisfied: decorator<5.0,>=4.0.2 in /usr/local/lib/python3.10/dist-packages (from moviepy->-r requirements.txt (line 21)) (4.4.2)\n",
"Requirement already satisfied: proglog<=1.0.0 in /usr/local/lib/python3.10/dist-packages (from moviepy->-r requirements.txt (line 21)) (0.1.10)\n",
"Requirement already satisfied: imageio<3.0,>=2.5 in /usr/local/lib/python3.10/dist-packages (from moviepy->-r requirements.txt (line 21)) (2.31.6)\n",
"Requirement already satisfied: imageio-ffmpeg>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from moviepy->-r requirements.txt (line 21)) (0.4.9)\n",
"Collecting Mako (from alembic>=1.5.0->optuna~=3.0->-r requirements.txt (line 9))\n",
" Downloading Mako-1.3.0-py3-none-any.whl (78 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.6/78.6 kB\u001b[0m \u001b[31m12.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: typing-extensions>=4 in /usr/local/lib/python3.10/dist-packages (from alembic>=1.5.0->optuna~=3.0->-r requirements.txt (line 9)) (4.5.0)\n",
"Collecting mypy-extensions>=0.4.3 (from black<24,>=23.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
"Collecting pathspec>=0.9.0 (from black<24,>=23.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading pathspec-0.12.1-py3-none-any.whl (31 kB)\n",
"Requirement already satisfied: platformdirs>=2 in /usr/local/lib/python3.10/dist-packages (from black<24,>=23.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (4.1.0)\n",
"Requirement already satisfied: tomli>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from black<24,>=23.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.0.1)\n",
"Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb->-r requirements.txt (line 16)) (1.16.0)\n",
"Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb->-r requirements.txt (line 16))\n",
" Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting farama-notifications>=0.0.1 (from gymnasium<0.30,>=0.28.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub~=0.8->huggingface_sb3<4.0,>=3.0->-r requirements.txt (line 17)) (3.13.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub~=0.8->huggingface_sb3<4.0,>=3.0->-r requirements.txt (line 17)) (2023.6.0)\n",
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->-r requirements.txt (line 20)) (0.1.2)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.2.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (4.47.2)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.4.5)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (3.1.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2023.3.post1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb->-r requirements.txt (line 16)) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb->-r requirements.txt (line 16)) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb->-r requirements.txt (line 16)) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb->-r requirements.txt (line 16)) (2023.11.17)\n",
"Collecting ale-py~=0.8.1 (from shimmy[atari]~=1.3.0->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading ale_py-0.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m89.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: sphinxcontrib-applehelp in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.0.7)\n",
"Requirement already satisfied: sphinxcontrib-devhelp in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.0.5)\n",
"Requirement already satisfied: sphinxcontrib-jsmath in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.0.1)\n",
"Requirement already satisfied: sphinxcontrib-htmlhelp>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.0.4)\n",
"Requirement already satisfied: sphinxcontrib-serializinghtml>=1.1.5 in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.1.9)\n",
"Requirement already satisfied: sphinxcontrib-qthelp in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.0.6)\n",
"Requirement already satisfied: Jinja2>=2.3 in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (3.1.3)\n",
"Requirement already satisfied: docutils<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (0.18.1)\n",
"Requirement already satisfied: snowballstemmer>=1.1 in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.2.0)\n",
"Requirement already satisfied: babel>=1.3 in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.14.0)\n",
"Requirement already satisfied: alabaster<0.8,>=0.7 in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (0.7.16)\n",
"Requirement already satisfied: imagesize in /usr/local/lib/python3.10/dist-packages (from sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.4.1)\n",
"Collecting sphinxcontrib-jquery<5,>=4 (from sphinx-rtd-theme>=1.3.0->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading sphinxcontrib_jquery-4.1-py2.py3-none-any.whl (121 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.1/121.1 kB\u001b[0m \u001b[31m19.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from sqlalchemy>=1.3.0->optuna~=3.0->-r requirements.txt (line 9)) (3.0.3)\n",
"Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.60.0)\n",
"Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.17.3)\n",
"Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.2.0)\n",
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (3.5.2)\n",
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (0.7.2)\n",
"Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (3.0.1)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (3.2.1)\n",
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.1.0)\n",
"Requirement already satisfied: iniconfig in /usr/local/lib/python3.10/dist-packages (from pytest->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.0.0)\n",
"Requirement already satisfied: pluggy<2.0,>=0.12 in /usr/local/lib/python3.10/dist-packages (from pytest->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.3.0)\n",
"Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /usr/local/lib/python3.10/dist-packages (from pytest->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.2.0)\n",
"Collecting coverage[toml]>=5.2.1 (from pytest-cov->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading coverage-7.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (233 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m233.2/233.2 kB\u001b[0m \u001b[31m31.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting execnet>=1.1 (from pytest-xdist->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading execnet-2.0.2-py3-none-any.whl (37 kB)\n",
"Collecting livereload (from sphinx-autobuild->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading livereload-2.6.3-py2.py3-none-any.whl (24 kB)\n",
"Collecting colorama (from sphinx-autobuild->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
"Collecting PyEnchant>=3.1.1 (from sphinxcontrib.spelling->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2))\n",
" Downloading pyenchant-3.2.2-py3-none-any.whl (55 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.7/55.7 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: importlib-resources in /usr/local/lib/python3.10/dist-packages (from ale-py~=0.8.1->shimmy[atari]~=1.3.0->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (6.1.1)\n",
"Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb->-r requirements.txt (line 16))\n",
" Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n",
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (5.3.2)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (0.3.0)\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (4.9)\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<2,>=0.5->tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.3.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=2.3->sphinx<8,>=5->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (2.1.3)\n",
"Requirement already satisfied: patsy>=0.5.4 in /usr/local/lib/python3.10/dist-packages (from statsmodels>=0.11->arch==5.3.0->rliable>=1.0.5->-r requirements.txt (line 15)) (0.5.6)\n",
"Requirement already satisfied: tornado in /usr/local/lib/python3.10/dist-packages (from livereload->sphinx-autobuild->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (6.3.2)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (1.3.0)\n",
"Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (0.5.1)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard>=2.9.1->stable-baselines3[docs,extra_no_roms,tests]<3.0,>=2.3.0a1->-r requirements.txt (line 2)) (3.2.2)\n",
"Building wheels for collected packages: gym, box2d-py\n",
" Building wheel for gym (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827624 sha256=15c7f697bf6bbb079560cf2d4f2f03ae15d66d87c5927a5ca7eca503895a3fb6\n",
" Stored in directory: /root/.cache/pip/wheels/b9/22/6d/3e7b32d98451b4cd9d12417052affbeeeea012955d437da1da\n",
" Building wheel for box2d-py (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for box2d-py: filename=box2d_py-2.3.8-cp310-cp310-linux_x86_64.whl size=2349118 sha256=8f1ee752a01e494744eb9d6afcb21fad5b421ef3db3f340f03e2e8227085c6a8\n",
" Stored in directory: /root/.cache/pip/wheels/47/01/d2/6a780da77ccb98b1d2facdd520a8d10838a03b590f6f8d50c0\n",
"Successfully built gym box2d-py\n",
"Installing collected packages: pybullet, farama-notifications, box2d-py, smmap, setproctitle, sentry-sdk, ruff, PyEnchant, property-cached, pathspec, mypy-extensions, Mako, livereload, gymnasium, gym, execnet, docker-pycreds, coverage, colorlog, colorama, ale-py, shimmy, pytest-xdist, pytest-env, pybullet_envs_gymnasium, mypy, gitdb, black, alembic, stable-baselines3, pytest-cov, optuna, huggingface_sb3, GitPython, wandb, sb3-contrib, arch, rliable, sphinxcontrib-jquery, sphinxcontrib.spelling, sphinx-rtd-theme, sphinx-copybutton, sphinx-autobuild\n",
" Attempting uninstall: gym\n",
" Found existing installation: gym 0.25.2\n",
" Uninstalling gym-0.25.2:\n",
" Successfully uninstalled gym-0.25.2\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"dopamine-rl 4.0.6 requires gym<=0.25.2, but you have gym 0.26.2 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0mSuccessfully installed GitPython-3.1.41 Mako-1.3.0 PyEnchant-3.2.2 ale-py-0.8.1 alembic-1.13.1 arch-5.3.0 black-23.12.1 box2d-py-2.3.8 colorama-0.4.6 colorlog-6.8.0 coverage-7.4.0 docker-pycreds-0.4.0 execnet-2.0.2 farama-notifications-0.0.4 gitdb-4.0.11 gym-0.26.2 gymnasium-0.29.1 huggingface_sb3-3.0 livereload-2.6.3 mypy-1.8.0 mypy-extensions-1.0.0 optuna-3.5.0 pathspec-0.12.1 property-cached-1.6.4 pybullet-3.2.6 pybullet_envs_gymnasium-0.4.0 pytest-cov-4.1.0 pytest-env-1.1.3 pytest-xdist-3.5.0 rliable-1.0.8 ruff-0.1.13 sb3-contrib-2.3.0a1 sentry-sdk-1.39.2 setproctitle-1.3.3 shimmy-1.3.0 smmap-5.0.1 sphinx-autobuild-2021.3.14 sphinx-copybutton-0.5.2 sphinx-rtd-theme-2.0.0 sphinxcontrib-jquery-4.1 sphinxcontrib.spelling-8.0.0 stable-baselines3-2.3.0a1 wandb-0.16.2\n"
]
}
],
"source": [
"!pip install -r requirements.txt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6gJ-pAbF7zRZ"
},
"source": [
"## Train an RL Agent\n",
"\n",
"\n",
"The trained agent is saved in the `logs/` folder (e.g. `logs/a2c/CartPole-v1_1` for the first A2C run on CartPole-v1).\n",
"\n",
"Here we train A2C on the CartPole-v1 environment for 100,000 steps.\n",
"\n",
"\n",
"To train on Pong (Atari) instead, pass `--env PongNoFrameskip-v4`.\n",
"\n",
"Note: to support a new environment, you need to add its hyperparameters to `hyperparams/algo.yml`, where `algo` is the algorithm name (e.g. `hyperparams/a2c.yml` for A2C). You can open and edit this file from the Files side panel in Google Colab (see https://stackoverflow.com/questions/46986398/import-data-into-google-colaboratory)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9bIR_N7R11XI",
"outputId": "1143e04d-43cb-4a34-8a77-26bf0e0db2f6"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2024-01-18 21:50:39.455808: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-01-18 21:50:39.455863: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-01-18 21:50:39.457836: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-01-18 21:50:39.468643: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-01-18 21:50:40.585959: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"========== CartPole-v1 ==========\n",
"Seed: 1491342044\n",
"Loading hyperparameters from: /content/rl-baselines3-zoo/hyperparams/a2c.yml\n",
"Default hyperparameters for environment (ones being tuned will be overridden):\n",
"OrderedDict([('ent_coef', 0.0),\n",
" ('n_envs', 8),\n",
" ('n_timesteps', 500000.0),\n",
" ('policy', 'MlpPolicy')])\n",
"Using 8 environments\n",
"Overwriting n_timesteps with n=100000\n",
"Creating test environment\n",
"Using cuda device\n",
"Log path: logs/a2c/CartPole-v1_1\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 24.5 |\n",
"| ep_rew_mean | 24.5 |\n",
"| time/ | |\n",
"| fps | 1297 |\n",
"| iterations | 100 |\n",
"| time_elapsed | 3 |\n",
"| total_timesteps | 4000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.688 |\n",
"| explained_variance | 0.249 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 99 |\n",
"| policy_loss | 1.69 |\n",
"| value_loss | 7.45 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 37 |\n",
"| ep_rew_mean | 37 |\n",
"| time/ | |\n",
"| fps | 1685 |\n",
"| iterations | 200 |\n",
"| time_elapsed | 4 |\n",
"| total_timesteps | 8000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.616 |\n",
"| explained_variance | 0.0869 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 199 |\n",
"| policy_loss | 1.31 |\n",
"| value_loss | 6.24 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 66.8 |\n",
"| ep_rew_mean | 66.8 |\n",
"| time/ | |\n",
"| fps | 1910 |\n",
"| iterations | 300 |\n",
"| time_elapsed | 6 |\n",
"| total_timesteps | 12000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.604 |\n",
"| explained_variance | 0.0361 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 299 |\n",
"| policy_loss | 1.25 |\n",
"| value_loss | 6.01 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 97.8 |\n",
"| ep_rew_mean | 97.8 |\n",
"| time/ | |\n",
"| fps | 2079 |\n",
"| iterations | 400 |\n",
"| time_elapsed | 7 |\n",
"| total_timesteps | 16000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.606 |\n",
"| explained_variance | -0.00657 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 399 |\n",
"| policy_loss | 1.13 |\n",
"| value_loss | 5.34 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 137 |\n",
"| ep_rew_mean | 137 |\n",
"| time/ | |\n",
"| fps | 2252 |\n",
"| iterations | 500 |\n",
"| time_elapsed | 8 |\n",
"| total_timesteps | 20000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.544 |\n",
"| explained_variance | 0.00982 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 499 |\n",
"| policy_loss | 1.01 |\n",
"| value_loss | 4.7 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 170 |\n",
"| ep_rew_mean | 170 |\n",
"| time/ | |\n",
"| fps | 2388 |\n",
"| iterations | 600 |\n",
"| time_elapsed | 10 |\n",
"| total_timesteps | 24000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.532 |\n",
"| explained_variance | 0.00171 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 599 |\n",
"| policy_loss | 1.1 |\n",
"| value_loss | 4.18 |\n",
"------------------------------------\n",
"Eval num_timesteps=25000, episode_reward=162.80 +/- 9.20\n",
"Episode length: 162.80 +/- 9.20\n",
"------------------------------------\n",
"| eval/ | |\n",
"| mean_ep_length | 163 |\n",
"| mean_reward | 163 |\n",
"| time/ | |\n",
"| total_timesteps | 25000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.566 |\n",
"| explained_variance | 0.0506 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 624 |\n",
"| policy_loss | 0.929 |\n",
"| value_loss | 4.01 |\n",
"------------------------------------\n",
"New best mean reward!\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 199 |\n",
"| ep_rew_mean | 199 |\n",
"| time/ | |\n",
"| fps | 2356 |\n",
"| iterations | 700 |\n",
"| time_elapsed | 11 |\n",
"| total_timesteps | 28000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.546 |\n",
"| explained_variance | 0.00541 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 699 |\n",
"| policy_loss | 0.75 |\n",
"| value_loss | 3.67 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 208 |\n",
"| ep_rew_mean | 208 |\n",
"| time/ | |\n",
"| fps | 2444 |\n",
"| iterations | 800 |\n",
"| time_elapsed | 13 |\n",
"| total_timesteps | 32000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.513 |\n",
"| explained_variance | 0.00403 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 799 |\n",
"| policy_loss | 0.936 |\n",
"| value_loss | 3.18 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 222 |\n",
"| ep_rew_mean | 222 |\n",
"| time/ | |\n",
"| fps | 2511 |\n",
"| iterations | 900 |\n",
"| time_elapsed | 14 |\n",
"| total_timesteps | 36000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.548 |\n",
"| explained_variance | 0.258 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 899 |\n",
"| policy_loss | 0.677 |\n",
"| value_loss | 2.6 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 229 |\n",
"| ep_rew_mean | 229 |\n",
"| time/ | |\n",
"| fps | 2578 |\n",
"| iterations | 1000 |\n",
"| time_elapsed | 15 |\n",
"| total_timesteps | 40000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.59 |\n",
"| explained_variance | 0.0286 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 999 |\n",
"| policy_loss | 0.721 |\n",
"| value_loss | 2.28 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 246 |\n",
"| ep_rew_mean | 246 |\n",
"| time/ | |\n",
"| fps | 2636 |\n",
"| iterations | 1100 |\n",
"| time_elapsed | 16 |\n",
"| total_timesteps | 44000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.585 |\n",
"| explained_variance | 0.000741 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1099 |\n",
"| policy_loss | 0.815 |\n",
"| value_loss | 1.88 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 259 |\n",
"| ep_rew_mean | 259 |\n",
"| time/ | |\n",
"| fps | 2663 |\n",
"| iterations | 1200 |\n",
"| time_elapsed | 18 |\n",
"| total_timesteps | 48000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.52 |\n",
"| explained_variance | 0.000329 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1199 |\n",
"| policy_loss | 0.599 |\n",
"| value_loss | 1.52 |\n",
"------------------------------------\n",
"Eval num_timesteps=50000, episode_reward=500.00 +/- 0.00\n",
"Episode length: 500.00 +/- 0.00\n",
"-------------------------------------\n",
"| eval/ | |\n",
"| mean_ep_length | 500 |\n",
"| mean_reward | 500 |\n",
"| time/ | |\n",
"| total_timesteps | 50000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.519 |\n",
"| explained_variance | -0.000137 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1249 |\n",
"| policy_loss | 0.579 |\n",
"| value_loss | 1.35 |\n",
"-------------------------------------\n",
"New best mean reward!\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 285 |\n",
"| ep_rew_mean | 285 |\n",
"| time/ | |\n",
"| fps | 2410 |\n",
"| iterations | 1300 |\n",
"| time_elapsed | 21 |\n",
"| total_timesteps | 52000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.522 |\n",
"| explained_variance | 0.000197 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1299 |\n",
"| policy_loss | 0.539 |\n",
"| value_loss | 1.18 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 304 |\n",
"| ep_rew_mean | 304 |\n",
"| time/ | |\n",
"| fps | 2464 |\n",
"| iterations | 1400 |\n",
"| time_elapsed | 22 |\n",
"| total_timesteps | 56000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.505 |\n",
"| explained_variance | 0.000204 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1399 |\n",
"| policy_loss | 0.404 |\n",
"| value_loss | 0.892 |\n",
"------------------------------------\n",
"-------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 330 |\n",
"| ep_rew_mean | 330 |\n",
"| time/ | |\n",
"| fps | 2497 |\n",
"| iterations | 1500 |\n",
"| time_elapsed | 24 |\n",
"| total_timesteps | 60000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.496 |\n",
"| explained_variance | -0.000126 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1499 |\n",
"| policy_loss | 0.366 |\n",
"| value_loss | 0.642 |\n",
"-------------------------------------\n",
"-------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 356 |\n",
"| ep_rew_mean | 356 |\n",
"| time/ | |\n",
"| fps | 2534 |\n",
"| iterations | 1600 |\n",
"| time_elapsed | 25 |\n",
"| total_timesteps | 64000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.514 |\n",
"| explained_variance | -0.000166 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1599 |\n",
"| policy_loss | 0.293 |\n",
"| value_loss | 0.426 |\n",
"-------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 381 |\n",
"| ep_rew_mean | 381 |\n",
"| time/ | |\n",
"| fps | 2573 |\n",
"| iterations | 1700 |\n",
"| time_elapsed | 26 |\n",
"| total_timesteps | 68000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.47 |\n",
"| explained_variance | 0.000167 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1699 |\n",
"| policy_loss | 0.3 |\n",
"| value_loss | 0.26 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 399 |\n",
"| ep_rew_mean | 399 |\n",
"| time/ | |\n",
"| fps | 2611 |\n",
"| iterations | 1800 |\n",
"| time_elapsed | 27 |\n",
"| total_timesteps | 72000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.505 |\n",
"| explained_variance | 1.55e-05 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1799 |\n",
"| policy_loss | 0.16 |\n",
"| value_loss | 0.137 |\n",
"------------------------------------\n",
"Eval num_timesteps=75000, episode_reward=500.00 +/- 0.00\n",
"Episode length: 500.00 +/- 0.00\n",
"------------------------------------\n",
"| eval/ | |\n",
"| mean_ep_length | 500 |\n",
"| mean_reward | 500 |\n",
"| time/ | |\n",
"| total_timesteps | 75000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.519 |\n",
"| explained_variance | 6.34e-05 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1874 |\n",
"| policy_loss | 0.0992 |\n",
"| value_loss | 0.0699 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 414 |\n",
"| ep_rew_mean | 414 |\n",
"| time/ | |\n",
"| fps | 2458 |\n",
"| iterations | 1900 |\n",
"| time_elapsed | 30 |\n",
"| total_timesteps | 76000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.492 |\n",
"| explained_variance | 5.94e-05 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1899 |\n",
"| policy_loss | 0.117 |\n",
"| value_loss | 0.0544 |\n",
"------------------------------------\n",
"-------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 431 |\n",
"| ep_rew_mean | 431 |\n",
"| time/ | |\n",
"| fps | 2469 |\n",
"| iterations | 2000 |\n",
"| time_elapsed | 32 |\n",
"| total_timesteps | 80000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.502 |\n",
"| explained_variance | -0.000255 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 1999 |\n",
"| policy_loss | 0.0482 |\n",
"| value_loss | 0.00927 |\n",
"-------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 449 |\n",
"| ep_rew_mean | 449 |\n",
"| time/ | |\n",
"| fps | 2487 |\n",
"| iterations | 2100 |\n",
"| time_elapsed | 33 |\n",
"| total_timesteps | 84000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.512 |\n",
"| explained_variance | -0.00438 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 2099 |\n",
"| policy_loss | 0.00471 |\n",
"| value_loss | 8.33e-05 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 455 |\n",
"| ep_rew_mean | 455 |\n",
"| time/ | |\n",
"| fps | 2519 |\n",
"| iterations | 2200 |\n",
"| time_elapsed | 34 |\n",
"| total_timesteps | 88000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.454 |\n",
"| explained_variance | 0.0035 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 2199 |\n",
"| policy_loss | 0.00176 |\n",
"| value_loss | 9.59e-06 |\n",
"------------------------------------\n",
"------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 463 |\n",
"| ep_rew_mean | 463 |\n",
"| time/ | |\n",
"| fps | 2550 |\n",
"| iterations | 2300 |\n",
"| time_elapsed | 36 |\n",
"| total_timesteps | 92000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.487 |\n",
"| explained_variance | 0.706 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 2299 |\n",
"| policy_loss | 2.05e-07 |\n",
"| value_loss | 3.42e-10 |\n",
"------------------------------------\n",
"-------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 468 |\n",
"| ep_rew_mean | 468 |\n",
"| time/ | |\n",
"| fps | 2578 |\n",
"| iterations | 2400 |\n",
"| time_elapsed | 37 |\n",
"| total_timesteps | 96000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.564 |\n",
"| explained_variance | 0.876 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 2399 |\n",
"| policy_loss | -1.63e-06 |\n",
"| value_loss | 1.43e-10 |\n",
"-------------------------------------\n",
"Eval num_timesteps=100000, episode_reward=500.00 +/- 0.00\n",
"Episode length: 500.00 +/- 0.00\n",
"-------------------------------------\n",
"| eval/ | |\n",
"| mean_ep_length | 500 |\n",
"| mean_reward | 500 |\n",
"| time/ | |\n",
"| total_timesteps | 100000 |\n",
"| train/ | |\n",
"| entropy_loss | -0.49 |\n",
"| explained_variance | 0.574 |\n",
"| learning_rate | 0.0007 |\n",
"| n_updates | 2499 |\n",
"| policy_loss | -7.62e-06 |\n",
"| value_loss | 7e-10 |\n",
"-------------------------------------\n",
"---------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 470 |\n",
"| ep_rew_mean | 470 |\n",
"| time/ | |\n",
"| fps | 2478 |\n",
"| iterations | 2500 |\n",
"| time_elapsed | 40 |\n",
"| total_timesteps | 100000 |\n",
"---------------------------------\n",
"Saving to logs/a2c/CartPole-v1_1\n"
]
}
],
"source": [
"!python -m rl_zoo3.train --algo a2c --env CartPole-v1 --n-timesteps 100000"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-fHBq73665yD"
},
"source": [
"#### Evaluate trained agent\n",
"\n",
"\n",
"You can remove `--folder logs/` to evaluate a pretrained agent instead."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bw8YuEgU6bT3",
"outputId": "41f6da92-e433-4f0e-cd59-c05f2a0620b7"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2024-01-18 21:52:40.172515: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-01-18 21:52:40.172559: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-01-18 21:52:40.173956: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-01-18 21:52:40.184323: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-01-18 21:52:41.505696: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"Loading latest experiment, id=1\n",
"Loading logs/a2c/CartPole-v1_1/CartPole-v1.zip\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"Episode Reward: 500.00\n",
"Episode Length 500\n",
"10 Episodes\n",
"Mean reward: 500.00 +/- 0.00\n",
"Mean episode length: 500.00 +/- 0.00\n"
]
}
],
"source": [
"!python -m rl_zoo3.enjoy --algo a2c --env CartPole-v1 --no-render --n-timesteps 5000 --folder logs/"
]
},
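{
"cell_type": "markdown",
"metadata": {},
"source": [
"The summary at the end of the `enjoy` output reports the mean and standard deviation over the completed episodes. The cell below is a minimal stdlib sketch of that arithmetic, not `rl_zoo3`'s own code. It assumes a population standard deviation (the default of `numpy.std`), and `summarize_episodes` is a hypothetical helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from statistics import mean, pstdev\n",
"\n",
"\n",
"def summarize_episodes(episode_rewards, episode_lengths):\n",
"    \"\"\"Format per-episode results like the enjoy summary (hypothetical helper).\"\"\"\n",
"    # pstdev is the population standard deviation, assumed to match numpy.std's default\n",
"    return (\n",
"        f\"{len(episode_rewards)} Episodes\\n\"\n",
"        f\"Mean reward: {mean(episode_rewards):.2f} +/- {pstdev(episode_rewards):.2f}\\n\"\n",
"        f\"Mean episode length: {mean(episode_lengths):.2f} +/- {pstdev(episode_lengths):.2f}\"\n",
"    )\n",
"\n",
"\n",
"# Ten CartPole-v1 episodes that all reach the 500-step limit, as in the run above\n",
"print(summarize_episodes([500.0] * 10, [500] * 10))"
]
},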
{
"cell_type": "markdown",
"metadata": {
"id": "w5Il2J0VHPLC"
},
"source": [
"#### Tune Hyperparameters\n",
"\n",
"We use [Optuna](https://optuna.org/) for optimizing the hyperparameters.\n",
"\n",
"Tune the hyperparameters for A2C on MountainCar-v0, using a TPE sampler and a median pruner with 2 parallel jobs,\n",
"with a budget of 30 trials and a maximum of 50000 timesteps per trial."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w2sC22eGHTH-",
"outputId": "a789c71d-2442-4e5f-d572-8aec15e90f19"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2024-01-18 22:11:18.859530: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-01-18 22:11:18.859582: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-01-18 22:11:18.861096: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-01-18 22:11:18.868191: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-01-18 22:11:19.930524: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"========== MountainCar-v0 ==========\n",
"Seed: 387089939\n",
"Loading hyperparameters from: /content/rl-baselines3-zoo/hyperparams/a2c.yml\n",
"Default hyperparameters for environment (ones being tuned will be overridden):\n",
"OrderedDict([('ent_coef', 0.0),\n",
" ('n_envs', 16),\n",
" ('n_timesteps', 1000000.0),\n",
" ('normalize', True),\n",
" ('policy', 'MlpPolicy')])\n",
"Using 16 environments\n",
"Overwriting n_timesteps with n=50000\n",
"Doing 1 intermediate evaluations for pruning based on the number of timesteps. (1 evaluation every 100k timesteps)\n",
"Normalizing input and reward\n",
"Optimizing hyperparameters\n",
"/usr/local/lib/python3.10/dist-packages/optuna/samplers/_tpe/sampler.py:319: ExperimentalWarning: ``multivariate`` option is an experimental feature. The interface can change in the future.\n",
" warnings.warn(\n",
"Sampler: tpe - Pruner: median\n",
"\u001b[32m[I 2024-01-18 22:11:23,107]\u001b[0m A new study created in memory with name: no-name-9fb10a58-d050-4f56-b10c-4e0f52dc9b75\u001b[0m\n",
"Normalizing input and reward\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:11:49,809]\u001b[0m Trial 0 finished with value: -200.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 0.9, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 512, 'lr_schedule': 'constant', 'learning_rate': 0.00935289341861223, 'ent_coef': 0.002061739820511231, 'vf_coef': 0.5463797780036018, 'ortho_init': True, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:11:49,898]\u001b[0m Trial 1 finished with value: -200.0 and parameters: {'gamma': 0.99, 'normalize_advantage': True, 'max_grad_norm': 2, 'use_rms_prop': True, 'gae_lambda': 0.9, 'n_steps': 512, 'lr_schedule': 'constant', 'learning_rate': 0.005703969808574036, 'ent_coef': 3.230266765780452e-08, 'vf_coef': 0.18010479542320434, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'relu'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:12:14,756]\u001b[0m Trial 2 finished with value: -200.0 and parameters: {'gamma': 0.95, 'normalize_advantage': True, 'max_grad_norm': 0.5, 'use_rms_prop': False, 'gae_lambda': 0.92, 'n_steps': 512, 'lr_schedule': 'linear', 'learning_rate': 0.001446274742824118, 'ent_coef': 8.58046330729597e-07, 'vf_coef': 0.36913144369368633, 'ortho_init': True, 'net_arch': 'medium', 'activation_fn': 'relu'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:12:17,884]\u001b[0m Trial 3 finished with value: -200.0 and parameters: {'gamma': 0.995, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': True, 'gae_lambda': 0.92, 'n_steps': 1024, 'lr_schedule': 'linear', 'learning_rate': 2.7167619314259116e-05, 'ent_coef': 3.253840991426006e-08, 'vf_coef': 0.4365096747284146, 'ortho_init': True, 'net_arch': 'medium', 'activation_fn': 'relu'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:12:37,949]\u001b[0m Trial 4 finished with value: -200.0 and parameters: {'gamma': 0.98, 'normalize_advantage': False, 'max_grad_norm': 5, 'use_rms_prop': False, 'gae_lambda': 0.98, 'n_steps': 256, 'lr_schedule': 'constant', 'learning_rate': 0.18835769661688095, 'ent_coef': 3.2584634595782944e-08, 'vf_coef': 0.6263030066179339, 'ortho_init': True, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:12:40,122]\u001b[0m Trial 5 finished with value: -200.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': False, 'max_grad_norm': 0.9, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 64, 'lr_schedule': 'linear', 'learning_rate': 0.00013662700466793371, 'ent_coef': 0.005633586673358821, 'vf_coef': 0.8802316330558665, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'relu'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:13:01,464]\u001b[0m Trial 6 finished with value: -200.0 and parameters: {'gamma': 0.995, 'normalize_advantage': True, 'max_grad_norm': 0.9, 'use_rms_prop': True, 'gae_lambda': 0.98, 'n_steps': 32, 'lr_schedule': 'linear', 'learning_rate': 8.82822377843917e-05, 'ent_coef': 0.00033196362779212054, 'vf_coef': 0.8907524001236121, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'relu'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:13:02,499]\u001b[0m Trial 7 finished with value: -200.0 and parameters: {'gamma': 0.9, 'normalize_advantage': False, 'max_grad_norm': 5, 'use_rms_prop': False, 'gae_lambda': 0.95, 'n_steps': 64, 'lr_schedule': 'linear', 'learning_rate': 0.010714633464532266, 'ent_coef': 3.5170399241378147e-06, 'vf_coef': 0.49940342039827024, 'ortho_init': False, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:13:24,236]\u001b[0m Trial 8 finished with value: -200.0 and parameters: {'gamma': 0.98, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': False, 'gae_lambda': 0.8, 'n_steps': 128, 'lr_schedule': 'constant', 'learning_rate': 0.025689253951425062, 'ent_coef': 4.262682882392802e-06, 'vf_coef': 0.3118328726812273, 'ortho_init': True, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:13:25,798]\u001b[0m Trial 9 finished with value: -200.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 0.8, 'use_rms_prop': False, 'gae_lambda': 0.92, 'n_steps': 256, 'lr_schedule': 'constant', 'learning_rate': 8.519680783801616e-05, 'ent_coef': 1.8672034732482123e-06, 'vf_coef': 0.8267584940793417, 'ortho_init': True, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:13:46,991]\u001b[0m Trial 10 finished with value: -200.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': True, 'gae_lambda': 0.95, 'n_steps': 64, 'lr_schedule': 'linear', 'learning_rate': 0.0004543312020615806, 'ent_coef': 5.796532003587722e-06, 'vf_coef': 0.22977436347109925, 'ortho_init': False, 'net_arch': 'small', 'activation_fn': 'relu'}. Best is trial 0 with value: -200.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:13:50,074]\u001b[0m Trial 11 finished with value: -189.2 and parameters: {'gamma': 0.95, 'normalize_advantage': True, 'max_grad_norm': 0.9, 'use_rms_prop': False, 'gae_lambda': 0.9, 'n_steps': 512, 'lr_schedule': 'constant', 'learning_rate': 0.026123831780410785, 'ent_coef': 0.04408015592323474, 'vf_coef': 0.6885515775919384, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 11 with value: -189.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:14:11,751]\u001b[0m Trial 12 finished with value: -200.0 and parameters: {'gamma': 0.995, 'normalize_advantage': True, 'max_grad_norm': 0.9, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 512, 'lr_schedule': 'constant', 'learning_rate': 0.00598379916062674, 'ent_coef': 1.0774691218759834e-05, 'vf_coef': 0.7636368348968973, 'ortho_init': True, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 11 with value: -189.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:14:14,498]\u001b[0m Trial 13 finished with value: -200.0 and parameters: {'gamma': 0.95, 'normalize_advantage': True, 'max_grad_norm': 0.9, 'use_rms_prop': True, 'gae_lambda': 1.0, 'n_steps': 512, 'lr_schedule': 'constant', 'learning_rate': 0.021876361022239534, 'ent_coef': 0.0015905082982847468, 'vf_coef': 0.6776675221223998, 'ortho_init': True, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 11 with value: -189.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:14:35,167]\u001b[0m Trial 14 finished with value: -200.0 and parameters: {'gamma': 0.95, 'normalize_advantage': True, 'max_grad_norm': 0.9, 'use_rms_prop': False, 'gae_lambda': 0.98, 'n_steps': 32, 'lr_schedule': 'constant', 'learning_rate': 0.01672856325584544, 'ent_coef': 0.00341237116937009, 'vf_coef': 0.9236934282542169, 'ortho_init': False, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 11 with value: -189.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:14:38,936]\u001b[0m Trial 15 finished with value: -200.0 and parameters: {'gamma': 0.95, 'normalize_advantage': True, 'max_grad_norm': 0.9, 'use_rms_prop': False, 'gae_lambda': 0.9, 'n_steps': 512, 'lr_schedule': 'linear', 'learning_rate': 0.007419909043994535, 'ent_coef': 0.0009414566867462527, 'vf_coef': 0.8439767427764007, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 11 with value: -189.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:15:03,802]\u001b[0m Trial 16 finished with value: -200.0 and parameters: {'gamma': 0.95, 'normalize_advantage': False, 'max_grad_norm': 0.9, 'use_rms_prop': False, 'gae_lambda': 0.98, 'n_steps': 1024, 'lr_schedule': 'constant', 'learning_rate': 0.03870526271119802, 'ent_coef': 0.001497881990578423, 'vf_coef': 0.18064999855866803, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 11 with value: -189.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:15:03,970]\u001b[0m Trial 17 finished with value: -188.4 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 5, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 512, 'lr_schedule': 'constant', 'learning_rate': 0.0009578157197135375, 'ent_coef': 3.9038616977557485e-05, 'vf_coef': 0.14833949631793475, 'ortho_init': True, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 17 with value: -188.4.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:15:26,219]\u001b[0m Trial 18 finished with value: -147.2 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 128, 'lr_schedule': 'constant', 'learning_rate': 0.0024782704751925423, 'ent_coef': 0.002471734383653773, 'vf_coef': 0.5971160227193472, 'ortho_init': True, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 18 with value: -147.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:15:27,304]\u001b[0m Trial 19 finished with value: -196.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 1, 'use_rms_prop': False, 'gae_lambda': 0.95, 'n_steps': 256, 'lr_schedule': 'constant', 'learning_rate': 0.0030806710199832543, 'ent_coef': 2.265399710111172e-05, 'vf_coef': 0.02860553419047976, 'ortho_init': True, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 18 with value: -147.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:15:48,977]\u001b[0m Trial 20 finished with value: -200.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': False, 'gae_lambda': 0.98, 'n_steps': 128, 'lr_schedule': 'constant', 'learning_rate': 0.007664547346527178, 'ent_coef': 0.0965623546069676, 'vf_coef': 0.386869196453981, 'ortho_init': True, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 18 with value: -147.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:15:54,727]\u001b[0m Trial 21 finished with value: -200.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': False, 'max_grad_norm': 5, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 2048, 'lr_schedule': 'constant', 'learning_rate': 0.0009693831301339943, 'ent_coef': 0.00878801007764484, 'vf_coef': 0.2386299344399865, 'ortho_init': True, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 18 with value: -147.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:16:13,360]\u001b[0m Trial 22 finished with value: -200.0 and parameters: {'gamma': 0.999, 'normalize_advantage': True, 'max_grad_norm': 5, 'use_rms_prop': False, 'gae_lambda': 0.98, 'n_steps': 512, 'lr_schedule': 'linear', 'learning_rate': 0.00010215860049699425, 'ent_coef': 4.137910155394835e-06, 'vf_coef': 0.35592198620052723, 'ortho_init': True, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 18 with value: -147.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:16:18,687]\u001b[0m Trial 23 finished with value: -200.0 and parameters: {'gamma': 0.95, 'normalize_advantage': False, 'max_grad_norm': 1, 'use_rms_prop': True, 'gae_lambda': 0.9, 'n_steps': 16, 'lr_schedule': 'constant', 'learning_rate': 0.01397783766726145, 'ent_coef': 0.01440753943248514, 'vf_coef': 0.9489702721700115, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 18 with value: -147.2.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:16:36,897]\u001b[0m Trial 24 finished with value: -143.4 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 16, 'lr_schedule': 'constant', 'learning_rate': 0.00030961294978402277, 'ent_coef': 0.022504720662278196, 'vf_coef': 0.9775763830430406, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 24 with value: -143.4.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:16:42,827]\u001b[0m Trial 25 finished with value: -200.0 and parameters: {'gamma': 0.95, 'normalize_advantage': True, 'max_grad_norm': 2, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 512, 'lr_schedule': 'constant', 'learning_rate': 0.0012849357990049664, 'ent_coef': 1.6935427753702717e-08, 'vf_coef': 0.31379552255146614, 'ortho_init': True, 'net_arch': 'small', 'activation_fn': 'tanh'}. Best is trial 24 with value: -143.4.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:17:01,366]\u001b[0m Trial 26 finished with value: -200.0 and parameters: {'gamma': 0.99, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 16, 'lr_schedule': 'constant', 'learning_rate': 0.0024890801362722455, 'ent_coef': 0.0019898572502017216, 'vf_coef': 0.9669089532815947, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'relu'}. Best is trial 24 with value: -143.4.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:17:05,708]\u001b[0m Trial 27 finished with value: -118.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': False, 'gae_lambda': 1.0, 'n_steps': 16, 'lr_schedule': 'constant', 'learning_rate': 0.0013596027226751505, 'ent_coef': 0.0010470298270232505, 'vf_coef': 0.8090181644454333, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 27 with value: -118.0.\u001b[0m\n",
"Normalizing input and reward\n",
"Normalization activated: {'norm_reward': False, 'training': False}\n",
"\u001b[32m[I 2024-01-18 22:17:23,924]\u001b[0m Trial 28 finished with value: -200.0 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 1, 'use_rms_prop': False, 'gae_lambda': 0.99, 'n_steps': 128, 'lr_schedule': 'constant', 'learning_rate': 7.63961189757746e-05, 'ent_coef': 0.019839451715683458, 'vf_coef': 0.984137558925232, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 27 with value: -118.0.\u001b[0m\n",
"\u001b[32m[I 2024-01-18 22:17:25,778]\u001b[0m Trial 29 finished with value: -118.6 and parameters: {'gamma': 0.9999, 'normalize_advantage': True, 'max_grad_norm': 0.6, 'use_rms_prop': False, 'gae_lambda': 0.92, 'n_steps': 16, 'lr_schedule': 'linear', 'learning_rate': 0.00013272671445433383, 'ent_coef': 0.0008801199346944073, 'vf_coef': 0.616132941436031, 'ortho_init': False, 'net_arch': 'medium', 'activation_fn': 'tanh'}. Best is trial 27 with value: -118.0.\u001b[0m\n",
"Number of finished trials: 30\n",
"Best trial:\n",
"Value: -118.0\n",
"Params: \n",
" gamma: 0.9999\n",
" normalize_advantage: True\n",
" max_grad_norm: 0.6\n",
" use_rms_prop: False\n",
" gae_lambda: 1.0\n",
" n_steps: 16\n",
" lr_schedule: constant\n",
" learning_rate: 0.0013596027226751505\n",
" ent_coef: 0.0010470298270232505\n",
" vf_coef: 0.8090181644454333\n",
" ortho_init: False\n",
" net_arch: medium\n",
" activation_fn: tanh\n",
"Writing report to logs/a2c/report_MountainCar-v0_30-trials-50000-tpe-median_1705616245\n",
"/usr/bin/xdg-open: 882: x-www-browser: not found\n",
"/usr/bin/xdg-open: 882: firefox: not found\n",
"/usr/bin/xdg-open: 882: iceweasel: not found\n",
"/usr/bin/xdg-open: 882: seamonkey: not found\n",
"/usr/bin/xdg-open: 882: mozilla: not found\n",
"/usr/bin/xdg-open: 882: epiphany: not found\n",
"/usr/bin/xdg-open: 882: konqueror: not found\n",
"/usr/bin/xdg-open: 882: chromium: not found\n",
"/usr/bin/xdg-open: 882: chromium-browser: not found\n",
"/usr/bin/xdg-open: 882: google-chrome: not found\n",
"/usr/bin/xdg-open: 882: www-browser: not found\n",
"/usr/bin/xdg-open: 882: links2: not found\n",
"/usr/bin/xdg-open: 882: elinks: not found\n",
"/usr/bin/xdg-open: 882: links: not found\n",
"/usr/bin/xdg-open: 882: lynx: not found\n",
"/usr/bin/xdg-open: 882: w3m: not found\n",
"xdg-open: no method available for opening 'http://127.0.0.1:45125'\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n",
" return _run_code(code, main_globals, None,\n",
" File \"/usr/lib/python3.10/runpy.py\", line 86, in _run_code\n",
" exec(code, run_globals)\n",
" File \"/content/rl-baselines3-zoo/rl_zoo3/train.py\", line 279, in <module>\n",
" train()\n",
" File \"/content/rl-baselines3-zoo/rl_zoo3/train.py\", line 275, in train\n",
" exp_manager.hyperparameters_optimization()\n",
" File \"/content/rl-baselines3-zoo/rl_zoo3/exp_manager.py\", line 915, in hyperparameters_optimization\n",
" fig1.show()\n",
" File \"/usr/local/lib/python3.10/dist-packages/plotly/basedatatypes.py\", line 3409, in show\n",
" File \"/usr/local/lib/python3.10/dist-packages/plotly/io/_renderers.py\", line 403, in show\n",
" renderers._perform_external_rendering(fig_dict, renderers_string=renderer, **kwargs)\n",
" File \"/usr/local/lib/python3.10/dist-packages/plotly/io/_renderers.py\", line 340, in _perform_external_rendering\n",
" renderer.render(fig_dict)\n",
" File \"/usr/local/lib/python3.10/dist-packages/plotly/io/_base_renderers.py\", line 761, in render\n",
" open_html_in_browser(html, self.using, self.new, self.autoraise)\n",
" File \"/usr/local/lib/python3.10/dist-packages/plotly/io/_base_renderers.py\", line 710, in open_html_in_browser\n",
" server.handle_request()\n",
" File \"/usr/lib/python3.10/socketserver.py\", line 294, in handle_request\n",
" ready = selector.select(timeout)\n",
" File \"/usr/lib/python3.10/selectors.py\", line 416, in select\n",
" fd_event_list = self._selector.poll(timeout)\n",
"KeyboardInterrupt\n"
]
}
],
"source": [
"!python -m rl_zoo3.train --algo a2c --env MountainCar-v0 -n 50000 -optimize --n-trials 30 --n-jobs 2 --sampler tpe --pruner median"
]
},
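{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `--pruner median`, a trial is stopped early when its intermediate evaluation is worse than the median of earlier trials at the same step. The cell below is a simplified, stdlib-only sketch of that rule for a maximized objective such as mean reward. It is not Optuna's implementation: `should_prune` is a hypothetical function, and the `n_startup_trials=5` default is an assumption modeled on Optuna's `MedianPruner`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from statistics import median\n",
"\n",
"\n",
"def should_prune(current, previous_trials, step, n_startup_trials=5):\n",
"    \"\"\"Simplified median-pruning rule for a maximized objective (hypothetical helper).\n",
"\n",
"    current: dict mapping step -> intermediate value of the running trial\n",
"    previous_trials: list of such dicts, one per finished trial\n",
"    \"\"\"\n",
"    # Do not prune until enough trials have finished to make the median meaningful\n",
"    if len(previous_trials) < n_startup_trials:\n",
"        return False\n",
"    peers = [trial[step] for trial in previous_trials if step in trial]\n",
"    reported = [value for s, value in current.items() if s <= step]\n",
"    if not peers or not reported:\n",
"        return False\n",
"    # Prune if the running trial's best value so far is below the peers' median\n",
"    return max(reported) < median(peers)\n",
"\n",
"\n",
"# Made-up intermediate mean rewards of five finished trials at step 1\n",
"finished = [{1: -200.0}, {1: -190.0}, {1: -150.0}, {1: -140.0}, {1: -200.0}]\n",
"print(should_prune({1: -200.0}, finished, step=1))  # True: below the median of -190.0\n",
"print(should_prune({1: -120.0}, finished, step=1))  # False"
]
},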
{
"cell_type": "markdown",
"metadata": {
"id": "xVm9QPNVwKXN"
},
"source": [
"### Record a Video"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "MPyfQxD5z26J"
},
"outputs": [],
"source": [
"# Set up display; otherwise rendering will fail\n",
"import os\n",
"os.system(\"Xvfb :1 -screen 0 1024x768x24 &\")\n",
"os.environ['DISPLAY'] = ':1'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qBuUfnzI8DN6"
},
"source": [
"### Display the video"
]
},
{
"cell_type": "code",
"source": [
"%ls logs/a2c/MountainCar-v0_1/"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VSxH3Q7fq68F",
"outputId": "2b6b438b-1601-4868-d64f-c9a9dba89f20"
},
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.monitor.csv \u001b[0m\u001b[01;34mMountainCar-v0\u001b[0m/\n"
]
}
]
},
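{
"cell_type": "markdown",
"metadata": {},
"source": [
"`0.monitor.csv` is written by Stable-Baselines3's `Monitor` wrapper. The sketch below assumes its usual layout: a first line starting with `#` that holds JSON metadata, followed by a CSV header `r,l,t` (episode reward, episode length, elapsed time in seconds). `read_monitor_csv` is a hypothetical helper; `stable_baselines3.common.monitor.load_results` provides a pandas-based alternative for a whole log folder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"import json\n",
"from pathlib import Path\n",
"from statistics import mean\n",
"\n",
"\n",
"def read_monitor_csv(path):\n",
"    \"\"\"Return (metadata, episodes) from a Monitor CSV file, assuming the layout described above.\"\"\"\n",
"    with open(path, newline=\"\") as f:\n",
"        first_line = f.readline()\n",
"        if first_line.startswith(\"#\"):\n",
"            # Metadata line, e.g. #{\"t_start\": ..., \"env_id\": ...}\n",
"            metadata = json.loads(first_line[1:])\n",
"        else:\n",
"            metadata = {}\n",
"            f.seek(0)  # no metadata line: rewind so the CSV header is read\n",
"        episodes = [\n",
"            {\"r\": float(row[\"r\"]), \"l\": int(row[\"l\"]), \"t\": float(row[\"t\"])}\n",
"            for row in csv.DictReader(f)\n",
"        ]\n",
"    return metadata, episodes\n",
"\n",
"\n",
"# Usage with the file listed above (skipped if it does not exist)\n",
"monitor_file = Path(\"logs/a2c/MountainCar-v0_1/0.monitor.csv\")\n",
"if monitor_file.exists():\n",
"    metadata, episodes = read_monitor_csv(monitor_file)\n",
"    if episodes:\n",
"        print(metadata.get(\"env_id\"), len(episodes), \"episodes, mean reward:\", mean(ep[\"r\"] for ep in episodes))"
]
},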
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "ZC3OTfpf8CXu"
},
"outputs": [],
"source": [
"import base64\n",
"from pathlib import Path\n",
"\n",
"from IPython import display as ipythondisplay\n",
"\n",
"\n",
"def show_videos(video_path=\"\", prefix=\"\"):\n",
" \"\"\"\n",
" Taken from https://github.com/eleurent/highway-env\n",
"\n",
" :param video_path: (str) Path to the folder containing videos\n",
" :param prefix: (str) Filter the videos, showing only those whose names start with this prefix\n",
" \"\"\"\n",
" html = []\n",
" for mp4 in Path(video_path).glob(\"{}*.mp4\".format(prefix)):\n",
" video_b64 = base64.b64encode(mp4.read_bytes())\n",
" html.append(\n",
" \"\"\"<video alt=\"{}\" autoplay\n",
" loop controls style=\"height: 400px;\">\n",
" <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n",
" </video>\"\"\".format(\n",
" mp4, video_b64.decode(\"ascii\")\n",
" )\n",
" )\n",
" ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oKOjFuwK9HI0"
},
"outputs": [],
"source": [
"show_videos(video_path='logs/a2c/CartPole-v1_1/videos/', prefix='')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RjdpP0HE8D2p"
},
"source": [
"### Continue Training\n",
"\n",
"Here, we continue training the previous model by passing its saved `.zip` file to `-i` (short for `--trained-agent`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zgMZQJJF6u1C"
},
"outputs": [],
"source": [
"!python -m rl_zoo3.train --algo a2c --env CartPole-v1 --n-timesteps 50000 -i logs/a2c/CartPole-v1_1/CartPole-v1.zip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GSaoyiAE8cVj"
},
"outputs": [],
"source": [
"!python -m rl_zoo3.enjoy --algo a2c --env CartPole-v1 --no-render --n-timesteps 1000 --folder logs/"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "rl-baselines-zoo.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"vscode": {
"interpreter": {
"hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}