Skip to content

Instantly share code, notes, and snippets.

@tomykaira
Created January 31, 2019 15:54
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save tomykaira/29fdf19bd17e616464277fab6a18e2c9 to your computer and use it in GitHub Desktop.
# import gym
import numpy as np
import plotly.plotly as py
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import gym
import random
# Load plotly.js into the notebook so iplot() can render inline.
init_notebook_mode(connected=True)
# CartPole-v1: 4-dim observation (cart position/velocity, pole angle/angular
# velocity), two discrete actions (0 = push left, 1 = push right).
env = gym.make('CartPole-v1')
data = []  # NOTE(review): appears unused anywhere in this script
# Weights of the linear Q-function, one per observation component.
theta = np.random.uniform(low=0, high=10, size=(4))
alpha = 0.001  # base learning rate (decayed per episode in update_qtable)
gamma = 0.98  # discount factor
decay = 0.001  # learning-rate decay coefficient
theta_log = []  # (episode, theta snapshot) pairs for plotting after training
def q(state):
    """Linear value of *state*: dot product of the global weights with it."""
    return np.dot(theta, state)
def encode_action(action):
    """Map a discrete gym action to a sign: 1 -> +1, anything else -> -1."""
    if action == 1:
        return 1
    return -1
def update_qtable(i, state, action, reward, next_state, next_action):
    """Apply one SARSA-style gradient update to the global linear weights.

    Args:
        i: episode index; used only to decay the learning rate.
        state, next_state: 4-dim observation vectors from the environment.
        action, next_action: discrete actions in {0, 1}.
        reward: (shaped) reward observed for this transition.
    """
    global theta
    # Learning rate shrinks as training progresses.
    lr = alpha * (1. / (1. + decay * i))
    q_next = q(next_state)
    # TD error scaled by the learning rate. A single linear model is shared
    # across both actions; encode_action supplies each action's sign.
    delta_q = lr * (reward + gamma * q_next * encode_action(next_action)
                    - q(state) * encode_action(action))
    # Gradient step on the weights. The original used `for i in range(4)`,
    # shadowing the episode-index parameter `i`; vectorize instead.
    theta += delta_q * np.asarray(state)
def make(state):
    """Greedy policy: action 1 when the linear value is non-negative, else 0."""
    if q(state) >= 0:
        return 1
    return 0
# Record the initial weights, then run 20000 training episodes.
theta_log.append([0, theta.copy()])
turns = []
for episode in range(20000):
    obs = env.reset()
    turn = 0
    action = make(obs)
    done = False
    while not done:
        next_obs, reward, done, _ = env.step(action)
        turn += 1
        # Reward shaping: heavy penalty for falling before the 500-step cap,
        # otherwise a small survival bonus that grows with episode length.
        reward = -500 if done and turn < 500 else reward + turn / 10
        next_action = make(next_obs)
        update_qtable(episode, obs, action, reward, next_obs, next_action)
        obs, action = next_obs, next_action
    # Episode finished: snapshot the weights and remember its length.
    theta_log.append([episode + 1, theta.copy()])
    turns.append(turn)
# Plot how each of the 4 weight components evolved over training.
# (Renamed from `plt`, which misleadingly suggests matplotlib.)
traces = []
for component in range(4):
    traces.append(go.Scatter(
        x=[entry[0] for entry in theta_log],
        y=[entry[1][component] for entry in theta_log],
    ))
iplot(traces)
# Final weights, length of the last episode, and mean length of the last 50
# episodes. The original summed turns[-50:-1] (49 values) while dividing by
# 50; use the full last-50 slice so the average is correct.
print(theta, turn, sum(turns[-50:]) / 50)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment