@SolClover
Created October 16, 2022 06:34
SARSA training function
import numpy as np

# Assumes `env`, `epsilon_greedy(Qtable, state, epsilon)` and
# `update_Q(Qtable, state, action, reward, next_state, next_action)`
# are defined in the earlier snippets.
def train(n_episodes, n_max_steps, start_epsilon, min_epsilon, decay_rate, Qtable):
    for episode in range(n_episodes):
        # Reset the environment at the start of each episode
        state, info = env.reset()
        # Calculate the epsilon value for this episode based on the decay rate
        epsilon = max(min_epsilon, (start_epsilon - min_epsilon) * np.exp(-decay_rate * episode))
        # Choose the first action using the previously defined epsilon-greedy policy
        action = epsilon_greedy(Qtable, state, epsilon)
        for t in range(n_max_steps):
            # Perform the action in the environment, get reward and next state
            # (step also returns a truncation flag; truncation is already
            # handled by the n_max_steps loop, so it is ignored here)
            next_state, reward, done, _, info = env.step(action)
            # Choose the next action with the same epsilon-greedy policy (on-policy)
            next_action = epsilon_greedy(Qtable, next_state, epsilon)
            # Update the Q-table using the SARSA rule
            Qtable = update_Q(Qtable, state, action, reward, next_state, next_action)
            # Move to the next state-action pair
            state = next_state
            action = next_action
            # Finish the episode when done=True, i.e., reached the goal or fallen into a hole
            if done:
                break
    # Return the final Q-table
    return Qtable
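
For context, here is a minimal sketch of how train could be called. The environment name (FrozenLake-v1), the hyperparameter values, the zero-initialised Q-table, and the gymnasium import are illustrative assumptions, not part of the original gist (the article may use the classic gym package instead); epsilon_greedy and update_Q are expected to come from the earlier snippets.

# Minimal usage sketch (illustrative values, not from the original gist)
import gymnasium as gym
import numpy as np

env = gym.make("FrozenLake-v1", is_slippery=False)

# Q-table with one row per state and one column per action
Qtable = np.zeros((env.observation_space.n, env.action_space.n))

Qtable = train(
    n_episodes=10000,
    n_max_steps=100,
    start_epsilon=1.0,
    min_epsilon=0.05,
    decay_rate=0.0005,
    Qtable=Qtable,
)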