@chetandhembre
Created October 22, 2016 20:34
Cliff World: Q-Learning vs SARSA Learning
# https://dl.dropboxusercontent.com/u/47591917/cliff_world_episode_length.png
# https://dl.dropboxusercontent.com/u/47591917/cliff_world_episode_reward.png
# https://dl.dropboxusercontent.com/u/47591917/cliff_world_episode_timestamp.png
# https://dl.dropboxusercontent.com/u/47591917/cliff_world_path.png
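#
# Cliff Walking gridworld (4 rows x 12 columns in the run at the bottom of the
# script): the agent starts at (0, 0), the goal is at (0, width - 1), and the
# bottom row between them is the cliff. Every step costs NORMAL_REWARD (-1);
# stepping into the cliff costs CLIFF_REWARD (-100) and resets the agent to the
# start. The script trains an epsilon-greedy agent with both Q-learning and
# SARSA and saves the plots linked above.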
import numpy as np
import matplotlib
import math
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.patches import Circle, Wedge, Polygon
from matplotlib.collections import PatchCollection
import matplotlib.patches as patches
import matplotlib.patches as mpatches
UP = 0
DOWN = 1
RIGHT = 2
LEFT = 3
ALLOWED_ACTIONS = [UP, DOWN, RIGHT, LEFT]
NORMAL_REWARD = -1
CLIFF_REWARD = -100

class State(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __hash__(self):
        return hash((self.x, self.y))

    def __eq__(self, other):
        return (self.x == other.x) and (self.y == other.y)

    def __str__(self):
        return str((self.x, self.y))

class StateAction(object):
    def __init__(self, state, action):
        self.state = state
        self.action = action

    def __hash__(self):
        return hash((self.state, self.action))

    def __eq__(self, other):
        return (self.state == other.state and self.action == other.action)

    def __str__(self):
        return str((str(self.state), self.action))

class ValueMap(object):
    def __init__(self, width, height, epsilon, discount_factor, alpha):
        self.width = width
        self.height = height
        self.epsilon = epsilon
        self.discount_factor = discount_factor
        self.alpha = alpha
        self.value_map = {}
        self.cliff = {}

    def initialize(self):
        for i in range(self.height):
            for j in range(self.width):
                state = State(i, j)
                for action in ALLOWED_ACTIONS:
                    action_state = StateAction(state, action)
                    if i == 0 and self.width - 1 > j > 0:
                        self.cliff[state] = 1
                        self.value_map[action_state] = -9999
                    else:
                        self.cliff[state] = 0
                        self.value_map[action_state] = 0

    def get_greedy_action(self, state):
        max_value = -np.inf
        max_action = None
        for action in ALLOWED_ACTIONS:
            action_state = StateAction(state, action)
            value = self.value_map[action_state]
            if max_value <= value:
                max_value = value
                max_action = action
        return max_action
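
    # select_next_action implements the epsilon-greedy behaviour policy: every
    # action gets probability epsilon / len(ALLOWED_ACTIONS), and the greedy
    # action receives the remaining (1 - epsilon) mass on top of that. With
    # is_greedy=True the greedy action is returned deterministically.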
    def select_next_action(self, current_state, is_greedy=False):
        greedy_action = self.get_greedy_action(current_state)
        if is_greedy:
            return greedy_action
        actions_probabilities = np.ones(len(ALLOWED_ACTIONS)) * self.epsilon / len(ALLOWED_ACTIONS)
        actions_probabilities[greedy_action] = actions_probabilities[greedy_action] + (1 - self.epsilon)
        return np.random.choice(np.arange(len(actions_probabilities)), p=actions_probabilities)

    def _get_next_state(self, current_state, current_action):
        x, y = current_state.x, current_state.y
        x_new, y_new = x, y
        if current_action == UP:
            x_new = x + 1
        elif current_action == RIGHT:
            y_new = y + 1
        elif current_action == DOWN:
            x_new = x - 1
        elif current_action == LEFT:
            y_new = y - 1
        x_new = max(0, x_new)
        x_new = min(self.height - 1, x_new)
        y_new = min(self.width - 1, y_new)
        y_new = max(0, y_new)
        return State(x_new, y_new)

    def _get_reward(self, state):
        if self.cliff[state]:
            return CLIFF_REWARD
        return NORMAL_REWARD
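
    # update_value_map applies the one-step TD update
    #   Q(s, a) <- Q(s, a) + alpha * (reward + gamma * Q(s', a') - Q(s, a)).
    # For SARSA (on-policy), a' is drawn from the epsilon-greedy behaviour
    # policy; for Q-learning (off-policy), a' is the greedy action in s', so
    # the target is effectively max_a Q(s', a). Falling off the cliff ends the
    # episode, signalled by returning None as the next state.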
    def update_value_map(self, current_state, current_action, is_sarsa=False, is_greedy=False):
        current_action_state = StateAction(current_state, current_action)
        next_state = self._get_next_state(current_state, current_action)
        reward = self._get_reward(next_state)
        next_action = self.select_next_action(next_state, is_greedy=is_greedy)
        if not is_sarsa:
            next_action = self.get_greedy_action(next_state)
        next_action_state = StateAction(next_state, next_action)
        target_value = self.value_map[next_action_state]
        old_value = self.value_map[current_action_state]
        self.value_map[current_action_state] = old_value + self.alpha * (reward + self.discount_factor * target_value - old_value)
        return (next_state if not self.cliff[next_state] else None), reward

class Game(object):
    def __init__(self, width, height, no_episodes, alpha=0.1, discount_factor=1, epsilon=0.1):
        self.width = width
        self.height = height
        self.value_map = ValueMap(width, height, epsilon, discount_factor, alpha)
        self.no_action_per_episodes = {}
        self.rewards_per_episodes = {}
        self.no_episodes = no_episodes
        self.start = State(0, 0)
        self.end = State(0, width - 1)
        self.actions_episodes = []
        self.rewards_episodes = []
        self.last_episode_actions = []
        self.sarsa_actions_episodes = []
        self.sarsa_rewards_episodes = []
        self.sarsa_last_episode_actions = []

    def plot(self):
        noshow = True
        labels = []
        labels.append(r'Q learning')
        labels.append(r'SARSA learning')

        # Plot the episode length over time
        fig1 = plt.figure(figsize=(10, 6))
        plt.plot(self.actions_episodes)
        plt.plot(self.sarsa_actions_episodes)
        plt.xlabel("Episode")
        plt.ylabel("Episode Length")
        plt.title("Episode Length over Time")
        plt.legend(labels, ncol=4, loc='center left',
                   bbox_to_anchor=[0.5, 1.1],
                   columnspacing=1.0, labelspacing=0.0,
                   handletextpad=0.0, handlelength=1.5,
                   fancybox=True, shadow=True)
        plt.savefig('cliff_world_episode_length.png')
        if noshow:
            plt.close(fig1)
        else:
            plt.show()

        # Plot the episode reward over time
        fig2 = plt.figure(figsize=(10, 6))
        smoothing_window = 10
        rewards_smoothed = pd.Series(self.rewards_episodes).rolling(smoothing_window, min_periods=smoothing_window).mean()
        plt.plot(rewards_smoothed)
        rewards_smoothed = pd.Series(self.sarsa_rewards_episodes).rolling(smoothing_window, min_periods=smoothing_window).mean()
        plt.plot(rewards_smoothed)
        plt.xlabel("Episode")
        plt.ylabel("Episode Reward (Smoothed)")
        plt.title("Episode Reward over Time (Smoothed over window size {})".format(smoothing_window))
        plt.legend(labels, ncol=4, loc='center left',
                   bbox_to_anchor=[0.5, 1.1],
                   columnspacing=1.0, labelspacing=0.0,
                   handletextpad=0.0, handlelength=1.5,
                   fancybox=True, shadow=True)
        plt.savefig('cliff_world_episode_reward.png')
        if noshow:
            plt.close(fig2)
        else:
            plt.show()

        # Plot time steps against episode number
        fig3 = plt.figure(figsize=(10, 6))
        plt.plot(np.cumsum(self.actions_episodes), np.arange(len(self.actions_episodes)))
        plt.plot(np.cumsum(self.sarsa_actions_episodes), np.arange(len(self.sarsa_actions_episodes)))
        plt.xlabel("Time Steps")
        plt.ylabel("Episode")
        plt.title("Episode per time step")
        plt.legend(labels, ncol=4, loc='center left',
                   bbox_to_anchor=[0.5, 1.1],
                   columnspacing=1.0, labelspacing=0.0,
                   handletextpad=0.0, handlelength=1.5,
                   fancybox=True, shadow=True)
        plt.savefig('cliff_world_episode_timestamp.png')
        if noshow:
            plt.close(fig3)
        else:
            plt.show()

        # Plot the path taken in the final (greedy) episode of each method
        plt.plot(self.start.y, self.start.x, 'x', markersize=20)
        previous_x = self.start.y
        previous_y = self.start.x
        for position in self.last_episode_actions[1:]:
            x, y = position.y, position.x
            plt.arrow(previous_x, previous_y, x - previous_x, y - previous_y, head_width=0.3, head_length=0.3, overhang=0, color='blue', label="Q learning")
            plt.plot(x, y, 'o', markersize=5)
            previous_x = x
            previous_y = y
        previous_x = self.start.y
        previous_y = self.start.x
        for position in self.sarsa_last_episode_actions[1:]:
            x, y = position.y, position.x
            plt.arrow(previous_x, previous_y, x - previous_x, y - previous_y, head_width=0.3, head_length=0.3, overhang=0, color='red', label="SARSA learning")
            plt.plot(x, y, 'o', markersize=5)
            previous_x = x
            previous_y = y
        plt.plot(self.end.y, self.end.x, 'x', markersize=20)
        axes = plt.gca()
        axes.set_xticks(range(-1, self.width + 1))
        axes.set_yticks(range(-1, self.height + 1))
        axes.set_title('Path to reach destination')
        axes.add_patch(
            patches.Rectangle(
                (1, 0),           # (x, y) of the cliff region
                self.width - 3,   # width
                1,                # height
                alpha=0.1
            )
        )
        red_patch = mpatches.Patch(color='red', label='SARSA Learning')
        blue_patch = mpatches.Patch(color='blue', label='Q Learning')
        plt.legend(handles=[red_patch, blue_patch])
        plt.grid()
        # plt.show()
        plt.savefig('cliff_world_path.png')

    def play(self, is_sarsa=False):
        self.value_map.initialize()
        for i in range(self.no_episodes):
            current_state = self.start
            actions = 0
            rewards = 0
            is_greedy = False
            if i == self.no_episodes - 1:
                # run the final episode greedily and record the path it takes
                is_greedy = True
                if is_sarsa:
                    self.sarsa_last_episode_actions.append(current_state)
                else:
                    self.last_episode_actions.append(current_state)
            while not (current_state == self.end):
                if actions > 2000:
                    break
                current_action = self.value_map.select_next_action(current_state, is_greedy=is_greedy)
                next_state, reward = self.value_map.update_value_map(current_state, current_action, is_sarsa=is_sarsa, is_greedy=is_greedy)
                if next_state is None:
                    next_state = self.start
                current_state = next_state
                actions = actions + 1
                rewards = rewards + reward
                if i == self.no_episodes - 1:
                    print(current_action, current_state)
                    if is_sarsa:
                        self.sarsa_last_episode_actions.append(current_state)
                    else:
                        self.last_episode_actions.append(current_state)
            if is_sarsa:
                self.sarsa_actions_episodes.append(actions)
                self.sarsa_rewards_episodes.append(rewards)
            else:
                self.actions_episodes.append(actions)
                self.rewards_episodes.append(rewards)

game = Game(12, 4, 1000)
game.play()
game.play(is_sarsa=True)
game.plot()
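
# Expected qualitative outcome (the classic Cliff Walking comparison from
# Sutton & Barto, stated as an expectation rather than a guarantee): Q-learning
# tends to converge to the short path that hugs the cliff edge, while SARSA,
# which evaluates the epsilon-greedy behaviour policy it actually follows,
# tends to prefer the longer, safer path away from the cliff and usually
# collects more reward per episode during training.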