@cadurosar
Last active August 4, 2023 22:19
An interactive IPython notebook for: Keras plays catch - https://gist.github.com/EderSantana/c7222daa328f0e885093
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Action End Test, Points: 0\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAFSCAYAAAB2cI2KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADV5JREFUeJzt3E2IlXX7wPHr/BVMx0QHdUQTQTeCYzIbQSaZjaPQxkR8\nC1820YNFoCtFSciNaIsgSkISXBqSDi4USxAxFAvEGISCFAoNmUxH0REJu59F0PP08B/PXPN2zpn7\n84GzCM99vK6yr7/7zMypFEVRBAAD8n+1HgCgkYgmQIJoAiSIJkCCaAIkiCZARjHMImLQj+7u7iFd\n36iPMu5t5/I8GnXv/lSG+/s0K5XKoK8timJI1zeqMu5t5/Jo1L37S6Pbc4AE0QRIEE2ABNEESBBN\ngATRBEgQTYAE0QRIEE2AhPEDedKBAwfi+++/j0qlEnv27InFixeP9FwA9anaz5J/++23xb/+9a+i\nKIrip59+KjZs2DBiP3s+1Osb9VHGve1cnkej7t2fqrfnV65ciRUrVkRExIIFC+LRo0fx5MmTapcB\njElVo3nv3r1obm7++5+nTZsW9+7dG9GhAOrVgN7T/G/VPhSpu7s7WltbBz3QMH/oUsMo4952Lo+x\ntHfVaM6cOfMfJ8uenp6YMWNGv88fyheJGvUjpIaqjHvbuTwade9BfzRce3t7nDt3LiIibty4ES0t\nLTFp0qThnQ6gQVQ9aba1tcWiRYti48aNMW7cuNi3b99ozAVQl3xyex0o4952Lo9G3dsntwMMA9EE\nSBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRI\nEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQ\nTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBN\ngATRBEgYP5AnHTp0KK5duxbPnz+Pt99+Ozo7O0d6LoC6VDWaV69ejZs3b8bx48ejt7c31qxZI5pA\naVWN5tKlS2PJkiURETFlypR4+vRpFEURlUplxIcDqDdV39OsVCrx0ksvRUTEiRMnoqOjQzCB0hrQ\ne5oREefPn4+TJ0/G0aNHR3IegLo2oGheunQpjhw5EkePHo3Jkye/8Lnd3d3R2to66IGKohj0tY2s\njHvbuTzG0t6Voso2jx8/jjfffDOOHTsWzc3N1V9wCLfuZX2vtIx727k8GnXv/tJY9aR55syZ6O3t\njR07dvy9/KFDh2LWrFnDPiRAvat60ky/oJNmWhn3tnN5NOre/aXRTwQBJIgmQIJoAiSIJkCCaAIk\niCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSI\nJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgm\nQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkDCgKL57Nmz\n6OzsjK6urpGeB6CuDSiahw8fjqlTp470LAB1r2o0b926Fbdu3YqOjo7RmAegrlWN5sGDB2P37t2j\nMQtA3XthNLu6uqKtrS3mzJkTERFFUYzKUAD1avyLfvHixYtx+/btuHDhQty9ezcmTJgQs2bNimXL\nlvV7TXd3d7S2tg56oLKGuYx727k8xtLelWKA23zyySfxyiuvxBtvvPHiF6xUBj1MURRDur5RlXFv\nO5dHo+7dXxp9nyZAwoBPmgN+QSfNtDLubefyaNS9nTQBhoFoAiSIJkCCaAIkiCZAgmgCJIgmQIJo\nAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgC\nJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIk\niCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAwoCiefr06Vi9enWs\nXbs2Ll68ONIzAdSvoooHDx4UK1euLPr6+orffvuteP/991/4/IgY9GOo1zfqo4x727k8j0bduz/j\no4rLly9He3t7TJw4MSZOnBj79++vdgnAmFX19vzOnTvx9OnT2L59e2zevDmuXLkyGnMB1KWqJ82i\nKKK3tzcOHz4ct2/fjq1bt8aFCxf6fX53d3e0trYOeqC/TvLlU8a97VweY2nvqtGcPn16tLW1RaVS\niblz50ZTU1Pcv38/mpub/9/nL168eNDDFEURlUpl0Nc3qjLubefyaNS9+wt91dvz9vb2uHr1ahRF\nEQ8ePIi+vr5+gwkw1lU9aba0tMSqVati/fr1UalUYt++faMxF0BdqhTD/GbDUI7hjXqMH6oy7m3n\n8mjUvQd9ew7Af4gmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJo\nAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgC\nJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIk\niCZAgmgCJIgmQIJoAiSIJkDC+GpP6Ovri127dsXDhw/jjz/+iHfffTdee+210ZgNoO5UjeapU6di\n/vz5sXPnzujp6Ylt27bF2bNnR2M2gLpT9fZ82rRp8eDBg4iIePjwYTQ3N4/4UAD1qlIURVHtSW+9\n9Vb88ssv8ejRozhy5Ei8+uqr/b9gpTLoYYqiGNL1jaqMe9u5PBp17/7SWPWkefr06Zg9e3Z89dVX\ncezYsfjggw+GfTiARlH1Pc1r167F8uXLIyJi4cKF0dPT88K/Obq7u6O1tXXQAw3g4DsmlXFvO5fH\nWNq7ajTnzZsX169fj87Ozrhz5040NTW98Ki9ePHiQQ/TqMf4oSrj3nYuj0bdu7/QV31Ps6+vL/bs\n2RO///57PH/+PHbs2BFLly7t9/ne08wr4952Lo9G3XvQ0cwSzbwy7m3n8mjUvQf9hSAA/kM0ARJE\nEyBBNAESRBMgQTQBEkQTIEE0ARJEEyBBNAESRBMgoeqnHFHdn3/+WRevMdqG+rEFz58/H9R148aN\nG9LvC0PhpAmQIJoACaIJkCCaAAmiCZAgmgAJogmQIJoACaIJkCCaAAmiCZAgmgAJogmQIJoACaIJ\nkCCaAAmiCZAgmgAJogmQIJoACaIJkCCaAAmiCZAgmgAJogmQIJoACaIJkCCaAAmiCZAwvtYDjAVF\nUQzp+kqlMujXqFQqQ/q9gZxKMdT/4wF
KxO05QIJoAiSIJkCCaAIkiCZAgmgCJNRNNA8cOBAbN26M\nTZs2RXd3d63HGRWHDh2KjRs3xrp16+Lrr7+u9Tij5tmzZ9HZ2RldXV21HmXUnD59OlavXh1r166N\nixcv1nqcEdfX1xfvvfdebN26NTZt2hTffPNNrUcaNnXxze3fffdd/Pzzz3H8+PG4efNm7N27N44f\nP17rsUbU1atX4+bNm3H8+PHo7e2NNWvWRGdnZ63HGhWHDx+OqVOn1nqMUdPb2xuffvppdHV1xZMn\nT+Ljjz+Ojo6OWo81ok6dOhXz58+PnTt3Rk9PT2zbti3Onj1b67GGRV1E88qVK7FixYqIiFiwYEE8\nevQonjx5Ek1NTTWebOQsXbo0lixZEhERU6ZMiadPn0ZRFGP+J3xu3boVt27dGvPR+G+XL1+O9vb2\nmDhxYkycODH2799f65FG3LRp0+LHH3+MiIiHDx9Gc3NzjScaPnVxe37v3r1//EudNm1a3Lt3r4YT\njbxKpRIvvfRSREScOHEiOjo6xnwwIyIOHjwYu3fvrvUYo+rOnTvx9OnT2L59e2zevDmuXLlS65FG\n3Ouvvx6//vprrFy5MrZs2RK7du2q9UjDpi5Omv+rTD/Zef78+Th58mQcPXq01qOMuK6urmhra4s5\nc+ZERHn+OxdFEb29vXH48OG4fft2bN26NS5cuFDrsUbU6dOnY/bs2fH555/HDz/8EHv37o0vv/yy\n1mMNi7qI5syZM/9xsuzp6YkZM2bUcKLRcenSpThy5EgcPXo0Jk+eXOtxRtzFixfj9u3bceHChbh7\n925MmDAhZs2aFcuWLav1aCNq+vTp0dbWFpVKJebOnRtNTU1x//79MXXL+r+uXbsWy5cvj4iIhQsX\nRk9Pz5h5+6kubs/b29vj3LlzERFx48aNaGlpiUmTJtV4qpH1+PHj+PDDD+Ozzz6Ll19+udbjjIqP\nPvooTpw4EV988UWsW7cu3nnnnTEfzIi//nxfvXo1iqKIBw8eRF9f35gOZkTEvHnz4vr16xHx19sT\nTU1NYyKYEXVy0mxra4tFixbFxo0bY9y4cbFv375ajzTizpw5E729vbFjx46//wY+dOhQzJo1q9aj\nMcxaWlpi1apVsX79+qhUKqX4871hw4bYs2dPbNmyJZ4/fz6mvvjlo+EAEuri9hygUYgmQIJoAiSI\nJkCCaAIkiCZAgmgCJIgmQMK/ARRcnZLdxlbtAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f3431d0d610>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"train(model)\n",
"test(model)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using Theano backend.\n",
"ERROR (theano.sandbox.cuda): nvcc compiler not found on $PATH. Check your nvcc installation and try again.\n",
"ERROR:theano.sandbox.cuda:nvcc compiler not found on $PATH. Check your nvcc installation and try again.\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import seaborn\n",
"seaborn.set()\n",
"import json\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import time\n",
"from keras.models import model_from_json\n",
"from qlearn import Catch\n",
"from PIL import Image\n",
"from IPython import display\n",
"last_frame_time = 0\n",
"translate_action = [\"Left\",\"Stay\",\"Right\",\"Create Ball\",\"End Test\"]\n",
"grid_size = 10\n",
"\n",
"def display_screen(action,points,input_t):\n",
"    global last_frame_time\n",
"    display.clear_output(wait=True)\n",
"    print \"Action %s, Points: %d\" % (translate_action[action],points)\n",
"    if(\"End\" not in translate_action[action]):\n",
"        plt.imshow(input_t.reshape((grid_size,)*2),\n",
"                   interpolation='none', cmap='gray')\n",
"        display.display(plt.gcf())\n",
"    last_frame_time = set_max_fps(last_frame_time)\n",
"def set_max_fps(last_frame_time,FPS = 1):\n",
"    current_milli_time = lambda: int(round(time.time() * 1000))\n",
"    sleep_time = 1./FPS - (current_milli_time() - last_frame_time)\n",
"    if sleep_time > 0:\n",
"        time.sleep(sleep_time)\n",
"    return current_milli_time()\n",
"def test(model):\n",
"    global last_frame_time\n",
"    plt.ion()\n",
"    # Define environment, game\n",
"    env = Catch(grid_size)\n",
"    c = 0\n",
"    last_frame_time = 0\n",
"    points = 0\n",
"    for e in range(10):\n",
"        loss = 0.\n",
"        env.reset()\n",
"        game_over = False\n",
"        # get initial input\n",
"        input_t = env.observe()\n",
"        display_screen(3,points,input_t)\n",
"        c += 1\n",
"        while not game_over:\n",
"            input_tm1 = input_t\n",
"            # get next action\n",
"            q = model.predict(input_tm1)\n",
"            action = np.argmax(q[0])\n",
"            # apply action, get rewards and new state\n",
"            input_t, reward, game_over = env.act(action)\n",
"            points += reward\n",
"            display_screen(action,points,input_t)\n",
"            c += 1\n",
"    display_screen(4,points,input_t)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import json\n",
"import numpy as np\n",
"from keras.models import Sequential\n",
"from keras.layers.core import Dense\n",
"from keras.optimizers import sgd\n",
"\n",
"\n",
"class Catch(object):\n",
"    def __init__(self, grid_size=10):\n",
"        self.grid_size = grid_size\n",
"        self.reset()\n",
"\n",
"    def _update_state(self, action):\n",
"        \"\"\"\n",
"        Input: action and states\n",
"        Output: new states and reward\n",
"        \"\"\"\n",
"        state = self.state\n",
"        if action == 0: # left\n",
"            action = -1\n",
"        elif action == 1: # stay\n",
"            action = 0\n",
"        else:\n",
"            action = 1 # right\n",
"        f0, f1, basket = state[0]\n",
"        new_basket = min(max(1, basket + action), self.grid_size-1)\n",
"        f0 += 1\n",
"        out = np.asarray([f0, f1, new_basket])\n",
"        out = out[np.newaxis]\n",
"\n",
"        assert len(out.shape) == 2\n",
"        self.state = out\n",
"\n",
"    def _draw_state(self):\n",
"        im_size = (self.grid_size,)*2\n",
"        state = self.state[0]\n",
"        canvas = np.zeros(im_size)\n",
"        canvas[state[0], state[1]] = 1 # draw fruit\n",
"        canvas[-1, state[2]-1:state[2] + 2] = 1 # draw basket\n",
"        return canvas\n",
"\n",
"    def _get_reward(self):\n",
"        fruit_row, fruit_col, basket = self.state[0]\n",
"        if fruit_row == self.grid_size-1:\n",
"            if abs(fruit_col - basket) <= 1:\n",
"                return 1\n",
"            else:\n",
"                return -1\n",
"        else:\n",
"            return 0\n",
"\n",
"    def _is_over(self):\n",
"        if self.state[0, 0] == self.grid_size-1:\n",
"            return True\n",
"        else:\n",
"            return False\n",
"\n",
"    def observe(self):\n",
"        canvas = self._draw_state()\n",
"        return canvas.reshape((1, -1))\n",
"\n",
"    def act(self, action):\n",
"        self._update_state(action)\n",
"        reward = self._get_reward()\n",
"        game_over = self._is_over()\n",
"        return self.observe(), reward, game_over\n",
"\n",
"    def reset(self):\n",
"        # state layout: [[fruit_row, fruit_col, basket_center_col]]\n",
"        n = np.random.randint(0, self.grid_size-1, size=1)\n",
"        m = np.random.randint(1, self.grid_size-2, size=1)\n",
"        self.state = np.asarray([0, n, m])[np.newaxis]\n",
"\n",
"\n",
"class ExperienceReplay(object):\n",
"    def __init__(self, max_memory=100, discount=.9):\n",
"        self.max_memory = max_memory\n",
"        self.memory = list()\n",
"        self.discount = discount\n",
"\n",
"    def remember(self, states, game_over):\n",
"        # memory[i] = [[state_t, action_t, reward_t, state_t+1], game_over?]\n",
"        self.memory.append([states, game_over])\n",
"        if len(self.memory) > self.max_memory:\n",
"            del self.memory[0]\n",
"\n",
"    def get_batch(self, model, batch_size=10):\n",
"        len_memory = len(self.memory)\n",
"        num_actions = model.output_shape[-1]\n",
"        env_dim = self.memory[0][0][0].shape[1]\n",
"        inputs = np.zeros((min(len_memory, batch_size), env_dim))\n",
"        targets = np.zeros((inputs.shape[0], num_actions))\n",
"        for i, idx in enumerate(np.random.randint(0, len_memory,\n",
"                                                  size=inputs.shape[0])):\n",
"            state_t, action_t, reward_t, state_tp1 = self.memory[idx][0]\n",
"            game_over = self.memory[idx][1]\n",
"\n",
"            inputs[i:i+1] = state_t\n",
"            # There should be no target values for actions not taken.\n",
"            # Thou shalt not correct actions not taken #deep\n",
"            targets[i] = model.predict(state_t)[0]\n",
"            Q_sa = np.max(model.predict(state_tp1)[0])\n",
"            if game_over: # if game_over is True\n",
"                targets[i, action_t] = reward_t\n",
"            else:\n",
"                # reward_t + gamma * max_a' Q(s', a')\n",
"                targets[i, action_t] = reward_t + self.discount * Q_sa\n",
"        return inputs, targets\n",
"\n",
" \n",
"# parameters\n",
"epsilon = .1 # exploration\n",
"num_actions = 3 # [move_left, stay, move_right]\n",
"epoch = 1000\n",
"max_memory = 500\n",
"hidden_size = 100\n",
"batch_size = 1\n",
"grid_size = 10\n",
"\n",
"model = Sequential()\n",
"model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))\n",
"model.add(Dense(hidden_size, activation='relu'))\n",
"model.add(Dense(num_actions))\n",
"model.compile(sgd(lr=.2), \"mse\")\n",
" \n",
"# If you want to continue training from a previous model, just uncomment the line below\n",
"# model.load_weights(\"model.h5\")\n",
"\n",
"# Define environment/game\n",
"env = Catch(grid_size)\n",
"\n",
"# Initialize experience replay object\n",
"exp_replay = ExperienceReplay(max_memory=max_memory)\n",
"\n",
"def train(model):\n",
"    # Train\n",
"    win_cnt = 0\n",
"    for e in range(1):\n",
"        loss = 0.\n",
"        env.reset()\n",
"        game_over = False\n",
"        # get initial input\n",
"        input_t = env.observe()\n",
"\n",
"        while not game_over:\n",
"            input_tm1 = input_t\n",
"            # get next action\n",
"            if np.random.rand() <= epsilon:\n",
"                action = np.random.randint(0, num_actions, size=1)\n",
"            else:\n",
"                q = model.predict(input_tm1)\n",
"                action = np.argmax(q[0])\n",
"\n",
"            # apply action, get rewards and new state\n",
"            input_t, reward, game_over = env.act(action)\n",
"            if reward == 1:\n",
"                win_cnt += 1\n",
"\n",
"            # store experience\n",
"            exp_replay.remember([input_tm1, action, reward, input_t], game_over)\n",
"\n",
"            # adapt model\n",
"            inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)\n",
"\n",
"            display_screen(action,3000,inputs[0])\n",
"\n",
"            loss += model.train_on_batch(inputs, targets)[0]\n",
"        print(\"Epoch {:03d}/999 | Loss {:.4f} | Win count {}\".format(e, loss, win_cnt))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

Guidosalimbeni commented Nov 20, 2018

I got this error message at the line
-> 166 loss += model.train_on_batch(inputs, targets)[0]
"invalid index to scalar variable".

I am using Python 3 on Windows (Anaconda) with the TensorFlow backend.
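
A likely cause, assuming a recent Keras release: when a model has a single output and no extra metrics compiled in, model.train_on_batch returns a scalar loss rather than a list, so indexing it with [0] raises "invalid index to scalar variable". A minimal, version-agnostic sketch of that one line:

    # train_on_batch may return a scalar (newer Keras) or a list/array
    # (older Keras, or when extra metrics are compiled in); handle both.
    out = model.train_on_batch(inputs, targets)
    loss += float(np.atleast_1d(out)[0])

The rest of the training loop stays unchanged; only the loss-accumulation line differs.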
