Skip to content

Instantly share code, notes, and snippets.

@yashpatel5400
Created July 25, 2017 08:20
Show Gist options
  • Save yashpatel5400/43a0017c3d4fad50e39d968350274b66 to your computer and use it in GitHub Desktop.
Save yashpatel5400/43a0017c3d4fad50e39d968350274b66 to your computer and use it in GitHub Desktop.
OpenAI CartPole w/ Keras: Prediction
"""
__name__ = predict.py
__author__ = Yash Patel
__description__ = Does the prediction using the defined model and data
"""
import gym
import numpy as np
from data import gather_data
from model import create_model
def _reset_observation(reset_result):
    """Return the initial observation from ``env.reset()`` across gym API versions.

    gym >= 0.26 returns ``(observation, info)`` with ``info`` a dict; older
    versions return the observation alone.
    """
    if (
        isinstance(reset_result, tuple)
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    ):
        return reset_result[0]
    return reset_result


def _unpack_step(step_result):
    """Normalize ``env.step()`` output to ``(observation, reward, done)``.

    gym >= 0.26 returns ``(obs, reward, terminated, truncated, info)``; older
    versions return ``(obs, reward, done, info)``.
    """
    if len(step_result) == 5:
        observation, reward, terminated, truncated, _ = step_result
        return observation, reward, terminated or truncated
    observation, reward, done, _ = step_result
    return observation, reward, done


def predict(num_trials=50, sim_steps=500, epochs=5, env_id="CartPole-v0"):
    """Train a model on gathered environment data, then evaluate it greedily.

    Training data comes from ``gather_data(env)`` and the network from
    ``create_model()`` (both project-local). The trained model then plays
    ``num_trials`` episodes, at each step taking the action whose model
    output is largest. The mean episode score is printed and returned.

    Args:
        num_trials: Number of evaluation episodes (must be >= 1).
        sim_steps: Maximum steps per episode; an episode may end earlier
            when the environment reports it is done.
        epochs: Number of training epochs passed to ``model.fit``.
        env_id: Gym environment id passed to ``gym.make``.

    Returns:
        The mean score over all evaluation episodes, as a float.

    Raises:
        ValueError: If ``num_trials`` is less than 1 (checked before any
            environment is created or training is done).
    """
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")

    env = gym.make(env_id)
    scores = []
    try:
        training_x, training_y = gather_data(env)
        model = create_model()
        model.fit(training_x, training_y, epochs=epochs)

        for _ in range(num_trials):
            observation = _reset_observation(env.reset())
            score = 0
            for _ in range(sim_steps):
                # Add a batch dimension of 1; -1 infers the observation length
                # instead of hard-coding it.
                batch = np.asarray(observation).reshape(1, -1)
                action = int(np.argmax(model.predict(batch)))
                observation, reward, done = _unpack_step(env.step(action))
                score += reward
                if done:
                    break
            scores.append(score)
    finally:
        # Release the environment (and any render window) even if training
        # or evaluation raises.
        env.close()

    mean_score = np.mean(scores)
    print(mean_score)
    return float(mean_score)
# Script entry point: train and evaluate only when this file is executed
# directly, so importing the module has no side effects.
if __name__ == "__main__":
    predict()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment