This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import gym

# Build the classic CartPole control environment.
env = gym.make("CartPole-v1")

# CartPole exposes two discrete actions: push cart left (0) or right (1).
# NOTE: np.str was deprecated in NumPy 1.20 and removed in 1.24 — use the
# builtin str() instead.
print("number of actions: " + str(env.action_space.n))

# Inspect the observation bounds for each dimension, in order:
# (cart position, cart velocity, pole angle, pole angular velocity).
# Iterating the bound arrays directly avoids the hard-coded range(4).
for high, low in zip(env.observation_space.high, env.observation_space.low):
    print(high)
    print(low)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def discretize_state(self, obs):
    """Map a continuous observation to a tuple of discrete bucket indices.

    Each dimension i of `obs` is rescaled to [0, 1] using the configured
    bounds, multiplied into `self.buckets[i]` bins, and clamped so the
    index is always a valid 0..buckets[i]-1 value (observations outside
    the bounds land in the first/last bucket).

    Parameters
    ----------
    obs : sequence of float
        Continuous observation; same length as self.buckets / bounds.

    Returns
    -------
    tuple of int
        Bucket index per dimension (hashable, usable as a Q-table key).
    """
    discretized = []
    for i in range(len(obs)):
        # FIX: normalize with (obs - lower) / (upper - lower).
        # The original used obs + abs(lower), which is only correct when
        # the lower bound is negative; subtraction is correct for any sign.
        scaling = (obs[i] - self.lower_bounds[i]) / (self.upper_bounds[i] - self.lower_bounds[i])
        new_obs = int(round((self.buckets[i] - 1) * scaling))
        # Clamp into the valid bucket range.
        new_obs = min(self.buckets[i] - 1, max(0, new_obs))
        discretized.append(new_obs)
    return tuple(discretized)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def QLupdate(self, state, action, reward, new_state):
    """Q-learning (off-policy) update of the visited state-action pair.

    Moves Q(state, action) toward the one-step bootstrapped target
    reward + discount * max_a Q(new_state, a) by a step of learning_rate.
    """
    best_next = np.max(self.Q_table[new_state])
    td_target = reward + self.discount * best_next
    td_error = td_target - self.Q_table[state][action]
    self.Q_table[state][action] += self.learning_rate * td_error
def SARSAupdate(self, state, action, reward, new_state, next_action):
    """SARSA (on-policy) update of the visited state-action pair.

    Unlike Q-learning, the bootstrap uses the action actually chosen next
    (next_action) rather than the greedy maximum over new_state.
    """
    td_target = reward + self.discount * self.Q_table[new_state][next_action]
    td_error = td_target - self.Q_table[state][action]
    self.Q_table[state][action] += self.learning_rate * td_error
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def QLtrain(self): | |
cum_reward = np.zeros((self.num_episodes)) | |
for ep in range(self.num_episodes): | |
current_state = self.discretize_state(self.env.reset()) | |
done = False | |
while not done: | |
#choosing action according to our exploration-exploitation policy | |
action = self.choose_action(current_state) | |
obs, reward, done, _ = self.env.step(action) | |
cum_reward[ep]+=reward |