Skip to content

Instantly share code, notes, and snippets.

@tomykaira
Created January 31, 2019 09:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tomykaira/da1d0a67c390255386995615ea75023f to your computer and use it in GitHub Desktop.
Save tomykaira/da1d0a67c390255386995615ea75023f to your computer and use it in GitHub Desktop.
import gym
import numpy as np
# import plotly.plotly as py
# import plotly.graph_objs as go
# from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# import random
# init_notebook_mode(connected=True)
# Environment the agent is trained on.
env = gym.make('CartPole-v1')

# Not referenced by any other code visible in this file.
data = []

# Learning rate and discount factor used by update_qtable().
alpha = 0.1
gamma = 0.98

# One 2-element weight vector per action (actions 0 and 1), each drawn
# uniformly from [-1, 1). q() dots these with two observation features.
theta = [np.random.uniform(low=-1, high=1, size=2) for _ in range(2)]
def q(state, action):
    """Return the linear action-value estimate for `action` in `state`.

    The value is the dot product of the module-level weight vector
    `theta[action]` with elements 2 and 3 of `state`; elements 0 and 1 of
    the observation are ignored.
    (For CartPole, elements 2 and 3 should be the pole angle and pole
    angular velocity -- confirm against the Gym observation layout.)
    """
    weights = theta[action]
    features = state[2:4]
    return weights @ features
def update_qtable(state, action, reward, next_state, next_action):
    """Apply one temporal-difference update to the weights of `action`.

    Despite the name there is no table: the estimate is linear (see `q`),
    so the update nudges `theta[action]` in place along the two state
    features that `q` reads.

    The error term is
        alpha * (reward + gamma * q(next_state, next_action) - q(state, action))
    i.e. the SARSA form, bootstrapping from the supplied `next_action`.
    Returns None; the only effect is the mutation of `theta[action]`.
    """
    td_target = reward + gamma * q(next_state, next_action)
    delta_q = alpha * (td_target - q(state, action))

    weights = theta[action]  # same array object, so the update is in place
    for offset in (0, 1):
        # The extra 0.01 factor further shrinks the step on top of alpha.
        weights[offset] += delta_q * state[offset + 2] * 0.01
def make(state):
    """Choose the greedy action for `state`.

    Returns 0 only when q(state, 0) is strictly greater than q(state, 1);
    ties (and any comparison that is not strictly greater) yield 1.
    There is no exploration step.
    """
    return 0 if q(state, 0) > q(state, 1) else 1
import itertools
# Per-episode history: index of the last step, and a snapshot of all four
# weights (theta[0] followed by theta[1]) taken when the episode ended.
turns = []
vals = []

# Step limit after which CartPole-v1 ends an episode on its own.
max_episode_steps = 500

# NOTE(review): this loop uses the classic Gym API (reset() returns the
# observation, step() returns a 4-tuple). Newer Gym/Gymnasium releases
# return (obs, info) and a 5-tuple instead -- confirm the installed version.
for episode in range(10000):
    state = env.reset()
    action = make(state)
    for turn in itertools.count():
        next_state, reward, done, _ = env.step(action)
        # `turn` is 0-based, so the final permitted step has
        # turn == max_episode_steps - 1. Comparing the 1-based step count
        # against the limit penalises only episodes that ended early
        # (the pole fell / cart left the track). The previous check
        # `turn < 500` was always true, so surviving the whole episode was
        # punished exactly like a failure.
        if done and turn + 1 < max_episode_steps:
            reward = -500
        next_action = make(next_state)
        update_qtable(state, action, reward, next_state, next_action)
        state = next_state
        action = next_action
        if done:
            turns.append(turn)
            vals.append(list(theta[0]) + list(theta[1]))
            break
# Plot how each of the four weights evolved over the recorded episodes,
# one trace per weight (columns 0-1 are theta[0], columns 2-3 are theta[1]).
# NOTE(review): `go` and `iplot` are only defined by the plotly imports that
# are commented out at the top of this file; as written this section raises
# NameError unless those imports are restored.
#plt.append(go.Scatter(x=[x for x in range(len(turns))], y=turns))
plt = [
    go.Scatter(
        x=list(range(len(vals))),
        y=[snapshot[column] for snapshot in vals],
    )
    for column in range(4)
]
iplot(plt)

# Final weights and the last step index of the final episode.
print(theta, turn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment