Last active
April 7, 2017 20:00
-
-
Save jaume-ferrarons/99604236bcf2e6e8dd7e34b1aadb4a14 to your computer and use it in GitHub Desktop.
OpenAI CartPole-v0 solved with a Support Vector Regressor - https://gym.openai.com/evaluations/eval_h4Cbis6SiSfwXHf7HuVnw#reproducibility
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
from gym import wrappers | |
import numpy as np | |
from sklearn import svm | |
# CartPole-v0 environment wrapped in gym's Monitor, which writes its
# output under /tmp/cartpole-experiment-1 (force=True lets it overwrite
# output left there by an earlier run).
env = wrappers.Monitor(
    gym.make('CartPole-v0'),
    '/tmp/cartpole-experiment-1',
    force=True,
)

# Training data accumulated across all episodes: one observation and one
# action per environment step, and (once its episode ends) one target per
# step in ``rewards``.
observations, actions, rewards = [], [], []

# Regressor fitted by update_model(); None until the first fit.
model = None
def update_model():
    """Refit the global regressor on every sample collected so far.

    Builds a design matrix whose rows are ``[observation..., action]`` from
    the module-level ``observations`` and ``actions`` lists, uses the
    module-level ``rewards`` list as the regression target, fits a fresh
    ``svm.SVR`` and rebinds the global ``model`` to it.

    If no targets have been collected yet, the current ``model`` is left
    unchanged.
    """
    global model
    if not rewards:
        # Nothing to learn from yet; keep whatever model we already have.
        return
    a_observations = np.asarray(observations, dtype=float)
    a_actions = np.asarray(actions, dtype=float).reshape(-1, 1)
    X = np.hstack([a_observations, a_actions])
    y = np.asarray(rewards, dtype=float)
    # Fit a local estimator first so the global ``model`` is never left
    # pointing at an unfitted SVR if ``fit`` raises (e.g. mismatched lengths).
    new_model = svm.SVR()
    new_model.fit(X, y)
    model = new_model
def predict(observation, exploring_ratio, n_actions=2):
    """Choose an action for ``observation`` using epsilon-greedy exploration.

    A uniformly random action is returned when no model has been fitted yet,
    or with probability ``exploring_ratio``. Otherwise the global ``model``
    scores every candidate action paired with ``observation``, and the action
    with the highest predicted value is returned.

    Args:
        observation: 1-D observation vector from the environment.
        exploring_ratio: probability of taking a random action.
        n_actions: number of discrete actions (CartPole uses 2).

    Returns:
        The chosen action as a Python ``int`` in ``range(n_actions)``.
    """
    if model is None or np.random.random() < exploring_ratio:
        return int(np.random.randint(n_actions))
    # One row per candidate action: [observation..., action], the same
    # feature layout that update_model() trains on.
    candidates = np.arange(n_actions).reshape(-1, 1)
    X = np.hstack([np.tile(observation, (n_actions, 1)), candidates])
    return int(np.argmax(model.predict(X)))
MAX_STEPS = 1000      # hard cap on steps per episode in this loop
SOLVED_STEPS = 200    # an episode lasting this many steps counts as solved


def _episode_targets(n_steps, solved):
    """Return one regression target per step of a finished episode.

    A solved episode labels every step 1.0. A failed episode labels its
    steps linearly from 1.0 (first step) down to 0.0 (the step that ended it).
    """
    if solved:
        return [1.0] * n_steps
    # max(..., 1) avoids 0/0 -> NaN targets for a single-step episode.
    denom = float(max(n_steps - 1, 1))
    return [(n_steps - 1 - k) / denom for k in range(n_steps)]


for i_episode in range(200):
    observation = env.reset()
    # Exploration probability decays with the episode index: 1.0, 8/9, 0.8, ...
    exploring_ratio = 1 / (i_episode / 8.0 + 1)
    for t in range(MAX_STEPS):
        env.render()
        observations.append(observation)
        action = predict(observation, exploring_ratio)
        actions.append(action)
        observation, reward, done, info = env.step(action)
        if done:
            print("Episode {} finished after {} timesteps".format(i_episode, t + 1))
            break
    # Always add exactly one target per recorded step, even if the episode
    # hit MAX_STEPS without reporting ``done``, so X and y stay aligned.
    n_steps = t + 1
    rewards += _episode_targets(n_steps, solved=n_steps >= SOLVED_STEPS)
    update_model()
env.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment