[Lecture 3 Dummy Q-learning] #reinforcement-learning
# -*- coding: utf-8 -*-
"""Lecture 3 Dummy Q-learning (table)"""
import numpy as np
import gym
from gym.envs.registration import register
import random as pr
import matplotlib.pyplot as plt
# argmax that chooses randomly among all eligible maximum indices
def rargmax(vector):
    m = np.amax(vector)
    indices = np.nonzero(vector == m)[0]
    return pr.choice(indices)
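
# Quick illustration (not from the original gist): unlike np.argmax, which
# always returns the first maximal index, rargmax breaks ties uniformly at
# random, so early episodes do not always favor action 0:
#   np.argmax(np.array([0.0, 1.0, 1.0]))  # always 1
#   rargmax(np.array([0.0, 1.0, 1.0]))    # 1 or 2, chosen at random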
register(
    id="FrozenLake-v3",
    entry_point="gym.envs.toy_text:FrozenLakeEnv",
    kwargs={"map_name": "4x4", "is_slippery": False},  # deterministic transitions
)
env = gym.make("FrozenLake-v3")
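
# Note: this script assumes the classic gym API (pre-0.26), where
# env.reset() returns the initial state directly and env.step() returns a
# 4-tuple (observation, reward, done, info); newer gym/gymnasium releases
# return (observation, info) from reset() and a 5-tuple from step().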

# Initialize Q-table with all zeros
Q = np.zeros([env.observation_space.n, env.action_space.n])

# Set learning parameters
num_episodes = 2000

# Create a list to contain the total reward per episode
rList = []
for i in range(num_episodes):
    # Reset environment and get first new observation
    state = env.reset()
    rAll = 0
    done = False

    # The Q-table learning algorithm
    while not done:
        # Greedy action from the current Q-table, breaking ties at random
        action = rargmax(Q[state, :])

        # Get new state and reward from environment
        new_state, reward, done, _ = env.step(action)

        # Dummy update: overwrite Q(s, a) with r + max_a' Q(s', a')
        # (no learning rate or discount factor yet)
        Q[state, action] = reward + np.max(Q[new_state, :])

        rAll += reward
        state = new_state

    rList.append(rAll)
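
# For reference, the full Q-learning update introduced later in the course
# adds a learning rate alpha and a discount factor gamma:
#   Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a'))
# The dummy update above is the special case alpha = 1, gamma = 1.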
print("Success rate: " + str(sum(rList)/num_episodes))
print("Final Q-Table Values")
print("LEFT DOWN RIGHT UP")
print(Q)
plt.bar(range(len(rList)), rList, color="blue")
plt.show()