Last active
September 17, 2019 09:57
-
-
Save PierreExeter/b9d731dade09f5257d0548601e0da3b6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# HYPERPARAMETERS
n_episodes = 1000  # Total train episodes
n_steps = 200      # Max steps per episode
min_alpha = 0.1    # learning-rate floor (decayed toward this by get_alpha)
min_epsilon = 0.1  # exploration-rate floor (decayed toward this by get_epsilon)
gamma = 1          # discount factor; 1 = undiscounted returns
ada_divisor = 25   # decay rate parameter for alpha and epsilon

# INITIALISE Q MATRIX
# One Q-value per (discretised state, action) pair.
# NOTE(review): assumes `buckets` is a tuple of per-dimension bucket counts
# and `n_actions` an int, both defined earlier in the file — confirm.
Q = np.zeros(buckets + (n_actions,))
print(np.shape(Q))  # sanity-check the Q-table shape
def discretize(obs, lower=None, upper=None, bins=None):
    """Map a continuous observation onto discrete bucket indices.

    Each component of ``obs`` is normalised to [0, 1] over its bound range,
    scaled onto the bucket grid, rounded to the nearest index, and clamped
    into the valid range [0, bins[i] - 1].

    Parameters
    ----------
    obs : sequence of float
        Continuous observation from the environment.
    lower, upper, bins : sequences, optional
        Per-dimension lower bounds, upper bounds and bucket counts.
        Default to the module-level ``lower_bounds``, ``upper_bounds``
        and ``buckets``, so existing call sites are unchanged.

    Returns
    -------
    tuple of int
        Discrete state index usable as a key into the Q matrix.
    """
    lower = lower_bounds if lower is None else lower
    upper = upper_bounds if upper is None else upper
    bins = buckets if bins is None else bins
    indices = []
    for x, lo, hi, n in zip(obs, lower, upper, bins):
        # Fix: normalise with (x - lo), not (x + abs(lo)) — the latter is
        # only correct when the lower bound happens to be negative.
        ratio = (x - lo) / (hi - lo)
        idx = int(round((n - 1) * ratio))
        # Clamp so out-of-range observations still map to a valid bucket.
        indices.append(min(n - 1, max(0, idx)))
    return tuple(indices)
def epsilon_policy(state, epsilon):
    """Epsilon-greedy action selection.

    With probability ``epsilon`` sample a random action (explore);
    otherwise take the action with the highest Q-value (exploit).
    """
    draw = np.random.random()
    if draw <= epsilon:
        # Exploration: uniformly random action from the environment.
        return env.action_space.sample()
    # Exploitation: best known action for this discrete state.
    return np.argmax(Q[state])
def greedy_policy(state):
    """Return the action with the highest Q-value for this state."""
    best_action = np.argmax(Q[state])
    return best_action
def update_q(current_state, action, reward, new_state, alpha):
    """One Q-learning step: move Q(s, a) toward the Bellman TD target."""
    td_target = reward + gamma * np.max(Q[new_state])
    td_error = td_target - Q[current_state][action]
    Q[current_state][action] += alpha * td_error
def get_epsilon(t):
    """Exploration rate for episode ``t``: log-decayed, floored at min_epsilon."""
    decayed = 1.0 - math.log10((t + 1) / ada_divisor)
    capped = min(1, decayed)
    return max(min_epsilon, capped)
def get_alpha(t):
    """Learning rate for episode ``t``: log-decayed, floored at min_alpha."""
    decayed = 1.0 - math.log10((t + 1) / ada_divisor)
    capped = min(1.0, decayed)
    return max(min_alpha, capped)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment