@SolClover
SolClover / Art057_Python_012.py
Created October 16, 2022 07:31
Evaluate agent by visualizing its actions
# Reset environment to initial state
state, info = env.reset()
# Cycle through 50 steps, rendering and displaying the environment state each time
for _ in range(50):
    # Render and display the current state of the environment
    plt.imshow(env.render())  # render current state and pass to pyplot
    plt.axis('off')
    display.display(plt.gcf())  # get current figure and display
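    # --- Hedged continuation (cut off in the gist preview): a minimal sketch of how the
    # --- loop likely proceeds, assuming the trained Qtable is acted on greedily and that
    # --- env.step() follows Gymnasium's 5-tuple return.
    display.clear_output(wait=True)  # clear output before showing the next frame
    # Select the greedy (best-known) action for the current state
    action = np.argmax(Qtable[state, :])
    # Take the action and observe the next state
    state, reward, terminated, truncated, info = env.step(action)
    # Stop rendering once the episode ends
    if terminated or truncated:
        break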
@SolClover
SolClover / Art057_Python_011.py
Created October 16, 2022 06:50
Call evaluation function and plot the results
# Evaluate
n_eval_episodes=10000
mean_reward, std_reward, episode_rewards = evaluate_agent(n_max_steps, n_eval_episodes, Qtable)
# Print evaluation results
print(f"Mean Reward = {mean_reward:.2f} +/- {std_reward:.2f}")
print(f"Min = {min(episode_rewards):.1f} and Max {max(episode_rewards):.1f}")
# Show the distribution of rewards obtained from evaluation
plt.figure(figsize=(9,6), dpi=200)
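# Hedged continuation (truncated in the preview): one plausible way to finish the plot,
# assuming a simple histogram of the per-episode rewards returned by evaluate_agent().
plt.hist(episode_rewards, bins=50)  # distribution of rewards across evaluation episodes
plt.xlabel('Episode reward')
plt.ylabel('Frequency')
plt.title('Distribution of evaluation rewards')
plt.show()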
@SolClover
SolClover / Art057_Python_010.py
Created October 16, 2022 06:49
Function to evaluate agent's performance
def evaluate_agent(n_max_steps, n_eval_episodes, Qtable):
    # Initialize an empty list to store rewards for each episode
    episode_rewards = []
    # Evaluate for each episode
    for episode in range(n_eval_episodes):
        # Reset the environment at the start of each episode
        state, info = env.reset()
        t = 0
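        # --- Hedged continuation (cut off in the preview): a minimal sketch of the rest of
        # --- the evaluation loop, assuming a purely greedy policy over Qtable and
        # --- Gymnasium's 5-tuple step API.
        total_reward = 0
        done = False
        # Step through the episode until it ends or n_max_steps is reached
        while not done and t < n_max_steps:
            # Always exploit during evaluation: pick the best-known action
            action = np.argmax(Qtable[state, :])
            state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            total_reward += reward
            t += 1
        episode_rewards.append(total_reward)
    # Summarize performance across all evaluation episodes
    return np.mean(episode_rewards), np.std(episode_rewards), episode_rewards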
@SolClover
SolClover / Art057_Python_009.py
Created October 16, 2022 06:38
Train the model using SARSA algorithm
# Train
Qtable = train(n_episodes, n_max_steps, start_epsilon, min_epsilon, decay_rate, Qtable)
# Show Q-table
Qtable
@SolClover
SolClover / Art057_Python_008.py
Created October 16, 2022 06:34
SARSA training function
def train(n_episodes, n_max_steps, start_epsilon, min_epsilon, decay_rate, Qtable):
    for episode in range(n_episodes):
        # Reset the environment at the start of each episode
        state, info = env.reset()
        t = 0
        done = False
        # Calculate epsilon value based on decay rate
        epsilon = max(min_epsilon, (start_epsilon - min_epsilon) * np.exp(-decay_rate * episode))
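        # --- Hedged continuation (cut off in the preview): a minimal sketch of the SARSA
        # --- update loop, assuming the epsilon_greedy() behaviour policy defined in the
        # --- other gist and the global alpha/gamma learning parameters.
        # Choose the first action with the epsilon-greedy behaviour policy
        action = epsilon_greedy(Qtable, state, epsilon)
        while not done and t < n_max_steps:
            # Take the action and observe the outcome
            new_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            # SARSA is on-policy: the next action is also chosen by the behaviour policy
            new_action = epsilon_greedy(Qtable, new_state, epsilon)
            # SARSA update: Q(s,a) <- Q(s,a) + alpha * (r + gamma * Q(s',a') - Q(s,a))
            Qtable[state, action] += alpha * (reward + gamma * Qtable[new_state, new_action] - Qtable[state, action])
            # Move to the next state/action pair
            state, action = new_state, new_action
            t += 1
    return Qtable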
@SolClover
SolClover / Art057_Python_007.py
Created October 16, 2022 06:30
Define functions to use in training and evaluation
# This is our acting policy (epsilon-greedy), which selects an action for exploration/exploitation during training
def epsilon_greedy(Qtable, state, epsilon):
    # Generate a random number and compare it to epsilon; if lower, explore, otherwise exploit
    randnum = np.random.uniform(0, 1)
    if randnum < epsilon:
        action = env.action_space.sample()  # explore
    else:
        action = np.argmax(Qtable[state, :])  # exploit
    return action
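The description mentions functions (plural), but the preview ends after epsilon_greedy. A greedy companion policy for evaluation is a plausible second definition; the sketch below is an assumption rather than the gist's own code, and the name greedy_policy is hypothetical.
# This would be the evaluation policy (greedy): always pick the highest-value action
def greedy_policy(Qtable, state):
    return np.argmax(Qtable[state, :])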
@SolClover
SolClover / Art057_Python_006.py
Created October 16, 2022 06:21
Initialize Q-table
# Initial Q-table
# Our Q-table is a matrix of state(observation) space x action space, i.e., 500 x 6
Qtable = np.zeros((env.observation_space.n, env.action_space.n))
# Show
Qtable
# SARSA parameters
alpha = 0.1 # learning rate
gamma = 0.95 # discount factor
# Training parameters
n_episodes = 100000 # number of episodes to use for training
n_max_steps = 100 # maximum number of steps per episode
# Exploration / Exploitation parameters
start_epsilon = 1.0 # start training by selecting purely random actions
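# Hedged continuation (cut off in the preview): min_epsilon and decay_rate are passed to
# train() elsewhere, so they must be defined here; the specific values below are assumptions.
min_epsilon = 0.05   # never stop exploring entirely
decay_rate = 0.0005  # how quickly epsilon decays from start_epsilon toward min_epsilon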
@SolClover
SolClover / Art057_Python_004.py
Created October 15, 2022 10:44
Agent performing random actions around the environment
# Reset environment to initial state
state, info = env.reset()
# Cycle through 30 random steps, rendering and displaying the agent inside the environment each time
for _ in range(30):
    # Render and display the current state of the environment
    plt.imshow(env.render())  # render current state and pass to pyplot
    plt.axis('off')
    display.display(plt.gcf())  # get current figure and display
    display.clear_output(wait=True)  # clear output before showing the next frame
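    # --- Hedged continuation (cut off in the preview): a minimal sketch of the random step
    # --- itself, assuming Gymnasium's 5-tuple step API.
    # Sample a random action and apply it to the environment
    action = env.action_space.sample()
    state, reward, terminated, truncated, info = env.step(action)
    # Reset if the episode ends so the random walk can continue
    if terminated or truncated:
        state, info = env.reset()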
@SolClover
SolClover / Art057_Python_003.py
Created October 15, 2022 10:35
Show key information about the environment
# Show environment description (map) as an array
print("Environment Array: ")
print(env.desc)
# Observation and action space
state_obs_space = env.observation_space # Returns the state (observation) space of the environment.
action_space = env.action_space # Returns the action space of the environment.
print("State(Observation) space:", state_obs_space)
print("Action space:", action_space)