Skip to content

Instantly share code, notes, and snippets.

@MarceloPrado
Created May 23, 2017 20:24
Show Gist options
  • Save MarceloPrado/637c4e059b9eade1321d24c9c7a93086 to your computer and use it in GitHub Desktop.
Save MarceloPrado/637c4e059b9eade1321d24c9c7a93086 to your computer and use it in GitHub Desktop.
'Esse arquivo gera alguns stats dos algoritmos'
import numpy as np
import gym
# from main import random_search
import os
def run_episode(env, params, max_reward, render=True):
    """Run one episode with a linear threshold policy and return its total reward.

    At every step the action is 1 when ``params . observation >= 0`` and 0
    otherwise, so ``params`` must be compatible with the observation for
    ``np.matmul`` (same length for 1-D vectors).

    Args:
        env: Environment following the classic Gym API used here:
            ``reset() -> observation`` and
            ``step(action) -> (observation, reward, done, info)``.
        params: Weight vector of the linear policy.
        max_reward: Maximum number of timesteps to run. Despite the name,
            this caps *steps*, not reward. Callers treat the two as
            equivalent, which assumes a reward of 1 per step.
        render: If True (the default, matching the previous behavior),
            call ``env.render()`` before every step so the policy can be
            watched. Pass False for fast, headless evaluation.

    Returns:
        The sum of rewards collected until ``done`` or the step limit.
    """
    observation = env.reset()
    total_reward = 0
    for _ in range(max_reward):
        if render:
            env.render()  # lets a human watch the current policy act
        action = 0 if np.matmul(params, observation) < 0 else 1
        observation, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward
def random_search(env, max_reward, streak_counter, max_episodes=30000,
                  verbose=True):
    """Random-search linear-policy weights until one keeps solving the task.

    An episode counts as solved when its total reward reaches ``max_reward``.
    Whenever the current streak is broken, a fresh 4-element weight vector
    is drawn uniformly from [-1, 1). While the streak is alive, the last
    winning weights are re-tested. The search stops once more than
    ``streak_counter`` consecutive episodes are solved, or after
    ``max_episodes`` episodes.

    Args:
        env: Environment passed through to ``run_episode``.
        max_reward: Reward (and step cap) that marks an episode as solved.
        streak_counter: The search stops after ``streak_counter + 1``
            consecutive solved episodes.
        max_episodes: Episode budget (previously hard-coded to 30000).
        verbose: If True (the default, matching the previous behavior),
            print the current streak at the start of every episode.

    Returns:
        The index of the last episode run, provided at least one episode
        was solved; otherwise 0. Note that 0 is ambiguous: it is returned
        both when nothing was ever solved and when the search stopped at
        episode 0.
    """
    # The threshold used to be a hard-coded 200, independent of max_reward.
    # With max_reward < 200, that made the streak start before any weights
    # had been stored, and the next episode crashed on params=None.
    best_params = None
    streak = 0
    episode_counter = 0
    for i_episode in range(max_episodes):
        if verbose:
            print(streak)
        if streak == 0:
            # No live streak: sample a new candidate in [-1, 1)^4.
            parameters = np.random.rand(4) * 2 - 1
        else:
            # Re-test the winning weights to check they solve consistently.
            parameters = best_params
        reward = run_episode(env, parameters, max_reward)
        if reward >= max_reward:
            best_params = parameters
            streak += 1
        else:
            streak = 0
        if best_params is not None:
            # Record progress only once the task has been solved at least once.
            episode_counter = i_episode
        if streak > streak_counter:
            break
    return episode_counter
# Single output directory for the Monitor's recordings and statistics.
# (The old code also created an unused 'gym/out' directory through an
# always-true `if out:` check, while the Monitor wrote here.)
directory = "gym-out/"
if not os.path.exists(directory):
    os.makedirs(directory)

env = gym.make("CartPole-v1")
# Record video only for episodes 0, 10000, 20000, ...; force=True clears
# results a previous run left in the same directory.
env = gym.wrappers.Monitor(
    env,
    directory,
    force=True,
    video_callable=lambda episode_id: episode_id % 10000 == 0,
)
try:
    # Report the episode index at which the search stopped (the "stat").
    print(random_search(env, 200, 200))
finally:
    # Close the env (and flush the Monitor) even if the search raises.
    env.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment