Skip to content

Instantly share code, notes, and snippets.

@bonus85
Created June 23, 2016 12:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bonus85/f4fd7bbe0f4ad5f5e488ca9009c6ce22 to your computer and use it in GitHub Desktop.
Save bonus85/f4fd7bbe0f4ad5f5e488ca9009c6ce22 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import gym
import random
import numpy as np
class Evaluator:
    """Score linear threshold policies on a gym environment.

    A policy is a weight vector ``distribution``. At every step the action is
    1 when ``np.sum(distribution * observation) > 0`` and 0 otherwise.
    """

    def __init__(self, env_name='CartPole-v0', max_iterations=200, render=False):
        """Create the environment and store the evaluation settings.

        Args:
            env_name: environment id passed to ``gym.make``.
            max_iterations: maximum number of env steps per episode.
            render: if true, ``env.render()`` is called before every step.
        """
        self.env = gym.make(env_name)
        self.render = render
        self.max_iterations = max_iterations
        # An episode passes test_distribution when its total reward is at
        # least this value (5 below the per-episode step cap).
        self.min_abs = max_iterations - 5
        # Total number of episodes run across all test_distribution calls.
        self.episode_count = 0

    def evaluate(self, distribution):
        """Run one episode with the given weights and return its total reward.

        The episode ends when the environment reports ``done`` or after
        ``max_iterations`` steps, whichever comes first.

        Args:
            distribution: weight vector multiplied element-wise with each
                observation (assumed to match the observation's length --
                4 for CartPole).

        Returns:
            The sum of the per-step rewards.
        """
        observation = self.env.reset()
        cumulative_reward = 0.
        # Bounded loop: at most max_iterations steps. The old
        # `t > max_iterations` check allowed one extra step.
        for _ in range(self.max_iterations):
            if self.render:
                self.env.render()
            w_sum = np.sum(distribution * observation)
            action = 1 if w_sum > 0 else 0
            # Old 4-tuple gym step API; the info dict is not needed here.
            observation, reward, done, _ = self.env.step(action)
            cumulative_reward += reward
            if done:
                break
        return cumulative_reward

    def test_distribution(self, distribution, n_eval=100):
        """Check whether ``distribution`` passes ``n_eval`` episodes in a row.

        Stops at the first episode whose reward is below ``min_abs`` and
        returns False. Every episode run, including a failing one, is added
        to ``episode_count``. On success the weights are printed.

        Returns:
            True if all ``n_eval`` episodes reached ``min_abs``, else False.
        """
        for _ in range(n_eval):
            reward = self.evaluate(distribution)
            self.episode_count += 1
            if reward < self.min_abs:
                return False
        print('Distribution passed: {}'.format(distribution))
        return True
def main(N):
    """Random-search for weights that make ``Evaluator.test_distribution`` pass.

    Tries up to ``N`` candidate weight vectors on the default environment
    ('CartPole-v0'). The first candidate is all zeros; each later one is a
    fresh ``np.random.rand(4)`` draw (components in [0, 1)). Stops at the
    first passing candidate and prints how many episodes were needed.

    The gym monitor is started on '/tmp/cartpole-experiment-1' and is
    always closed, even if evaluation raises.
    """
    evaluator = Evaluator(max_iterations=201)
    evaluator.env.monitor.start('/tmp/cartpole-experiment-1', force=True)
    try:
        w = np.zeros(4)
        for _ in range(N):
            if evaluator.test_distribution(w):
                # print() call form: valid in both Python 2 and 3, and
                # consistent with the rest of the file.
                print('Episodes to solve: {}'.format(evaluator.episode_count))
                break
            # Each candidate is an independent random draw. The previous
            # `best_distribution + rand` was always `zeros + rand`, because
            # the "best" was never updated before the loop ended.
            w = np.random.rand(4)
    finally:
        evaluator.env.monitor.close()
if __name__ == '__main__':
    # Script entry point: search with a budget of 100 candidate weight vectors.
    main(N=100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment