Last active
May 23, 2021 08:08
-
-
Save joooyzee/25672a72ea4c3e1675460b934a7884e4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import gym | |
import random | |
def _reset_env(env):
    """Reset *env* and return only the initial observation.

    Accepts both a bare observation and an ``(observation, info)`` pair as
    the result of ``env.reset()``, so the caller always gets a value usable
    as a Q-table row index.
    """
    outcome = env.reset()
    if isinstance(outcome, tuple):
        return outcome[0]
    return outcome


def _step_env(env, action):
    """Apply *action* to *env* and return ``(new_state, reward, done)``.

    Accepts both the 4-value step result ``(obs, reward, done, info)`` and
    the 5-value one ``(obs, reward, terminated, truncated, info)``; in the
    latter case the episode counts as done when either flag is set.
    """
    outcome = env.step(action)
    if len(outcome) == 5:
        new_state, reward, terminated, truncated, _info = outcome
        return new_state, reward, bool(terminated or truncated)
    new_state, reward, done, _info = outcome
    return new_state, reward, bool(done)


def main():
    """Train a tabular Q-learning agent on Taxi-v3, then replay one greedy episode.

    Training uses an epsilon-greedy policy whose epsilon is recomputed after
    every episode as ``exp(-decay_rate * episode)``. After training the user
    is prompted, and a single episode is played by always taking the action
    with the highest Q-value, rendering each step and printing the running
    score. The environment is closed even if an error interrupts the run.
    """
    # create Taxi environment
    env = gym.make('Taxi-v3')
    try:
        # initialize q-table: one row per state, one column per action
        state_size = env.observation_space.n
        action_size = env.action_space.n
        qtable = np.zeros((state_size, action_size))

        # hyperparameters
        learning_rate = 0.9
        discount_rate = 0.8
        epsilon = 1.0
        decay_rate = 0.005

        # training variables
        num_episodes = 1000
        max_steps = 99  # per episode

        # training
        for episode in range(num_episodes):
            # reset the environment
            state = _reset_env(env)

            for _ in range(max_steps):
                # exploration-exploitation tradeoff
                if random.uniform(0, 1) < epsilon:
                    # explore
                    action = env.action_space.sample()
                else:
                    # exploit
                    action = np.argmax(qtable[state, :])

                # take action and observe reward
                new_state, reward, done = _step_env(env, action)

                # Q-learning update: move the estimate toward the one-step
                # target (reward plus discounted best value of the next state)
                target = reward + discount_rate * np.max(qtable[new_state, :])
                qtable[state, action] += learning_rate * (target - qtable[state, action])

                # Update to our new state
                state = new_state

                # if done, finish episode
                if done:
                    break

            # Decrease epsilon
            epsilon = np.exp(-decay_rate * episode)

        print(f"Training completed over {num_episodes} episodes")
        input("Press Enter to watch trained agent...")

        # watch trained agent
        state = _reset_env(env)
        rewards = 0

        for step in range(max_steps):
            print("TRAINED AGENT")
            print(f"Step {step + 1}")

            # always act greedily with respect to the learned Q-table
            action = np.argmax(qtable[state, :])
            new_state, reward, done = _step_env(env, action)
            rewards += reward
            # NOTE(review): some gym releases only draw when a render_mode was
            # passed to gym.make -- confirm against the installed version.
            env.render()
            print(f"score: {rewards}")
            state = new_state

            if done:
                break
    finally:
        # release the environment even when training or replay is interrupted
        env.close()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment