@tomykaira
Created February 4, 2019 06:35
# Q-learning with designed reward
import gym
import numpy as np
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot

init_notebook_mode(connected=True)
env = gym.make('CartPole-v1')

# tweak: start the parameters above 0
theta = np.random.uniform(low=0, high=10, size=4)
alpha = 0.001  # base learning rate
gamma = 0.50   # discount factor
decay = 0.000  # per-episode learning-rate decay (disabled here)
theta_log = []

def q(state, action):
    # Linear action-value function; actions are encoded as -1 and +1.
    return (theta @ state) * action

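# Aside (addition, not in the original gist): with actions in {-1, +1} the two
# action values are exact negatives of each other, so the greedy action is
# just the sign of theta @ state. A quick sanity check on a made-up state:
_s = np.array([0.1, -0.2, 0.03, 0.4])  # hypothetical CartPole observation
assert np.isclose(q(_s, 1), -q(_s, -1))
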
def update_qtable(i, state, action, reward, next_state):
    # Learning rate with optional decay (decay = 0, so lr stays at alpha).
    lr = alpha * (1. / (1. + decay * i))
    q_max = max(q(next_state, -1), q(next_state, 1))
    delta_q = lr * (reward + gamma * q_max - q(state, action))
    for j in range(4):  # j, not i: do not shadow the episode index
        theta[j] += delta_q * state[j]

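# Aside (addition, not in the original gist): since q(s, a) = (theta @ s) * a,
# the gradient of q with respect to theta is a * s, so a gradient-faithful
# update would also multiply by the action. A hedged sketch of that variant,
# unused below so the gist's original behaviour is preserved:
def update_qtable_grad(i, state, action, reward, next_state):
    lr = alpha * (1. / (1. + decay * i))
    q_max = max(q(next_state, -1), q(next_state, 1))
    td_error = reward + gamma * q_max - q(state, action)
    return theta + lr * td_error * action * np.asarray(state)
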
def make(state):
    # Greedy policy: choose the action with the larger Q-value.
    return 1 if q(state, 1) >= q(state, -1) else -1

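# Aside (addition, not in the original gist): the policy above is purely
# greedy; the gist imported `random` without using it, so epsilon-greedy
# exploration may have been the intent. A minimal sketch, unused below:
import random

def make_eps(state, eps=0.1):
    if random.random() < eps:
        return random.choice([-1, 1])  # explore: random action
    return make(state)                 # exploit: greedy action
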
theta_log.append([0, theta.copy()])
turns = []
for i in range(20000):
    obs = env.reset()
    turn = 0
    action = make(obs)
    while True:
        # Map the internal actions {-1, +1} onto gym's action space {0, 1}.
        next_obs, reward, done, _ = env.step(1 if action == 1 else 0)
        turn += 1
        # tweak on reward: punish an early fall heavily, otherwise add a
        # survival bonus that grows with the episode length
        if turn < 500 and done:
            reward = -500
        else:
            reward += turn / 10
        next_action = make(next_obs)
        update_qtable(i, obs, action, reward, next_obs)
        obs = next_obs
        action = next_action
        if done:
            theta_log.append([i + 1, theta.copy()])
            turns.append(turn)
            break

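# Aside (addition, not in the original gist): a sketch for evaluating the
# learned greedy policy without further updates, using the same pre-0.26
# gym API as the rest of the file:
def evaluate(episodes=10):
    lengths = []
    for _ in range(episodes):
        s, done, steps = env.reset(), False, 0
        while not done:
            s, _, done, _ = env.step(1 if make(s) == 1 else 0)
            steps += 1
        lengths.append(steps)
    return sum(lengths) / episodes
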
# Plot how each of the four theta components evolved over training.
traces = []
for i in range(4):
    traces.append(go.Scatter(x=[x[0] for x in theta_log],
                             y=[x[1][i] for x in theta_log]))
iplot(traces)
# Report the final parameters, the last episode length, and the mean length
# of the last 50 episodes (the original divided a 49-element slice by 50).
print(theta, turn, sum(turns[-50:]) / 50)
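
# Note (addition, not in the original gist): this code targets the pre-0.26
# gym API. On gym >= 0.26 (or gymnasium), reset() returns (obs, info) and
# step() returns a five-tuple, so the episode loop would change roughly to:
#
#     obs, _ = env.reset()
#     next_obs, reward, terminated, truncated, _ = env.step(1 if action == 1 else 0)
#     done = terminated or truncated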