Created
April 24, 2019 15:19
-
-
Save AOrtizDZ/725c1efc191d3aeb854dddbc1cbacbd3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import neat | |
import pickle | |
import os | |
import operator | |
from matplotlib import pyplot as plt | |
import pygame | |
import random | |
import numpy as np | |
# ---------------------------------------------------------------------------
# NEAT training bookkeeping (read/updated by eval_genomes and game).
# ---------------------------------------------------------------------------
GENERATION = 0
MAX_FITNESS = 0
BEST_GENOME = 0
reward = 0
fitness_log = []

# ---------------------------------------------------------------------------
# Window geometry and nominal frame rate.
# ---------------------------------------------------------------------------
WIDTH, HEIGHT = 600, 200
FPS = 200

# ---------------------------------------------------------------------------
# RGB colours. The SEMI_* shades are what a player is painted with while it
# is invulnerable (see Player.update).
# ---------------------------------------------------------------------------
WHITE, BLACK = (255, 255, 255), (0, 0, 0)
RED, SEMI_RED = (255, 0, 0), (255, 100, 100)
GREEN, SEMI_GREEN = (0, 255, 0), (100, 255, 100)
BLUE, SEMI_BLUE = (0, 0, 255), (100, 100, 255)

# ---------------------------------------------------------------------------
# Game parameters and enemy-spawn scheduler state.
# ---------------------------------------------------------------------------
num_players = 1
level = 0  # added to every enemy's base speed
prob_per_second = 0.6  # enemy spawn rate; divided by FPS to get a per-frame chance
score = 0
colors = [RED, GREEN, BLUE]
semicolors = [SEMI_RED, SEMI_GREEN, SEMI_BLUE]
dontcreate = False  # True during the cool-down right after a spawn
mustcreate = False  # True when a spawn is forced on the next frame
forbidden_creation_counter = 0
must_creation_counter = 0

# ---------------------------------------------------------------------------
# pygame setup: window, clock and font.
# ---------------------------------------------------------------------------
pygame.init()
pygame.mixer.init()
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Dinosaur")
clock = pygame.time.Clock()
font = pygame.font.Font(None, 35)

# Per-player key bindings as pygame key codes, each ordered [left, jump, right]:
# arrow keys for player 0, a/w/d for player 1, keypad 4/8/6 for player 2.
keys = [
    [getattr(pygame, "K_" + name) for name in bindings]
    for bindings in (("LEFT", "UP", "RIGHT"), ("a", "w", "d"), ("KP4", "KP8", "KP6"))
]
class Player():
    """A 40x40 square that can run left/right and jump.

    When ``bot`` is False the player is steered from the keyboard using
    ``keys`` ([left, jump, right] key codes); otherwise it follows the
    ``action`` vector passed to :meth:`update`.
    """

    def __init__(self, keys, color, semicolor, x_position):
        # Appearance: ``color`` normally, ``semicolor`` while invulnerable.
        self.color, self.semicolor = color, semicolor
        self.image = pygame.Surface((40, 40))
        self.image.fill(self.color)

        # Start standing on the floor line, 10 px above the window bottom,
        # shifted right by ``x_position``.
        self.rect = self.image.get_rect()
        self.rect.centerx = WIDTH / 8 + x_position
        self.rect.bottom = HEIGHT - 10

        # Kinematics: per-frame velocity and constant gravity.
        self.speedx, self.speedy = 0, 0
        self.accel_y = 0.5

        # Lives, invulnerability flag and the frame counter the game loop
        # uses to time the invulnerability window.
        self.vidas = 1
        self.inmortal = False
        self.counter = 0

        self.keys = keys
        self.bot = True

    def update(self, action):
        """Advance the player by one frame.

        ``action`` is indexed [left, jump, right] (1 = active) and is only
        read when the player is a bot; a human player polls the keyboard.
        If both directions are requested, moving right wins.
        """
        floor = HEIGHT - 10
        human = self.bot == False

        if human:
            pressed = pygame.key.get_pressed()
            go_left = pressed[self.keys[0]]
            go_right = pressed[self.keys[2]]
        else:
            go_left = action[0] == 1
            go_right = action[2] == 1

        if go_right:
            self.speedx = 5
        elif go_left:
            self.speedx = -5
        else:
            self.speedx = 0

        # A jump can only start while touching the floor.
        if self.rect.bottom >= floor:
            wants_jump = pressed[self.keys[1]] if human else action[1] == 1
            if wants_jump:
                self.speedy = -8

        # Gravity, then the vertical move, clamped to the floor.
        self.speedy += self.accel_y
        self.rect.y += self.speedy
        if self.rect.bottom > floor:
            self.rect.bottom = floor
            self.speedy = 0

        # Horizontal move, clamped to the window edges.
        self.rect.x += self.speedx
        if self.rect.right > WIDTH:
            self.rect.right = WIDTH
        if self.rect.left < 0:
            self.rect.left = 0

        self.image.fill(self.semicolor if self.inmortal == True else self.color)

    def choque(self):
        """Register a hit: lose one life and become invulnerable.

        Hits taken while already invulnerable are ignored.
        """
        if self.inmortal:
            return
        self.vidas -= 1
        self.inmortal = True
class Enemy():
    """A black rectangular obstacle that spawns at the right edge and slides left."""

    def __init__(self, level):
        # Random size; width is drawn before height, then the speed, so the
        # RNG stream is consumed in the same order on every spawn.
        width = random.randint(10, 40)
        height = random.randint(10, 40)
        self.image = pygame.Surface((width, height))
        self.image.fill(BLACK)

        # Sit on the floor line (10 px above the bottom) flush with the right edge.
        self.rect = self.image.get_rect()
        self.rect.bottom = HEIGHT - 10
        self.rect.right = WIDTH

        # Negative = moving left; ``level`` makes every enemy faster.
        base_speed = random.randint(4, 7)
        self.speedx = -(base_speed + level)

    def update(self):
        """Move the enemy by one frame's worth of horizontal velocity."""
        self.rect.x += self.speedx
def _maybe_spawn(enemies):
    """Run one frame of the enemy-spawn scheduler, appending to *enemies* when due."""
    global mustcreate, dontcreate, must_creation_counter
    global forbidden_creation_counter
    # random.random() is the left operand, so it is drawn on every frame
    # regardless of the flags; this keeps the RNG stream identical.
    if (random.random() < prob_per_second / FPS and not dontcreate) or mustcreate:
        enemies.append(Enemy(level))
        dontcreate = True
        mustcreate = False
        must_creation_counter = 0
    else:
        must_creation_counter += 1
    # After a spawn, further spawns are blocked for 30 frames ...
    if dontcreate:
        forbidden_creation_counter += 1
        if forbidden_creation_counter == 30:
            dontcreate = False
            forbidden_creation_counter = 0
    # ... and a spawn is forced once 120 frames pass without one.
    if must_creation_counter == 120:
        mustcreate = True


def _observation(player, enemies, size=15):
    """Build the fixed-length network input for *player* and *enemies*.

    Layout: [left, top, vertical speed] of the player, followed by
    [left, width, height, speedx] for each enemy in list order, then
    zero-padded or truncated to *size* entries.
    """
    # A vertical speed of exactly 0.5 (one gravity step from rest) is
    # reported as 0 -- presumably to hide the resting-on-floor jitter.
    vy = player.speedy if player.speedy != 0.5 else 0
    state = [player.rect.left, player.rect.top, vy]
    for enemy in enemies:
        state.extend((enemy.rect.left, enemy.rect.width,
                      enemy.rect.height, enemy.speedx))
    state = state[:size]
    state.extend([0] * (size - len(state)))
    return state


def game(genome, config, max_steps=None):
    """Play one headless episode controlled by *genome* and return its reward.

    The reward is +1 per frame survived and -2 for each frame on which the
    network chooses to jump. The episode ends when a player runs out of
    lives, the window is closed, or *max_steps* frames have been played
    (``None`` = no limit).

    The obstacle course is reproducible: the global RNG is seeded with 12
    for the episode, and the module-level spawn scheduler and ``score`` are
    reset on entry. The caller's RNG state is restored on exit, so the
    fixed seed does not leak into NEAT's own mutation/crossover draws.
    """
    global reward, mustcreate, dontcreate, must_creation_counter, score
    global forbidden_creation_counter
    rng_state = random.getstate()
    random.seed(12)
    try:
        net = neat.nn.FeedForwardNetwork.create(genome, config)
        pygame.init()
        pygame.mixer.init()

        # Fresh scheduler state so every genome faces the same course.
        dontcreate = False
        mustcreate = False
        forbidden_creation_counter = 0
        must_creation_counter = 0
        score = 0

        players = [Player(keys[i], colors[i], semicolors[i], i * 45)
                   for i in range(num_players)]
        enemies = []
        reward = 0
        # Action vector [left, jump, right]; the first frame moves left.
        action = np.zeros([3])
        action[0] = 1
        steps = 0
        running = True
        # Rendering is disabled during training for speed.
        while running:
            if max_steps is not None and steps >= max_steps:
                break
            reward += 1

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False

            _maybe_spawn(enemies)

            for player in players:
                player.update(action)

            # Iterate over a copy: enemies that leave the screen are removed.
            for enemy in list(enemies):
                enemy.update()
                if enemy.rect.right < 0:
                    enemies.remove(enemy)
                    score += 1
                for player in players:
                    if pygame.sprite.collide_rect(player, enemy):
                        player.choque()

            # Tick invulnerability windows (120 frames); stop on death.
            for player in players:
                if player.inmortal:
                    player.counter += 1
                    if player.counter == 120:
                        player.counter = 0
                        player.inmortal = False
                if player.vidas == 0:
                    return reward

            output = net.activate(_observation(players[-1], enemies))
            max_Q = int(np.argmax(output))
            action = np.zeros([3])
            action[max_Q] = 1
            if max_Q == 1:  # discourage needless jumping
                reward -= 2
            steps += 1
        # Window closed (or step limit hit): report the reward so far
        # instead of falling through and returning None.
        return reward
    finally:
        random.setstate(rng_state)
def eval_genomes(genomes, config):
    """NEAT fitness callback: evaluate every ``(genome_id, genome)`` pair.

    Each genome's fitness is the value returned by ``game``. This also:
    - increments the global ``GENERATION`` counter once per call;
    - prints a progress line per genome. The "Max Fitness" shown is the
      value before this genome is taken into account;
    - records the best genome so far in ``MAX_FITNESS`` / ``BEST_GENOME``
      (ties go to the later genome);
    - resets the global ``reward`` and appends every fitness to ``fitness_log``.
    """
    global reward
    global GENERATION, MAX_FITNESS, BEST_GENOME
    GENERATION += 1
    for index, (_genome_id, genome) in enumerate(genomes):
        genome.fitness = game(genome, config)
        print("Gen : %d Genome # : %d Fitness : %f Max Fitness : %f"
              % (GENERATION, index, genome.fitness, MAX_FITNESS))
        if not genome.fitness < MAX_FITNESS:
            MAX_FITNESS, BEST_GENOME = genome.fitness, genome
        reward = 0
        fitness_log.append(genome.fitness)
# Load the NEAT configuration (file named 'config' in the working directory)
# and evolve a population for 500 generations with eval_genomes as fitness.
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                     neat.DefaultSpeciesSet, neat.DefaultStagnation,
                     'config')
pop = neat.Population(config)
stats = neat.StatisticsReporter()
pop.add_reporter(stats)
winner = pop.run(eval_genomes, 500)
print(winner)

# Fitness of every evaluated genome, in evaluation order. No plt.show() is
# called, so the figure is only visible in an interactive/inline backend.
plt.plot(fitness_log)

# Save the winning genome as "<n>_<max fitness>.p" in outputDir, where <n>
# is one more than the number of entries already in that directory.
outputDir = 'C:/Users/alexo/.spyder-py3/'
os.chdir(outputDir)
serialNo = len(os.listdir(outputDir)) + 1
# The context manager guarantees the pickle is flushed and the handle is
# closed. Previously the file stayed open, and possibly unflushed, for as
# long as the interactive console kept `outputFile` alive.
with open('%d_%d.p' % (serialNo, int(MAX_FITNESS)), 'wb') as outputFile:
    pickle.dump(winner, outputFile)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment