@MattChanTK
Last active March 24, 2018 07:39
import gym
import numpy as np
import random
import math
from time import sleep
## Initialize the "Cart-Pole" environment
env = gym.make('CartPole-v0')
## Defining the environment related constants
# Number of discrete states (bucket) per state dimension
NUM_BUCKETS = (1, 1, 6, 3) # (x, x', theta, theta')
# Number of discrete actions
NUM_ACTIONS = env.action_space.n # (left, right)
# Bounds for each discrete state
STATE_BOUNDS = list(zip(env.observation_space.low, env.observation_space.high))
# Manually narrow the velocity bounds, which gym reports as effectively unbounded
STATE_BOUNDS[1] = [-0.5, 0.5]                            # cart velocity (x')
STATE_BOUNDS[3] = [-math.radians(50), math.radians(50)]  # pole angular velocity (theta')
# Index of the action
ACTION_INDEX = len(NUM_BUCKETS)
## Creating a Q-Table for each state-action pair
q_table = np.zeros(NUM_BUCKETS + (NUM_ACTIONS,))
## Learning related constants
MIN_EXPLORE_RATE = 0.01
MIN_LEARNING_RATE = 0.1
## Defining the simulation related constants
NUM_EPISODES = 1000
MAX_T = 250
STREAK_TO_END = 120
SOLVED_T = 199
DEBUG_MODE = True
def simulate():
    ## Instantiating the learning related parameters
    learning_rate = get_learning_rate(0)
    explore_rate = get_explore_rate(0)
    discount_factor = 0.99  # since the world is unchanging

    num_streaks = 0

    for episode in range(NUM_EPISODES):

        # Reset the environment
        obv = env.reset()

        # the initial state
        state_0 = state_to_bucket(obv)

        for t in range(MAX_T):
            env.render()

            # Select an action
            action = select_action(state_0, explore_rate)

            # Execute the action
            obv, reward, done, _ = env.step(action)

            # Observe the result
            state = state_to_bucket(obv)

            # Update the Q based on the result
            best_q = np.amax(q_table[state])
            q_table[state_0 + (action,)] += learning_rate*(reward + discount_factor*(best_q) - q_table[state_0 + (action,)])

            # Setting up for the next iteration
            state_0 = state

            # Print data
            if (DEBUG_MODE):
                print("\nEpisode = %d" % episode)
                print("t = %d" % t)
                print("Action: %d" % action)
                print("State: %s" % str(state))
                print("Reward: %f" % reward)
                print("Best Q: %f" % best_q)
                print("Explore rate: %f" % explore_rate)
                print("Learning rate: %f" % learning_rate)
                print("Streaks: %d" % num_streaks)
                print("")

            if done:
                print("Episode %d finished after %d time steps" % (episode, t))
                if (t >= SOLVED_T):
                    num_streaks += 1
                else:
                    num_streaks = 0
                break

            #sleep(0.25)

        # It's considered done when it's solved over 120 times consecutively
        if num_streaks > STREAK_TO_END:
            break

        # Update parameters
        explore_rate = get_explore_rate(episode)
        learning_rate = get_learning_rate(episode)
def select_action(state, explore_rate):
    # Select a random action
    if random.random() < explore_rate:
        action = env.action_space.sample()
    # Select the action with the highest q
    else:
        action = np.argmax(q_table[state])
    return action
def get_explore_rate(t):
    # Logarithmic decay: stays at 1.0 for the first ~25 episodes, then falls towards MIN_EXPLORE_RATE
    return max(MIN_EXPLORE_RATE, min(1, 1.0 - math.log10((t+1)/25)))

def get_learning_rate(t):
    # Same schedule, capped at 0.5 and floored at MIN_LEARNING_RATE
    return max(MIN_LEARNING_RATE, min(0.5, 1.0 - math.log10((t+1)/25)))
def state_to_bucket(state):
    bucket_indice = []
    for i in range(len(state)):
        if state[i] <= STATE_BOUNDS[i][0]:
            bucket_index = 0
        elif state[i] >= STATE_BOUNDS[i][1]:
            bucket_index = NUM_BUCKETS[i] - 1
        else:
            # Mapping the state bounds to the bucket array
            bound_width = STATE_BOUNDS[i][1] - STATE_BOUNDS[i][0]
            offset = (NUM_BUCKETS[i]-1)*STATE_BOUNDS[i][0]/bound_width
            scaling = (NUM_BUCKETS[i]-1)/bound_width
            bucket_index = int(round(scaling*state[i] - offset))
        bucket_indice.append(bucket_index)
    return tuple(bucket_indice)
if __name__ == "__main__":
    simulate()
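
After training, the learned table can also be played back greedily (explore rate of 0) without further updates, to see what the policy has actually learned. A minimal sketch reusing the functions above; the function name run_greedy and the episode count are illustrative, not part of the original gist:

def run_greedy(num_episodes=5):
    # Roll out the greedy policy from the learned q_table: no exploration, no updates
    for episode in range(num_episodes):
        obv = env.reset()
        state = state_to_bucket(obv)
        total_reward = 0
        for t in range(MAX_T):
            env.render()
            action = select_action(state, 0)  # explore_rate = 0 -> always the argmax action
            obv, reward, done, _ = env.step(action)
            state = state_to_bucket(obv)
            total_reward += reward
            if done:
                break
        print("Greedy episode %d: reward = %d" % (episode, total_reward))

Calling run_greedy() after simulate() in the __main__ block would show the balanced pole if training has converged.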
cosme12 commented Sep 14, 2017

Hello, I read your blog post and it's great! But I have a question: I ran the script again with the last best matrix found as the default q_table, but no improvement is shown. Is it because the environment changes when reloading the script? Thanks.
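
Reusing a learned table across runs requires saving and reloading the array explicitly, since the script always starts from np.zeros. A minimal sketch with NumPy, assuming the constants from the gist; the file name q_table.npy is just an illustrative choice:

import os
import numpy as np

Q_TABLE_FILE = "q_table.npy"  # hypothetical file name, not part of the original gist

# Load a previously learned table if one exists, otherwise start from zeros
if os.path.exists(Q_TABLE_FILE):
    q_table = np.load(Q_TABLE_FILE)
else:
    q_table = np.zeros(NUM_BUCKETS + (NUM_ACTIONS,))

# ... train with simulate() ...

# Persist the learned values for the next run
np.save(Q_TABLE_FILE, q_table)

Worth noting: even with a reloaded table, get_explore_rate and get_learning_rate restart from episode 0, so the first episodes are dominated by random exploration again, which can mask any head start from the saved values.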
