@tomtung
Last active November 22, 2020 22:31
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deep Q-Learning for Solving Unity's \"Banana Collectors\" Problem\n",
"\n",
"In this notebook we describe how we used deep Q-learning to solve a modified version of Unity's \"Banana Collectors\" environment, in which the agent navigates a 3D space to collect as many yellow bananas as possible while avoiding blue bananas.\n",
"\n",
"This notebook contains all the code for training and running the agent.\n",
"\n",
"A demo of a trained agent is shown in the gif below:\n",
"\n",
"![demo](https://user-images.githubusercontent.com/513210/99918973-558ba680-2ccf-11eb-8503-a78689a3a029.gif)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environment Setup\n",
"\n",
"The dependencies can be set up by following the instructions from the [DRLND repo](https://github.com/udacity/deep-reinforcement-learning#dependencies). Once it's done, the following imports should work:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from unityagents import UnityEnvironment\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that to make GPU training work on our machine, the following version of PyTorch is used:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.7.0+cu110'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda', index=0)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"device"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also need to download the pre-built Unity environment, which is available for several platforms:\n",
"\n",
"- [Linux](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Linux.zip)\n",
"- [Mac OSX](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana.app.zip)\n",
"- [Windows (32-bit)](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Windows_x86.zip)\n",
"- [Windows (64-bit)](https://s3-us-west-1.amazonaws.com/udacity-drlnd/P1/Banana/Banana_Windows_x86_64.zip)\n",
"\n",
"Once it is downloaded and extracted, set the path below accordingly, e.g.:\n",
"\n",
"- **Mac**: `\"path/to/Banana.app\"`\n",
"- **Windows** (x86): `\"path/to/Banana_Windows_x86/Banana.exe\"`\n",
"- **Windows** (x86_64): `\"path/to/Banana_Windows_x86_64/Banana.exe\"`\n",
"- **Linux** (x86): `\"path/to/Banana_Linux/Banana.x86\"`\n",
"- **Linux** (x86_64): `\"path/to/Banana_Linux/Banana.x86_64\"`\n",
"- **Linux** (x86, headless): `\"path/to/Banana_Linux_NoVis/Banana.x86\"`\n",
"- **Linux** (x86_64, headless): `\"path/to/Banana_Linux_NoVis/Banana.x86_64\"`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"ENV_PATH = \"../Banana_Linux/Banana.x86_64\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If set up correctly, we should be able to initialize the environment:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:unityagents:\n",
"'Academy' started successfully!\n",
"Unity Academy name: Academy\n",
" Number of Brains: 1\n",
" Number of External Brains : 1\n",
" Lesson number : 0\n",
" Reset Parameters :\n",
"\t\t\n",
"Unity brain name: BananaBrain\n",
" Number of Visual Observations (per agent): 0\n",
" Vector Observation space type: continuous\n",
" Vector Observation space size (per agent): 37\n",
" Number of stacked Vector Observation: 1\n",
" Vector Action space type: discrete\n",
" Vector Action space size (per agent): 4\n",
" Vector Action descriptions: , , , \n"
]
}
],
"source": [
"env = UnityEnvironment(file_name=ENV_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# get the default brain\n",
"brain_name = env.brain_names[0]\n",
"brain = env.brains[brain_name]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Environment description\n",
"\n",
"The simulation contains a single agent that navigates a large environment. At each time step, it has four actions at its disposal:\n",
"- `0` - walk forward \n",
"- `1` - walk backward\n",
"- `2` - turn left\n",
"- `3` - turn right\n",
"\n",
"The state space has `37` dimensions and contains the agent's velocity, along with ray-based perception of objects around the agent's forward direction. A reward of `+1` is provided for collecting a yellow banana, and a reward of `-1` is provided for collecting a blue banana."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of actions: 4\n",
"States have length: 37\n"
]
}
],
"source": [
"# reset the environment\n",
"env_info = env.reset(train_mode=True)[brain_name]\n",
"\n",
"# number of actions\n",
"action_size = brain.vector_action_space_size\n",
"print('Number of actions:', action_size)\n",
"\n",
"# examine the state space \n",
"state = env_info.vector_observations[0]\n",
"state_size = len(state)\n",
"print('States have length:', state_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Methodology & Implementation\n",
"\n",
"To solve this toy problem, we used deep Q-learning as described in Mnih et al. (2015), together with the following extensions:\n",
"\n",
"- Prioritized experience replay\n",
"- Noisy network\n",
"- Double Q-learning\n",
"- Dueling network\n",
"\n",
"The details and code are presented in the following sub-sections."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hyper-parameters\n",
"\n",
"The hyper-parameters we used are the default values of the following data class:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentConfig(batch_size=256, learning_rate=0.0003, replay_buffer_size=100000, target_params_update_every=4, target_params_update_ratio=0.005, reward_discount_factor=0.99, reward_unroll_steps=5)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"!pip install -q dataclasses\n",
"\n",
"from dataclasses import dataclass\n",
"\n",
"\n",
"@dataclass\n",
"class AgentConfig:\n",
"    batch_size: int = 256\n",
"    learning_rate: float = 3e-4\n",
"    replay_buffer_size: int = 100_000\n",
"    target_params_update_every: int = 4\n",
"    target_params_update_ratio: float = 0.005\n",
"    reward_discount_factor: float = 0.99\n",
"    reward_unroll_steps: int = 5\n",
"\n",
"\n",
"AgentConfig()"
]
},
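{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `reward_discount_factor` and `reward_unroll_steps` defaults work together. Each stored transition carries the discounted sum of the next 5 rewards, and the bootstrapped target-network estimate is then discounted by $\\gamma^5 = 0.99^5 \\approx 0.951$. The sketch below only illustrates that arithmetic. `n_step_return`, `GAMMA` and `N_STEPS` are hypothetical names, and the reward window is made up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"GAMMA = 0.99  # assumed equal to AgentConfig.reward_discount_factor\n",
"N_STEPS = 5  # assumed equal to AgentConfig.reward_unroll_steps\n",
"\n",
"\n",
"def n_step_return(rewards, gamma=GAMMA):\n",
"    # Discounted sum r_0 + gamma * r_1 + ... + gamma**(n-1) * r_(n-1)\n",
"    return sum((gamma ** i) * r for i, r in enumerate(rewards))\n",
"\n",
"\n",
"# Hypothetical window: a yellow banana now and another one 4 steps later\n",
"window = [1.0, 0.0, 0.0, 0.0, 1.0]\n",
"print(n_step_return(window))  # equals 1 + 0.99 ** 4\n",
"# Discount applied to the bootstrapped target-network estimate\n",
"print(GAMMA ** N_STEPS)"
]
},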
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prioritized experience replay\n",
"\n",
"Here we use prioritized experience replay as described in Schaul et al. (2015). Specifically, we implemented the proportional prioritization variant with the sum-tree data structure.\n",
"\n",
"The code for the replay buffer is as follows:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import random\n",
"from typing import Any, Generator, List, Tuple, Optional\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"class ProportionallyPrioritizedReplayBuffer:\n",
"    \"\"\"A proportionally prioritized replay buffer implemented with sum-tree.\"\"\"\n",
"\n",
"    _curr_index: int\n",
"    _size: int\n",
"    _max_priority: float\n",
"    _sum_tree: List[float]\n",
"    _priorities: List[float]\n",
"    _samples: List[Any]\n",
"\n",
"    def __init__(self, buffer_size: int):\n",
"        assert buffer_size > 1\n",
"        self._curr_index = 0\n",
"        self._size = 0\n",
"        self._max_priority = 1.0\n",
"        # Internal nodes of a complete binary tree with at least buffer_size leaves;\n",
"        # the leaves themselves are the entries of self._priorities\n",
"        self._sum_tree = [0] * (2 ** (math.floor(math.log2(buffer_size - 1)) + 1) - 1)\n",
"        self._priorities = [0] * buffer_size\n",
"        self._samples = [None] * buffer_size\n",
"\n",
"    def _ancestor_indices(self, sample_index: int) -> Generator[int, None, None]:\n",
"        assert 0 <= sample_index < len(self._samples)\n",
"        # Leaves are addressed as if they were stored right after the internal nodes\n",
"        index = sample_index + len(self._sum_tree)\n",
"        while index > 0:\n",
"            index = (index - 1) // 2\n",
"            yield index\n",
"\n",
"    @staticmethod\n",
"    def _children_indices(index: int) -> Tuple[int, int]:\n",
"        # Note that it could go out-of-bounds for the sum tree array\n",
"        left_index = index * 2 + 1\n",
"        right_index = left_index + 1\n",
"        return left_index, right_index\n",
"\n",
"    def _set_priority(self, sample_index: int, priority: float):\n",
"        assert priority > 0, \"Priorities must be positive\"\n",
"        delta = priority - self._priorities[sample_index]\n",
"        self._priorities[sample_index] = priority\n",
"        for index in self._ancestor_indices(sample_index):\n",
"            self._sum_tree[index] += delta\n",
"\n",
"        self._max_priority = max(self._max_priority, priority)\n",
"\n",
"    def _set_sample(self, sample_index: int, sample: Any, priority: float):\n",
"        self._set_priority(sample_index, priority)\n",
"        self._samples[sample_index] = sample\n",
"\n",
"    class _SampleHandle:\n",
"        _parent: \"ProportionallyPrioritizedReplayBuffer\"\n",
"        _index: int\n",
"\n",
"        def __init__(self, parent: \"ProportionallyPrioritizedReplayBuffer\", index: int):\n",
"            assert 0 <= index < len(parent._samples)\n",
"            self._parent = parent\n",
"            self._index = index\n",
"\n",
"        @property\n",
"        def value(self) -> Any:\n",
"            return self._parent._samples[self._index]\n",
"\n",
"        @property\n",
"        def priority(self) -> float:\n",
"            return self._parent._priorities[self._index]\n",
"\n",
"        @priority.setter\n",
"        def priority(self, priority: float):\n",
"            self._parent._set_priority(self._index, priority)\n",
"\n",
"        def reset(self, value: Any, priority: float):\n",
"            self._parent._set_sample(self._index, value, priority)\n",
"\n",
"    def add(self, value: Any, priority: Optional[float] = None):\n",
"        \"\"\"Add a new sample, overwriting the oldest one once the buffer is full.\"\"\"\n",
"        if priority is None:\n",
"            priority = self._max_priority\n",
"\n",
"        self._SampleHandle(self, self._curr_index).reset(value, priority)\n",
"\n",
"        buffer_size = len(self._samples)\n",
"        self._curr_index = (self._curr_index + 1) % buffer_size\n",
"        self._size = min(self._size + 1, buffer_size)\n",
"\n",
"    @property\n",
"    def priority_sum(self):\n",
"        return self._sum_tree[0]\n",
"\n",
"    @property\n",
"    def max_priority(self):\n",
"        return self._max_priority\n",
"\n",
"    def sample_single(self, query: Optional[float] = None) -> _SampleHandle:\n",
"        \"\"\"Draw a sample with probability proportional to its priority.\"\"\"\n",
"        assert self.priority_sum > 0.0, \"Nothing has been added\"\n",
"\n",
"        if query is None:\n",
"            query = random.random()\n",
"\n",
"        assert 0.0 <= query <= 1.0\n",
"        target = self.priority_sum * query\n",
"        index = 0\n",
"        # Walk down from the root until the children of the current node are leaves\n",
"        while True:\n",
"            assert 0.0 <= target <= self._sum_tree[index]\n",
"            index_l, index_r = self._children_indices(index)\n",
"            assert (index_l < len(self._sum_tree)) == (index_r < len(self._sum_tree))\n",
"            if index_l >= len(self._sum_tree):\n",
"                index_l -= len(self._sum_tree)\n",
"                index_r -= len(self._sum_tree)\n",
"                break\n",
"\n",
"            sum_l = self._sum_tree[index_l]\n",
"            if target <= sum_l:\n",
"                index = index_l\n",
"            else:\n",
"                target -= sum_l\n",
"                index = index_r\n",
"\n",
"        # Choose between the two leaves (samples) below the last internal node\n",
"        assert index_l < len(self._priorities)\n",
"        if target <= self._priorities[index_l]:\n",
"            index = index_l\n",
"        else:\n",
"            assert index_r < len(self._priorities)\n",
"            index = index_r\n",
"\n",
"        return self._SampleHandle(self, index)\n",
"\n",
"    def sample_batch(self, batch_size: int) -> List[_SampleHandle]:\n",
"        \"\"\"Draw a stratified batch of samples with the given size.\"\"\"\n",
"        end_points = np.linspace(0.0, 1.0, batch_size + 1).tolist()\n",
"        return [\n",
"            self.sample_single(query=random.uniform(l, r))\n",
"            for l, r in zip(end_points[:-1], end_points[1:])\n",
"        ]\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the current size of internal memory.\"\"\"\n",
"        return self._size"
]
},
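{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the sum-tree idea more concrete, here is a deliberately small, standalone NumPy sketch. `TinySumTree` and `importance_weights` are illustrative names and are not part of the buffer above. Leaves hold priorities and every internal node holds the sum of its two children, so finding the leaf whose cumulative-priority interval contains a query value takes $O(\\log n)$ steps. `importance_weights` follows the $(N \\cdot P(i))^{-\\beta}$ form from Schaul et al. (2015), normalized by the largest weight as in that paper. The agent below instead uses the un-normalized weights with $\\beta = 1$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"class TinySumTree:\n",
"    # Array-backed sum tree; capacity is rounded up to a power of two\n",
"    def __init__(self, capacity):\n",
"        self.n_leaves = 1\n",
"        while self.n_leaves < capacity:\n",
"            self.n_leaves *= 2\n",
"        # nodes[1] is the root, leaves live in nodes[n_leaves:], nodes[0] is unused\n",
"        self.nodes = np.zeros(2 * self.n_leaves)\n",
"\n",
"    def update(self, i, priority):\n",
"        # Set leaf i and propagate the change up to the root\n",
"        node = i + self.n_leaves\n",
"        delta = priority - self.nodes[node]\n",
"        while node >= 1:\n",
"            self.nodes[node] += delta\n",
"            node //= 2\n",
"\n",
"    def total(self):\n",
"        return self.nodes[1]\n",
"\n",
"    def find(self, target):\n",
"        # Return the leaf whose cumulative-priority interval contains target (0 <= target < total)\n",
"        node = 1\n",
"        while node < self.n_leaves:\n",
"            left = 2 * node\n",
"            if target < self.nodes[left]:\n",
"                node = left\n",
"            else:\n",
"                target -= self.nodes[left]\n",
"                node = left + 1\n",
"        return node - self.n_leaves\n",
"\n",
"\n",
"def importance_weights(probs, buffer_len, beta=1.0):\n",
"    # (N * P(i)) ** -beta, divided by the largest weight in the batch\n",
"    w = (buffer_len * np.asarray(probs)) ** (-beta)\n",
"    return w / w.max()\n",
"\n",
"\n",
"tree = TinySumTree(3)\n",
"for i, p in enumerate([1.0, 2.0, 5.0]):\n",
"    tree.update(i, p)\n",
"print(tree.total())\n",
"# Leaf 0 owns [0, 1), leaf 1 owns [1, 3), leaf 2 owns [3, 8)\n",
"print([tree.find(t) for t in (0.5, 1.5, 2.9, 3.0, 7.9)])\n",
"print(importance_weights([1 / 8, 2 / 8, 5 / 8], buffer_len=3))"
]
},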
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Noisy Network\n",
"\n",
"Following Fortunato et al. (2017), here we add factorized Gaussian noise to all layer parameters for more effective exploration.\n",
"\n",
"Besides improving exploration during training, the noisy layers helped in another way. Without them, the agent could get stuck in certain states and keep oscillating without making further progress. The small amount of noise that remains in the trained noisy layers effectively prevents this.\n",
"\n",
"The custom noisy linear layer is implemented below:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"class FactorizedNoisyLinear(torch.nn.Module):\n",
"    \"\"\"Linear layer with learnable factorized Gaussian noise (Fortunato et al., 2017).\"\"\"\n",
"\n",
"    __constants__ = [\"in_features\", \"out_features\"]\n",
"    in_features: int\n",
"    out_features: int\n",
"\n",
"    # weight_var / bias_var hold the learnable noise scales (sigma in the paper)\n",
"    weight_mean: torch.Tensor\n",
"    weight_var: torch.Tensor\n",
"    bias_mean: torch.Tensor\n",
"    bias_var: torch.Tensor\n",
"\n",
"    def __init__(self, in_features: int, out_features: int) -> None:\n",
"        super().__init__()\n",
"        self.in_features = in_features\n",
"        self.out_features = out_features\n",
"        self.weight_mean = torch.nn.Parameter(torch.Tensor(out_features, in_features))\n",
"        self.weight_var = torch.nn.Parameter(torch.Tensor(out_features, in_features))\n",
"        self.bias_mean = torch.nn.Parameter(torch.Tensor(out_features))\n",
"        self.bias_var = torch.nn.Parameter(torch.Tensor(out_features))\n",
"        self.reset_parameters()\n",
"\n",
"    def reset_parameters(self) -> None:\n",
"        # Means ~ U(-1/sqrt(p), 1/sqrt(p)); noise scales start at 0.4/sqrt(p)\n",
"        x = 1 / math.sqrt(self.in_features)\n",
"        for mean in [self.weight_mean, self.bias_mean]:\n",
"            torch.nn.init.uniform_(mean, -x, x)\n",
"\n",
"        for var in [self.weight_var, self.bias_var]:\n",
"            torch.nn.init.constant_(var, 0.4 * x)\n",
"\n",
"    def forward(self, input: torch.Tensor) -> torch.Tensor:\n",
"        def f(x: torch.Tensor):\n",
"            return torch.sign(x) * torch.sqrt(torch.abs(x))\n",
"\n",
"        # Fresh noise on every forward pass (also in eval mode), using only\n",
"        # in_features + out_features Gaussian samples\n",
"        epsilon_in = f(torch.randn(1, self.in_features, device=input.device))\n",
"        epsilon_out = f(torch.randn(self.out_features, 1, device=input.device))\n",
"\n",
"        weight = self.weight_mean + (epsilon_out @ epsilon_in) * self.weight_var\n",
"        bias = self.bias_mean + epsilon_out.squeeze() * self.bias_var\n",
"        return torch.nn.functional.linear(input, weight, bias)"
]
},
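{
"cell_type": "markdown",
"metadata": {},
"source": [
"Factorized noise needs only $p + q$ Gaussian samples per forward pass instead of $p \\times q$. The weight noise is the outer product $f(\\varepsilon_{out})\\, f(\\varepsilon_{in})^\\top$ with $f(x) = \\mathrm{sgn}(x)\\sqrt{|x|}$, so it is a rank-1 matrix. The NumPy sketch below is independent of PyTorch, and `factorized_noise` is a hypothetical helper. It rebuilds that construction for the `37 -> 128` input layer and checks the shapes and the rank:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def f(x):\n",
"    # Noise-shaping function of factorized NoisyNets: sign(x) * sqrt(|x|)\n",
"    return np.sign(x) * np.sqrt(np.abs(x))\n",
"\n",
"\n",
"def factorized_noise(in_features, out_features, rng):\n",
"    # Build (weight_noise, bias_noise) from in_features + out_features Gaussian samples\n",
"    eps_in = f(rng.standard_normal((1, in_features)))\n",
"    eps_out = f(rng.standard_normal((out_features, 1)))\n",
"    weight_noise = eps_out @ eps_in  # (out, 1) @ (1, in) -> (out, in)\n",
"    bias_noise = eps_out[:, 0]  # shape (out,)\n",
"    return weight_noise, bias_noise\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"w_noise, b_noise = factorized_noise(in_features=37, out_features=128, rng=rng)\n",
"print(w_noise.shape, b_noise.shape)\n",
"print(np.linalg.matrix_rank(w_noise))  # an outer product of two non-zero vectors has rank 1"
]
},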
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dueling network\n",
"\n",
"We also use the dueling network architecture described in Wang et al. (2015). Specifically, the network has one fully-connected layer before forking into value and advantage streams, and each stream has one hidden layer and one output layer. We use 128 hidden units for all hidden layers, and use ReLU as the activation function."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class DuelingNetwork(torch.nn.Module):\n",
"    \"\"\"Dueling Q-network built from noisy linear layers.\"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        state_size,\n",
"        action_size,\n",
"        hidden_size_1=128,\n",
"        hidden_size_2=128,\n",
"        activation_fn=torch.nn.functional.relu,\n",
"    ):\n",
"        \"\"\"Initialize parameters and build model.\n",
"        Params\n",
"        ======\n",
"            state_size (int): Dimension of each state\n",
"            action_size (int): Dimension of each action\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.activation_fn = activation_fn\n",
"        # Shared layer, then a hidden layer for each of the value and advantage streams\n",
"        self.fc1 = FactorizedNoisyLinear(state_size, hidden_size_1)\n",
"        self.fc2_v = FactorizedNoisyLinear(hidden_size_1, hidden_size_2)\n",
"        self.fc2_a = FactorizedNoisyLinear(hidden_size_1, hidden_size_2)\n",
"        # Output layers: scalar state value V(s) and one advantage per action\n",
"        self.fc3_v = FactorizedNoisyLinear(hidden_size_2, 1)\n",
"        self.fc3_a = FactorizedNoisyLinear(hidden_size_2, action_size)\n",
"\n",
"    def forward(self, state):\n",
"        \"\"\"Build a network that maps state -> action values.\"\"\"\n",
"        x = state\n",
"        x = self.activation_fn(self.fc1(x))\n",
"\n",
"        v = self.activation_fn(self.fc2_v(x))\n",
"        v = self.fc3_v(v)\n",
"\n",
"        a = self.activation_fn(self.fc2_a(x))\n",
"        a = self.fc3_a(a)\n",
"        a = a - a.mean(dim=1, keepdim=True)\n",
"\n",
"        x = v + a\n",
"        return x"
]
},
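{
"cell_type": "markdown",
"metadata": {},
"source": [
"Subtracting the mean advantage makes the value/advantage split identifiable. Adding the same constant to every advantage no longer changes $Q$, and averaging $Q$ over actions recovers $V$. The NumPy check below mirrors the aggregation at the end of `forward` using made-up numbers. `dueling_aggregate` is an illustrative name:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def dueling_aggregate(value, advantages):\n",
"    # Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a'), for a batch of states\n",
"    return value + advantages - advantages.mean(axis=1, keepdims=True)\n",
"\n",
"\n",
"value = np.array([[2.0], [0.5]])  # shape (batch, 1), like the fc3_v output\n",
"advantages = np.array([[1.0, -1.0, 3.0, 1.0], [0.0, 0.0, 0.0, 4.0]])  # shape (batch, 4), like fc3_a\n",
"q = dueling_aggregate(value, advantages)\n",
"print(q)\n",
"print(q.mean(axis=1))  # recovers V for every state\n",
"# Shifting all advantages by the same constant leaves Q unchanged\n",
"print(np.allclose(q, dueling_aggregate(value, advantages + 10.0)))"
]
},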
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Double Q-learning\n",
"\n",
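"Finally, to put it all together, we use double Q-learning as described by van Hasselt et al. (2015). The agent also learns from 5-step returns (`reward_unroll_steps` in `AgentConfig`), so the target-network estimate is discounted by $\\gamma^5$. The agent is implemented as follows:"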
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from collections import namedtuple, deque\n",
"import copy\n",
"from typing import List, Any, Generator, Tuple, Optional\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"class Agent:\n",
"    \"\"\"Interacts with and learns from the environment.\"\"\"\n",
"\n",
"    def __init__(self, state_size, action_size, config: Optional[AgentConfig] = None):\n",
"        \"\"\"Initialize an Agent object.\n",
"\n",
"        Params\n",
"        ======\n",
"            state_size (int): dimension of each state\n",
"            action_size (int): dimension of each action\n",
"            config (AgentConfig): hyper-parameters, defaults to AgentConfig()\n",
"        \"\"\"\n",
"        self.state_size = state_size\n",
"        self.action_size = action_size\n",
"        self.config = config or AgentConfig()\n",
"\n",
"        # Q-Network\n",
"        self.qnetwork_local = DuelingNetwork(state_size, action_size).to(device)\n",
"        self.qnetwork_target = copy.deepcopy(self.qnetwork_local)\n",
"        self.qnetwork_target.eval()\n",
"\n",
"        self.optimizer = optim.Adam(\n",
"            self.qnetwork_local.parameters(), lr=self.config.learning_rate\n",
"        )\n",
"\n",
"        # Replay memory\n",
"        self.replay_buffer = ProportionallyPrioritizedReplayBuffer(\n",
"            self.config.replay_buffer_size\n",
"        )\n",
"        # Sliding window of recent transitions used to build n-step returns\n",
"        self.experience_unroll_queue = deque(\n",
"            maxlen=self.config.reward_unroll_steps\n",
"        )\n",
"\n",
"        # Initialize time step\n",
"        self.t_step = 0\n",
"\n",
"    def _store_unrolled_experience(self):\n",
"        # Store the n-step transition that starts at the oldest queued experience\n",
"        state, action, _, _, _ = self.experience_unroll_queue[0]\n",
"        _, _, next_state, _, done = self.experience_unroll_queue[-1]\n",
"        reward = sum(\n",
"            (self.config.reward_discount_factor ** i) * r\n",
"            for i, (_, _, _, r, _) in enumerate(self.experience_unroll_queue)\n",
"        )\n",
"        self.replay_buffer.add((state, action, reward, next_state, done))\n",
"\n",
"    def step(self, state, action, reward, next_state, done):\n",
"        self.t_step += 1\n",
"\n",
"        # Save experience in replay memory\n",
"        self.experience_unroll_queue.append(\n",
"            (state, action, next_state, reward, done)\n",
"        )\n",
"        if len(self.experience_unroll_queue) >= self.config.reward_unroll_steps:\n",
"            self._store_unrolled_experience()\n",
"        if done:\n",
"            # Flush the shorter windows at the end of the episode so that no stored\n",
"            # transition spans two episodes (their bootstrap term is masked by done)\n",
"            if len(self.experience_unroll_queue) >= self.config.reward_unroll_steps:\n",
"                self.experience_unroll_queue.popleft()  # its window was stored above\n",
"            while self.experience_unroll_queue:\n",
"                self._store_unrolled_experience()\n",
"                self.experience_unroll_queue.popleft()\n",
"\n",
"        self.learn()\n",
"\n",
"    def act(self, state):\n",
"        \"\"\"Returns actions for given state as per current policy.\n",
"\n",
"        Params\n",
"        ======\n",
"            state (array_like): current state\n",
"        \"\"\"\n",
"        state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
"        self.qnetwork_local.eval()\n",
"        with torch.no_grad():\n",
"            action_values = self.qnetwork_local(state)\n",
"        self.qnetwork_local.train()\n",
"        return np.argmax(action_values.cpu().data.numpy())\n",
"\n",
"    def learn(self):\n",
"        \"\"\"Update value parameters using a prioritized batch of experience tuples.\"\"\"\n",
"        if len(self.replay_buffer) < self.config.batch_size:\n",
"            return\n",
"\n",
"        experiences = self.replay_buffer.sample_batch(self.config.batch_size)\n",
"\n",
"        with torch.no_grad():\n",
"            states, actions, rewards, next_states, dones = [\n",
"                torch.from_numpy(np.vstack(v)).to(dtype=dtype, device=device)\n",
"                for v, dtype in zip(\n",
"                    zip(*[e.value for e in experiences]),\n",
"                    [torch.float, torch.long, torch.float, torch.float, torch.float],\n",
"                )\n",
"            ]\n",
"            sample_probs = (\n",
"                torch.from_numpy(np.vstack([e.priority for e in experiences])).to(\n",
"                    dtype=torch.float, device=device\n",
"                )\n",
"                / self.replay_buffer.priority_sum\n",
"            )\n",
"            # Importance-sampling weights (N * P(i))^-1, i.e. beta = 1 without max-normalization\n",
"            weights = 1.0 / (sample_probs * len(self.replay_buffer))\n",
"\n",
"        q_curr = torch.gather(self.qnetwork_local(states), dim=-1, index=actions)\n",
"\n",
"        with torch.no_grad():\n",
"            # Double Q-learning: the local network selects the next action on the\n",
"            # next states, and the target network evaluates it\n",
"            _, next_actions = torch.max(\n",
"                self.qnetwork_local(next_states), dim=-1, keepdim=True\n",
"            )\n",
"            q_target = rewards + (\n",
"                (1.0 - dones)\n",
"                * (self.config.reward_discount_factor ** self.config.reward_unroll_steps)\n",
"                * torch.gather(\n",
"                    self.qnetwork_target(next_states), dim=-1, index=next_actions\n",
"                )\n",
"            )\n",
"\n",
"        losses = F.mse_loss(q_curr, q_target, reduction=\"none\")\n",
"\n",
"        # Update sample priorities with the new absolute TD errors\n",
"        with torch.no_grad():\n",
"            new_priorities = (\n",
"                (losses.sqrt() + 1e-6)\n",
"                .squeeze()\n",
"                .cpu()\n",
"                .numpy()\n",
"            )\n",
"        assert len(experiences) == len(new_priorities)\n",
"        for experience, priority in zip(experiences, new_priorities):\n",
"            experience.priority = priority\n",
"\n",
"        loss = torch.mean(losses * weights)\n",
"        self.optimizer.zero_grad()\n",
"        loss.backward()\n",
"        self.optimizer.step()\n",
"\n",
"        # Update target network\n",
"        if self.t_step % self.config.target_params_update_every == 0:\n",
"            self.soft_update(self.qnetwork_local, self.qnetwork_target)\n",
"\n",
"    def soft_update(self, local_model, target_model):\n",
"        \"\"\"Soft update model parameters.\n",
"        θ_target = τ*θ_local + (1 - τ)*θ_target\n",
"\n",
"        Params\n",
"        ======\n",
"            local_model (PyTorch model): weights will be copied from\n",
"            target_model (PyTorch model): weights will be copied to\n",
"        \"\"\"\n",
"        tau = self.config.target_params_update_ratio\n",
"        for target_param, local_param in zip(\n",
"            target_model.parameters(), local_model.parameters()\n",
"        ):\n",
"            target_param.data.copy_(\n",
"                tau * local_param.data + (1.0 - tau) * target_param.data\n",
"            )"
]
},
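{
"cell_type": "markdown",
"metadata": {},
"source": [
"At its core, `learn` builds the double Q-learning target with 5-step returns, $y = R^{(n)} + (1 - d)\\, \\gamma^n\\, Q_{target}\\big(s_{t+n}, \\arg\\max_a Q_{local}(s_{t+n}, a)\\big)$. The local network selects the next action and the target network evaluates it. The NumPy sketch below isolates that computation, using the hypothetical helper `double_dqn_targets` and made-up Q-values. In the first row, the local network prefers action 1, so the target uses the target network's value 2.0 for that action. A vanilla DQN target would instead have used that network's maximum, 7.0:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def double_dqn_targets(rewards, dones, q_local_next, q_target_next, gamma=0.99, n=5):\n",
"    # rewards, dones: shape (batch, 1); q_*_next: Q-values of the next states, shape (batch, actions)\n",
"    # Select the next action with the local (online) network...\n",
"    next_actions = q_local_next.argmax(axis=1)[:, None]\n",
"    # ...but evaluate it with the target network\n",
"    next_values = np.take_along_axis(q_target_next, next_actions, axis=1)\n",
"    return rewards + (1.0 - dones) * (gamma ** n) * next_values\n",
"\n",
"\n",
"rewards = np.array([[1.0], [0.0]])\n",
"dones = np.array([[0.0], [1.0]])  # the second transition ends the episode\n",
"q_local_next = np.array([[0.2, 0.9, 0.1, 0.0], [1.0, 0.0, 0.0, 0.0]])\n",
"q_target_next = np.array([[5.0, 2.0, 7.0, 0.0], [3.0, 3.0, 3.0, 3.0]])\n",
"y = double_dqn_targets(rewards, dones, q_local_next, q_target_next)\n",
"print(y)"
]
},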
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train & Evaluate\n",
"\n",
"With everything set up, we're now ready to train the agent. Note that since we use noisy networks for exploration, the epsilon-greedy strategy is not needed here.\n",
"\n",
"We train the model until the agent gets an average score of +13 over 100 consecutive episodes."
]
},
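{
"cell_type": "markdown",
"metadata": {},
"source": [
"The stopping rule can be written as a small standalone check. `first_solved_episode` is a hypothetical helper that `dqn` does not use. Unlike `dqn`, which only starts checking once more than 100 scores are recorded, this helper already accepts the first full 100-episode window:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import deque\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def first_solved_episode(scores, window=100, threshold=13.0):\n",
"    # 1-based episode at which the trailing `window`-episode mean first exceeds `threshold`, else None\n",
"    recent = deque(maxlen=window)\n",
"    for episode, score in enumerate(scores, start=1):\n",
"        recent.append(score)\n",
"        if len(recent) == window and np.mean(recent) > threshold:\n",
"            return episode\n",
"    return None\n",
"\n",
"\n",
"# Synthetic scores: 50 poor episodes followed by consistently good ones\n",
"print(first_solved_episode([5.0] * 50 + [15.0] * 200))"
]
},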
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"agent = Agent(state_size=state_size, action_size=action_size)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from collections import deque\n",
"\n",
"def dqn(n_episodes=3000, max_t=1000):\n",
"    \"\"\"Deep Q-Learning.\n",
"\n",
"    Params\n",
"    ======\n",
"        n_episodes (int): maximum number of training episodes\n",
"        max_t (int): maximum number of timesteps per episode\n",
"    \"\"\"\n",
"    scores = []  # list containing scores from each episode\n",
"    scores_window = deque(maxlen=100)  # last 100 scores\n",
"    for i_episode in range(1, n_episodes + 1):\n",
"        env_info = env.reset(train_mode=True)[brain_name]\n",
"        state = env_info.vector_observations[0]\n",
"        score = 0\n",
"        for _ in range(max_t):\n",
"            action = agent.act(state)\n",
"            env_info = env.step(action)[brain_name]\n",
"            next_state = env_info.vector_observations[0]\n",
"            reward = env_info.rewards[0]\n",
"            done = env_info.local_done[0]\n",
"            agent.step(state, action, reward, next_state, done)\n",
"\n",
"            score += reward\n",
"            state = next_state\n",
"            if done:\n",
"                break\n",
"\n",
"        scores_window.append(score)  # save most recent score\n",
"        scores.append(score)  # save most recent score\n",
"        average_score = np.mean(scores_window)\n",
"        print(\n",
"            f\"\\rEpisode {i_episode}\\tScore: {score:.2f}\\tWindowed average Score: {average_score:.2f}\",\n",
"            end=\"\\n\" if i_episode % 100 == 0 else \"\",\n",
"        )\n",
"        if len(scores) > 100 and average_score > 13.:\n",
"            print(\n",
"                f\"\\nEnvironment solved between episode {i_episode - 100} and episode {i_episode}!\"\n",
"                f\"\\tAverage Score: {average_score:.2f}\"\n",
"            )\n",
"            torch.save(agent.qnetwork_local.state_dict(), \"model.pt\")\n",
"            break\n",
"\n",
"    return scores"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode 100\tScore: 8.00\tWindowed average Score: 10.17\n",
"Episode 135\tScore: 16.00\tWindowed average Score: 13.04\n",
"Environment solved between episode 35 and episode 135!\tAverage Score: 13.04\n"
]
}
],
"source": [
"%%time\n",
"scores = dqn()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following graph shows how the score changes as we train for more episodes:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(111)\n",
"plt.plot(np.arange(1, len(scores)+1), scores)\n",
"plt.plot(np.arange(1, len(scores)+1), [13.] * len(scores))\n",
"plt.ylabel('Score')\n",
"plt.xlabel('Episode #')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, running the following cell simulates one episode to demonstrate how the trained agent performs:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score: 11.0\n"
]
}
],
"source": [
"env_info = env.reset(train_mode=False)[brain_name]\n",
"state = env_info.vector_observations[0]\n",
"score = 0\n",
"while True:\n",
"    action = agent.act(state)\n",
"    env_info = env.step(action)[brain_name]\n",
"    next_state = env_info.vector_observations[0]\n",
"    reward = env_info.rewards[0]\n",
"    done = env_info.local_done[0]\n",
"    score += reward\n",
"    state = next_state\n",
"    if done:\n",
"        break\n",
"\n",
"print(\"Score: {}\".format(score))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"During our experiments, we did not get around to thoroughly tuning the hyper-parameters or analyzing how each extension to DQN separately affects the training process and agent performance.\n",
"\n",
"It would also be interesting to experiment with all the other extensions to DQN that are described in Rainbow (Hessel et al. 2017), which effectively serves as the current baseline.\n",
"\n",
"We also observed that the agent sometimes ignores a cluster of yellow bananas near it that it saw a few seconds earlier, and instead chases another banana that is currently visible but far away. This inefficiency is likely caused by the decision process not being fully Markovian, since the agent cannot observe all directions around it at the same time. One solution would be to introduce a recurrent mechanism as described in Hausknecht and Stone (2015), so that the agent can remember what it has seen recently and act more efficiently."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"\n",
"- Mnih, V. et al. (2015) ‘Human-level control through deep reinforcement learning’, Nature, 518(7540), pp. 529–533. doi: 10.1038/nature14236.\n",
"- Schaul, T. et al. (2015) ‘Prioritized Experience Replay’. Available at: http://arxiv.org/abs/1511.05952\n",
"- Wang, Z. et al. (2015) ‘Dueling Network Architectures for Deep Reinforcement Learning’. Available at: http://arxiv.org/abs/1511.06581\n",
"- Fortunato, M. et al. (2017) ‘Noisy Networks for Exploration’. Available at: https://arxiv.org/abs/1706.10295\n",
"- van Hasselt, H., Guez, A. and Silver, D. (2015) ‘Deep Reinforcement Learning with Double Q-learning’. Available at: http://arxiv.org/abs/1509.06461\n",
"- Hessel, M. et al. (2017) ‘Rainbow: Combining Improvements in Deep Reinforcement Learning’. Available at: http://arxiv.org/abs/1710.02298\n",
"- Hausknecht, M. and Stone, P. (2015) ‘Deep Recurrent Q-Learning for Partially Observable MDPs’. Available at: http://arxiv.org/abs/1507.06527"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "drlnd",
"language": "python",
"name": "drlnd"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}