Skip to content

Instantly share code, notes, and snippets.

@hdlim15
Created June 23, 2016 09:04
Show Gist options
  • Save hdlim15/f2cc80369fe54051823da66a9d56f513 to your computer and use it in GitHub Desktop.
Save hdlim15/f2cc80369fe54051823da66a9d56f513 to your computer and use it in GitHub Desktop.
import gym
import pybrain
import random
from pybrain.datasets import SupervisedDataSet
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
# Q-learning on gym's CartPole-v0 using a small PyBrain feed-forward network.
#
# The network maps a 4-value observation to 2 outputs, one value estimate per
# action (index 0 and index 1).  Transitions are collected into a buffer and,
# every BATCH_SIZE environment steps, replayed in random order, one
# single-sample backprop pass each.  Training stops once 200 consecutive
# episodes each lasted until a step index of at least 195.

name = "CartPole-v0"
env = gym.make(name)
#env.monitor.start("/tmp/CartPole-v0-6")

BATCH_SIZE = 50      # environment steps collected between training rounds
REPORT_EVERY = 50    # episodes between progress reports / epsilon decay

count = -1           # global step counter across all episodes (0-based after first step)
episode = -1         # episode counter (0-based after the first increment)
avg_score = 0        # sum of final step indices since the last report
streak = 0           # consecutive episodes whose final step index was >= 195
EPSILON = 0.6        # exploration probability; decayed during training
GAMMA = 0.95         # discount factor applied to the next state's best estimate

net = buildNetwork(4, 4, 2)
miniBatch = SupervisedDataSet(4, 2)

observation = env.reset()
# The first observation component is zeroed everywhere so the network never
# sees it.  NOTE(review): presumably this is the cart position -- confirm
# against the environment's observation layout.
observation[0] = 0

while streak < 200:
    episode += 1
    for t in range(201):
        # env.render()
        count += 1
        bestGuess = net.activate(observation)

        # Epsilon-greedy action selection: explore with probability EPSILON,
        # otherwise take the action with the larger estimated value.
        if random.random() < EPSILON:
            action = env.action_space.sample()
        else:
            action = 0 if bestGuess[0] > bestGuess[1] else 1

        observationNew, reward, done, info = env.step(action)
        observationNew[0] = 0

        # Training target for the action that was taken.
        # NOTE(review): every `done` gets the -100 target, including episodes
        # that end because they ran long enough to count as a success below;
        # confirm that penalising those is intended.
        if done:
            nextGuess = -100
        else:
            ff = net.activate(observationNew)
            nextGuess = reward + GAMMA * max(ff[0], ff[1])

        # Only the taken action's output is moved toward the target; the other
        # output keeps the network's current estimate.
        if action == 0:
            updateGuess = [nextGuess, bestGuess[1]]
        else:
            updateGuess = [bestGuess[0], nextGuess]
        miniBatch.addSample(observation, updateGuess)
        observation = observationNew

        # Every BATCH_SIZE steps the buffer holds exactly BATCH_SIZE samples:
        # replay them in random order, then start a fresh buffer.
        if count % BATCH_SIZE == BATCH_SIZE - 1:
            randList = random.sample(range(BATCH_SIZE), BATCH_SIZE)
            for rand in randList:
                # A fresh one-sample dataset and trainer per transition, so
                # each train() call runs on exactly that sample.
                point = SupervisedDataSet(4, 2)
                obs, lab = miniBatch.getSample(rand)
                point.addSample(obs, lab)
                trainer = BackpropTrainer(net, point)
                trainer.train()
            miniBatch = SupervisedDataSet(4, 2)

        if done or t == 200:
            # NOTE(review): t is the zero-based index of the last step, so this
            # accumulates one less than the number of steps taken.
            avg_score += t
            if t >= 195:
                # Long episode: cut exploration sharply and extend the streak.
                EPSILON *= 0.5
                streak += 1
            else:
                streak = 0
            if episode % REPORT_EVERY == REPORT_EVERY - 1:
                # Floor division keeps the report number an integer on both
                # Python 2 and Python 3.
                print(episode // REPORT_EVERY + 1)
                print('{0} average score'.format(float(avg_score) / 50.0))
                EPSILON *= 0.9
                avg_score = 0
            observation = env.reset()
            # Apply the same preprocessing as for every other observation so
            # the first state of each episode matches what the network is
            # trained on.
            observation[0] = 0
            break

print(episode)
#env.monitor.close()
# NOTE: never commit a real API key; supply it at run time instead.
#gym.upload("/tmp/CartPole-v0-6", algorithm_id="hdlim15", api_key="<YOUR_API_KEY>")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment