-- NEATEvolve MAME
-- MarI/O by SethBling | |
-- Feel free to use this code, but please do not redistribute it. | |
-- usage: mame nes smb1 -window -autoboot_script NEATEvolve.lua | |
-- create savestate 1 at the start of the level you want it to learn | |
-- Per-game configuration, selected from the MAME software-list name.
local Filename      -- savestate slot reloaded at the start of every run
local ButtonNames   -- controller field names; one network output per entry
local Game          -- canonical title used to pick the RAM-reading code paths
-- Mutable state shared by the functions below.
local pool
local rightmost
local timeout
local marioX
local marioY
local screenX
local screenY
local controller
local softname = emu.softname()
if softname == "smwu" then
  Filename = "1"
  ButtonNames = {
    "A",
    "B",
    "X",
    "Y",
    "P1 Up",
    "P1 Down",
    "P1 Left",
    "P1 Right",
  }
  Game = "Super Mario World (USA)"
elseif softname == "smb1" then
  Filename = "1"
  ButtonNames = {
    "A",
    "B",
    "P1 Up",
    "P1 Down",
    "P1 Left",
    "P1 Right",
  }
  Game = "Super Mario Bros."
else
  -- Fail early with a clear message instead of crashing later on #ButtonNames.
  error("NEATEvolve: unsupported software '" .. tostring(softname) .. "' (expected smwu or smb1)")
end
-- Vision window: a square of (2*BoxRadius+1)^2 tiles centred on Mario.
local BoxRadius = 6
local BoxWidth = BoxRadius*2 + 1
local InputSize = BoxWidth*BoxWidth
-- One extra input neuron (index Inputs) carries the constant bias.
local Inputs = InputSize + 1
-- One output neuron per controller button.
local Outputs = #ButtonNames

-- Population size and speciation (compatibility distance) parameters.
local Population = 300
local DeltaDisjoint = 2.0
local DeltaWeights = 0.4
local DeltaThreshold = 1.0
local StaleSpecies = 15

-- Fixed evolution probabilities.
local PerturbChance = 0.90
local CrossoverChance = 0.75

-- Seed values for every genome's own, self-adapting mutation rates.
local MutateConnectionsChance = 0.25
local LinkMutationChance = 2.0
local NodeMutationChance = 0.50
local BiasMutationChance = 0.40
local StepSize = 0.1
local DisableMutationChance = 0.4
local EnableMutationChance = 0.2

-- Base number of frames without new rightward progress before a run ends
-- (the frame loop extends it by currentFrame/4).
local TimeoutConstant = 20

-- Output neurons use ids MaxNodes+1..MaxNodes+Outputs; hidden neurons use ids
-- between Inputs and MaxNodes.
local MaxNodes = 1000000

-- MAME handles: the main CPU's program address space and a screen to draw on
-- (whichever screen next() returns first).
local machine = manager:machine()
local memory = machine.devices[":maincpu"].spaces["program"]
local gui = select(2, next(machine.screens))
--- Serialises the whole pool to a plain-text file that loadFile can read back.
-- Layout: generation, maxFitness, species count; per species its topFitness,
-- staleness and genome count; per genome its fitness, maxneuron, one
-- "name\nrate" pair per mutation rate terminated by "done", the gene count,
-- and one "into out weight innovation enabled(1/0)" line per gene.
-- @tparam string filename path of the file to (over)write
-- @treturn boolean true on success, false if the file could not be opened
local function writeFile(filename)
  local file, err = io.open(filename, "w")
  if not file then
    -- Saving is a best-effort backup: report it rather than abort the run.
    emu.print_error("Could not write pool file " .. filename .. ": " .. tostring(err))
    return false
  end
  file:write(pool.generation .. "\n")
  file:write(pool.maxFitness .. "\n")
  file:write(#pool.species .. "\n")
  -- ipairs guarantees the records follow the order of the counts just written.
  for _, species in ipairs(pool.species) do
    file:write(species.topFitness .. "\n")
    file:write(species.staleness .. "\n")
    file:write(#species.genomes .. "\n")
    for _, genome in ipairs(species.genomes) do
      file:write(genome.fitness .. "\n")
      file:write(genome.maxneuron .. "\n")
      for mutation, rate in pairs(genome.mutationRates) do
        file:write(mutation .. "\n")
        file:write(rate .. "\n")
      end
      file:write("done\n")
      file:write(#genome.genes .. "\n")
      for _, gene in ipairs(genome.genes) do
        file:write(gene.into .. " " .. gene.out .. " " .. gene.weight .. " " .. gene.innovation .. " ")
        file:write(gene.enabled and "1\n" or "0\n")
      end
    end
  end
  file:close()
  return true
end
--- Refreshes marioX/marioY (level coordinates) and screenX/screenY
-- (on-screen coordinates) from game RAM for the configured Game.
local function getPositions()
  if Game == "Super Mario Bros." then
    -- Level X is page byte * 256 plus the X offset within the page.
    marioX = memory:read_u8(0x6D)*0x100 + memory:read_u8(0x86)
    marioY = memory:read_u8(0x03B8) + 16
    screenX = memory:read_u8(0x03AD)
    screenY = memory:read_u8(0x03B8)
  elseif Game == "Super Mario World (USA)" then
    marioX = memory:read_i16(0x94)
    marioY = memory:read_i16(0x96)
    -- Subtract the layer-1 scroll position to get screen coordinates.
    screenX = marioX - memory:read_i16(0x1A)
    screenY = marioY - memory:read_i16(0x1C)
  end
end
--- Samples the level tile at a pixel offset (dx, dy) from Mario.
-- Super Mario World: returns the raw byte from the level tile map.
-- Super Mario Bros.: returns 1 for a non-empty tile, 0 for an empty tile or
-- a row outside the 13 stored rows.
local function getTile(dx, dy)
  if Game == "Super Mario Bros." then
    local px = marioX + dx + 8
    local py = marioY + dy - 16
    local row = math.floor((py - 32)/16)
    if row < 0 or row >= 13 then
      return 0
    end
    -- Two alternating 13x16 tile pages starting at 0x500.
    local page = math.floor(px/256) % 2
    local column = math.floor((px % 256)/16)
    local tileByte = memory:read_u8(0x500 + page*13*16 + row*16 + column)
    return tileByte ~= 0 and 1 or 0
  elseif Game == "Super Mario World (USA)" then
    local tx = math.floor((marioX + dx + 8)/16)
    local ty = math.floor((marioY + dy)/16)
    return memory:read_u8(0x7FC800 + math.floor(tx/0x10)*0x1B0 + ty*0x10 + tx % 0x10)
  end
end
--- Lists active enemy sprites as {x=, y=} tables in level coordinates.
-- Returns an empty list for an unrecognised Game (previously nil).
local function getSprites()
  local sprites = {}
  if Game == "Super Mario World (USA)" then
    for slot = 0, 11 do
      -- A non-zero status byte marks an occupied sprite slot.
      if memory:read_u8(0x14C8 + slot) ~= 0 then
        local spritex = memory:read_u8(0xE4 + slot) + memory:read_u8(0x14E0 + slot)*256
        local spritey = memory:read_u8(0xD8 + slot) + memory:read_u8(0x14D4 + slot)*256
        sprites[#sprites + 1] = { x = spritex, y = spritey }
      end
    end
  elseif Game == "Super Mario Bros." then
    for slot = 0, 4 do
      if memory:read_u8(0xF + slot) ~= 0 then
        local ex = memory:read_u8(0x6E + slot)*0x100 + memory:read_u8(0x87 + slot)
        local ey = memory:read_u8(0xCF + slot) + 24
        sprites[#sprites + 1] = { x = ex, y = ey }
      end
    end
  end
  return sprites
end
--- Lists active extended sprites (SMW only) as {x=, y=} level coordinates.
-- Super Mario Bros. and unrecognised games yield an empty list.
local function getExtendedSprites()
  local extended = {}
  if Game == "Super Mario World (USA)" then
    for slot = 0, 11 do
      -- A non-zero sprite number marks an occupied slot.
      if memory:read_u8(0x170B + slot) ~= 0 then
        local spritex = memory:read_u8(0x171F + slot) + memory:read_u8(0x1733 + slot)*256
        local spritey = memory:read_u8(0x1715 + slot) + memory:read_u8(0x1729 + slot)*256
        extended[#extended + 1] = { x = spritex, y = spritey }
      end
    end
  end
  return extended
end
--- Builds the network input vector from the tiles and sprites around Mario.
-- One value per 16px cell of the (2*BoxRadius+1)^2 window, row by row
-- (dy outer, dx inner): 1 = solid tile, -1 = sprite (overrides the tile),
-- 0 = empty. The bias input is added separately by evaluateNetwork.
-- @treturn table array of InputSize numbers
local function getInputs()
  getPositions()
  local sprites = getSprites()
  local extended = getExtendedSprites()
  local inputs = {}
  local span = BoxRadius*16
  for dy = -span, span, 16 do
    for dx = -span, span, 16 do
      local cellX, cellY = marioX + dx, marioY + dy
      local value = 0
      if getTile(dx, dy) == 1 and cellY < 0x1B0 then
        value = 1
      end
      for i = 1, #sprites do
        local distx = math.abs(sprites[i].x - cellX)
        local disty = math.abs(sprites[i].y - cellY)
        if distx <= 8 and disty <= 8 then
          value = -1
        end
      end
      for i = 1, #extended do
        local distx = math.abs(extended[i].x - cellX)
        local disty = math.abs(extended[i].y - cellY)
        if distx < 8 and disty < 8 then
          value = -1
        end
      end
      inputs[#inputs + 1] = value
    end
  end
  return inputs
end
--- Steepened logistic curve rescaled to the open interval (-1, 1).
local function sigmoid(x)
  local e = math.exp(-4.9*x)
  return 2/(1 + e) - 1
end
--- Allocates and returns the next pool-wide innovation number.
local function newInnovation()
  local nextId = pool.innovation + 1
  pool.innovation = nextId
  return nextId
end
--- Creates an empty pool positioned at the first genome of generation 0.
local function newPool()
  return {
    species = {},
    generation = 0,
    -- newInnovation increments this, so the first new gene gets Outputs+1.
    innovation = Outputs,
    currentSpecies = 1,
    currentGenome = 1,
    currentFrame = 0,
    maxFitness = 0,
  }
end
--- Creates an empty species with zeroed fitness bookkeeping.
local function newSpecies()
  return {
    topFitness = 0,
    staleness = 0,
    genomes = {},
    averageFitness = 0,
  }
end
--- Creates a gene-less genome whose mutation rates start at the global seeds.
local function newGenome()
  -- Filled key by key, in this order, so pairs() traversal of the rates
  -- (used by mutate and writeFile) stays the same.
  local rates = {}
  rates["connections"] = MutateConnectionsChance
  rates["link"] = LinkMutationChance
  rates["bias"] = BiasMutationChance
  rates["node"] = NodeMutationChance
  rates["enable"] = EnableMutationChance
  rates["disable"] = DisableMutationChance
  rates["step"] = StepSize
  return {
    genes = {},
    fitness = 0,
    adjustedFitness = 0,
    network = {},
    maxneuron = 0,
    globalRank = 0,
    mutationRates = rates,
  }
end
--- Creates an enabled, zero-weight connection gene with no endpoints yet.
local function newGene()
  return {
    into = 0,
    out = 0,
    weight = 0.0,
    enabled = true,
    innovation = 0,
  }
end
--- Returns an independent copy of a connection gene.
local function copyGene(gene)
  return {
    into = gene.into,
    out = gene.out,
    weight = gene.weight,
    enabled = gene.enabled,
    innovation = gene.innovation,
  }
end
--- Deep-copies a genome's genes, maxneuron and mutation rates.
-- Fitness, rank and network are not copied; the clone starts fresh.
local function copyGenome(genome)
  local clone = newGenome()
  for g = 1, #genome.genes do
    clone.genes[g] = copyGene(genome.genes[g])
  end
  clone.maxneuron = genome.maxneuron
  -- Copy every evolved rate, including "step", which the previous
  -- key-by-key copy silently reset to the StepSize default.
  for mutation, rate in pairs(genome.mutationRates) do
    clone.mutationRates[mutation] = rate
  end
  return clone
end
--- Creates a neuron with no incoming genes and an output value of 0.
local function newNeuron()
  return { incoming = {}, value = 0.0 }
end
--- Builds genome.network from the genome's enabled genes.
-- Neurons: inputs 1..Inputs, outputs MaxNodes+1..MaxNodes+Outputs, and every
-- endpoint of an enabled gene. Each gene is attached to the `incoming` list
-- of its destination neuron.
-- Side effect: genome.genes is sorted in place by destination neuron id.
local function generateNetwork(genome)
  local neurons = {}
  for i = 1, Inputs do
    neurons[i] = newNeuron()
  end
  for o = 1, Outputs do
    neurons[MaxNodes + o] = newNeuron()
  end
  table.sort(genome.genes, function(a, b) return a.out < b.out end)
  for _, gene in ipairs(genome.genes) do
    if gene.enabled then
      local target = neurons[gene.out]
      if target == nil then
        target = newNeuron()
        neurons[gene.out] = target
      end
      table.insert(target.incoming, gene)
      if neurons[gene.into] == nil then
        neurons[gene.into] = newNeuron()
      end
    end
  end
  genome.network = { neurons = neurons }
end
--- Runs one evaluation of `network` and stores the result in `controller`.
-- Neurons are updated in a single pass in pairs() order. Values are never
-- reset, so a neuron may read a value left over from the previous call.
-- @tparam table network as built by generateNetwork
-- @tparam table inputs Inputs-1 values; not modified (bias is set to 1 here)
-- @treturn table button name -> pressed (boolean); empty on a size mismatch
local function evaluateNetwork(network, inputs)
  local outputs = {}
  -- The last input neuron is the bias, so callers supply one value fewer.
  if #inputs + 1 ~= Inputs then
    emu.print_error("Incorrect number of neural network inputs.")
    return outputs
  end
  local neurons = network.neurons
  for i = 1, Inputs - 1 do
    neurons[i].value = inputs[i]
  end
  neurons[Inputs].value = 1
  for _, neuron in pairs(neurons) do
    local incoming = neuron.incoming
    if #incoming > 0 then
      local sum = 0
      for j = 1, #incoming do
        local link = incoming[j]
        sum = sum + link.weight * neurons[link.into].value
      end
      neuron.value = sigmoid(sum)
    end
  end
  for o = 1, Outputs do
    local button = ButtonNames[o]
    local pressed = neurons[MaxNodes + o].value > 0
    controller[button].state = pressed and 1 or 0
    outputs[button] = pressed
  end
  return outputs
end
--- Breeds a child from two parents.
-- Genes follow the fitter parent (g1 on ties). For a gene both parents share,
-- the other parent's enabled copy is taken with probability 1/2. Mutation
-- rates come from the fitter parent.
local function crossover(g1, g2)
  local fitter, weaker = g1, g2
  if g2.fitness > g1.fitness then
    fitter, weaker = g2, g1
  end
  local child = newGenome()
  local byInnovation = {}
  for _, gene in ipairs(weaker.genes) do
    byInnovation[gene.innovation] = gene
  end
  for _, gene1 in ipairs(fitter.genes) do
    local gene2 = byInnovation[gene1.innovation]
    local source = gene1
    if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
      source = gene2
    end
    child.genes[#child.genes + 1] = copyGene(source)
  end
  child.maxneuron = math.max(fitter.maxneuron, weaker.maxneuron)
  for mutation, rate in pairs(fitter.mutationRates) do
    child.mutationRates[mutation] = rate
  end
  return child
end
--- Picks a uniformly random neuron id among those known to `genes`.
-- Candidates are the outputs, every gene endpoint and, unless `nonInput` is
-- set, all input neurons. With `nonInput`, ids <= Inputs are excluded.
local function randomNeuron(genes, nonInput)
  local candidates = {}
  local function consider(id)
    if not nonInput or id > Inputs then
      candidates[id] = true
    end
  end
  if not nonInput then
    for i = 1, Inputs do
      candidates[i] = true
    end
  end
  for o = 1, Outputs do
    candidates[MaxNodes + o] = true
  end
  for i = 1, #genes do
    consider(genes[i].into)
    consider(genes[i].out)
  end
  local ids = {}
  for id in pairs(candidates) do
    ids[#ids + 1] = id
  end
  return ids[math.random(1, #ids)] or 0
end
--- Reports whether `genes` already has a gene with the same into/out pair
-- as `link` (enabled or not).
-- @treturn boolean
local function containsLink(genes, link)
  for _, gene in ipairs(genes) do
    if gene.into == link.into and gene.out == link.out then
      return true
    end
  end
  return false
end
--- Mutates every gene weight.
-- With probability PerturbChance the weight is nudged by up to +/- step.
-- Otherwise it is replaced by a fresh value in [-2, 2).
local function pointMutate(genome)
  local step = genome.mutationRates["step"]
  for _, gene in ipairs(genome.genes) do
    if math.random() >= PerturbChance then
      gene.weight = math.random()*4 - 2
    else
      gene.weight = gene.weight + math.random()*step*2 - step
    end
  end
end
--- Tries to add one new connection gene with a random weight in [-2, 2).
-- With `forceBias` the connection always starts at the bias neuron (Inputs).
-- Nothing is added when both picks are inputs or the link already exists.
local function linkMutate(genome, forceBias)
  local source = randomNeuron(genome.genes, false)
  local target = randomNeuron(genome.genes, true)
  if source <= Inputs and target <= Inputs then
    return
  end
  -- Keep any input neuron on the `into` side of the link.
  if target <= Inputs then
    source, target = target, source
  end
  local link = newGene()
  link.into = forceBias and Inputs or source
  link.out = target
  if containsLink(genome.genes, link) then
    return
  end
  link.innovation = newInnovation()
  link.weight = math.random()*4 - 2
  table.insert(genome.genes, link)
end
--- Splits a random enabled gene A->B into A->H (weight 1) and H->B (old
-- weight) through a new hidden neuron H, disabling the original gene.
-- A new neuron id is consumed only when a split actually happens; the old
-- code bumped maxneuron even when the chosen gene was disabled.
local function nodeMutate(genome)
  if #genome.genes == 0 then
    return
  end
  local gene = genome.genes[math.random(1, #genome.genes)]
  if not gene.enabled then
    return
  end
  genome.maxneuron = genome.maxneuron + 1
  local hidden = genome.maxneuron
  gene.enabled = false

  local inbound = copyGene(gene)
  inbound.out = hidden
  inbound.weight = 1.0
  inbound.innovation = newInnovation()
  inbound.enabled = true
  table.insert(genome.genes, inbound)

  local outbound = copyGene(gene)
  outbound.into = hidden
  outbound.innovation = newInnovation()
  outbound.enabled = true
  table.insert(genome.genes, outbound)
end
--- Toggles one random gene whose state is the opposite of `enable`.
-- enable=true re-enables a disabled gene; enable=false disables an enabled one.
local function enableDisableMutate(genome, enable)
  local candidates = {}
  for _, gene in pairs(genome.genes) do
    if gene.enabled == not enable then
      candidates[#candidates + 1] = gene
    end
  end
  local count = #candidates
  if count > 0 then
    local picked = candidates[math.random(1, count)]
    picked.enabled = not picked.enabled
  end
end
--- Calls operator(genome, arg) floor(rate) times, plus once more with
-- probability equal to rate's fractional part. One math.random() draw is
-- made per loop iteration.
local function repeatByRate(rate, operator, genome, arg)
  local p = rate
  while p > 0 do
    if math.random() < p then
      operator(genome, arg)
    end
    p = p - 1
  end
end

--- Mutates a genome in place.
-- First every mutation rate drifts by a factor of 0.95 or 1.05263 (about
-- 1/0.95). Then the weight, link, bias-link, node and enable/disable
-- mutations are applied according to the updated rates.
local function mutate(genome)
  local rates = genome.mutationRates
  for mutation, rate in pairs(rates) do
    rates[mutation] = (math.random(1, 2) == 1) and 0.95*rate or 1.05263*rate
  end
  if math.random() < rates["connections"] then
    pointMutate(genome)
  end
  repeatByRate(rates["link"], linkMutate, genome, false)
  repeatByRate(rates["bias"], linkMutate, genome, true)
  repeatByRate(rates["node"], nodeMutate, genome)
  repeatByRate(rates["enable"], enableDisableMutate, genome, true)
  repeatByRate(rates["disable"], enableDisableMutate, genome, false)
end
--- Fraction of genes present in only one of the two gene lists, normalised by
-- the longer list's length. Two empty lists are identical (0) instead of 0/0 = NaN.
-- @treturn number
local function disjoint(genes1, genes2)
  local n = math.max(#genes1, #genes2)
  if n == 0 then
    return 0
  end
  local in1, in2 = {}, {}
  for _, gene in ipairs(genes1) do
    in1[gene.innovation] = true
  end
  for _, gene in ipairs(genes2) do
    in2[gene.innovation] = true
  end
  local count = 0
  for _, gene in ipairs(genes1) do
    if not in2[gene.innovation] then
      count = count + 1
    end
  end
  for _, gene in ipairs(genes2) do
    if not in1[gene.innovation] then
      count = count + 1
    end
  end
  return count / n
end
--- Mean absolute weight difference over genes (by innovation) present in
-- both lists. Returns 0 when nothing coincides instead of 0/0 = NaN. In that
-- case every gene counts as disjoint, which sameSpecies weighs separately.
-- @treturn number
local function weights(genes1, genes2)
  local byInnovation = {}
  for _, gene in ipairs(genes2) do
    byInnovation[gene.innovation] = gene
  end
  local sum, coincident = 0, 0
  for _, gene in ipairs(genes1) do
    local other = byInnovation[gene.innovation]
    if other ~= nil then
      sum = sum + math.abs(gene.weight - other.weight)
      coincident = coincident + 1
    end
  end
  if coincident == 0 then
    return 0
  end
  return sum / coincident
end
--- NEAT compatibility test: weighted disjoint fraction plus weighted mean
-- weight difference must stay below DeltaThreshold.
local function sameSpecies(genome1, genome2)
  local distance = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
    + DeltaWeights*weights(genome1.genes, genome2.genes)
  return distance < DeltaThreshold
end
--- Assigns every genome in the pool a globalRank: 1 for the lowest fitness,
-- up to the total genome count for the highest.
local function rankGlobally()
  local everyone = {}
  for _, species in ipairs(pool.species) do
    for _, genome in ipairs(species.genomes) do
      everyone[#everyone + 1] = genome
    end
  end
  table.sort(everyone, function(a, b) return a.fitness < b.fitness end)
  for rank, genome in ipairs(everyone) do
    genome.globalRank = rank
  end
end
--- Sets species.averageFitness to the mean globalRank (not raw fitness) of
-- the species' genomes.
local function calculateAverageFitness(species)
  local rankSum = 0
  for _, genome in ipairs(species.genomes) do
    rankSum = rankSum + genome.globalRank
  end
  species.averageFitness = rankSum / #species.genomes
end
--- Sums averageFitness over all species in the pool.
local function totalAverageFitness()
  local sum = 0
  for _, species in ipairs(pool.species) do
    sum = sum + species.averageFitness
  end
  return sum
end
--- Sorts each species best-first and drops its weaker genomes, keeping the
-- top half (rounded up), or only the single best when `cutToOne` is set.
local function cullSpecies(cutToOne)
  for _, species in ipairs(pool.species) do
    local genomes = species.genomes
    table.sort(genomes, function(a, b) return a.fitness > b.fitness end)
    local keep = cutToOne and 1 or math.ceil(#genomes/2)
    for i = #genomes, keep + 1, -1 do
      genomes[i] = nil
    end
  end
end
--- Produces one mutated offspring from `species`: a crossover of two random
-- members (probability CrossoverChance) or a copy of one random member.
-- The parent variables are now local instead of leaking as globals g1/g2/g.
local function breedChild(species)
  local genomes = species.genomes
  local child
  if math.random() < CrossoverChance then
    local g1 = genomes[math.random(1, #genomes)]
    local g2 = genomes[math.random(1, #genomes)]
    child = crossover(g1, g2)
  else
    child = copyGenome(genomes[math.random(1, #genomes)])
  end
  mutate(child)
  return child
end
--- Updates each species' topFitness/staleness from its best genome.
-- Keeps only species that are not yet stale, or that hold a topFitness of
-- at least the pool's maxFitness.
local function removeStaleSpecies()
  local survivors = {}
  for _, species in ipairs(pool.species) do
    table.sort(species.genomes, function(a, b) return a.fitness > b.fitness end)
    local best = species.genomes[1].fitness
    if best > species.topFitness then
      species.topFitness = best
      species.staleness = 0
    else
      species.staleness = species.staleness + 1
    end
    local fresh = species.staleness < StaleSpecies
    if fresh or species.topFitness >= pool.maxFitness then
      survivors[#survivors + 1] = species
    end
  end
  pool.species = survivors
end
--- Drops species whose share of the total averageFitness would earn them
-- fewer than one offspring in a Population-sized generation.
-- `breed` is now local instead of an accidental global.
local function removeWeakSpecies()
  local survivors = {}
  local sum = totalAverageFitness()
  for _, species in ipairs(pool.species) do
    local breed = math.floor(species.averageFitness / sum * Population)
    if breed >= 1 then
      survivors[#survivors + 1] = species
    end
  end
  pool.species = survivors
end
--- Places `child` in the first species whose representative (its first
-- genome) is compatible, or starts a new species for it.
local function addToSpecies(child)
  for _, species in ipairs(pool.species) do
    if sameSpecies(child, species.genomes[1]) then
      table.insert(species.genomes, child)
      return
    end
  end
  local founded = newSpecies()
  table.insert(founded.genomes, child)
  table.insert(pool.species, founded)
end
--- Replaces the pool's genomes with a new generation and writes a backup.
-- Culls and ranks the species, drops stale and weak ones, and breeds children
-- in proportion to each species' average rank. Only each species' champion
-- is kept, then the population is topped up to Population.
-- `breed` is now local instead of an accidental global.
local function newGeneration()
  cullSpecies(false) -- Cull the bottom half of each species
  rankGlobally()
  removeStaleSpecies()
  rankGlobally()
  for s = 1, #pool.species do
    calculateAverageFitness(pool.species[s])
  end
  removeWeakSpecies()
  local sum = totalAverageFitness()
  local children = {}
  for s = 1, #pool.species do
    local species = pool.species[s]
    -- One slot per species is left for the champion kept by cullSpecies(true).
    local breed = math.floor(species.averageFitness / sum * Population) - 1
    for _ = 1, breed do
      table.insert(children, breedChild(species))
    end
  end
  cullSpecies(true) -- Cull all but the top member of each species
  -- Fill any remaining slots with children of randomly chosen species.
  while #children + #pool.species < Population do
    local species = pool.species[math.random(1, #pool.species)]
    table.insert(children, breedChild(species))
  end
  for c = 1, #children do
    addToSpecies(children[c])
  end
  pool.generation = pool.generation + 1
  writeFile("backup." .. pool.generation .. "." .. emu.softname() .. emu.romname() .. ".pool")
end
--- Creates a founding genome. It has no genes, and its hidden-neuron ids
-- start after the inputs. One round of mutation is applied, which may add
-- its first links. (An unused `innovation` local was removed.)
local function basicGenome()
  local genome = newGenome()
  genome.maxneuron = Inputs
  mutate(genome)
  return genome
end
--- Pushes each button's cached state (0 or 1) to its MAME input field.
local function joypad_set(buttons)
  for _, button in pairs(buttons) do
    local field, state = button.field, button.state
    field:set_value(state)
  end
end
--- Rebuilds `controller` and releases every button.
-- Each entry is {port, field, state = 0}, one per name in ButtonNames, taken
-- from the first input port whose tag contains "ctrl1". The port is located
-- once rather than once per button. Buttons that cannot be found are
-- reported instead of being left silently nil, which would otherwise crash
-- evaluateNetwork later.
local function clearJoypad()
  controller = {}
  local ctrlPort
  for name, port in pairs(manager:machine():ioport().ports) do
    if name:match("ctrl1") then
      ctrlPort = port
      break
    end
  end
  if ctrlPort then
    local wanted = {}
    for _, buttonName in ipairs(ButtonNames) do
      wanted[buttonName] = true
    end
    for fname, field in pairs(ctrlPort.fields) do
      if wanted[fname] then
        controller[fname] = { port = ctrlPort, field = field, state = 0 }
      end
    end
  end
  for _, buttonName in ipairs(ButtonNames) do
    if controller[buttonName] == nil then
      emu.print_error("Controller button not found: " .. buttonName)
    end
  end
  joypad_set(controller)
end
--- Feeds the current game state through the current genome's network and
-- pushes the resulting button states to MAME. Opposite directions pressed
-- together cancel out. `inputs` is now local instead of an accidental global.
local function evaluateCurrent()
  local genome = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  local inputs = getInputs()
  evaluateNetwork(genome.network, inputs)
  if controller["P1 Left"].state ~= 0 and controller["P1 Right"].state ~= 0 then
    controller["P1 Left"].state = 0
    controller["P1 Right"].state = 0
  end
  if controller["P1 Up"].state ~= 0 and controller["P1 Down"].state ~= 0 then
    controller["P1 Up"].state = 0
    controller["P1 Down"].state = 0
  end
  joypad_set(controller)
end
--- Starts a fresh run of the current genome.
-- Reloads savestate `Filename`, resets the progress and timeout counters,
-- releases all buttons, rebuilds the network and computes the first inputs.
local function initializeRun()
  manager:machine():load(Filename)
  pool.currentFrame = 0
  rightmost = 0
  timeout = TimeoutConstant
  clearJoypad()
  generateNetwork(pool.species[pool.currentSpecies].genomes[pool.currentGenome])
  evaluateCurrent()
end
--- Replaces the pool with Population freshly mutated founding genomes,
-- speciates them, and starts the first run.
-- The genome variable is no longer leaked as the global `basic`.
local function initializePool()
  pool = newPool()
  for _ = 1, Population do
    addToSpecies(basicGenome())
  end
  initializeRun()
end
-- Build a fresh pool at startup; `pool` has not been assigned yet here.
if not pool then
  initializePool()
end
--- Advances the (species, genome) cursor. After the last genome of the last
-- species, a new generation is bred and the cursor wraps to species 1.
local function nextGenome()
  pool.currentGenome = pool.currentGenome + 1
  if pool.currentGenome <= #pool.species[pool.currentSpecies].genomes then
    return
  end
  pool.currentGenome = 1
  pool.currentSpecies = pool.currentSpecies + 1
  if pool.currentSpecies > #pool.species then
    newGeneration()
    pool.currentSpecies = 1
  end
end
--- True when the genome under the cursor already has a fitness. A fitness of
-- exactly 0 means "not run yet"; measured zero scores are stored as -1.
local function fitnessAlreadyMeasured()
  local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  return current.fitness ~= 0
end
--- Draws the genome's network as an overlay on the MAME screen.
-- Input neurons form a grid centred at (50, 70), 5px apart. The bias neuron
-- sits at (80, 110). Output neurons are drawn in a column at x = 220, each
-- labelled with its button name. Hidden neurons start at (140, 40) and are
-- moved by four relaxation passes toward the cells they connect to.
-- Each enabled gene is drawn as a line from its source cell to its
-- destination cell.
-- Relies on genome.network having been built by generateNetwork.
local function displayGenome(genome)
  local network = genome.network
  local cells = {}
  local i = 1
  local cell = {}
  -- Cells 1..InputSize: one per vision-window tile, in the same row-major
  -- order that getInputs fills the inputs.
  for dy=-BoxRadius,BoxRadius do
    for dx=-BoxRadius,BoxRadius do
      cell = {}
      cell.x = 50+5*dx
      cell.y = 70+5*dy
      cell.value = network.neurons[i].value
      cells[i] = cell
      i = i + 1
    end
  end
  -- The bias neuron (index Inputs).
  local biasCell = {}
  biasCell.x = 80
  biasCell.y = 110
  biasCell.value = network.neurons[Inputs].value
  cells[Inputs] = biasCell
  -- Output neurons plus their button labels; the label is drawn in
  -- 0xFF0000FF when the output is positive and 0xFF000000 otherwise.
  for o = 1,Outputs do
    cell = {}
    cell.x = 220
    cell.y = 30 + 8 * o
    cell.value = network.neurons[MaxNodes + o].value
    cells[MaxNodes+o] = cell
    local color
    if cell.value > 0 then
      color = 0xFF0000FF
    else
      color = 0xFF000000
    end
    gui:draw_text(223, 24+8*o, ButtonNames[o], color)
  end
  -- Hidden neurons (ids in (Inputs, MaxNodes]) all start at (140, 40).
  for n,neuron in pairs(network.neurons) do
    cell = {}
    if n > Inputs and n <= MaxNodes then
      cell.x = 140
      cell.y = 40
      cell.value = neuron.value
      cells[n] = cell
    end
  end
  -- Four relaxation passes over the enabled genes. Each hidden endpoint is
  -- pulled 25% toward the other end of the gene and pushed 40px further
  -- when the source is not to the left of the target. Its x is then clamped
  -- to [90, 220]. The result depends on pairs() order because cells are
  -- updated in place.
  for n=1,4 do
    for _,gene in pairs(genome.genes) do
      if gene.enabled then
        local c1 = cells[gene.into]
        local c2 = cells[gene.out]
        if gene.into > Inputs and gene.into <= MaxNodes then
          c1.x = 0.75*c1.x + 0.25*c2.x
          if c1.x >= c2.x then
            c1.x = c1.x - 40
          end
          if c1.x < 90 then
            c1.x = 90
          end
          if c1.x > 220 then
            c1.x = 220
          end
          c1.y = 0.75*c1.y + 0.25*c2.y
        end
        if gene.out > Inputs and gene.out <= MaxNodes then
          c2.x = 0.25*c1.x + 0.75*c2.x
          if c1.x >= c2.x then
            c2.x = c2.x + 40
          end
          if c2.x < 90 then
            c2.x = 90
          end
          if c2.x > 220 then
            c2.x = 220
          end
          c2.y = 0.25*c1.y + 0.75*c2.y
        end
      end
    end
  end
  -- Frame around the input grid.
  gui:draw_box(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2, 0x80808080,0xFF000000)
  -- Neuron cells: outputs and hidden neurons are always drawn, inputs only
  -- when non-zero. The value in [-1, 1] maps to a grey level 0..255, and
  -- zero-valued cells use a lower alpha.
  for n,cell in pairs(cells) do
    if n > Inputs or cell.value ~= 0 then
      local color = math.floor((cell.value+1)/2*256)
      if color > 255 then color = 255 end
      if color < 0 then color = 0 end
      local opacity = 0xFF000000
      if cell.value == 0 then
        opacity = 0x50000000
      end
      color = opacity + color*0x10000 + color*0x100 + color
      gui:draw_box(cell.x-2,cell.y-2,cell.x+2,cell.y+2,color,opacity)
    end
  end
  -- Gene lines: the base channel is 0x8000 for positive weights and
  -- 0x800000 for negative ones. The other channel fades with |sigmoid(weight)|.
  -- Lines from zero-valued sources are drawn more transparent.
  -- NOTE(review): colours are assumed to be ARGB; confirm against the MAME API.
  for _,gene in pairs(genome.genes) do
    if gene.enabled then
      local c1 = cells[gene.into]
      local c2 = cells[gene.out]
      local opacity = 0xA0000000
      if c1.value == 0 then
        opacity = 0x20000000
      end
      local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
      if gene.weight > 0 then
        color = opacity + 0x8000 + 0x10000*color
      else
        color = opacity + 0x800000 + 0x100*color
      end
      gui:draw_line(c1.x+1, c1.y, c2.x-3, c2.y, color)
    end
  end
  -- Small box at the centre of the input grid (presumably Mario's cell).
  gui:draw_box(49,71,51,78,0x80FF0000,0x00000000)
  -- Mutation-rate display from the BizHawk original, disabled here.
  --[[ if forms.ischecked(showMutationRates) then
  local pos = 100
  for mutation,rate in pairs(genome.mutationRates) do
  gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
  pos = pos + 8
  end
  end]]
end
--- Saves the pool as "<softname><romname>.pool" in the working directory.
local function savePool()
  writeFile(emu.softname() .. emu.romname() .. ".pool")
end
--- Parses a pool file in the format written by writeFile.
-- Raises an error (message only) on malformed or truncated input.
-- @param file open file handle positioned at the start of the file
-- @treturn table the reconstructed pool
local function readPoolFile(file)
  local function number(what)
    local value = file:read("*number")
    if value == nil then
      error("expected a number for " .. what, 0)
    end
    return value
  end
  local loaded = newPool()
  loaded.generation = number("generation")
  loaded.maxFitness = number("maxFitness")
  for _ = 1, number("species count") do
    local species = newSpecies()
    table.insert(loaded.species, species)
    species.topFitness = number("species topFitness")
    species.staleness = number("species staleness")
    for _ = 1, number("genome count") do
      local genome = newGenome()
      table.insert(species.genomes, genome)
      genome.fitness = number("genome fitness")
      genome.maxneuron = number("genome maxneuron")
      -- "name\nrate" pairs until "done". Reading a number leaves the rest of
      -- its line unread, so the empty remainders are skipped rather than used
      -- as keys. Hitting EOF is an error instead of an endless loop.
      local line = file:read("*line")
      while line ~= "done" do
        if line == nil then
          error("unexpected end of file in mutation rates", 0)
        end
        if line ~= "" then
          genome.mutationRates[line] = number("mutation rate '" .. line .. "'")
        end
        line = file:read("*line")
      end
      for _ = 1, number("gene count") do
        local gene = newGene()
        table.insert(genome.genes, gene)
        gene.into = number("gene into")
        gene.out = number("gene out")
        gene.weight = number("gene weight")
        gene.innovation = number("gene innovation")
        gene.enabled = number("gene enabled flag") ~= 0
      end
    end
  end
  return loaded
end

--- Loads a pool saved by writeFile and resumes at the first genome whose
-- fitness has not been measured yet.
-- The file is parsed completely before the current pool is replaced. On any
-- failure the current pool is kept and an error is printed.
-- @tparam string filename
-- @treturn boolean true if the pool was loaded
local function loadFile(filename)
  local file, openErr = io.open(filename, "r")
  if not file then
    emu.print_error("Could not open pool file " .. filename .. ": " .. tostring(openErr))
    return false
  end
  local ok, result = pcall(readPoolFile, file)
  file:close()
  if not ok then
    emu.print_error("Could not load pool file " .. filename .. ": " .. tostring(result))
    return false
  end
  pool = result
  while fitnessAlreadyMeasured() do
    nextGenome()
  end
  initializeRun()
  pool.currentFrame = pool.currentFrame + 1
  return true
end
--- Loads the pool from "<softname><romname>.pool" in the working directory.
local function loadPool()
  loadFile(emu.softname() .. emu.romname() .. ".pool")
end
--- Selects the highest-fitness genome (first one found on ties), records its
-- fitness as pool.maxFitness and replays it.
-- Previously this indexed a nil species (crash) when no genome had a
-- positive fitness. It now picks the best genome regardless of sign, and
-- returns early with an error message if the pool is empty.
local function playTop()
  local bestFitness, maxs, maxg
  for s, species in pairs(pool.species) do
    for g, genome in pairs(species.genomes) do
      if bestFitness == nil or genome.fitness > bestFitness then
        bestFitness, maxs, maxg = genome.fitness, s, g
      end
    end
  end
  if maxs == nil then
    emu.print_error("playTop: the pool has no genomes to play")
    return
  end
  pool.currentSpecies = maxs
  pool.currentGenome = maxg
  pool.maxFitness = bestFitness
  initializeRun()
  pool.currentFrame = pool.currentFrame + 1
end
-- Write a snapshot of the startup pool to temp.pool in the working directory.
writeFile("temp.pool")
-- The block below is the BizHawk (event/forms) UI from the original MarI/O
-- script. It is kept for reference but disabled, presumably because those
-- BizHawk APIs are not available under MAME.
--event.onexit(onExit)
--[[form = forms.newform(200, 260, "Fitness")
maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
showNetwork = forms.checkbox(form, "Show Map", 5, 30)
showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
restartButton = forms.button(form, "Restart", initializePool, 5, 77)
saveButton = forms.button(form, "Save", savePool, 5, 102)
loadButton = forms.button(form, "Load", loadPool, 80, 102)
saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)]]
-- Fitness bonus for passing the level-end X coordinate of the current game.
local function levelClearBonus(x)
  if Game == "Super Mario World (USA)" and x > 4816 then
    return 1000
  elseif Game == "Super Mario Bros." and x > 3186 then
    return 1000
  end
  return 0
end

-- Scores the genome whose run just timed out, logs the result, saves a
-- backup on a new record, and starts the next unmeasured genome.
local function finishRun(genome)
  local fitness = rightmost - pool.currentFrame / 2
  local bonus = levelClearBonus(rightmost)
  if bonus ~= 0 then
    fitness = fitness + bonus
  end
  -- 0 is reserved for "not run yet", so a measured zero is stored as -1.
  if fitness == 0 then
    fitness = -1
  end
  genome.fitness = fitness
  if fitness > pool.maxFitness then
    pool.maxFitness = fitness
    writeFile("backup." .. pool.generation .. "." .. emu.softname() .. emu.romname() .. ".pool")
  end
  emu.print_error("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness .. " marioX = " .. marioX .. " marioY = " .. marioY)
  pool.currentSpecies = 1
  pool.currentGenome = 1
  while fitnessAlreadyMeasured() do
    nextGenome()
  end
  initializeRun()
end

-- Returns how many genomes already have a fitness, and the total count.
local function countMeasured()
  local measured, total = 0, 0
  for _, species in pairs(pool.species) do
    for _, genome in pairs(species.genomes) do
      total = total + 1
      if genome.fitness ~= 0 then
        measured = measured + 1
      end
    end
  end
  return measured, total
end

-- Per-frame driver: draws the overlay, drives the controller, tracks
-- rightward progress and ends the run once the timeout expires.
emu.register_frame_done(function()
  local backgroundColor = 0xD0FFFFFF
  gui:draw_box(0, 0, 300, 26, backgroundColor, backgroundColor)
  local genome = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  displayGenome(genome)
  -- The network is re-evaluated every fifth frame; buttons are held between.
  if pool.currentFrame % 5 == 0 then
    evaluateCurrent()
  end
  joypad_set(controller)
  getPositions()
  if marioX > rightmost then
    rightmost = marioX
    timeout = TimeoutConstant
  end
  timeout = timeout - 1
  local timeoutBonus = pool.currentFrame / 4
  if timeout + timeoutBonus <= 0 then
    finishRun(genome)
  end
  local measured, total = countMeasured()
  local black = 0xFF000000
  gui:draw_text(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", black)
  gui:draw_text(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), black)
  gui:draw_text(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), black)
  pool.currentFrame = pool.currentFrame + 1
end)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment