Skip to content

Instantly share code, notes, and snippets.

@StuartFarmer
Created August 6, 2016 03:59
Show Gist options
  • Save StuartFarmer/c0e2d2eb6c3bb7b075da2735aedfad9f to your computer and use it in GitHub Desktop.
Save StuartFarmer/c0e2d2eb6c3bb7b075da2735aedfad9f to your computer and use it in GitHub Desktop.
# I'll refactor this later, but it's almost midnight and I just wanted to demonstrate how good NEAT is at solving problems without giving the algorithm any explicit direction
from __future__ import print_function
import gym
import numpy as np
import itertools
import os
from neat import nn, population, statistics
# Single CartPole environment shared by the fitness function during training
# and by the final evaluation of the winning genome further down.
env = gym.make('CartPole-v0')

# Never elide numpy arrays with "..." when printing them; an infinite
# threshold means every element is always shown.
np.set_printoptions(threshold=np.inf)
# run through the population
def eval_fitness(genomes):
for g in genomes:
observation = env.reset()
env.render()
net = nn.create_feed_forward_phenotype(g)
fitness = 0
while 1:
inputs = observation
# active neurons
output = net.serial_activate(inputs)
if (output[0] >= 0):
observation, reward, done, info = env.step(1)
else:
observation, reward, done, info = env.step(0)
fitness += reward
env.render()
if done:
print(fitness)
env.reset()
break
# evaluate the fitness
g.fitness = fitness
# Train a NEAT population on CartPole, then replay the best genome under the
# gym monitor until it chains 100 successful episodes, and upload the results.
local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, 'cartPole_config')
pop = population.Population(config_path)
pop.run(eval_fitness, 300)

# The monitor is started only after training, so only the evaluation of the
# winner below is recorded.
env.monitor.start('cartpole-experiment/', force=True)
try:
    winner = pop.statistics.best_genome()
    streak = 0
    winningnet = nn.create_feed_forward_phenotype(winner)
    observation = env.reset()
    env.render()
    # Replay the winner until it scores >= 195 in 100 episodes in a row; any
    # failing episode resets the streak to zero.
    # NOTE(review): this loops forever if the winner can never sustain such
    # a streak.
    while streak < 100:
        fitness = 0
        frames = 0
        while True:
            # activate the network on the current observation
            output = winningnet.serial_activate(observation)
            action = 1 if output[0] >= 0 else 0
            observation, reward, done, info = env.step(action)
            fitness += reward
            env.render()
            frames += 1
            # Episodes are capped at 200 frames.
            if done or frames >= 200:
                break
        if fitness >= 195:
            # Count this episode before reporting, so the printed value is
            # the current streak length.
            streak += 1
            print('streak: ', streak)
        else:
            print(fitness)
            print('streak: ', streak)
            streak = 0
        # Start the next episode from the state returned by reset(), not from
        # the terminal observation of the episode that just ended.
        observation = env.reset()
    print("completed!")
finally:
    # Close the monitor even if evaluation is interrupted, so its output is
    # finalized.
    env.monitor.close()

# Keep credentials out of source: read the key from the environment, falling
# back to the original placeholder when it is not set.
gym.upload('cartpole-experiment/',
           api_key=os.environ.get('OPENAI_GYM_API_KEY', 'XXX'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment