Failed first attempt to adapt a simple Q-Learning example to CartPole
import gym
import numpy as np
import matplotlib.pyplot as plt

# Note: this uses the classic gym API (reset() returns only the observation,
# step() returns a 4-tuple), i.e. gym releases before 0.26.
class hello_cartpole:
    def __init__(self):
        self.env = gym.make('CartPole-v0')
        self.env.reset()

    def close(self):
        self.env.close()

    # Given a new observed state, see if it exceeded our existing set of
    # observed min/max values and update accordingly.
    def update_observed_min_max(self, observed):
        for item in range(len(observed)):  # What's a more Python-y way to do this? (One option is sketched after this method.)
            if self.observed_min[item] > observed[item]:
                self.observed_min[item] = observed[item]
            if self.observed_max[item] < observed[item]:
                self.observed_max[item] = observed[item]
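
    # One possible answer to the question above: NumPy can do the elementwise
    # min/max update without an explicit index loop. This variant is only a
    # sketch and is not called anywhere else in this file.
    def update_observed_min_max_vectorized(self, observed):
        self.observed_min = np.minimum(self.observed_min, observed)
        self.observed_max = np.maximum(self.observed_max, observed)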
    # The full range of the environment's observation space is available from
    # gym, but (for CartPole) the velocity bounds are effectively unbounded,
    # so instead explore how much of that range we actually see in a
    # randomized sampling.
    def sample_observation_space(self):
        # Start with the "observed" min/max at the opposite extremes, so the
        # first real observation replaces them. Copy the arrays so we never
        # modify the environment's own bounds in place.
        self.observed_min = self.env.observation_space.high.copy()
        self.observed_max = self.env.observation_space.low.copy()
        print('Sampling starting from {} to {}'.format(self.observed_min, self.observed_max))
        # Sample over 1000 episodes of random actions.
        for episode in range(1, 1001):
            state = self.env.reset()
            self.update_observed_min_max(state)
            done = False
            while not done:
                state, reward, done, info = self.env.step(self.env.action_space.sample())
                self.update_observed_min_max(state)
        print('Observed range from {} to {}'.format(self.observed_min, self.observed_max))

    # Given a new observed state, generate a discrete index into an array
    # for the state. Each element contributes one decimal digit to the index,
    # based on where it falls between the observed min and max. So for a
    # four-element observation, the index is between 0 and 9999.
    def discrete_state(self, observed):
        digit = 1
        index = 0
        for item in range(len(observed)):
            if self.observed_min[item] > observed[item]:
                # A new value might be less than the sampled min; clamp to 0.
                value = 0
            elif self.observed_max[item] < observed[item]:
                # A new value might be greater than the sampled max; clamp to 9.
                value = 9
            else:
                # Everything else is divided evenly across the observed range.
                # (min() guards the edge case observed == observed_max, which
                # would otherwise produce a digit of 10.)
                value = min(9, int(10.0 * (observed[item] - self.observed_min[item])
                                   / (self.observed_max[item] - self.observed_min[item])))
            # Add the result at the current digit, then move on to the next digit.
            index += value * digit
            digit *= 10
        return index
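
    # Worked example of the encoding above (numbers are illustrative, not from
    # a real run): with observed_min = [-2.4, -3.0, -0.2, -3.0] and
    # observed_max = [2.4, 3.0, 0.2, 3.0], the observation
    # [0.0, -3.5, 0.1, 1.5] produces digits 5, 0, 7, 7 (least significant
    # first), so the index is 5 + 0*10 + 7*100 + 7*1000 = 7705.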

    # An attempt to adapt the discrete Q-learning algorithm from this page to CartPole:
    # https://www.oreilly.com/learning/introduction-to-reinforcement-learning-and-openai-gym
    def q_learning(self):
        self.history = list()
        # 10000 discrete states (four digits of 0-9) by the number of actions.
        Q = np.zeros([10000, self.env.action_space.n])
        alpha = 0.618  # Don't understand where this value came from yet
        for episode in range(1, 10001):
            done = False
            G, reward = 0, 0
            state = self.env.reset()
            state_d = self.discrete_state(state)
            while not done:
                # Always act greedily on the current Q values (no exploration).
                action = np.argmax(Q[state_d])
                state2, reward, done, info = self.env.step(action)
                state2_d = self.discrete_state(state2)
                # Standard Q-learning update with an implicit discount factor of 1.
                Q[state_d, action] += alpha * (reward + np.max(Q[state2_d]) - Q[state_d, action])
                G += reward
                state_d = state2_d
            if episode % 50 == 0:
                print('Episode {} Total Reward: {}'.format(episode, G))
            # Track the total reward from every episode for plotting.
            self.history.append(G)
        print(np.count_nonzero(Q))
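
    # A possible next step (not used above, just a sketch): with Q initialized
    # to all zeros, np.argmax breaks ties toward action 0 and CartPole's
    # rewards are always positive, so action 1 may never be tried. A common
    # fix is an epsilon-greedy choice like this hypothetical helper.
    def choose_action_epsilon_greedy(self, Q, state_d, epsilon=0.1):
        # With probability epsilon take a random action, otherwise be greedy.
        if np.random.random() < epsilon:
            return self.env.action_space.sample()
        return np.argmax(Q[state_d])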

    # Plot the recorded episode rewards.
    def plot_g(self):
        plt.plot(self.history)
        plt.show()


if __name__ == "__main__":
    hc = hello_cartpole()
    hc.sample_observation_space()
    hc.q_learning()
    hc.plot_g()
    hc.close()