Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save nunofernandes-plight/dc2f9669fd21c24a86d94e1ae0542b60 to your computer and use it in GitHub Desktop.
Save nunofernandes-plight/dc2f9669fd21c24a86d94e1ae0542b60 to your computer and use it in GitHub Desktop.
Q-Table learning in OpenAI grid world.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Table Learning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"# Id of the registered gym environment used throughout the notebook\n",
"ENV_NAME = 'FrozenLake-v0'\n",
"env = gym.make(ENV_NAME)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implement the Q-Table learning algorithm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"# Initialize the Q-table with zeros: one row per state, one column per action\n",
"n_actions = env.action_space.n\n",
"Q = np.zeros([env.observation_space.n, n_actions])\n",
"# Set learning parameters\n",
"lr = .8  # step size of each Q-value update\n",
"y = .95  # discount applied to the best next-state value\n",
"num_episodes = 2000\n",
"max_steps = 99  # upper bound on steps taken within one episode\n",
"# Total reward collected in each episode\n",
"rList = []\n",
"for i in range(num_episodes):\n",
"    # Reset environment and get first new observation\n",
"    s = env.reset()\n",
"    rAll = 0\n",
"    d = False\n",
"    # The Q-Table learning algorithm\n",
"    for j in range(1, max_steps + 1):\n",
"        # Pick the best action from the Q-table after adding Gaussian noise;\n",
"        # the noise is scaled by 1/(i+1), so it shrinks as episodes go by\n",
"        a = np.argmax(Q[s, :] + np.random.randn(1, n_actions) * (1. / (i + 1)))\n",
"        # Get new state, reward and done flag from the environment\n",
"        s1, r, d, _ = env.step(a)\n",
"        # Move Q[s, a] towards the target r + y * max(Q[s1, :])\n",
"        Q[s, a] = Q[s, a] + lr * (r + y * np.max(Q[s1, :]) - Q[s, a])\n",
"        rAll += r\n",
"        s = s1\n",
"        # Stop the episode as soon as the environment reports it is done\n",
"        if d:\n",
"            break\n",
"    rList.append(rAll)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"# Average reward per episode: rList holds one total reward per episode.\n",
"# float() keeps the division from truncating on Python 2 if the rewards are ints.\n",
"print(\"Score over time: \" + str(sum(rList) / float(num_episodes)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Show the learned Q-table (rows are states, columns are actions).\n",
"# The call form of print is valid on both Python 2 and Python 3.\n",
"print(\"Final Q-Table Values\")\n",
"print(Q)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment