Skip to content
{{ message }}

Instantly share code, notes, and snippets.

# Roger-random/hello_cartpole.py

Created Jun 2, 2019
Failed first attempt to adapt simple Q-Learning to CartPole example
 import gym import numpy as np import matplotlib.pyplot as plt class hello_cartpole: def __init__(self): self.env = gym.make('CartPole-v0') self.env.reset() def close(self): self.env.close() # Given a new observed state, see if it exceeded our existing set of # observed min/max values and update accordingly. def update_observed_min_max(self, observed): for item in range(0,len(observed)): # What's a more Python-y way to do this? if self.observed_min[item] > observed[item]: self.observed_min[item] = observed[item] if self.observed_max[item] < observed[item]: self.observed_max[item] = observed[item] # The full range of an environment's observable space is available, explore # how much of that we actually see in a randomized sampling. def sample_observation_space(self): # We start by setting our "observed" min/max to the opposite extreme. self.observed_min = self.env.observation_space.high self.observed_max = self.env.observation_space.low print('Sampling starting from {} to {}'.format(self.observed_min, self.observed_max)) # Sample over 1000 episodes for episode in range(1,1001): state = self.env.reset() self.update_observed_min_max(state) done = False while done != True: state, reward, done, info = self.env.step(self.env.action_space.sample()) self.update_observed_min_max(state) print('Observed range from {} to {}'.format(self.observed_min, self.observed_max)) # Given a new observed state, generate a discrete index into an array # for the state. Each element generates a digit into the index based # on where it is between observed min and max. So for a four-element # array, the index is between 0 and 9999. def discrete_state(self, observed): digit = 1 sum = 0 value = 0 for item in range(0,len(observed)): if self.observed_min[item] > observed[item]: # A new value might be less than sampled min, assign to 0. value = 0 elif self.observed_max[item] < observed[item]: # A new value might be greater than sampled max, assign to 9. value = 9 else: # Everything else is divided into the observed range. value = int(10.0*(observed[item] - self.observed_min[item])/(self.observed_max[item] - self.observed_min[item])) # Add result into digit, move on to next digit. sum += (value * digit) digit *= 10 return sum # An attempt to adapt discrete Q-learning algorithm from this page to CartPole # https://www.oreilly.com/learning/introduction-to-reinforcement-learning-and-openai-gym def q_learning(self): self.history = list() done = False Q = np.zeros([10000, self.env.action_space.n]) G = 0 alpha = 0.618 # Don't understand where this value came from yet for episode in range (1,10001): done = False G, reward = 0,0 state = self.env.reset() state_d = self.discrete_state(state) while done != True: action = np.argmax(Q[state_d]) state2, reward, done, info = self.env.step(action) state2_d = self.discrete_state(state2) Q[state_d,action] += alpha * (reward + np.max(Q[state2_d]) - Q[state_d,action]) G += reward state_d = state2_d if episode % 50 == 0: print('Episode {} Total Reward: {}'.format(episode,G)) self.history.append(G) print(np.count_nonzero(Q)) def plot_g(self): plt.plot(self.history) plt.show() if __name__ == "__main__": hc = hello_cartpole() hc.sample_observation_space() hc.q_learning() hc.plot_g() hc.close()
to join this conversation on GitHub. Already have an account? Sign in to comment