Created
May 23, 2017 20:24
-
-
Save MarceloPrado/637c4e059b9eade1321d24c9c7a93086 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'Esse arquivo gera alguns stats dos algoritmos' | |
import numpy as np | |
import gym | |
# from main import random_search | |
import os | |
def run_episode(env, params, max_reward, render=True):
    """Run one episode with a linear threshold policy and return its total reward.

    At every step the action is ``0`` if ``np.matmul(params, observation) < 0``
    and ``1`` otherwise.

    Args:
        env: Environment with the classic gym API: ``reset()`` returns the
            observation and ``step(action)`` returns
            ``(observation, reward, done, info)``.
        params: Weight vector combined with the observation via ``np.matmul``.
        max_reward: Maximum number of timesteps to run. Despite its name this
            is a step cap. If every step yields a reward of at most 1, it also
            bounds the total reward.
        render: Call ``env.render()`` on every step. The default ``True``
            keeps the historical behaviour; pass ``False`` for a fast,
            headless evaluation.

    Returns:
        Sum of the rewards collected until ``done`` or the step cap.
    """
    observation = env.reset()
    total_reward = 0
    for _ in range(max_reward):
        if render:
            env.render()  # lets you watch the (trained) policy act
        action = 0 if np.matmul(params, observation) < 0 else 1
        observation, reward, done, _info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward
def random_search(env, max_reward, streak_counter, max_episodes=30000, verbose=False):
    """Randomly sample linear-policy weights until they solve the task repeatedly.

    While the current streak is zero, fresh weights are drawn uniformly from
    ``[-1, 1)`` (4 values). While on a streak, the last successful weights are
    re-evaluated. An episode is a success when its total reward reaches
    ``max_reward``. The search stops as soon as the streak exceeds
    ``streak_counter``.

    Args:
        env: Environment passed through to ``run_episode``.
        max_reward: Per-episode step cap (forwarded to ``run_episode``) and
            the reward threshold that counts as "solved".
        streak_counter: Stop once more than this many consecutive episodes
            have succeeded.
        max_episodes: Upper bound on the number of episodes evaluated.
        verbose: Print the current streak before every episode.

    Returns:
        Index of the last successful episode. This is the episode at which the
        streak criterion was met if the search converged; otherwise it is the
        last success seen, or 0 if there was none.
    """
    best_params = None
    streak = 0
    episode_counter = 0
    for i_episode in range(max_episodes):
        if verbose:
            print(streak)
        # Explore with new random weights unless we are re-testing a success.
        if streak == 0:
            parameters = np.random.rand(4) * 2 - 1
        else:
            parameters = best_params
        reward = run_episode(env, parameters, max_reward)
        # Judge success against max_reward directly. A separate hard-coded
        # threshold (previously 200) disagrees with max_reward for any other
        # value and can leave best_params as None while the streak grows.
        if reward >= max_reward:
            best_params = parameters
            streak += 1
            episode_counter = i_episode
        else:
            streak = 0
        if streak > streak_counter:
            break
    return episode_counter
if __name__ == "__main__":
    # Single output directory for the Monitor wrapper. The previous code also
    # created an unrelated, never-used 'gym/out' directory.
    directory = "gym-out/"
    os.makedirs(directory, exist_ok=True)

    env = gym.make("CartPole-v1")
    # Record a video only every 10000th episode to keep the output small.
    # force=True presumably clears earlier monitor files in `directory`;
    # confirm against the gym Monitor docs for the installed version.
    env = gym.wrappers.Monitor(
        env,
        directory,
        force=True,
        video_callable=lambda episode_id: episode_id % 10000 == 0,
    )
    try:
        solved_at = random_search(env, 200, 200)
        print("random_search finished at episode", solved_at)
    finally:
        # Always release the environment/monitor, even if the search fails.
        env.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment