-- NEATEvolve MAME -- GitHub gist cracyc/02cbcac7b869329e7d8b280ef6c77fa5
-- (last active February 21, 2024 21:12).
-- Note: GitHub flags this file as containing bidirectional Unicode text that may
-- be interpreted or compiled differently than it appears. To review it, open the
-- file in an editor that reveals hidden Unicode characters.
-- MarI/O by SethBling | |
-- Feel free to use this code, but please do not redistribute it. | |
-- usage: mame nes smb1 -window -autoboot_script NEATEvolve.lua | |
-- create savestate 1 at the start of the level you want it to learn | |
-- Shared run state. Everything here is a file-level local that is assigned at
-- runtime, so the script does not leak globals.
local Filename    -- save-state name passed to machine():load() at the start of each run
local ButtonNames -- MAME input field names driven by the network, one per output neuron
local Game        -- game title; selects the memory-map code path in the RAM readers
local pool        -- the NEAT population (built by newPool / loadFile)
local rightmost   -- furthest marioX reached during the current run
local timeout     -- frames left before a run that stops making progress is ended
local marioX      -- Mario's level and screen coordinates, refreshed by getPositions()
local marioY
local screenX
local screenY
local controller  -- button name -> { port, field, state }, built by clearJoypad()

-- Select the button set and memory-map variant from the loaded software.
local softname = emu.softname()
if softname == "smwu" then
  Filename = "1"
  ButtonNames = {
    "A",
    "B",
    "X",
    "Y",
    "P1 Up",
    "P1 Down",
    "P1 Left",
    "P1 Right",
  }
  Game = "Super Mario World (USA)"
elseif softname == "smb1" then
  Filename = "1"
  ButtonNames = {
    "A",
    "B",
    "P1 Up",
    "P1 Down",
    "P1 Left",
    "P1 Right",
  }
  Game = "Super Mario Bros."
else
  -- Without this check the script dies a few lines below on `#ButtonNames`
  -- with an opaque "attempt to get length of a nil value".
  error("NEATEvolve: unsupported software '" .. tostring(softname)
    .. "'; expected 'smwu' or 'smb1'")
end

-- Network shape: the vision grid is (2*BoxRadius+1)^2 tiles centred on Mario.
local BoxRadius = 6
local InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
local Inputs = InputSize+1          -- +1: constant bias input appended by evaluateNetwork
local Outputs = #ButtonNames

local Population = 300              -- genomes per generation

-- Speciation (sameSpecies): distance = DeltaDisjoint*disjoint + DeltaWeights*weights.
local DeltaDisjoint = 2.0
local DeltaWeights = 0.4
local DeltaThreshold = 1.0          -- genomes closer than this share a species
local StaleSpecies = 15             -- generations without improvement before a species may be dropped

-- Initial per-genome mutation rates (newGenome / mutate). The link, bias, node,
-- enable and disable rates may exceed 1; each whole unit fires once per mutate().
local MutateConnectionsChance = 0.25
local PerturbChance = 0.90          -- pointMutate: nudge a weight instead of redrawing it
local CrossoverChance = 0.75        -- breedChild: crossover instead of cloning
local LinkMutationChance = 2.0
local NodeMutationChance = 0.50
local BiasMutationChance = 0.40
local StepSize = 0.1                -- pointMutate: maximum weight nudge
local DisableMutationChance = 0.4
local EnableMutationChance = 0.2

local TimeoutConstant = 20          -- idle frames allowed without rightward progress (plus a bonus)
local MaxNodes = 1000000            -- output o has id MaxNodes+o; hidden ids lie in (Inputs, MaxNodes]

local memory = manager:machine().devices[":maincpu"].spaces["program"]
-- First entry of the machine's screen table (pairs order); used for all drawing.
local gui = select(2, next(manager:machine().screens))
--- Serialise `pool` to a plain-text file that loadFile() can read back.
-- Layout, one value per line: generation, maxFitness, species count; then per
-- species: topFitness, staleness, genome count; then per genome: fitness,
-- maxneuron, alternating mutation-name / rate lines ended by "done", gene
-- count, and one "into out weight innovation enabled(1|0)" line per gene.
-- @tparam string filename path to create or overwrite
-- @raise error carrying io.open's message when the file cannot be opened
local function writeFile(filename)
  -- Previously an unopenable path crashed later with "attempt to index a nil value".
  local file = assert(io.open(filename, "w"))
  file:write(pool.generation .. "\n")
  file:write(pool.maxFitness .. "\n")
  file:write(#pool.species .. "\n")
  -- ipairs keeps the written order aligned with the counts written above.
  for _, species in ipairs(pool.species) do
    file:write(species.topFitness .. "\n")
    file:write(species.staleness .. "\n")
    file:write(#species.genomes .. "\n")
    for _, genome in ipairs(species.genomes) do
      file:write(genome.fitness .. "\n")
      file:write(genome.maxneuron .. "\n")
      for mutation, rate in pairs(genome.mutationRates) do
        file:write(mutation .. "\n")
        file:write(rate .. "\n")
      end
      file:write("done\n")
      file:write(#genome.genes .. "\n")
      for _, gene in ipairs(genome.genes) do
        file:write(gene.into .. " ")
        file:write(gene.out .. " ")
        file:write(gene.weight .. " ")
        file:write(gene.innovation .. " ")
        file:write(gene.enabled and "1\n" or "0\n")
      end
    end
  end
  file:close()
end
--- Refresh marioX/marioY (level coordinates) and screenX/screenY from RAM.
-- SMB: marioX = page byte 0x6D * 256 + offset byte 0x86; marioY = byte 0x03B8 + 16;
--      screenX/screenY are bytes 0x03AD / 0x03B8.
-- SMW: signed 16-bit words 0x94/0x96 give Mario's position; subtracting the
--      layer-1 words at 0x1A/0x1C gives the on-screen position.
local function getPositions()
  if Game == "Super Mario Bros." then
    local page = memory:read_u8(0x6D)
    local offset = memory:read_u8(0x86)
    marioX = page * 0x100 + offset
    marioY = memory:read_u8(0x03B8) + 16
    screenX = memory:read_u8(0x03AD)
    screenY = memory:read_u8(0x03B8)
  elseif Game == "Super Mario World (USA)" then
    marioX = memory:read_i16(0x94)
    marioY = memory:read_i16(0x96)
    local layerX = memory:read_i16(0x1A)
    local layerY = memory:read_i16(0x1C)
    screenX = marioX - layerX
    screenY = marioY - layerY
  end
end
--- Sample the level tile data at a pixel offset from Mario.
-- @param dx horizontal offset from marioX, in pixels
-- @param dy vertical offset from marioY, in pixels
-- @return SMW: the raw byte read from 0x7FC800 + column/row offset.
--   SMB: 1 if the byte in the 0x500 tile buffer (two pages of 13x16) is
--   non-zero, 0 if it is zero or the row lies outside rows 0..12.
--   Nothing for any other Game.
local function getTile(dx, dy)
  if Game == "Super Mario World (USA)" then
    local col = math.floor((marioX + dx + 8) / 16)
    local row = math.floor((marioY + dy) / 16)
    local addr = 0x7FC800 + math.floor(col / 0x10) * 0x1B0 + row * 0x10 + col % 0x10
    return memory:read_u8(addr)
  end
  if Game == "Super Mario Bros." then
    local px = marioX + dx + 8
    local py = marioY + dy - 16
    local row = math.floor((py - 32) / 16)
    if row < 0 or row >= 13 then
      return 0
    end
    local page = math.floor(px / 256) % 2
    local col = math.floor((px % 256) / 16)
    local addr = 0x500 + page * 13 * 16 + row * 16 + col
    return memory:read_u8(addr) ~= 0 and 1 or 0
  end
end
--- Collect level positions of active enemy sprites.
-- SMW: 12 slots; a slot is active when its byte at 0x14C8+slot is non-zero.
-- SMB: 5 slots; a slot is active when its byte at 0x0F+slot is non-zero, and
--      its Y is offset by +24.
-- @return array of { x = ..., y = ... } tables; nothing for an unknown Game
local function getSprites()
  local sprites = {}
  if Game == "Super Mario World (USA)" then
    for slot = 0, 11 do
      if memory:read_u8(0x14C8 + slot) ~= 0 then
        -- These were previously assigned without `local`, leaking into _G.
        local spritex = memory:read_u8(0xE4 + slot) + memory:read_u8(0x14E0 + slot) * 256
        local spritey = memory:read_u8(0xD8 + slot) + memory:read_u8(0x14D4 + slot) * 256
        sprites[#sprites + 1] = { x = spritex, y = spritey }
      end
    end
    return sprites
  elseif Game == "Super Mario Bros." then
    for slot = 0, 4 do
      if memory:read_u8(0xF + slot) ~= 0 then
        local ex = memory:read_u8(0x6E + slot) * 0x100 + memory:read_u8(0x87 + slot)
        local ey = memory:read_u8(0xCF + slot) + 24
        sprites[#sprites + 1] = { x = ex, y = ey }
      end
    end
    return sprites
  end
end
--- Collect level positions of active extended sprites.
-- SMW: 12 slots; a slot is active when its byte at 0x170B+slot is non-zero.
-- SMB: none are tracked, so this always returns an empty table.
-- @return array of { x = ..., y = ... } tables; nothing for an unknown Game
local function getExtendedSprites()
  if Game == "Super Mario World (USA)" then
    local extended = {}
    for slot = 0, 11 do
      if memory:read_u8(0x170B + slot) ~= 0 then
        -- These were previously assigned without `local`, leaking into _G.
        local spritex = memory:read_u8(0x171F + slot) + memory:read_u8(0x1733 + slot) * 256
        local spritey = memory:read_u8(0x1715 + slot) + memory:read_u8(0x1729 + slot) * 256
        extended[#extended + 1] = { x = spritex, y = spritey }
      end
    end
    return extended
  elseif Game == "Super Mario Bros." then
    return {}
  end
end
--- Build the network's input vector from the tiles and sprites around Mario.
-- Walks a (2*BoxRadius+1)^2 grid in 16-pixel steps, row by row (dy outer, dx
-- inner). A cell is 1 when getTile() returns 1 and marioY+dy < 0x1B0; -1 when
-- any sprite is within 8 px on both axes (extended sprites: strictly closer
-- than 8 px); otherwise 0. A sprite hit overrides a tile hit.
-- @return array of InputSize numbers; evaluateNetwork appends the bias input
local function getInputs()
  getPositions()
  local sprites = getSprites()
  local extended = getExtendedSprites()
  local inputs = {}
  local span = BoxRadius * 16
  for dy = -span, span, 16 do
    for dx = -span, span, 16 do
      local cellX, cellY = marioX + dx, marioY + dy
      local value = 0
      if getTile(dx, dy) == 1 and cellY < 0x1B0 then
        value = 1
      end
      -- distx/disty were previously globals; keep them local to each check.
      for i = 1, #sprites do
        local distx = math.abs(sprites[i].x - cellX)
        local disty = math.abs(sprites[i].y - cellY)
        if distx <= 8 and disty <= 8 then
          value = -1
        end
      end
      for i = 1, #extended do
        local distx = math.abs(extended[i].x - cellX)
        local disty = math.abs(extended[i].y - cellY)
        if distx < 8 and disty < 8 then
          value = -1
        end
      end
      inputs[#inputs + 1] = value
    end
  end
  return inputs
end
--- Steep logistic curve rescaled to the open interval (-1, 1); sigmoid(0) == 0.
local function sigmoid(x)
  local decay = math.exp(-4.9 * x)
  return 2 / (1 + decay) - 1
end
--- Allocate the next innovation number from the pool-wide counter.
local function newInnovation()
  local nextId = pool.innovation + 1
  pool.innovation = nextId
  return nextId
end
--- Create an empty population; innovation numbering starts at Outputs.
local function newPool()
  return {
    species = {},
    generation = 0,
    innovation = Outputs,
    currentSpecies = 1,
    currentGenome = 1,
    currentFrame = 0,
    maxFitness = 0,
  }
end
--- Create an empty species with zeroed fitness bookkeeping.
local function newSpecies()
  return { topFitness = 0, staleness = 0, genomes = {}, averageFitness = 0 }
end
--- Create an empty genome whose mutation rates start at the file's defaults.
local function newGenome()
  -- Keys are inserted in this exact order; mutate() and writeFile() walk them with pairs.
  local rates = {}
  rates.connections = MutateConnectionsChance
  rates.link = LinkMutationChance
  rates.bias = BiasMutationChance
  rates.node = NodeMutationChance
  rates.enable = EnableMutationChance
  rates.disable = DisableMutationChance
  rates.step = StepSize
  return {
    genes = {},
    fitness = 0,
    adjustedFitness = 0,
    network = {},
    maxneuron = 0,
    globalRank = 0,
    mutationRates = rates,
  }
end
--- Create a blank connection gene: enabled, zero weight, zero ids.
local function newGene()
  return { into = 0, out = 0, weight = 0.0, enabled = true, innovation = 0 }
end
--- Return an independent copy of a connection gene's five fields.
local function copyGene(gene)
  return {
    into = gene.into,
    out = gene.out,
    weight = gene.weight,
    enabled = gene.enabled,
    innovation = gene.innovation,
  }
end
--- Clone a genome's genes, maxneuron and mutation rates.
-- Fitness, rank and network are not copied; they keep newGenome()'s defaults.
local function copyGenome(genome)
  local genome2 = newGenome()
  for g = 1, #genome.genes do
    genome2.genes[g] = copyGene(genome.genes[g])
  end
  genome2.maxneuron = genome.maxneuron
  -- Copy every rate, as crossover() does. The old hand-written key list
  -- skipped "step", so each clone silently reset its step size to StepSize.
  for mutation, rate in pairs(genome.mutationRates) do
    genome2.mutationRates[mutation] = rate
  end
  return genome2
end
--- Create a neuron with no incoming genes and value 0.
local function newNeuron()
  return { incoming = {}, value = 0.0 }
end
--- Build genome.network from the genome's enabled genes.
-- Creates neurons 1..Inputs and MaxNodes+1..MaxNodes+Outputs, then any
-- endpoint of an enabled gene that is still missing. Each enabled gene is
-- appended to its `out` neuron's `incoming` list.
-- Side effect: genome.genes is sorted in place by `out`.
local function generateNetwork(genome)
  local neurons = {}
  for i = 1, Inputs do
    neurons[i] = newNeuron()
  end
  for o = 1, Outputs do
    neurons[MaxNodes + o] = newNeuron()
  end

  table.sort(genome.genes, function(a, b) return a.out < b.out end)

  for _, gene in ipairs(genome.genes) do
    if gene.enabled then
      neurons[gene.out] = neurons[gene.out] or newNeuron()
      table.insert(neurons[gene.out].incoming, gene)
      neurons[gene.into] = neurons[gene.into] or newNeuron()
    end
  end

  genome.network = { neurons = neurons }
end
--- Run one evaluation pass of `network` and latch the controller buttons.
-- Appends the bias input 1 to `inputs` (this mutates the caller's table) and
-- copies the inputs onto neurons 1..Inputs. It then visits neurons in `pairs`
-- order; each neuron with incoming genes gets sigmoid(weighted sum). Because
-- `pairs` order is unspecified, a neuron may read a value computed on an
-- earlier call, so the pass is not guaranteed to be topologically ordered.
-- Output neuron MaxNodes+o > 0 sets controller[ButtonNames[o]].state to 1,
-- otherwise to 0.
-- @return table mapping each button name to true (pressed) or false; empty
--   when the input count is wrong. This was always empty before.
local function evaluateNetwork(network, inputs)
  table.insert(inputs, 1)
  if #inputs ~= Inputs then
    emu.print_error("Incorrect number of neural network inputs.")
    return {}
  end
  for i = 1, Inputs do
    network.neurons[i].value = inputs[i]
  end
  for _, neuron in pairs(network.neurons) do
    local sum = 0
    for j = 1, #neuron.incoming do
      local incoming = neuron.incoming[j]
      local other = network.neurons[incoming.into]
      sum = sum + incoming.weight * other.value
    end
    if #neuron.incoming > 0 then
      neuron.value = sigmoid(sum)
    end
  end
  local outputs = {}
  for o = 1, Outputs do
    local button = ButtonNames[o]
    local pressed = network.neurons[MaxNodes + o].value > 0
    outputs[button] = pressed
    controller[button].state = pressed and 1 or 0
  end
  return outputs
end
--- NEAT crossover of two parent genomes.
-- The fitter parent (the first argument on a tie) provides the gene layout.
-- For each of its genes, when the other parent has an enabled gene with the
-- same innovation number, that gene is taken instead with probability 1/2.
-- The child gets the larger maxneuron and the fitter parent's mutation rates.
local function crossover(g1, g2)
  local fitter, weaker = g1, g2
  if g2.fitness > g1.fitness then
    fitter, weaker = g2, g1
  end

  local child = newGenome()
  local byInnovation = {}
  for _, gene in ipairs(weaker.genes) do
    byInnovation[gene.innovation] = gene
  end

  for _, gene in ipairs(fitter.genes) do
    local partner = byInnovation[gene.innovation]
    if partner ~= nil and math.random(2) == 1 and partner.enabled then
      table.insert(child.genes, copyGene(partner))
    else
      table.insert(child.genes, copyGene(gene))
    end
  end

  child.maxneuron = math.max(fitter.maxneuron, weaker.maxneuron)
  for mutation, rate in pairs(fitter.mutationRates) do
    child.mutationRates[mutation] = rate
  end
  return child
end
--- Pick a uniformly random neuron id from the candidates implied by `genes`.
-- Candidates are always the outputs (MaxNodes+1..MaxNodes+Outputs); the inputs
-- 1..Inputs unless `nonInput` is set; and every gene endpoint (only ids above
-- Inputs when `nonInput` is set). The pick walks the set in `pairs` order.
-- @return chosen neuron id, or 0 as an unreachable fallback
local function randomNeuron(genes, nonInput)
  local candidates = {}
  if not nonInput then
    for id = 1, Inputs do
      candidates[id] = true
    end
  end
  for o = 1, Outputs do
    candidates[MaxNodes + o] = true
  end
  for _, gene in ipairs(genes) do
    if not nonInput or gene.into > Inputs then
      candidates[gene.into] = true
    end
    if not nonInput or gene.out > Inputs then
      candidates[gene.out] = true
    end
  end

  local size = 0
  for _ in pairs(candidates) do
    size = size + 1
  end

  local remaining = math.random(1, size)
  for id in pairs(candidates) do
    remaining = remaining - 1
    if remaining == 0 then
      return id
    end
  end
  return 0
end
--- Report whether `genes` already has a gene with the same into/out pair as `link`.
-- @return true when a matching gene exists, false otherwise (previously nil)
local function containsLink(genes, link)
  for _, gene in ipairs(genes) do
    if gene.into == link.into and gene.out == link.out then
      return true
    end
  end
  return false
end
--- Perturb or redraw every gene weight.
-- With probability PerturbChance a weight moves by a uniform amount in
-- [-step, step), where step = mutationRates.step; otherwise it is redrawn
-- uniformly from [-2, 2).
local function pointMutate(genome)
  local step = genome.mutationRates["step"]
  for _, gene in ipairs(genome.genes) do
    if math.random() >= PerturbChance then
      gene.weight = math.random() * 4 - 2
    else
      gene.weight = gene.weight + math.random() * step * 2 - step
    end
  end
end
--- Try to add one new connection gene to `genome`.
-- The source may be any neuron. The target is drawn from the non-input
-- candidates, and the two ends are swapped if the target is an input. With
-- `forceBias` set, the source becomes the bias input (id Inputs). An existing
-- into/out pair is not duplicated. A new gene gets a fresh innovation number
-- and a weight drawn uniformly from [-2, 2).
local function linkMutate(genome, forceBias)
  local source = randomNeuron(genome.genes, false)
  local target = randomNeuron(genome.genes, true)
  if source <= Inputs and target <= Inputs then
    return -- both ends are inputs
  end
  if target <= Inputs then
    source, target = target, source
  end

  local link = newGene()
  link.into = forceBias and Inputs or source
  link.out = target
  if containsLink(genome.genes, link) then
    return
  end

  link.innovation = newInnovation()
  link.weight = math.random() * 4 - 2
  table.insert(genome.genes, link)
end
--- Split a random enabled gene by inserting a new hidden neuron.
-- genome.maxneuron is bumped before the gene is picked, so an id is consumed
-- even when the picked gene is disabled and nothing else changes. On a split
-- the old gene is disabled and replaced by into->new (weight 1.0) and
-- new->out (the old weight), each with a fresh innovation number.
local function nodeMutate(genome)
  local genes = genome.genes
  if #genes == 0 then
    return
  end
  genome.maxneuron = genome.maxneuron + 1

  local split = genes[math.random(1, #genes)]
  if not split.enabled then
    return
  end
  split.enabled = false

  local head = copyGene(split)
  head.out = genome.maxneuron
  head.weight = 1.0
  head.innovation = newInnovation()
  head.enabled = true
  table.insert(genes, head)

  local tail = copyGene(split)
  tail.into = genome.maxneuron
  tail.innovation = newInnovation()
  tail.enabled = true
  table.insert(genes, tail)
end
--- Flip one random gene's `enabled` flag.
-- With `enable` true a disabled gene is enabled; with `enable` false an
-- enabled gene is disabled. Does nothing when there is no such gene.
local function enableDisableMutate(genome, enable)
  local wanted = not enable -- current flag a candidate must carry
  local candidates = {}
  for _, gene in pairs(genome.genes) do
    if gene.enabled == wanted then
      candidates[#candidates + 1] = gene
    end
  end
  local count = #candidates
  if count > 0 then
    local chosen = candidates[math.random(1, count)]
    chosen.enabled = not chosen.enabled
  end
end
--- Apply one round of mutation to `genome`.
-- First every mutation rate is scaled by 0.95 or 1.05263 (a coin flip per
-- rate). Then the point mutation runs once with probability
-- rates.connections. Finally the link, bias, node, enable and disable
-- operators each fire once per whole unit of their rate, plus once more with
-- probability equal to the fractional remainder.
local function mutate(genome)
  local rates = genome.mutationRates
  for name, rate in pairs(rates) do
    if math.random(1, 2) == 1 then
      rates[name] = 0.95 * rate
    else
      rates[name] = 1.05263 * rate
    end
  end

  if math.random() < rates["connections"] then
    pointMutate(genome)
  end

  local function repeatWithRate(rate, op)
    local remaining = rate
    while remaining > 0 do
      if math.random() < remaining then
        op()
      end
      remaining = remaining - 1
    end
  end

  repeatWithRate(rates["link"], function() linkMutate(genome, false) end)
  repeatWithRate(rates["bias"], function() linkMutate(genome, true) end)
  repeatWithRate(rates["node"], function() nodeMutate(genome) end)
  repeatWithRate(rates["enable"], function() enableDisableMutate(genome, true) end)
  repeatWithRate(rates["disable"], function() enableDisableMutate(genome, false) end)
end
--- Count genes whose innovation number appears in only one of the two lists,
-- divided by the length of the longer list.
-- @return value in [0, 2]; 0 when both lists are empty (previously 0/0 = NaN)
local function disjoint(genes1, genes2)
  local seen1, seen2 = {}, {}
  for _, gene in ipairs(genes1) do
    seen1[gene.innovation] = true
  end
  for _, gene in ipairs(genes2) do
    seen2[gene.innovation] = true
  end

  local count = 0
  for _, gene in ipairs(genes1) do
    if not seen2[gene.innovation] then
      count = count + 1
    end
  end
  for _, gene in ipairs(genes2) do
    if not seen1[gene.innovation] then
      count = count + 1
    end
  end

  local n = math.max(#genes1, #genes2)
  if n == 0 then
    return 0
  end
  return count / n
end
--- Mean absolute weight difference over genes that share an innovation number.
-- @return the mean, or 0 when the lists share no innovation (previously
--   0/0 = NaN). In that case disjoint() is already >= 1, so sameSpecies()
--   still rejects the pair unless both lists are empty.
local function weights(genes1, genes2)
  local byInnovation = {}
  for _, gene in ipairs(genes2) do
    byInnovation[gene.innovation] = gene
  end

  local sum, coincident = 0, 0
  for _, gene in ipairs(genes1) do
    local match = byInnovation[gene.innovation]
    if match ~= nil then
      sum = sum + math.abs(gene.weight - match.weight)
      coincident = coincident + 1
    end
  end

  if coincident == 0 then
    return 0
  end
  return sum / coincident
end
--- NEAT compatibility test: true when DeltaDisjoint * disjoint(...) plus
-- DeltaWeights * weights(...) is below DeltaThreshold.
local function sameSpecies(genome1, genome2)
  local structural = DeltaDisjoint * disjoint(genome1.genes, genome2.genes)
  local parametric = DeltaWeights * weights(genome1.genes, genome2.genes)
  local distance = structural + parametric
  return distance < DeltaThreshold
end
--- Give every genome in the pool a globalRank, from 1 for the lowest fitness
-- up to the total genome count for the highest.
local function rankGlobally()
  local everyone = {}
  for _, species in ipairs(pool.species) do
    for _, genome in ipairs(species.genomes) do
      everyone[#everyone + 1] = genome
    end
  end
  table.sort(everyone, function(a, b) return a.fitness < b.fitness end)
  for rank, genome in ipairs(everyone) do
    genome.globalRank = rank
  end
end
--- Set species.averageFitness to the mean globalRank (not raw fitness) of its
-- genomes. Expects a non-empty species.
local function calculateAverageFitness(species)
  local rankSum = 0
  for _, genome in ipairs(species.genomes) do
    rankSum = rankSum + genome.globalRank
  end
  species.averageFitness = rankSum / #species.genomes
end
--- Sum of averageFitness over all species in the pool.
local function totalAverageFitness()
  local sum = 0
  for _, species in ipairs(pool.species) do
    sum = sum + species.averageFitness
  end
  return sum
end
--- Sort each species best-first and drop its weaker genomes.
-- @param cutToOne when truthy keep only the best genome; otherwise keep ceil(n/2)
local function cullSpecies(cutToOne)
  for _, species in ipairs(pool.species) do
    local genomes = species.genomes
    table.sort(genomes, function(a, b) return a.fitness > b.fitness end)
    local keep = cutToOne and 1 or math.ceil(#genomes / 2)
    for _ = #genomes, keep + 1, -1 do
      table.remove(genomes)
    end
  end
end
--- Breed one mutated child from `species`.
-- With probability CrossoverChance the child is the crossover of two randomly
-- chosen members (possibly the same genome twice); otherwise it is a copy of
-- one random member.
local function breedChild(species)
  local genomes = species.genomes
  local child
  if math.random() < CrossoverChance then
    -- g1/g2/g used to be accidental globals.
    local g1 = genomes[math.random(1, #genomes)]
    local g2 = genomes[math.random(1, #genomes)]
    child = crossover(g1, g2)
  else
    local g = genomes[math.random(1, #genomes)]
    child = copyGenome(g)
  end
  mutate(child)
  return child
end
--- Update each species' topFitness and staleness, then drop stale species.
-- A species survives while its staleness is below StaleSpecies, or when its
-- topFitness is at least pool.maxFitness. Also sorts each species best-first.
local function removeStaleSpecies()
  local kept = {}
  for _, species in ipairs(pool.species) do
    table.sort(species.genomes, function(a, b) return a.fitness > b.fitness end)
    local best = species.genomes[1].fitness
    if best <= species.topFitness then
      species.staleness = species.staleness + 1
    else
      species.topFitness = best
      species.staleness = 0
    end
    local fresh = species.staleness < StaleSpecies
    if fresh or species.topFitness >= pool.maxFitness then
      kept[#kept + 1] = species
    end
  end
  pool.species = kept
end
--- Drop species whose floored share of the population
-- (averageFitness / total * Population) would be zero.
local function removeWeakSpecies()
  local survived = {}
  local sum = totalAverageFitness()
  for _, species in ipairs(pool.species) do
    -- `breed` used to be an accidental global.
    local breed = math.floor(species.averageFitness / sum * Population)
    if breed >= 1 then
      survived[#survived + 1] = species
    end
  end
  pool.species = survived
end
--- Put `child` in the first species whose representative (genomes[1]) passes
-- sameSpecies; otherwise start a new species containing only `child`.
local function addToSpecies(child)
  for _, species in ipairs(pool.species) do
    if sameSpecies(child, species.genomes[1]) then
      table.insert(species.genomes, child)
      return
    end
  end
  local fresh = newSpecies()
  table.insert(fresh.genomes, child)
  table.insert(pool.species, fresh)
end
--- Advance the pool by one generation.
-- Steps: cull each species to its top half, rank, drop stale and weak species,
-- breed floor(share) - 1 children per species, cut every species to its best
-- genome, top up with children of random species until children + species
-- reaches Population, speciate the children, bump the generation counter, and
-- write "backup.<generation>.<softname><romname>.pool".
local function newGeneration()
  cullSpecies(false) -- Cull the bottom half of each species
  rankGlobally()
  removeStaleSpecies()
  rankGlobally()
  for s = 1, #pool.species do
    calculateAverageFitness(pool.species[s])
  end
  removeWeakSpecies()

  local sum = totalAverageFitness()
  local children = {}
  for s = 1, #pool.species do
    local species = pool.species[s]
    -- `breed` used to be an accidental global. The -1 presumably reserves a
    -- slot for the species' surviving top genome.
    local breed = math.floor(species.averageFitness / sum * Population) - 1
    for _ = 1, breed do
      table.insert(children, breedChild(species))
    end
  end

  cullSpecies(true) -- Cull all but the top member of each species
  while #children + #pool.species < Population do
    local species = pool.species[math.random(1, #pool.species)]
    table.insert(children, breedChild(species))
  end
  for c = 1, #children do
    addToSpecies(children[c])
  end

  pool.generation = pool.generation + 1
  writeFile("backup." .. pool.generation .. "." .. emu.softname() .. emu.romname() .. ".pool")
end
--- Create a starting genome: no hidden neurons yet (maxneuron = Inputs),
-- seeded by one round of mutate().
local function basicGenome()
  local genome = newGenome()
  genome.maxneuron = Inputs
  mutate(genome)
  return genome
end
--- Push each button's latched `state` to its MAME input field.
-- @param buttons table mapping button name -> { field = ioport field, state = 0|1 }
local function joypad_set(buttons)
  for _, button in pairs(buttons) do
    local field, value = button.field, button.state
    field:set_value(value)
  end
end
--- Rebuild `controller` from the machine's input ports and release every button.
-- For each name in ButtonNames, the first port (in pairs order) whose tag
-- contains "ctrl1" is searched for a field with that name. A matching field is
-- recorded with state 0. Names without a matching field are left out of
-- `controller`.
local function clearJoypad()
  controller = {}
  for _, buttonName in ipairs(ButtonNames) do
    for portTag, port in pairs(manager:machine():ioport().ports) do
      if portTag:match("ctrl1") then
        for fieldName, field in pairs(port.fields) do
          if fieldName == buttonName then
            controller[buttonName] = { port = port, field = field, state = 0 }
          end
        end
        break
      end
    end
  end
  joypad_set(controller)
end
--- Run the current genome's network on fresh inputs and push the result to the
-- joypad. Opposite directions pressed together (Left+Right, Up+Down) are both
-- released.
local function evaluateCurrent()
  local species = pool.species[pool.currentSpecies]
  local genome = species.genomes[pool.currentGenome]
  -- `inputs` used to be an accidental global.
  local inputs = getInputs()
  evaluateNetwork(genome.network, inputs)

  local function cancelOpposites(a, b)
    if controller[a].state ~= 0 and controller[b].state ~= 0 then
      controller[a].state = 0
      controller[b].state = 0
    end
  end
  cancelOpposites("P1 Left", "P1 Right")
  cancelOpposites("P1 Up", "P1 Down")

  joypad_set(controller)
end
--- Start a fresh attempt for the current genome: reload the save state, reset
-- the run counters and the joypad, build the network, and evaluate it once.
local function initializeRun()
  manager:machine():load(Filename)
  pool.currentFrame = 0
  rightmost = 0
  timeout = TimeoutConstant
  clearJoypad()
  local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  generateNetwork(current)
  evaluateCurrent()
end
--- Replace `pool` with Population fresh basic genomes, grouped into species,
-- and start the first run.
local function initializePool()
  pool = newPool()
  for _ = 1, Population do
    -- `basic` used to be an accidental global.
    local basic = basicGenome()
    addToSpecies(basic)
  end
  initializeRun()
end
-- Nothing has assigned `pool` by this point in the script, so startup always
-- builds a fresh population and begins the first run.
if not pool then
  initializePool()
end
--- Advance (currentSpecies, currentGenome) to the next genome. Past the last
-- genome of the last species, run newGeneration() and wrap to species 1.
local function nextGenome()
  pool.currentGenome = pool.currentGenome + 1
  if pool.currentGenome <= #pool.species[pool.currentSpecies].genomes then
    return
  end
  pool.currentGenome = 1
  pool.currentSpecies = pool.currentSpecies + 1
  if pool.currentSpecies > #pool.species then
    newGeneration()
    pool.currentSpecies = 1
  end
end
--- True when the current genome has a non-zero fitness. Zero means "not run
-- yet"; the frame loop stores -1 for a genuine zero score.
local function fitnessAlreadyMeasured()
  local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  return current.fitness ~= 0
end
--- Draw `genome`'s network as an overlay on the game screen.
-- Layout:
-- * input cells: a grid centred at (50, 70), 5 px apart
-- * bias cell: (80, 110)
-- * output cells: a column at x = 220, labelled with their button names
-- * hidden cells: start at (140, 40), then move over four relaxation passes
--   across the enabled genes
-- Cells are shaded by neuron value; enabled genes are drawn as lines whose
-- colour depends on the weight's sign and magnitude.
-- Assumes genome.network was built by generateNetwork() for this genome.
local function displayGenome(genome)
  local neurons = genome.network.neurons
  local cells = {}

  local function isHidden(id)
    return id > Inputs and id <= MaxNodes
  end

  -- Keep a hidden cell's x inside the [90, 220] band between inputs and outputs.
  local function clampX(x)
    if x < 90 then x = 90 end
    if x > 220 then x = 220 end
    return x
  end

  -- Input grid, row by row, in the same order getInputs fills neurons 1..InputSize.
  local id = 1
  for dy = -BoxRadius, BoxRadius do
    for dx = -BoxRadius, BoxRadius do
      cells[id] = { x = 50 + 5 * dx, y = 70 + 5 * dy, value = neurons[id].value }
      id = id + 1
    end
  end

  cells[Inputs] = { x = 80, y = 110, value = neurons[Inputs].value }

  for o = 1, Outputs do
    local outCell = { x = 220, y = 30 + 8 * o, value = neurons[MaxNodes + o].value }
    cells[MaxNodes + o] = outCell
    local labelColor = 0xFF000000
    if outCell.value > 0 then
      labelColor = 0xFF0000FF
    end
    gui:draw_text(223, 24 + 8 * o, ButtonNames[o], labelColor)
  end

  for nid, neuron in pairs(neurons) do
    if isHidden(nid) then
      cells[nid] = { x = 140, y = 40, value = neuron.value }
    end
  end

  -- Four relaxation passes. Each enabled gene pulls a hidden endpoint a quarter
  -- of the way toward its partner. If the source ends up at or right of the
  -- target, the hidden endpoint is pushed 40 px apart (sources left, targets
  -- right). Field updates happen in this exact order so a self-loop gene
  -- (src == dst) behaves as before.
  for _ = 1, 4 do
    for _, gene in pairs(genome.genes) do
      if gene.enabled then
        local src = cells[gene.into]
        local dst = cells[gene.out]
        if isHidden(gene.into) then
          src.x = 0.75 * src.x + 0.25 * dst.x
          if src.x >= dst.x then
            src.x = src.x - 40
          end
          src.x = clampX(src.x)
          src.y = 0.75 * src.y + 0.25 * dst.y
        end
        if isHidden(gene.out) then
          dst.x = 0.25 * src.x + 0.75 * dst.x
          if src.x >= dst.x then
            dst.x = dst.x + 40
          end
          dst.x = clampX(dst.x)
          dst.y = 0.25 * src.y + 0.75 * dst.y
        end
      end
    end
  end

  -- Frame around the input grid.
  gui:draw_box(50-BoxRadius*5-3, 70-BoxRadius*5-3, 50+BoxRadius*5+2, 70+BoxRadius*5+2, 0x80808080, 0xFF000000)

  -- Cells: grey level from the value. Input and bias cells are skipped while
  -- their value is 0; other zero-valued cells are drawn translucent.
  for cid, cell in pairs(cells) do
    if cid > Inputs or cell.value ~= 0 then
      local shade = math.floor((cell.value + 1) / 2 * 256)
      if shade > 255 then shade = 255 end
      if shade < 0 then shade = 0 end
      local alpha = 0xFF000000
      if cell.value == 0 then
        alpha = 0x50000000
      end
      local argb = alpha + shade * 0x10000 + shade * 0x100 + shade
      gui:draw_box(cell.x - 2, cell.y - 2, cell.x + 2, cell.y + 2, argb, alpha)
    end
  end

  -- Connections: the fixed 0x80 channel depends on the weight's sign, and the
  -- other channel shrinks as |weight| grows. Lines are faint when the source
  -- value is 0.
  for _, gene in pairs(genome.genes) do
    if gene.enabled then
      local src = cells[gene.into]
      local dst = cells[gene.out]
      local alpha = 0xA0000000
      if src.value == 0 then
        alpha = 0x20000000
      end
      local fade = 0x80 - math.floor(math.abs(sigmoid(gene.weight)) * 0x80)
      local argb
      if gene.weight > 0 then
        argb = alpha + 0x8000 + 0x10000 * fade
      else
        argb = alpha + 0x800000 + 0x100 * fade
      end
      gui:draw_line(src.x + 1, src.y, dst.x - 3, dst.y, argb)
    end
  end

  -- Small marker just below the centre cell of the input grid.
  gui:draw_box(49, 71, 51, 78, 0x80FF0000, 0x00000000)

  -- Disabled mutation-rate overlay, kept for reference (depends on a form
  -- checkbox that is not created in this script):
  --[[ if forms.ischecked(showMutationRates) then
    local pos = 100
    for mutation,rate in pairs(genome.mutationRates) do
      gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
      pos = pos + 8
    end
  end]]
end
--- Save the pool to "<softname><romname>.pool" (a relative path).
local function savePool()
  writeFile(emu.softname() .. emu.romname() .. ".pool")
end
--- Replace `pool` with the population stored in `filename` (the layout that
-- writeFile produces). Then advance to the first genome whose fitness is
-- still 0 and start a run for it.
-- @tparam string filename
-- @raise error carrying io.open's message if the file cannot be opened, or an
--   explicit error if the mutation-rate section is truncated
local function loadFile(filename)
  local file = assert(io.open(filename, "r"))
  pool = newPool()
  pool.generation = file:read("*number")
  pool.maxFitness = file:read("*number")
  local numSpecies = file:read("*number")
  for _ = 1, numSpecies do
    local species = newSpecies()
    table.insert(pool.species, species)
    species.topFitness = file:read("*number")
    species.staleness = file:read("*number")
    local numGenomes = file:read("*number")
    for _ = 1, numGenomes do
      local genome = newGenome()
      table.insert(species.genomes, genome)
      genome.fitness = file:read("*number")
      genome.maxneuron = file:read("*number")
      -- Finish the maxneuron line, then read name/rate line pairs up to
      -- "done". Whole-line reads replace the old trick of relying on failed
      -- "*number" reads to skip the leftover line endings.
      file:read("*line")
      local name = file:read("*line")
      while name ~= "done" do
        if name == nil then
          file:close()
          error("loadFile: unexpected end of file in mutation rates of " .. filename)
        end
        genome.mutationRates[name] = tonumber(file:read("*line"))
        name = file:read("*line")
      end
      local numGenes = file:read("*number")
      for _ = 1, numGenes do
        local gene = newGene()
        table.insert(genome.genes, gene)
        local enabled
        gene.into, gene.out, gene.weight, gene.innovation, enabled =
          file:read("*number", "*number", "*number", "*number", "*number")
        gene.enabled = enabled ~= 0
      end
    end
  end
  file:close()
  while fitnessAlreadyMeasured() do
    nextGenome()
  end
  initializeRun()
  pool.currentFrame = pool.currentFrame + 1
end
--- Load the pool from "<softname><romname>.pool" (a relative path).
local function loadPool()
  loadFile(emu.softname() .. emu.romname() .. ".pool")
end
--- Switch to the genome with the highest positive fitness, record that
-- fitness as pool.maxFitness, and start a run for it.
-- Does nothing when no genome has a positive fitness. Previously this set
-- currentSpecies/currentGenome to nil and crashed inside initializeRun().
local function playTop()
  local bestFitness, bestSpecies, bestGenome = 0, nil, nil
  for s, species in ipairs(pool.species) do
    for g, genome in ipairs(species.genomes) do
      if genome.fitness > bestFitness then
        bestFitness, bestSpecies, bestGenome = genome.fitness, s, g
      end
    end
  end
  if bestSpecies == nil then
    return
  end
  pool.currentSpecies = bestSpecies
  pool.currentGenome = bestGenome
  pool.maxFitness = bestFitness
  initializeRun()
  pool.currentFrame = pool.currentFrame + 1
end
-- Write the population created at startup to temp.pool straight away.
writeFile("temp.pool")

-- The exit hook and control form below are not wired up in this script. They
-- call `event.*` / `forms.*`, which nothing here defines, so they stay
-- commented out and are kept for reference only.
--[[
event.onexit(onExit)
form = forms.newform(200, 260, "Fitness")
maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
showNetwork = forms.checkbox(form, "Show Map", 5, 30)
showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
restartButton = forms.button(form, "Restart", initializePool, 5, 77)
saveButton = forms.button(form, "Save", savePool, 5, 102)
loadButton = forms.button(form, "Load", loadPool, 80, 102)
saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
]]
-- Count the genomes in the pool and how many of them already have a fitness.
local function countMeasured()
  local measured, total = 0, 0
  for _, species in pairs(pool.species) do
    for _, member in pairs(species.genomes) do
      total = total + 1
      if member.fitness ~= 0 then
        measured = measured + 1
      end
    end
  end
  return measured, total
end

-- Score the run of `genome` that just timed out and log it. Then, scanning
-- from species 1 / genome 1, select the first genome whose fitness is still 0
-- and start a fresh run for it.
local function finishRun(genome)
  -- Distance reached minus a time penalty, plus 1000 once rightmost passes a
  -- per-game threshold (presumably the end of the trained level; unverified).
  local fitness = rightmost - pool.currentFrame / 2
  if Game == "Super Mario World (USA)" and rightmost > 4816 then
    fitness = fitness + 1000
  end
  if Game == "Super Mario Bros." and rightmost > 3186 then
    fitness = fitness + 1000
  end
  -- 0 means "not measured yet" to fitnessAlreadyMeasured, so store -1 instead.
  if fitness == 0 then
    fitness = -1
  end
  genome.fitness = fitness

  if fitness > pool.maxFitness then
    pool.maxFitness = fitness
    writeFile("backup." .. pool.generation .. "." .. emu.softname() .. emu.romname() .. ".pool")
  end
  emu.print_error("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness .. " marioX = " .. marioX .. " marioY = " .. marioY)

  pool.currentSpecies = 1
  pool.currentGenome = 1
  while fitnessAlreadyMeasured() do
    nextGenome()
  end
  initializeRun()
end

-- Per-frame driver. Each frame it:
-- * draws the banner and network overlay (always on; the original form
--   toggles are gone)
-- * re-runs the network every 5th frame
-- * tracks rightward progress and ends the run on timeout
-- * draws generation and fitness statistics
emu.register_frame_done(function()
  local bannerColor = 0xD0FFFFFF
  gui:draw_box(0, 0, 300, 26, bannerColor, bannerColor)

  local genome = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  displayGenome(genome)

  if pool.currentFrame % 5 == 0 then
    evaluateCurrent()
  end
  joypad_set(controller)

  getPositions()
  if marioX > rightmost then
    rightmost = marioX
    timeout = TimeoutConstant
  end
  timeout = timeout - 1

  -- Allowed idle time grows by one frame for every four frames played.
  local timeoutBonus = pool.currentFrame / 4
  if timeout + timeoutBonus <= 0 then
    finishRun(genome)
  end

  local measured, total = countMeasured()
  gui:draw_text(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000)
  gui:draw_text(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000)
  gui:draw_text(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000)

  pool.currentFrame = pool.currentFrame + 1
end)
-- (GitHub page footer: sign up for free, or sign in, to comment on this gist.)