# Gist by Roger-random (72cb438d52f9899c44eb5a525ea1b249), created June 2, 2019.
# Failed first attempt to adapt simple Q-Learning to the CartPole example.
import gym
import numpy as np
import matplotlib.pyplot as plt
class hello_cartpole:
    # Experiments with tabular Q-learning on gym's CartPole-v0.
    # Typical use: sample_observation_space() to learn the value range the
    # environment actually visits, then q_learning(), plot_g(), close().
    # NOTE(review): written against the older gym API in which reset()
    # returns only the observation and step() returns a 4-tuple - confirm
    # against the installed gym version.

    # Number of buckets each observation element is discretized into; each
    # element contributes one base-BINS digit to the discrete state index.
    BINS = 10

    def __init__(self):
        self.env = gym.make('CartPole-v0')

    # Release the resources held by the gym environment.
    # NOTE(review): this method's body was lost when the gist page was
    # scraped; reconstructed as the minimal env.close() call - confirm.
    def close(self):
        self.env.close()

    # Given a new observed state, widen the observed min/max range so that it
    # includes that state. np.minimum / np.maximum compare element-wise (the
    # "more Python-y way" than an index loop) and return new arrays; because
    # the result is rebound rather than written in place, whatever array
    # observed_min / observed_max previously referred to is never modified.
    def update_observed_min_max(self, observed):
        self.observed_min = np.minimum(self.observed_min, observed)
        self.observed_max = np.maximum(self.observed_max, observed)

    # The full range of an environment's observable space is available;
    # explore how much of that we actually see in a randomized sampling.
    #   episodes: number of random-action episodes to sample (default 1000).
    # Sets self.observed_min / self.observed_max, which discrete_state() and
    # q_learning() depend on.
    def sample_observation_space(self, episodes=1000):
        # Start each bound at the opposite extreme so real observations
        # replace it. np.array() copies, so the space's own high/low arrays
        # can never be modified through these attributes.
        self.observed_min = np.array(self.env.observation_space.high)
        self.observed_max = np.array(self.env.observation_space.low)
        print('Sampling starting from {} to {}'.format(self.observed_min, self.observed_max))
        for _ in range(episodes):
            state = self.env.reset()
            self.update_observed_min_max(state)
            done = False
            while not done:
                state, reward, done, info = self.env.step(self.env.action_space.sample())
                # Every visited state must widen the range; otherwise the
                # bounds stay at their inverted starting values.
                self.update_observed_min_max(state)
        print('Observed range from {} to {}'.format(self.observed_min, self.observed_max))

    # Given a new observed state, generate a discrete index into the Q table.
    # Each element becomes one base-BINS digit (first element = least
    # significant) according to where it falls between observed min and max.
    # Values at or below the min map to 0; values at or above the max map to
    # BINS - 1, so a digit can never carry into its neighbour. With BINS = 10
    # and a four-element observation the index is between 0 and 9999.
    def discrete_state(self, observed):
        index = 0
        place = 1
        for value, low, high in zip(observed, self.observed_min, self.observed_max):
            span = high - low
            if span <= 0 or value <= low:
                # Degenerate range (never sampled / never varied) or below
                # the sampled min: lowest bucket. Also avoids dividing by 0.
                digit = 0
            elif value >= high:
                # Above the sampled max: highest bucket.
                digit = self.BINS - 1
            else:
                # min() guards against float rounding pushing a value just
                # under `high` up to BINS.
                digit = min(int(self.BINS * (value - low) / span), self.BINS - 1)
            index += digit * place
            place *= self.BINS
        return index

    # Tabular Q-learning adapted from a discrete example to CartPole, using
    # discrete_state() to bucket the continuous observations.
    # Requires sample_observation_space() to have been called first.
    #   episodes: number of training episodes (default 10000).
    #   alpha:    learning rate; 0.618 is carried over from the example
    #             without a known justification.
    #   gamma:    discount factor. The default 1.0 reproduces the example's
    #             undiscounted update; if every step yields a positive reward,
    #             undiscounted Q values can grow without bound, and a value
    #             below 1 (e.g. 0.99) keeps them bounded.
    #   epsilon:  probability of taking a random exploratory action. The
    #             default 0.0 is purely greedy, as in the example.
    # Records each episode's total reward G in self.history (plotted by
    # plot_g) and leaves the learned table in self.Q.
    def q_learning(self, episodes=10000, alpha=0.618, gamma=1.0, epsilon=0.0):
        self.history = list()
        Q = np.zeros([self.BINS ** len(self.observed_min), self.env.action_space.n])
        self.Q = Q
        for episode in range(1, episodes + 1):
            done = False
            G = 0
            state_d = self.discrete_state(self.env.reset())
            while not done:
                if epsilon > 0 and np.random.random() < epsilon:
                    action = self.env.action_space.sample()
                else:
                    action = np.argmax(Q[state_d])
                state2, reward, done, info = self.env.step(action)
                state2_d = self.discrete_state(state2)
                # Do not bootstrap past the end of an episode: the bucket of
                # the final observation is shared with non-terminal states,
                # so its Q row says nothing about what follows the end.
                # NOTE(review): if the environment also ends episodes at a
                # step limit, that truncation is treated as terminal here too.
                if done:
                    target = reward
                else:
                    target = reward + gamma * np.max(Q[state2_d])
                Q[state_d, action] += alpha * (target - Q[state_d, action])
                G += reward
                state_d = state2_d
            self.history.append(G)
            if episode % 50 == 0:
                print('Episode {} Total Reward: {}'.format(episode, G))

    # Plot the per-episode total reward G recorded by q_learning().
    # NOTE(review): this method's body was lost when the gist page was
    # scraped; reconstructed from its name and the matplotlib import -
    # confirm against the original gist.
    def plot_g(self):
        plt.plot(self.history)
        plt.xlabel('Episode')
        plt.ylabel('Total reward (G)')
        plt.show()
if __name__ == "__main__":
    hc = hello_cartpole()
    # NOTE(review): the rest of this driver was lost when the gist page was
    # scraped; reconstructed as the obvious sample -> learn -> plot sequence,
    # with the environment released even if a step fails - confirm.
    try:
        hc.sample_observation_space()
        hc.q_learning()
        hc.plot_g()
    finally:
        hc.close()
# End of gist.