Skip to content

Instantly share code, notes, and snippets.

@dharma6872
Created January 29, 2021 09:57
Show Gist options
  • Save dharma6872/d290523052db536dae5b7c2cdefed14f to your computer and use it in GitHub Desktop.
Save dharma6872/d290523052db536dae5b7c2cdefed14f to your computer and use it in GitHub Desktop.
[Lecture 4 Q learning] #강화학습
# -*- coding: utf-8 -*-
"""Lecture 4 Q learning"""
import gym
import numpy as np
import random
import matplotlib.pyplot as plt
from gym.envs.registration import register
# Register a non-slippery (deterministic) 4x4 FrozenLake under a custom id,
# then build an environment instance from that registration.
ENV_ID = "FrozenLake-v3"
_frozen_lake_kwargs = {"map_name": "4x4", "is_slippery": False}
register(
    id=ENV_ID,
    entry_point="gym.envs.toy_text:FrozenLakeEnv",
    kwargs=_frozen_lake_kwargs,
)
env = gym.make(ENV_ID)
# Q-table: one row per state, one column per action, every entry starts at 0.
n_states = env.observation_space.n
n_actions = env.action_space.n
Q = np.zeros((n_states, n_actions))
# Discount factor (gamma) applied to the bootstrapped future value.
dis = 0.99
# How many training episodes to run.
num_episodes = 2000
# Total reward collected in each episode, in episode order (reward only).
rList = list()
# Train the Q-table for `num_episodes` episodes, appending each episode's total
# reward to `rList`. Works with both the old gym API (reset -> obs,
# step -> 4-tuple) and gym>=0.26 (reset -> (obs, info), step -> 5-tuple).
for i in range(num_episodes):
    # Reset the environment; newer gym returns (obs, info) instead of obs.
    state = env.reset()
    if isinstance(state, tuple):
        state = state[0]
    rAll = 0
    done = False
    # Run one episode, updating the Q-table after every step.
    while not done:
        # Greedy action on Q plus Gaussian noise whose scale decays as
        # 1 / (1 + episode), so exploration shrinks over training.
        noise = np.random.randn(1, env.action_space.n) / (1 + i)
        action = int(np.argmax(Q[state, :] + noise))
        step_result = env.step(action)
        if len(step_result) == 5:
            # gym>=0.26: (obs, reward, terminated, truncated, info)
            new_state, reward, terminated, truncated, _ = step_result
            done = terminated or truncated
        else:
            # legacy gym: (obs, reward, done, info)
            new_state, reward, done, _ = step_result
        # Learning rate of 1: the target overwrites Q(s, a) outright, which is
        # adequate here because the map was registered with is_slippery=False.
        Q[state, action] = reward + dis * np.max(Q[new_state, :])
        rAll += reward
        state = new_state
    rList.append(rAll)
# Summarize training: fraction of episodes that reached the goal, the learned
# Q-table, and a per-episode bar chart of collected reward.
success_rate = sum(rList) / num_episodes
print(f"Success rate: {success_rate}")
print("Final Q-Table Values")
print(Q)
episode_indices = range(len(rList))
plt.bar(episode_indices, rList, color="blue")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment