Cartpole Video notebook
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Importing libraries\n",
"import gym\n",
"import numpy as np\n",
"from itertools import count\n",
"from collections import namedtuple"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.distributions import Categorical\n",
"# Importing PyTorch here"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"env = gym.make('CartPole-v0') # We make the Cartpole environment here"
]
},
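{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at the spaces (an added sanity check, not part of the original run): CartPole-v0 observes 4 continuous values (cart position, cart velocity, pole angle, pole angular velocity) and accepts 2 discrete actions. These numbers are what the network layer sizes below are based on."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the observation and action spaces\n",
"print(env.observation_space)  # Box with 4 entries\n",
"print(env.action_space)       # Discrete(2)"
]
},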
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 2 actions\n"
]
}
],
"source": [
"print(\"There are {} actions\".format(env.action_space.n))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# You can move either left or right to balance the pole\n",
"# Lets implement the Actor critic network\n",
"class ActorCritic(nn.Module):\n",
" def __init__(self):\n",
" super(ActorCritic, self).__init__()\n",
" self.fc1 = nn.Linear(4, 128) # 4 because there are 4 parameters as the observation space\n",
" self.actor = nn.Linear(128, 2) # 2 for the number of actions\n",
" self.critic = nn.Linear(128, 1) # Critic is always 1\n",
" self.saved_actions = []\n",
" self.rewards = []\n",
" def forward(self, x):\n",
" x = F.relu(self.fc1(x))\n",
" action_prob = F.softmax(self.actor(x), dim=-1)\n",
" state_values = self.critic(x)\n",
" return action_prob, state_values"
]
},
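{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sanity check (added here, not in the original run): the shared `fc1` layer feeds both heads, so a single 4-dimensional observation should produce a 2-element action distribution that sums to 1 and a single state value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical shape check on a fresh, untrained network\n",
"_net = ActorCritic()\n",
"_probs, _value = _net(torch.zeros(4))\n",
"print(_probs.shape, _probs.sum().item(), _value.shape)  # expect: torch.Size([2]) 1.0 torch.Size([1])"
]
},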
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def select_action(state):\n",
" state = torch.from_numpy(state).float()\n",
" probs, state_value = model(state)\n",
" m = Categorical(probs)\n",
" action = m.sample()\n",
" model.saved_actions.append(SavedAction(m.log_prob(action), state_value))\n",
" return action.item()\n",
"# In this function, we decide whehter we want the block to move left or right,based on what the model decided"
]
},
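{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a standalone sketch of what `Categorical` does (the 0.7/0.3 probabilities are made up for illustration): it samples an action index from the given distribution, and `log_prob` returns the log-probability of that sample, which the policy-gradient loss uses later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: sample from a fixed two-action distribution\n",
"_m = Categorical(torch.tensor([0.7, 0.3]))\n",
"_a = _m.sample()\n",
"print(_a.item(), _m.log_prob(_a).item())"
]
},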
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def finish_episode():\n",
" # We calculate the losses and perform backprop in this function\n",
" R = 0\n",
" saved_actions = model.saved_actions\n",
" policy_losses = []\n",
" value_losses =[]\n",
" returns = []\n",
" \n",
" for r in model.rewards[::-1]:\n",
" R = r + 0.99 * R # 0.99 is our gamma number\n",
" returns.insert(0, R)\n",
" returns = torch.tensor(returns)\n",
" returns = (returns - returns.mean()) / (returns.std() + eps)\n",
" \n",
" for (log_prob, value), R in zip(saved_actions, returns):\n",
" advantage = R - value.item()\n",
" \n",
" policy_losses.append(-log_prob * advantage)\n",
" value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))\n",
" \n",
" optimizer.zero_grad()\n",
" loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()\n",
" \n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" del model.rewards[:]\n",
" del model.saved_actions[:]"
]
},
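{
"cell_type": "markdown",
"metadata": {},
"source": [
"To spell out the math in `finish_episode` (an added note): the backward loop builds discounted returns $R_t = r_t + \\gamma R_{t+1}$ with $\\gamma = 0.99$, which are then normalized to zero mean and unit variance. The actor minimizes the policy-gradient loss $-\\log \\pi(a_t \\mid s_t)\\,(R_t - V(s_t))$, where $R_t - V(s_t)$ is the advantage, while the critic is regressed onto $R_t$ with a smooth L1 loss."
]
},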
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model = ActorCritic()\n",
"optimizer = optim.Adam(model.parameters(), lr=3e-2)\n",
"eps = np.finfo(np.float32).eps.item()"
]
},
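{
"cell_type": "markdown",
"metadata": {},
"source": [
"A brief note (added): `eps` is the float32 machine epsilon, about 1.19e-07. It is added to the standard deviation in `finish_episode` so the return normalization never divides by zero, e.g. on an episode where every return is identical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# float32 machine epsilon; keeps (returns.std() + eps) strictly positive\n",
"print(eps)  # 1.1920929e-07"
]
},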
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def train():\n",
" running_reward = 10\n",
" for i_episode in count(): # We need around this much episodes\n",
" state = env.reset()\n",
" ep_reward = 0\n",
" for t in range(1, 10000):\n",
" action = select_action(state)\n",
" state, reward, done, _ = env.step(action)\n",
" model.rewards.append(reward)\n",
" ep_reward += reward\n",
" if done:\n",
" break\n",
" running_reward = 0.05 * ep_reward + (1-0.05) * running_reward\n",
" finish_episode()\n",
" if i_episode % 10 == 0: # We will print some things out\n",
" print(\"Episode {}\\tLast Reward: {:.2f}\\tAverage reward: {:.2f}\".format(\n",
" i_episode, ep_reward, running_reward\n",
" ))\n",
" if running_reward > env.spec.reward_threshold:\n",
" print(\"Solved, running reward is now {} and the last episode runs to {} time steps\".format(\n",
" running_reward, t\n",
" ))\n",
" break\n",
" # This means that we solved cartpole and training is complete\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode 0\tLast Reward: 20.00\tAverage reward: 10.50\n",
"Episode 10\tLast Reward: 10.00\tAverage reward: 13.26\n",
"Episode 20\tLast Reward: 12.00\tAverage reward: 11.96\n",
"Episode 30\tLast Reward: 10.00\tAverage reward: 11.27\n",
"Episode 40\tLast Reward: 10.00\tAverage reward: 10.59\n",
"Episode 50\tLast Reward: 9.00\tAverage reward: 10.35\n",
"Episode 60\tLast Reward: 25.00\tAverage reward: 12.55\n",
"Episode 70\tLast Reward: 17.00\tAverage reward: 20.20\n",
"Episode 80\tLast Reward: 60.00\tAverage reward: 31.55\n",
"Episode 90\tLast Reward: 83.00\tAverage reward: 46.08\n",
"Episode 100\tLast Reward: 76.00\tAverage reward: 51.85\n",
"Episode 110\tLast Reward: 50.00\tAverage reward: 53.48\n",
"Episode 120\tLast Reward: 145.00\tAverage reward: 52.82\n",
"Episode 130\tLast Reward: 200.00\tAverage reward: 104.65\n",
"Episode 140\tLast Reward: 200.00\tAverage reward: 142.91\n",
"Episode 150\tLast Reward: 130.00\tAverage reward: 139.55\n",
"Episode 160\tLast Reward: 151.00\tAverage reward: 143.53\n",
"Episode 170\tLast Reward: 114.00\tAverage reward: 125.28\n",
"Episode 180\tLast Reward: 125.00\tAverage reward: 123.16\n",
"Episode 190\tLast Reward: 147.00\tAverage reward: 128.21\n",
"Episode 200\tLast Reward: 200.00\tAverage reward: 146.84\n",
"Episode 210\tLast Reward: 200.00\tAverage reward: 168.17\n",
"Episode 220\tLast Reward: 200.00\tAverage reward: 180.94\n",
"Episode 230\tLast Reward: 200.00\tAverage reward: 188.59\n",
"Episode 240\tLast Reward: 91.00\tAverage reward: 162.25\n",
"Episode 250\tLast Reward: 158.00\tAverage reward: 153.06\n",
"Episode 260\tLast Reward: 200.00\tAverage reward: 164.78\n",
"Episode 270\tLast Reward: 200.00\tAverage reward: 176.65\n",
"Episode 280\tLast Reward: 200.00\tAverage reward: 185.83\n",
"Episode 290\tLast Reward: 82.00\tAverage reward: 175.56\n",
"Episode 300\tLast Reward: 200.00\tAverage reward: 185.37\n",
"Episode 310\tLast Reward: 97.00\tAverage reward: 169.63\n",
"Episode 320\tLast Reward: 77.00\tAverage reward: 153.32\n",
"Episode 330\tLast Reward: 200.00\tAverage reward: 156.71\n",
"Episode 340\tLast Reward: 200.00\tAverage reward: 172.85\n",
"Episode 350\tLast Reward: 200.00\tAverage reward: 183.75\n",
"Episode 360\tLast Reward: 200.00\tAverage reward: 190.27\n",
"Episode 370\tLast Reward: 200.00\tAverage reward: 194.17\n",
"Episode 380\tLast Reward: 58.00\tAverage reward: 169.13\n",
"Episode 390\tLast Reward: 200.00\tAverage reward: 171.02\n",
"Episode 400\tLast Reward: 200.00\tAverage reward: 173.28\n",
"Episode 410\tLast Reward: 200.00\tAverage reward: 179.81\n",
"Episode 420\tLast Reward: 200.00\tAverage reward: 187.91\n",
"Episode 430\tLast Reward: 200.00\tAverage reward: 192.76\n",
"Solved, running reward is now 195.19851545387158 and the last episode runs to 200 time steps\n"
]
}
],
"source": [
"train()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# There. we finished\n",
"# Lets see it in action\n",
"done = False\n",
"cnt = 0"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/site-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.\u001b[0m\n",
" warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-25-0fffeb1f509a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mcnt\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mselect_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobservation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mobservation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, **kwargs)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'human'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 233\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 234\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/gym/envs/classic_control/cartpole.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpoletrans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_rotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 213\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreturn_rgb_array\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'rgb_array'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 214\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/gym/envs/classic_control/rendering.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, return_rgb_array)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0marr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwidth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0marr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monetime_geoms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0marr\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_rgb_array\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misopen\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/pyglet/window/cocoa/__init__.py\u001b[0m in \u001b[0;36mflip\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 287\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw_mouse_cursor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 289\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 290\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdispatch_events\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/pyglet/gl/cocoa.py\u001b[0m in \u001b[0;36mflip\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mflip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 327\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nscontext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflushBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/pyglet/libs/darwin/cocoapy/runtime.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 783\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 784\u001b[0m \u001b[0;34m\"\"\"Call the method with the given arguments.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 785\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobjc_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 786\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[0;31m######################################################################\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/site-packages/pyglet/libs/darwin/cocoapy/runtime.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, objc_id, *args)\u001b[0m\n\u001b[1;32m 753\u001b[0m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 754\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 755\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobjc_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 756\u001b[0m \u001b[0;31m# Convert result to python type if it is a instance or class pointer.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 757\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestype\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mObjCInstance\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"observation = env.reset()\n",
"while True:\n",
" cnt += 1\n",
" env.render()\n",
" action = select_action(observation)\n",
" observation, reward, done, _ = env.step(action)\n",
" # Lets see how long it lasts until failing\n",
"print(f\"Game lasted {cnt} moves\")"
]
},
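{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since this gist is titled \"Cartpole Video notebook\", here is a hedged sketch of actually recording a rollout to disk with gym's classic `Monitor` wrapper (gym 0.x API, matching the 4-tuple `step` used above; the `./video` output directory is an arbitrary choice, and ffmpeg must be installed):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, assuming the gym 0.x Monitor wrapper is available;\n",
"# it saves an .mp4 of the episode into ./video\n",
"from gym import wrappers\n",
"rec_env = wrappers.Monitor(gym.make('CartPole-v0'), './video', force=True)\n",
"obs = rec_env.reset()\n",
"finished = False\n",
"while not finished:\n",
"    action = select_action(obs)\n",
"    obs, reward, finished, _ = rec_env.step(action)\n",
"rec_env.close()"
]
},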
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# There We are done"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}