"# Getting our feet wet with Stable Baselines 3"
"import gymnasium as gym\n",
"import imageio\n",
"import numpy as np\n",
"from stable_baselines3 import PPO\n",
"from stable_baselines3.common.evaluation import evaluate_policy\n",
"from stable_baselines3.ppo.policies import MlpPolicy"
"# Fix the seed so reruns reproduce the numbers (bitwise reproducibility on GPU is not guaranteed).\n",
"env = gym.make(\"CartPole-v1\")\n",
"model = PPO(MlpPolicy, env, verbose=0, seed=0)"
" (features_extractor): FlattenExtractor(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" )\n",
" (pi_features_extractor): FlattenExtractor(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" )\n",
" (vf_features_extractor): FlattenExtractor(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" )\n",
" (mlp_extractor): MlpExtractor(\n",
" (policy_net): Sequential(\n",
" (0): Linear(in_features=4, out_features=64, bias=True)\n",
" (1): Tanh()\n",
" (2): Linear(in_features=64, out_features=64, bias=True)\n",
" (3): Tanh()\n",
" )\n",
" (value_net): Sequential(\n",
" (0): Linear(in_features=4, out_features=64, bias=True)\n",
" (1): Tanh()\n",
" (2): Linear(in_features=64, out_features=64, bias=True)\n",
" (3): Tanh()\n",
" )\n",
" )\n",
" (action_net): Linear(in_features=64, out_features=2, bias=True)\n",
" (value_net): Linear(in_features=64, out_features=1, bias=True)\n",
"mean_reward: 9.81 +/- 0.52\n"
"# Evaluate the model before training to get a baseline score.\n",
"mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)\n",
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")"
"CPU times: user 11.2 s, sys: 162 ms, total: 11.4 s\n",
"Wall time: 11.4 s\n"
"<stable_baselines3.ppo.ppo.PPO at 0x79850070dd20>"
"mean_reward: 466.10 +/- 63.32\n",
"CPU times: user 19.8 s, sys: 2.89 ms, total: 19.8 s\n",
"Wall time: 19.7 s\n"
"# Score the current model over 100 evaluation episodes.\n",
"mean_reward, std_reward = evaluate_policy(\n",
"    model,\n",
"    env,\n",
"    n_eval_episodes=100,\n",
"    warn=False,\n",
")\n",
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")"
"mean_reward: 500.00 +/- 0.00\n"
"# Score the current model again over 100 evaluation episodes.\n",
"mean_reward, std_reward = evaluate_policy(\n",
"    model,\n",
"    env,\n",
"    n_eval_episodes=100,\n",
"    warn=False,\n",
")\n",
"print(f\"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}\")"
"With just 20k training steps we maxed out the score: 500 is the episode cap for CartPole-v1, and all 100 evaluation episodes reached it :)\n",
"That makes for a really boring recording, but this is how to make one."
"def record_video(env_name='CartPole-v1', ext='mp4', policy=None):\n",
"    \"\"\"Roll out one episode with a policy and save its rendered frames to a video/GIF.\n",
"\n",
"    Args:\n",
"        env_name: Gymnasium environment id to record.\n",
"        ext: file extension; the output is written to f'{env_name}.{ext}'\n",
"            in the working directory (e.g. 'mp4' or 'gif').\n",
"        policy: object exposing `predict(obs, deterministic=True)` that returns\n",
"            `(action, state)`; defaults to the notebook-global `model`.\n",
"\n",
"    Prints the episode length and the output path; returns None.\n",
"    \"\"\"\n",
"    agent = model if policy is None else policy\n",
"    env = gym.make(env_name, render_mode=\"rgb_array\")\n",
"    try:\n",
"        obs, _ = env.reset()\n",
"        images = [env.render()]  # the initial state is the first frame\n",
"        ep_len = 0\n",
"        while True:\n",
"            action, _ = agent.predict(obs, deterministic=True)\n",
"            obs, _, terminated, truncated, _ = env.step(action)\n",
"            images.append(env.render())\n",
"            ep_len += 1\n",
"            if terminated or truncated:\n",
"                break\n",
"    finally:\n",
"        # Close the env even if a step or render call raises.\n",
"        env.close()\n",
"    print(f'Episode length: {ep_len}')\n",
"    fn = f'{env_name}.{ext}'\n",
"    imageio.mimsave(fn, images)\n",
"    print(f'Video saved to: {fn}')"
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/ DeprecationWarning: pkg_resources is deprecated as an API\n",
" warnings.warn(\"pkg_resources is deprecated as an API\", DeprecationWarning)\n",
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/ DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.\n",
"Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See\n",
" declare_namespace(pkg)\n",
"/home/radek/miniforge3/envs/cleanrl/lib/python3.10/site-packages/pkg_resources/ DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`.\n",
"Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See\n",
" declare_namespace(pkg)\n",
"IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (600, 400) to (608, 400) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
"name": "stdout",
"output_type": "stream",
"text": [
"Episode length: 500\n"
"name": "stderr",
"output_type": "stream",
"text": [
"[swscaler @ 0x704ff00] Warning: data is not aligned! This can lead to a speed loss\n"
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to: CartPole-v1.mp4\n"
"# Training on a more complex problem"
"Using cuda device\n",
"Creating environment from the given name 'Acrobot-v1'\n",
"Wrapping the env with a `Monitor` wrapper\n",
"Wrapping the env in a DummyVecEnv.\n"
"env_name = \"Acrobot-v1\"\n",
"# Passing the env id lets SB3 create and wrap the env itself (see its verbose log).\n",
"# seed=0 makes reruns comparable. NOTE(review): SB3 recommends running PPO with MLP policies on CPU; consider device='cpu'.\n",
"model = PPO('MlpPolicy', env_name, verbose=1, seed=0)"
"IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (500, 500) to (512, 512) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
"Episode length: 500\n"
"[swscaler @ 0x67d5d80] Warning: data is not aligned! This can lead to a speed loss\n"
"Video saved to: Acrobot-v1.mp4\n"
"(-500.0, 0.0)"
"# Baseline for the untrained model; displays (mean_reward, std_reward). -500 means every episode ran the full 500 steps at -1 each.\n",
"evaluate_policy(\n",
"    model,\n",
"    gym.make(env_name),\n",
"    n_eval_episodes=100,\n",
"    warn=False,\n",
")"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080\"> 100%</span> <span style=\"color: #729c1f; text-decoration-color: #729c1f\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #008000; text-decoration-color: #008000\">20,478/20,000 </span> [ <span style=\"color: #808000; text-decoration-color: #808000\">0:00:24</span> &lt; <span style=\"color: #008080; text-decoration-color: #008080\">0:00:00</span> , <span style=\"color: #800000; text-decoration-color: #800000\">817 it/s</span> ]\n",
"text/plain": [
"\u001b[35m 100%\u001b[0m \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20,478/20,000 \u001b[0m [ \u001b[33m0:00:24\u001b[0m < \u001b[36m0:00:00\u001b[0m , \u001b[31m817 it/s\u001b[0m ]\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
"text/plain": []
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"text/plain": [
"CPU times: user 26 s, sys: 126 ms, total: 26.1 s\n",
"Wall time: 26 s\n"
"<stable_baselines3.ppo.ppo.PPO at 0x7984d1913a30>"
"# Train the Acrobot model for 20k timesteps.\n",
"model.learn(total_timesteps=20_000, progress_bar=True)"
"(-80.96, 12.161348609426506)"
"# Score the trained model; compare with the hand-written evaluate() loop.\n",
"evaluate_policy(\n",
"    model,\n",
"    gym.make(env_name),\n",
"    n_eval_episodes=100,\n",
"    warn=False,\n",
")"
"# are all these wrappers changing the behavior of the env? Let's try with a function I wrote myself\n",
"def evaluate(n_runs=100, policy=None, env_id=None):\n",
"    \"\"\"Roll out `n_runs` episodes on a plain `gym.make` env and print summary stats.\n",
"\n",
"    Serves as a cross-check of `evaluate_policy` that bypasses SB3's env wrappers.\n",
"\n",
"    Args:\n",
"        n_runs: number of episodes to roll out (must be >= 1).\n",
"        policy: object exposing `predict(obs, deterministic=True)` that returns\n",
"            `(action, state)`; defaults to the notebook-global `model`.\n",
"        env_id: Gymnasium environment id; defaults to the notebook-global `env_name`.\n",
"\n",
"    Prints the mean episode length and the mean undiscounted episode return; returns None.\n",
"\n",
"    Raises:\n",
"        ValueError: if `n_runs` < 1 (the means would otherwise be NaN).\n",
"    \"\"\"\n",
"    if n_runs < 1:\n",
"        raise ValueError(f'n_runs must be >= 1, got {n_runs}')\n",
"    agent = model if policy is None else policy\n",
"    # No render_mode: nothing is rendered during evaluation.\n",
"    env = gym.make(env_name if env_id is None else env_id)\n",
"    ep_lens, ep_returns = [], []\n",
"    try:\n",
"        for _ in range(n_runs):\n",
"            obs, _ = env.reset()\n",
"            ep_len, ep_return = 0, 0\n",
"            terminated = truncated = False\n",
"            while not (terminated or truncated):\n",
"                action, _ = agent.predict(obs, deterministic=True)\n",
"                obs, reward, terminated, truncated, _ = env.step(action)\n",
"                ep_len += 1\n",
"                ep_return += reward\n",
"            ep_lens.append(ep_len)\n",
"            ep_returns.append(ep_return)\n",
"    finally:\n",
"        env.close()\n",
"    print(f'Mean episode length: {np.mean(ep_lens)}\\tMean reward: {np.mean(ep_returns)}')"
"> The goal is to have the free end reach a designated target height in as few steps as possible, and as such all steps that do not reach the goal incur a reward of -1. Achieving the target height results in termination with a reward of 0. The reward threshold is -100."
"Mean episode length: 82.91\tMean reward: -81.91\n"
"# There seem to be no fundamental differences introduced by these wrappers (-81.91 here vs -80.96 from evaluate_policy)\n",
"Episode length: 87\n",
"Video saved to: Acrobot-v1.gif\n"
"# Record one episode of the trained Acrobot agent as a GIF.\n",
"record_video(env_name=env_name, ext='gif')"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080\"> 100%</span> <span style=\"color: #729c1f; text-decoration-color: #729c1f\">━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━</span> <span style=\"color: #008000; text-decoration-color: #008000\">20,374/20,000 </span> [ <span style=\"color: #808000; text-decoration-color: #808000\">0:00:26</span> &lt; <span style=\"color: #008080; text-decoration-color: #008080\">0:00:00</span> , <span style=\"color: #800000; text-decoration-color: #800000\">765 it/s</span> ]\n",
"text/plain": [
"\u001b[35m 100%\u001b[0m \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20,374/20,000 \u001b[0m [ \u001b[33m0:00:26\u001b[0m < \u001b[36m0:00:00\u001b[0m , \u001b[31m765 it/s\u001b[0m ]\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
"text/plain": []
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"text/plain": [
"CPU times: user 27.9 s, sys: 218 ms, total: 28.1 s\n",
"Wall time: 27.9 s\n"
"<stable_baselines3.ppo.ppo.PPO at 0x7984d1913a30>"
"# Continue training the same model object for another 20k timesteps.\n",
"model.learn(total_timesteps=20_000, progress_bar=True)"
"Mean episode length: 86.34\tMean reward: -85.34\n"
"Episode length: 78\n",
"Video saved to: Acrobot-v1.gif\n"
"# Record again after the extra training (overwrites the previous Acrobot-v1.gif).\n",
"record_video(env_name=env_name, ext='gif')"
