Q-Table learning in the OpenAI Gym FrozenLake grid world.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Table Learning"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the environment"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"env = gym.make('FrozenLake-v0')"
]
},
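{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the spaces the Q-table will be built from, assuming the classic `gym` API where `FrozenLake-v0` exposes `Discrete` observation and action spaces (a 4x4 grid, so 16 states and 4 actions)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the discrete state and action spaces (illustrative check, classic gym API assumed)\n",
"print(env.observation_space.n)  # number of states (16 for the 4x4 FrozenLake map)\n",
"print(env.action_space.n)       # number of actions (4: left, down, right, up)"
]
},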
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implement Q-Table learning algorithm"
]
},
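{
"cell_type": "markdown",
"metadata": {},
"source": [
"The update in the cell below is the standard tabular Q-learning rule,\n",
"\n",
"$$Q(s, a) \\leftarrow Q(s, a) + \\alpha \\left( r + \\gamma \\max_{a'} Q(s', a') - Q(s, a) \\right)$$\n",
"\n",
"with learning rate $\\alpha$ (`lr`) and discount factor $\\gamma$ (`gamma`). Exploration comes from adding Gaussian noise, scaled by $1/(i+1)$, to the Q-values before taking the argmax, so action selection becomes increasingly greedy as training progresses."
]
},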
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Initialize table with all zeros\n",
"Q = np.zeros([env.observation_space.n,env.action_space.n])\n",
"\n",
"# Set learning parameters\n",
"lr = .8\n",
"gamma = .95\n",
"num_episodes = 2000\n",
"max_steps = 99\n",
"\n",
"# Create lists to contain total rewards and steps per episode\n",
"rewards_list = []\n",
"for i in range(num_episodes):\n",
" # Reset environment and get first new observation\n",
" state = env.reset()\n",
" total_reward = 0\n",
" done = False\n",
" step = 0\n",
" \n",
" # The Q-Table learning algorithm\n",
" while step < max_steps:\n",
" step += 1\n",
" \n",
" # Choose an action by greedily (with noise) picking from Q table\n",
" action = np.argmax(Q[state,:] + np.random.randn(1, env.action_space.n) * (1./(i+1)))\n",
" \n",
" # Get new state and reward from environment\n",
" new_state, reward, done, _ = env.step(action)\n",
" \n",
" # Update Q-Table with new knowledge\n",
" Q[state, action] = Q[state, action] + lr * (reward + gamma * np.max(Q[new_state,:]) - Q[state, action])\n",
" total_reward += reward\n",
" state = new_state\n",
" \n",
" if done == True:\n",
" break\n",
"\n",
" rewards_list.append(total_reward)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score over time: 0.0\n",
"Final Q-Table Values\n",
"[[ 1.73606976e-01 8.09353670e-03 5.37655073e-03 7.57670043e-03]\n",
" [ 8.56387621e-04 2.40744772e-05 8.14224075e-04 1.60722899e-01]\n",
" [ 8.83784512e-04 7.39989148e-04 4.10694044e-03 1.91931139e-01]\n",
" [ 9.15500710e-04 1.27182348e-04 9.00477296e-04 7.90341515e-02]\n",
" [ 1.56972694e-01 1.41182122e-03 5.59359958e-03 7.14299174e-04]\n",
" [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
" [ 8.90677852e-02 8.89296303e-05 6.39010325e-04 2.34099326e-08]\n",
" [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
" [ 7.00541641e-04 1.83801419e-04 5.54609322e-04 1.63727386e-01]\n",
" [ 5.12418126e-04 3.02610162e-01 2.87072750e-04 8.52808810e-04]\n",
" [ 4.75554318e-01 3.22479170e-04 1.51289856e-04 1.97135557e-04]\n",
" [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
" [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
" [ 1.10336815e-03 1.45180020e-03 6.05380068e-01 1.93977187e-04]\n",
" [ 0.00000000e+00 0.00000000e+00 9.55785422e-01 0.00000000e+00]\n",
" [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n"
]
}
],
"source": [
"print(\"Score over time: \" + str(sum(rewards_list)/num_episodes))\n",
"print(\"Final Q-Table Values\")\n",
"print(Q)"
]
},
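{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check on what the table has learned, the minimal sketch below runs a handful of episodes that act purely greedily with respect to `Q`, with no exploration noise, and reports the average reward. The episode count is an arbitrary choice; `env`, `Q`, and `max_steps` are taken from the cells above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the greedy policy implied by the learned Q-table (illustrative sketch)\n",
"eval_episodes = 100\n",
"eval_rewards = []\n",
"for _ in range(eval_episodes):\n",
"    state = env.reset()\n",
"    total_reward = 0\n",
"    for _ in range(max_steps):\n",
"        # Always take the action with the highest Q-value (no exploration)\n",
"        action = np.argmax(Q[state, :])\n",
"        state, reward, done, _ = env.step(action)\n",
"        total_reward += reward\n",
"        if done:\n",
"            break\n",
"    eval_rewards.append(total_reward)\n",
"\n",
"print(\"Greedy policy average reward: \" + str(sum(eval_rewards) / eval_episodes))"
]
},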
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}