Created
October 10, 2016 22:56
-
-
Save stefanopalmieri/2efec0e09c14de06fb93fbb91e18a93a to your computer and use it in GitHub Desktop.
Walker Experiment Gist code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#--- parameters for the BipedalWalker experiment (based on the pole experiment v1 config) ---#
# The `Types` section specifies which classes should be used for various
# tasks in the NEAT algorithm. If you use a non-default class here, you
# must register it with your Config instance before loading the config file.
[Types]
stagnation_type = DefaultStagnation
reproduction_type = DefaultReproduction
[phenotype]
input_nodes = 24
hidden_nodes = 0
output_nodes = 4
initial_connection = fs_neat
max_weight = 10
min_weight = -10
feedforward = 0
activation_functions = tanh sigmoid relu identity
weight_stdev = 3
[genetic]
pop_size = 400
max_fitness_threshold = 302
prob_add_conn = 0.3
prob_add_node = 0.1
prob_delete_conn = 0.05
prob_delete_node = 0.03
prob_mutate_bias = 0.00109
bias_mutation_power = 0.01
prob_mutate_response = 0.01
response_mutation_power = 0.01
prob_mutate_weight = 0.3
prob_replace_weight = 0.03
weight_mutation_power = 0.1
prob_mutate_activation = 0.01
prob_toggle_link = 0.0138
reset_on_extinction = 1
[genotype compatibility]
compatibility_threshold = 3
excess_coefficient = 1.0
disjoint_coefficient = 1.0
weight_coefficient = 0.4
[DefaultStagnation]
species_fitness_func = mean
max_stagnation = 5
[DefaultReproduction]
elitism = 3
survival_threshold = 0.2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The organism shown in the OpenAI Gym evaluation was found in the 58th generation.
# A total of 116,000 episodes were used for training (400 genomes * 5 episodes * 58 generations).
# Although the highest-fitness organism scored 303 in training, it did not solve the environment at test time.
from __future__ import print_function | |
import gym | |
import numpy as np | |
import itertools | |
import os | |
from neat import nn, population, statistics | |
# Never summarize numpy arrays with "..." when they are printed; show them in full.
np.set_printoptions(threshold=np.inf)

# A single environment instance, shared by eval_fitness during training and
# by the champion-evaluation loop at the end of the script.
env = gym.make("BipedalWalker-v2")
# Fitness function handed to the NEAT population: runs every genome through
# the shared walker environment.
def eval_fitness(genomes, episodes=5, max_frames=2000):
    """Set each genome's ``fitness`` to its mean per-episode reward.

    Each genome is turned into a feed-forward network and played for
    ``episodes`` independent episodes on the module-level ``env``.  The reward
    of every step is summed per episode, and the genome's fitness is the
    average of those episode totals.  The fitness is also printed.

    Args:
        genomes: iterable of NEAT genomes; each one gets ``.fitness`` assigned.
        episodes: number of episodes averaged per genome (default 5, as in
            the original experiment).
        max_frames: an episode is cut off once more than this many steps have
            been taken, even if the environment has not reported ``done``.

    Raises:
        ValueError: if ``episodes`` is less than 1 (the mean is undefined).
    """
    if episodes < 1:
        raise ValueError("episodes must be at least 1")

    for g in genomes:
        net = nn.create_feed_forward_phenotype(g)
        total_fitness = 0.0
        for _ in range(episodes):
            # Every episode starts from the observation returned by reset and
            # with its own reward/step counters.  Sharing the counters across
            # episodes would re-add earlier rewards (inflating the mean) and
            # cut later episodes short once the step count passed max_frames.
            observation = env.reset()
            fitness = 0.0
            frames = 0
            while True:
                output = net.serial_activate(observation)
                # Clip the network output to [-1, 1] before using it as the
                # action (presumably the walker's action bounds).
                action = np.clip(output, -1, 1)
                observation, reward, done, info = env.step(action)
                fitness += reward
                frames += 1
                if done or frames > max_frames:
                    break
            total_fitness += fitness
        g.fitness = total_fitness / episodes
        print(g.fitness)
# --- Training ---------------------------------------------------------------
# The NEAT config is loaded from a file that lives next to this script.
# NOTE(review): the file is named 'xor2_config'; presumably it holds the
# walker parameters (24 inputs / 4 outputs) -- confirm before re-running.
local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, "xor2_config")

pop = population.Population(config_path)
# Evolve for up to 1000 generations, scoring genomes with eval_fitness.
pop.run(eval_fitness, 1000)

# Keep only the champion genome, release the population, and rebuild the
# champion as a feed-forward network for evaluation.
winner = pop.statistics.best_genome()
del pop
winningnet = nn.create_feed_forward_phenotype(winner)
# --- Evaluation -------------------------------------------------------------
# Replay the champion network with the monitor recording until it scores at
# least SOLVED_SCORE in REQUIRED_STREAK consecutive episodes.
SOLVED_SCORE = 300
REQUIRED_STREAK = 100
MAX_FRAMES = 2000

env.monitor.start('walker-experiment/', force=True)
try:
    streak = 0
    while streak < REQUIRED_STREAK:
        fitness = 0
        frames = 0
        observation = env.reset()
        env.render()
        while True:
            output = winningnet.serial_activate(observation)
            # Clip the network output to [-1, 1] before using it as the action.
            observation, reward, done, info = env.step(np.clip(output, -1, 1))
            fitness += reward
            env.render()
            frames += 1
            if done or frames > MAX_FRAMES:
                break
        # Report the episode score and the streak it was achieved on.
        print(fitness)
        print('streak: ', streak)
        # A single episode below the threshold restarts the streak from zero.
        streak = streak + 1 if fitness >= SOLVED_SCORE else 0
    print("completed!")
finally:
    # The loop has no attempt limit, so it is often ended by an interrupt;
    # close the monitor on every exit path, not just on success.
    env.monitor.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment