Skip to content

Instantly share code, notes, and snippets.

@ppkn
Created December 10, 2022 19:22
Show Gist options
  • Save ppkn/4cd4103451f2961c51ed9dc30d54087b to your computer and use it in GitHub Desktop.
Save ppkn/4cd4103451f2961c51ed9dc30d54087b to your computer and use it in GitHub Desktop.
Main script for NEAT Zelda reinforcement learning project.
-- ZeldAI by dpipkin
-- Feel free to use this code, but please do not redistribute it.
-- Borrowed heavily from MarI/O by SethBling
-- Intended for use with the BizHawk emulator and The Legend of Zelda (NES)
-- Make sure you have a save state on the overworld opening scene named
-- "zelda0.state" and put a copy in both the Lua folder and the root directory
-- of BizHawk
-- Save state restored at the start of every evaluation run.
Filename = "zelda0.state"

-- Buttons the network controls; output neuron MaxNodes + k drives ButtonNames[k].
ButtonNames = { "A", "B", "Up", "Down", "Left", "Right", "Start" }

-- Network inputs: a square window of (2 * BoxRadius + 1) tiles per side
-- centred on Link, plus one constant bias input (the last input neuron).
BoxRadius = 6
do
  local side = BoxRadius * 2 + 1
  InputSize = side * side
end
Inputs = InputSize + 1
Outputs = #ButtonNames

-- Starting mutation rates copied into every new genome. A rate above 1
-- means several attempts per mutate() call (one per whole unit of the rate).
MutateConnectionsChance = 0.25 -- probability of perturbing all weights
LinkMutationChance = 2.0       -- new-connection attempts
BiasMutationChance = 0.40      -- new connections from the bias input
NodeMutationChance = 0.50      -- split a connection with a hidden neuron
EnableMutationChance = 0.2     -- re-enable a disabled gene
DisableMutationChance = 0.4    -- disable an enabled gene
StepSize = 0.1                 -- max weight perturbation per point mutation
PerturbChance = 0.90           -- perturb (rather than replace) each weight

-- Speciation: two genomes are compatible when
-- DeltaDisjoint * disjoint + DeltaWeights * weights < DeltaThreshold.
DeltaDisjoint = 2.0
DeltaWeights = 0.4
DeltaThreshold = 1.0

-- Reproduction.
CrossoverChance = 0.75 -- probability a child comes from crossover, not a copy
StaleSpecies = 15      -- staleness limit used by removeStaleSpecies
Population = 300

-- Base frame budget between visits to new screens (the main loop adds a
-- bonus that grows with the length of the run).
TimeoutConstant = 500
-- Neuron ids: 1..Inputs are inputs, Inputs+1..MaxNodes are hidden,
-- MaxNodes+1..MaxNodes+Outputs are outputs.
MaxNodes = 1000000
--- Read Link's pixel position from RAM (0x70 = x, 0x84 = y) and convert it
-- to 1-based tile-grid coordinates stored in the globals `linkX`/`linkY`.
-- A position more than 8 pixels into a 16-pixel tile rounds up to the next
-- column/row. Y is shifted by 64 pixels before conversion (presumably the
-- HUD height at the top of the screen — TODO confirm).
function getLinkPosition()
  local px = memory.readbyte(0x70)
  local py = memory.readbyte(0x84)
  local roundX = (px % 16) > 8 and 1 or 0
  local roundY = (py % 16) > 8 and 1 or 0
  linkX = math.floor(1 + px / 16) + roundX
  linkY = math.floor(1 + (py - 64) / 16) + roundY
end
-- Nametable tile ids treated as walkable ground; every other tile is solid.
-- Hoisted out of getBGTiles so the set is built once, not on every call.
local WALKABLE_TILES = {
  [0xf3] = true, [0x24] = true, [0x26] = true, [0x74] = true,
  [0x68] = true, [0x84] = true, [0x76] = true, [0x70] = true,
}

--- Build an 11x16 walkability grid of the current screen from nametable RAM.
-- Temporarily switches the memory domain to 'CIRAM (nametables)' and
-- restores the main memory domain afterwards. Rows are 0x20 tiles wide. The
-- loop starts at 0x100 (8 tile rows in; presumably skipping the HUD — TODO
-- confirm) and steps two rows and two columns at a time, so each grid cell
-- samples the top-left tile of a 2x2 tile group.
-- @treturn table bgTiles[row][col]: 0 = walkable, 1 = solid
function getBGTiles()
  memory.usememorydomain('CIRAM (nametables)')
  local StartAddr = 0x100
  local EndAddr = 0x380
  local RowLen = 0x20
  local RowOffset = 30 -- last even column offset within a row (16 samples)
  local bgTiles = {}
  local i = 1
  for rowAddr = StartAddr, EndAddr, RowLen * 2 do
    local row = {}
    local j = 1
    for addr = rowAddr, rowAddr + RowOffset, 2 do
      -- `tile` was previously an accidental global.
      local tile = memory.readbyte(addr)
      row[j] = WALKABLE_TILES[tile] and 0 or 1
      j = j + 1
    end
    bgTiles[i] = row
    i = i + 1
  end
  memory.usememorydomain(mainmemory.getname())
  return bgTiles
end
--- Mark grid cells occupied by active enemies with -1.
-- Reads the object-type bytes at 0x0350..0x0357. For the k-th of those
-- slots (k = 1..8) the pixel position is read from 0x70 + k (x) and
-- 0x84 + k (y), i.e. the addresses after Link's own position bytes.
-- Types 0 and 0x60 are skipped. Positions are converted to 1-based cells
-- with the same rounding as getLinkPosition and clamped to the 16x11 grid.
-- All working variables are now local (they previously leaked as globals).
-- @tparam table bgTiles grid from getBGTiles (modified in place)
-- @treturn table the same bgTiles table
function addEnemyTiles(bgTiles)
  local EnemyStart = 0x0350
  local EnemyEnd = 0x0357
  local xPosOffset = 0x70
  local yPosOffset = 0x84
  local enemyNumber = 1
  for addr = EnemyStart, EnemyEnd do
    local enemyType = memory.readbyte(addr)
    if enemyType > 0 and enemyType ~= 0x60 then
      local enemyX = memory.readbyte(enemyNumber + xPosOffset)
      local enemyY = memory.readbyte(enemyNumber + yPosOffset)
      local x = 1 + math.floor(enemyX / 16)
      if (enemyX % 16) > 8 then
        x = x + 1
      end
      local y = 1 + math.floor((enemyY - 64) / 16)
      if (enemyY % 16) > 8 then
        y = y + 1
      end
      x = math.max(1, math.min(16, x))
      y = math.max(1, math.min(11, y))
      bgTiles[y][x] = -1
    end
    enemyNumber = enemyNumber + 1
  end
  return bgTiles
end
--- Build the network's tile inputs for the current frame.
-- Reads a (2 * BoxRadius + 1)^2 window of the walkability/enemy grid
-- centred on Link, row by row. Cells outside the 16x11 grid read as 0.
-- Side effect: refreshes the globals linkX/linkY via getLinkPosition().
-- @treturn table flat array of tile values
function getInputs()
  local grid = addEnemyTiles(getBGTiles())
  getLinkPosition()
  local result = {}
  local count = 0
  for dy = -BoxRadius, BoxRadius do
    local row = linkY + dy
    local rowInside = row >= 1 and row <= 11
    for dx = -BoxRadius, BoxRadius do
      local col = linkX + dx
      local value = 0
      if rowInside and col >= 1 and col <= 16 then
        value = grid[row][col]
      end
      count = count + 1
      result[count] = value
    end
  end
  return result
end
--- Steepened logistic curve rescaled to the open interval (-1, 1).
-- @tparam number x weighted input sum
-- @treturn number activation; exactly 0 when x is 0
function sigmoid(x)
  local denom = 1 + math.exp(-4.9 * x)
  return 2 / denom - 1
end
--- Advance the pool-wide innovation counter.
-- @treturn number the new innovation number
function newInnovation()
  local nextId = pool.innovation + 1
  pool.innovation = nextId
  return nextId
end
--- Create an empty pool: no species, generation 0, evaluation cursor at
-- species 1 / genome 1, and the innovation counter starting at Outputs.
function newPool()
  return {
    species = {},
    generation = 0,
    innovation = Outputs,
    currentSpecies = 1,
    currentGenome = 1,
    currentFrame = 0,
    maxFitness = 0,
  }
end
--- Create an empty species record with zeroed statistics.
function newSpecies()
  return {
    topFitness = 0,
    staleness = 0,
    genomes = {},
    averageFitness = 0,
  }
end
--- Create an empty genome whose mutation rates start at the global defaults.
-- The rates are inserted in a fixed order because mutate() walks them with
-- `pairs`, and that order decides which random draw each rate gets.
function newGenome()
  local rates = {}
  rates.connections = MutateConnectionsChance
  rates.link = LinkMutationChance
  rates.bias = BiasMutationChance
  rates.node = NodeMutationChance
  rates.enable = EnableMutationChance
  rates.disable = DisableMutationChance
  rates.step = StepSize
  return {
    genes = {},
    fitness = 0,
    adjustedFitness = 0,
    network = {},
    maxneuron = 0,
    globalRank = 0,
    mutationRates = rates,
  }
end
--- Deep-copy a genome's genes, maxneuron and mutation rates into a fresh genome.
-- Fitness, ranks and the network are not copied; they keep newGenome defaults.
-- Every mutation rate is copied, including "step". The previous hand-written
-- list omitted "step", so copies silently reset their evolved step size to
-- StepSize. The loop matches how crossover copies rates.
-- @tparam table genome source genome (not modified)
-- @treturn table the new genome
function copyGenome(genome)
  local genome2 = newGenome()
  for g = 1, #genome.genes do
    table.insert(genome2.genes, copyGene(genome.genes[g]))
  end
  genome2.maxneuron = genome.maxneuron
  for mutation, rate in pairs(genome.mutationRates) do
    genome2.mutationRates[mutation] = rate
  end
  return genome2
end
--- Create a starting genome: no genes, maxneuron set to the input count
-- (so new hidden neurons get ids above the inputs), then one round of
-- mutate() to give it initial structure.
-- (Removed an unused local `innovation`.)
-- @treturn table the new genome
function basicGenome()
  local genome = newGenome()
  genome.maxneuron = Inputs
  mutate(genome)
  return genome
end
--- Create a blank, enabled connection gene (endpoints, weight and
-- innovation all zero).
function newGene()
  return { into = 0, out = 0, weight = 0.0, enabled = true, innovation = 0 }
end
--- Return a new gene table with the same five fields as `gene`.
function copyGene(gene)
  return {
    into = gene.into,
    out = gene.out,
    weight = gene.weight,
    enabled = gene.enabled,
    innovation = gene.innovation,
  }
end
--- Create a neuron with no incoming genes and value 0.
function newNeuron()
  return { incoming = {}, value = 0.0 }
end
--- Build the evaluable network for a genome and store it in genome.network.
-- Input neurons 1..Inputs and output neurons MaxNodes+1..MaxNodes+Outputs
-- always exist. Every enabled gene is attached to its target neuron's
-- `incoming` list, and both endpoints are created on demand.
-- Side effect: sorts genome.genes in place by output neuron id.
function generateNetwork(genome)
  local neurons = {}
  for i = 1, Inputs do
    neurons[i] = newNeuron()
  end
  for o = 1, Outputs do
    neurons[MaxNodes + o] = newNeuron()
  end
  table.sort(genome.genes, function (a, b) return a.out < b.out end)
  for _, gene in ipairs(genome.genes) do
    if gene.enabled then
      local target = neurons[gene.out]
      if target == nil then
        target = newNeuron()
        neurons[gene.out] = target
      end
      table.insert(target.incoming, gene)
      if neurons[gene.into] == nil then
        neurons[gene.into] = newNeuron()
      end
    end
  end
  genome.network = { neurons = neurons }
end
--- Run one forward pass of `network` and translate its outputs to buttons.
-- A constant bias of 1 is appended as the last input value. The caller's
-- `inputs` table is no longer modified: the old code appended the bias to it
-- in place, so a second call with the same table failed the size check.
-- Non-input neurons are updated in `pairs` order (unspecified), so a neuron
-- may read a source value computed on an earlier call.
-- @tparam table network network built by generateNetwork
-- @tparam table inputs flat array of Inputs - 1 values
-- @treturn table map "P1 <button>" -> boolean (true when the output neuron's
--   value is > 0), or {} after logging if the input count is wrong
function evaluateNetwork(network, inputs)
  local values = {}
  for i = 1, #inputs do
    values[i] = inputs[i]
  end
  values[#values + 1] = 1 -- bias input
  if #values ~= Inputs then
    console.writeline("Incorrect number of neural network inputs.")
    return {}
  end
  local neurons = network.neurons
  for i = 1, Inputs do
    neurons[i].value = values[i]
  end
  for _, neuron in pairs(neurons) do
    local incoming = neuron.incoming
    if #incoming > 0 then
      local sum = 0
      for j = 1, #incoming do
        local link = incoming[j]
        sum = sum + link.weight * neurons[link.into].value
      end
      neuron.value = sigmoid(sum)
    end
  end
  local outputs = {}
  for o = 1, Outputs do
    outputs["P1 " .. ButtonNames[o]] = neurons[MaxNodes + o].value > 0
  end
  return outputs
end
--- Produce a child genome from two parents.
-- The fitter parent becomes g1 (on a tie g1 stays as given) and supplies
-- the gene list. For each of its genes, if the other parent has the same
-- innovation number, the other parent's copy is taken with probability 1/2
-- when that copy is enabled. maxneuron is the larger of the two parents'
-- values, and mutation rates come from the fitter parent.
-- (The swap previously went through an accidental global `tempg`.)
-- @treturn table the child genome (genes are copies, not shared)
function crossover(g1, g2)
  if g2.fitness > g1.fitness then
    g1, g2 = g2, g1
  end
  local child = newGenome()
  local innovations2 = {}
  for i = 1, #g2.genes do
    local gene = g2.genes[i]
    innovations2[gene.innovation] = gene
  end
  for i = 1, #g1.genes do
    local gene1 = g1.genes[i]
    local gene2 = innovations2[gene1.innovation]
    if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
      table.insert(child.genes, copyGene(gene2))
    else
      table.insert(child.genes, copyGene(gene1))
    end
  end
  child.maxneuron = math.max(g1.maxneuron, g2.maxneuron)
  for mutation, rate in pairs(g1.mutationRates) do
    child.mutationRates[mutation] = rate
  end
  return child
end
--- Pick a uniformly random neuron id among the genome's candidates.
-- Candidates are all output neurons, every gene endpoint, and (unless
-- `nonInput` is true) all input neurons including the bias. With
-- `nonInput` true, gene endpoints that are inputs are excluded as well.
-- Fix: the input test read the undefined global `nonInputs` (a typo for
-- `nonInput`), so input neurons were always candidates.
-- @tparam table genes the genome's gene list
-- @tparam boolean nonInput exclude input neurons when true
-- @treturn number the chosen neuron id (0 is a fallback the loop does not
--   reach when at least one candidate exists)
function randomNeuron(genes, nonInput)
  local neurons = {}
  if not nonInput then
    for i = 1, Inputs do
      neurons[i] = true
    end
  end
  for o = 1, Outputs do
    neurons[MaxNodes + o] = true
  end
  for i = 1, #genes do
    local gene = genes[i]
    if (not nonInput) or gene.into > Inputs then
      neurons[gene.into] = true
    end
    if (not nonInput) or gene.out > Inputs then
      neurons[gene.out] = true
    end
  end
  local count = 0
  for _ in pairs(neurons) do
    count = count + 1
  end
  local n = math.random(1, count)
  for k in pairs(neurons) do
    n = n - 1
    if n == 0 then
      return k
    end
  end
  return 0
end
--- Report whether `genes` already has a gene with the same endpoints as `link`.
-- Now returns an explicit `false` instead of falling off the end (nil).
-- @treturn boolean true if some gene has equal `into` and `out`
function containsLink(genes, link)
  for i = 1, #genes do
    local gene = genes[i]
    if gene.into == link.into and gene.out == link.out then
      return true
    end
  end
  return false
end
--- Mutate every gene weight in place.
-- With probability PerturbChance a weight is nudged by a uniform amount in
-- [-step, step), where step is the genome's "step" rate. Otherwise it is
-- replaced by a fresh uniform value in [-2, 2).
function pointMutate(genome)
  local step = genome.mutationRates["step"]
  for _, gene in ipairs(genome.genes) do
    if math.random() >= PerturbChance then
      gene.weight = math.random() * 4 - 2
    else
      gene.weight = gene.weight + math.random() * step * 2 - step
    end
  end
end
--- Try to add one new connection gene to `genome`.
-- A source (any candidate) and a target (non-input candidate) are drawn
-- with randomNeuron. The attempt is abandoned if both are inputs or if an
-- identical link already exists. If the target is an input, the endpoints
-- are swapped so the link leaves the input layer. With `forceBias`, the
-- source becomes the bias input (index Inputs, the last input).
-- Fix: the swap test was written `neuron2 <- Inputs`, which Lua parses as
-- `neuron2 < -Inputs`. That is never true for positive ids, so the swap
-- never ran.
-- @tparam table genome genome to extend (modified in place)
-- @tparam boolean forceBias connect from the bias input
function linkMutate(genome, forceBias)
  local neuron1 = randomNeuron(genome.genes, false)
  local neuron2 = randomNeuron(genome.genes, true)
  if neuron1 <= Inputs and neuron2 <= Inputs then
    -- Both endpoints are inputs: no meaningful link.
    return
  end
  if neuron2 <= Inputs then
    neuron1, neuron2 = neuron2, neuron1
  end
  local newLink = newGene()
  newLink.into = neuron1
  newLink.out = neuron2
  if forceBias then
    newLink.into = Inputs
  end
  if containsLink(genome.genes, newLink) then
    return
  end
  newLink.innovation = newInnovation()
  newLink.weight = math.random() * 4 - 2
  table.insert(genome.genes, newLink)
end
--- Split a random connection by inserting a new hidden neuron.
-- The chosen gene is disabled and replaced by two enabled genes:
-- into -> new (weight 1.0) and new -> out (the original weight), each with
-- a fresh innovation number. maxneuron is bumped before the enabled check,
-- so picking an already-disabled gene still consumes a neuron id.
function nodeMutate(genome)
  local genes = genome.genes
  if #genes == 0 then
    return
  end
  genome.maxneuron = genome.maxneuron + 1
  local victim = genes[math.random(1, #genes)]
  if not victim.enabled then
    return
  end
  victim.enabled = false
  local hidden = genome.maxneuron

  local front = copyGene(victim)
  front.out = hidden
  front.weight = 1.0
  front.innovation = newInnovation()
  front.enabled = true
  table.insert(genes, front)

  local back = copyGene(victim)
  back.into = hidden
  back.innovation = newInnovation()
  back.enabled = true
  table.insert(genes, back)
end
--- Flip the `enabled` flag of one randomly chosen gene.
-- With `enable` true the pick is among disabled genes, which get enabled.
-- With `enable` false the pick is among enabled genes, which get disabled.
-- Does nothing when there is no eligible gene.
function enableDisableMutate(genome, enable)
  local eligible = {}
  local wanted = not enable
  for _, gene in pairs(genome.genes) do
    if wanted == gene.enabled then
      eligible[#eligible + 1] = gene
    end
  end
  local count = #eligible
  if count > 0 then
    local chosen = eligible[math.random(1, count)]
    chosen.enabled = not chosen.enabled
  end
end
-- Call `action(genome, arg)` repeatedly, driven by a rate that may exceed 1.
-- Each pass succeeds when math.random() < p, where p starts at `rate` and
-- drops by 1 per pass. Whole units therefore always fire, and the final
-- fraction fires with that probability.
local function repeatByRate(rate, action, genome, arg)
  local p = rate
  while p > 0 do
    if math.random() < p then
      action(genome, arg)
    end
    p = p - 1
  end
end

--- Apply one round of mutation to `genome`, in place.
-- First every mutation rate is randomly scaled by 0.95 or 1.05263, so rates
-- drift per genome. Then: with probability rates.connections all weights
-- are perturbed, followed by rate-driven link, bias-link, node, enable and
-- disable mutations, in that order.
function mutate(genome)
  local rates = genome.mutationRates
  for name, rate in pairs(rates) do
    local factor = (math.random(1, 2) == 1) and 0.95 or 1.05263
    rates[name] = factor * rate
  end
  if math.random() < rates["connections"] then
    pointMutate(genome)
  end
  repeatByRate(rates["link"], linkMutate, genome, false)
  repeatByRate(rates["bias"], linkMutate, genome, true)
  repeatByRate(rates["node"], nodeMutate, genome)
  repeatByRate(rates["enable"], enableDisableMutate, genome, true)
  repeatByRate(rates["disable"], enableDisableMutate, genome, false)
end
--- Fraction of genes whose innovation number appears in only one genome.
-- Fix: the second lookup table was never filled (genes2's innovations were
-- written into `i1`). That made every gene of genes1 count as disjoint and
-- none of genes2, so identical genomes scored 1 instead of 0.
-- Two empty gene lists now give 0 instead of 0/0 (NaN).
-- @treturn number disjoint gene count divided by the longer list's length
function disjoint(genes1, genes2)
  local i1 = {}
  for i = 1, #genes1 do
    i1[genes1[i].innovation] = true
  end
  local i2 = {}
  for i = 1, #genes2 do
    i2[genes2[i].innovation] = true
  end
  local disjointGenes = 0
  for i = 1, #genes1 do
    if not i2[genes1[i].innovation] then
      disjointGenes = disjointGenes + 1
    end
  end
  for i = 1, #genes2 do
    if not i1[genes2[i].innovation] then
      disjointGenes = disjointGenes + 1
    end
  end
  local n = math.max(#genes1, #genes2)
  if n == 0 then
    return 0
  end
  return disjointGenes / n
end
--- Mean absolute weight difference over genes that share an innovation number.
-- If no innovation numbers match, the result is 0/0 (NaN). sameSpecies then
-- evaluates `NaN < DeltaThreshold`, which is false, so such genomes are
-- never grouped together.
function weights(genes1, genes2)
  local byInnovation = {}
  for _, gene in ipairs(genes2) do
    byInnovation[gene.innovation] = gene
  end
  local total, matches = 0, 0
  for _, gene in ipairs(genes1) do
    local other = byInnovation[gene.innovation]
    if other then
      total = total + math.abs(gene.weight - other.weight)
      matches = matches + 1
    end
  end
  return total / matches
end
--- Decide whether two genomes are compatible (same species).
-- Distance = DeltaDisjoint * disjoint + DeltaWeights * weights. The genomes
-- match when that distance is below DeltaThreshold.
function sameSpecies(genome1, genome2)
  local a, b = genome1.genes, genome2.genes
  local distance = DeltaDisjoint * disjoint(a, b) + DeltaWeights * weights(a, b)
  return distance < DeltaThreshold
end
--- Rank every genome in the pool by fitness.
-- Sets globalRank = 1 for the lowest fitness up to N for the highest.
function rankGlobally()
  local everyone = {}
  for _, species in ipairs(pool.species) do
    for _, genome in ipairs(species.genomes) do
      everyone[#everyone + 1] = genome
    end
  end
  table.sort(everyone, function (a, b) return a.fitness < b.fitness end)
  for rank, genome in ipairs(everyone) do
    genome.globalRank = rank
  end
end
--- Store the mean globalRank of the species' genomes in species.averageFitness.
-- Despite the name, this averages ranks (set by rankGlobally), not raw fitness.
function calculateAverageFitness(species)
  local genomes = species.genomes
  local total = 0
  for _, genome in ipairs(genomes) do
    total = total + genome.globalRank
  end
  species.averageFitness = total / #genomes
end
--- Sum of averageFitness over all species in the pool.
function totalAverageFitness()
  local sum = 0
  for _, species in ipairs(pool.species) do
    sum = sum + species.averageFitness
  end
  return sum
end
--- Sort each species best-first and drop its weaker genomes.
-- Keeps the top half (rounded up), or only the single best genome when
-- `cutToOne` is true.
function cullSpecies(cutToOne)
  for _, species in ipairs(pool.species) do
    local genomes = species.genomes
    table.sort(genomes, function (a, b) return a.fitness > b.fitness end)
    local keep = cutToOne and 1 or math.ceil(#genomes / 2)
    for _ = #genomes, keep + 1, -1 do
      table.remove(genomes)
    end
  end
end
--- Breed one mutated child from a species.
-- With probability CrossoverChance, two random members (possibly the same)
-- are crossed over. Otherwise a random member is copied. The child is then
-- mutated. Parent selections are now locals; they previously leaked as
-- globals g1, g2 and g.
-- @treturn table the new child genome
function breedChild(species)
  local genomes = species.genomes
  local child
  if math.random() < CrossoverChance then
    local g1 = genomes[math.random(1, #genomes)]
    local g2 = genomes[math.random(1, #genomes)]
    child = crossover(g1, g2)
  else
    child = copyGenome(genomes[math.random(1, #genomes)])
  end
  mutate(child)
  return child
end
--- Update each species' staleness and drop species that stopped improving.
-- A species whose best genome beats its recorded topFitness resets its
-- staleness to 0; otherwise staleness grows by one. A species survives if
-- its staleness is below StaleSpecies or its topFitness exceeds
-- pool.maxFitness. Also sorts every species' genomes best-first.
function removeStaleSpecies()
  local kept = {}
  for _, species in ipairs(pool.species) do
    table.sort(species.genomes, function (a, b) return a.fitness > b.fitness end)
    local best = species.genomes[1].fitness
    if best > species.topFitness then
      species.topFitness = best
      species.staleness = 0
    else
      species.staleness = species.staleness + 1
    end
    local fresh = species.staleness < StaleSpecies
    if fresh or species.topFitness > pool.maxFitness then
      kept[#kept + 1] = species
    end
  end
  pool.species = kept
end
--- Drop species whose share of the next population rounds down to zero.
-- A species' share is floor(averageFitness / total * Population).
-- `breed` is now local (it previously leaked as a global).
function removeWeakSpecies()
  local survived = {}
  local sum = totalAverageFitness()
  for s = 1, #pool.species do
    local species = pool.species[s]
    local breed = math.floor(species.averageFitness / sum * Population)
    if breed >= 1 then
      table.insert(survived, species)
    end
  end
  pool.species = survived
end
--- Place `child` in the first species whose representative (its first
-- genome) is compatible according to sameSpecies. If none matches, start a
-- new species containing only the child.
function addToSpecies(child)
  for _, species in ipairs(pool.species) do
    if sameSpecies(child, species.genomes[1]) then
      table.insert(species.genomes, child)
      return
    end
  end
  local fresh = newSpecies()
  table.insert(fresh.genomes, child)
  table.insert(pool.species, fresh)
end
--- Replace the pool's population with the next generation.
-- Steps:
-- 1. Halve every species, keeping the fittest.
-- 2. Rank all genomes, drop stale species, then re-rank.
-- 3. Compute each species' average rank and drop species whose population
--    share rounds down to zero.
-- 4. Breed each species' share minus one slot (the one champion left after
--    cutting every species to its best genome).
-- 5. Top up with children of random species until Population is reached.
-- 6. Place all children into species.
-- Finally increments pool.generation and writes a backup file named
-- "backup.<generation>.<save/load file name>".
-- `breed` is now local (it previously leaked as a global).
function newGeneration()
  cullSpecies(false)
  rankGlobally()
  removeStaleSpecies()
  rankGlobally()
  for s = 1, #pool.species do
    calculateAverageFitness(pool.species[s])
  end
  removeWeakSpecies()
  local sum = totalAverageFitness()
  local children = {}
  for s = 1, #pool.species do
    local species = pool.species[s]
    local breed = math.floor(species.averageFitness / sum * Population) - 1
    for _ = 1, breed do
      table.insert(children, breedChild(species))
    end
  end
  cullSpecies(true)
  while #children + #pool.species < Population do
    local species = pool.species[math.random(1, #pool.species)]
    table.insert(children, breedChild(species))
  end
  for c = 1, #children do
    addToSpecies(children[c])
  end
  pool.generation = pool.generation + 1
  writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
end
--- Replace the global `pool` with Population fresh basic genomes, sorted
-- into species, and start evaluating the first one.
-- (The genome temp is now passed directly; it previously leaked as the
-- global `basic`.)
function initializePool()
  pool = newPool()
  for _ = 1, Population do
    addToSpecies(basicGenome())
  end
  initializeRun()
end
--- Release every button: set the global `controller` to all-false
-- ("P1 <button>" keys) and send it with joypad.set.
function clearJoypad()
  local released = {}
  for _, name in ipairs(ButtonNames) do
    released["P1 " .. name] = false
  end
  controller = released
  joypad.set(controller)
end
--- Start a fresh evaluation run for the current genome.
-- Reloads the save state and resets the per-run globals: visited screens,
-- the visit count, the frame counter and the timeout. It then releases all
-- buttons, builds the current genome's network and evaluates it once.
function initializeRun()
  savestate.load(Filename)
  numVisited, visited = 0, {}
  pool.currentFrame = 0
  timeout = TimeoutConstant
  clearJoypad()
  local genome = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  generateNetwork(genome)
  evaluateCurrent()
end
--- Feed the current screen to the current genome's network and press the
-- resulting buttons. Opposite directions pressed together cancel out.
-- The decision is kept in the global `controller`, which the main loop
-- re-sends every frame. `inputs` is now local (it previously leaked as a
-- global).
function evaluateCurrent()
  local species = pool.species[pool.currentSpecies]
  local genome = species.genomes[pool.currentGenome]
  local inputs = getInputs()
  controller = evaluateNetwork(genome.network, inputs)
  if controller["P1 Left"] and controller["P1 Right"] then
    controller["P1 Left"] = false
    controller["P1 Right"] = false
  end
  if controller["P1 Up"] and controller["P1 Down"] then
    controller["P1 Up"] = false
    controller["P1 Down"] = false
  end
  joypad.set(controller)
end
-- Create the initial population unless a pool already exists.
if not pool then
  initializePool()
end
--- Advance the evaluation cursor to the next genome.
-- Rolls over into the next species. After the last species it breeds a new
-- generation and restarts at species 1.
function nextGenome()
  local nextIndex = pool.currentGenome + 1
  pool.currentGenome = nextIndex
  if nextIndex <= #pool.species[pool.currentSpecies].genomes then
    return
  end
  pool.currentGenome = 1
  pool.currentSpecies = pool.currentSpecies + 1
  if pool.currentSpecies > #pool.species then
    newGeneration()
    pool.currentSpecies = 1
  end
end
--- True when the current genome already has a fitness (0 means "unmeasured").
function fitnessAlreadyMeasured()
  local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  return current.fitness ~= 0
end
--- Draw the genome's network over the game screen using BizHawk gui.* calls.
-- Layout:
--   * input window: (2*BoxRadius+1)^2 cells, 5 px apart, centred on (50, 70)
--   * bias input: (80, 110)
--   * outputs: a column at x = 220, each labelled with its button name
--   * hidden neurons: start at (140, 40), then nudged toward their neighbours
-- Only the local `cells` layout table is written; the genome and network
-- are read, not modified.
-- Colour constants are assumed to be 0xAARRGGBB — TODO confirm against the
-- BizHawk gui docs.
function displayGenome(genome)
local network = genome.network
local cells = {}
local i = 1
local cell = {}
-- Input window cells, row by row, in the same order as getInputs builds them.
for dy = -BoxRadius, BoxRadius do
for dx = -BoxRadius, BoxRadius do
cell = {}
cell.x = 50 + 5 * dx
cell.y = 70 + 5 * dy
cell.value = network.neurons[i].value
cells[i] = cell
i = i + 1
end
end
-- The bias is the last input neuron (index Inputs); it gets its own spot.
local biasCell = {}
biasCell.x = 80
biasCell.y = 110
biasCell.value = network.neurons[Inputs].value
cells[Inputs] = biasCell
-- Output neurons. A button label is drawn in the 0xFF0000FF colour when
-- its output is positive, black otherwise.
for o = 1, Outputs do
cell = {}
cell.x = 220
cell.y = 30 + 8 * o
cell.value = network.neurons[MaxNodes + o].value
cells[MaxNodes + o] = cell
local color
if cell.value > 0 then
color = 0xFF0000FF
else
color = 0xFF000000
end
gui.drawText(223, 24 + 8 * o, ButtonNames[o], color, 9)
end
-- Hidden neurons (ids between Inputs and MaxNodes) all start at (140, 40).
-- (`cell = {}` also runs for non-hidden ids; that table is simply unused.)
for n, neuron in pairs(network.neurons) do
cell = {}
if n > Inputs and n <= MaxNodes then
cell.x = 140
cell.y = 40
cell.value = neuron.value
cells[n] = cell
end
end
-- Four relaxation passes. Along each enabled gene, a hidden source moves
-- 25% toward its target and is pushed 40 px left if not left of it; a
-- hidden target moves 75%/25% and is pushed 40 px right if not right of the
-- source. Hidden x is clamped to [90, 220].
for n = 1, 4 do
for _, gene in pairs(genome.genes) do
if gene.enabled then
local c1 = cells[gene.into]
local c2 = cells[gene.out]
if gene.into > Inputs and gene.into <= MaxNodes then
c1.x = 0.75 * c1.x + 0.25 * c2.x
if c1.x >= c2.x then
c1.x = c1.x - 40
end
if c1.x < 90 then
c1.x = 90
end
if c1.x > 220 then
c1.x = 220
end
c1.y = 0.75 * c1.y + 0.25 * c2.y
end
if gene.out > Inputs and gene.out <= MaxNodes then
c2.x = 0.25 * c1.x + 0.75 * c2.x
if c1.x >= c2.x then
c2.x = c2.x + 40
end
if c2.x < 90 then
c2.x = 90
end
if c2.x > 220 then
c2.x = 220
end
c2.y = 0.25 * c1.y + 0.75 * c2.y
end
end
end
end
-- Translucent backdrop behind the input window.
gui.drawBox(50 - BoxRadius * 5 - 3, 70 - BoxRadius * 5 - 3, 50 + BoxRadius * 5 + 2, 70 + BoxRadius * 5 + 2, 0xFF000000, 0x80808080)
-- Neuron squares, shaded grey by value (-1 -> black, +1 -> white).
-- Zero-valued input cells are skipped; other zero-valued cells are drawn
-- faint (opacity 0x50).
for n, cell in pairs(cells) do
if n > Inputs or cell.value ~= 0 then
local color = math.floor((cell.value + 1) / 2 * 256)
if color > 255 then color = 255 end
if color < 0 then color = 0 end
local opacity = 0xFF000000
if cell.value == 0 then
opacity = 0x50000000
end
color = opacity + color * 0x10000 + color * 0x100 + color
gui.drawBox(cell.x - 2, cell.y - 2, cell.x + 2, cell.y + 2, opacity, color)
end
end
-- Connection lines for enabled genes: green channel 0x80 for positive
-- weights, red channel 0x80 for negative. The other channel fades as
-- |sigmoid(weight)| grows. Lines from zero-valued sources are faint.
for _, gene in pairs(genome.genes) do
if gene.enabled then
local c1 = cells[gene.into]
local c2 = cells[gene.out]
local opacity = 0xA0000000
if c1.value == 0 then
opacity = 0x20000000
end
local color = 0x80 - math.floor(math.abs(sigmoid(gene.weight)) * 0x80)
if gene.weight > 0 then
color = opacity + 0x8000 + 0x10000 * color
else
color = opacity + 0x800000 + 0x100 * color
end
gui.drawLine(c1.x + 1, c1.y, c2.x - 3, c2.y, color)
end
end
-- Small red box just below the centre of the input window (presumably
-- marking Link's cell — TODO confirm).
gui.drawBox(49, 71, 51, 78, 0x00000000, 0x80FF0000)
-- Optional overlay listing this genome's current mutation rates.
if forms.ischecked(showMutationRates) then
local pos = 100
for mutation, rate in pairs(genome.mutationRates) do
gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
pos = pos + 8
end
end
end
--- Serialise the whole pool to a plain-text file that loadFile can read.
-- Layout (one value per line unless noted):
--   generation, maxFitness, species count
--   per species: topFitness, staleness, genome count
--   per genome: fitness, maxneuron, then a name line and a value line for
--     each mutation rate, then "done", the gene count, and one line per
--     gene: "into out weight innovation enabled(1|0)"
-- Raises an error naming the file if it cannot be opened. The old code
-- failed with an opaque "attempt to index a nil value". Species, genomes
-- and genes are walked with ipairs so they are written in array order.
-- @tparam string filename path to write (overwritten)
function writeFile(filename)
  local file = assert(io.open(filename, "w"))
  file:write(pool.generation .. "\n")
  file:write(pool.maxFitness .. "\n")
  file:write(#pool.species .. "\n")
  for _, species in ipairs(pool.species) do
    file:write(species.topFitness .. "\n")
    file:write(species.staleness .. "\n")
    file:write(#species.genomes .. "\n")
    for _, genome in ipairs(species.genomes) do
      file:write(genome.fitness .. "\n")
      file:write(genome.maxneuron .. "\n")
      for mutation, rate in pairs(genome.mutationRates) do
        file:write(mutation .. "\n")
        file:write(rate .. "\n")
      end
      file:write("done\n")
      file:write(#genome.genes .. "\n")
      for _, gene in ipairs(genome.genes) do
        local flag = gene.enabled and "1" or "0"
        file:write(gene.into .. " " .. gene.out .. " " .. gene.weight .. " "
          .. gene.innovation .. " " .. flag .. "\n")
      end
    end
  end
  file:close()
end
--- "Save" button handler: write the pool to the file named in the
-- save/load textbox.
function savePool()
  writeFile(forms.gettext(saveLoadFile))
end
--- Replace the global `pool` with one read from a file written by writeFile,
-- then resume at the first genome without a fitness.
-- Fixes:
--   * An unopenable file now raises an error naming the file, instead of an
--     opaque nil-index error.
--   * pool.innovation is restored to the highest innovation number loaded.
--     newPool resets it to Outputs, so new genes would otherwise reuse
--     innovation numbers already present in the loaded genomes.
-- @tparam string filename path of the pool file
function loadFile(filename)
  local file = assert(io.open(filename, "r"))
  pool = newPool()
  pool.generation = file:read("*number")
  pool.maxFitness = file:read("*number")
  forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  local maxInnovation = pool.innovation
  local numSpecies = file:read("*number")
  for _ = 1, numSpecies do
    local species = newSpecies()
    table.insert(pool.species, species)
    species.topFitness = file:read("*number")
    species.staleness = file:read("*number")
    local numGenomes = file:read("*number")
    for _ = 1, numGenomes do
      local genome = newGenome()
      table.insert(species.genomes, genome)
      genome.fitness = file:read("*number")
      genome.maxneuron = file:read("*number")
      -- With writeFile's layout, the first "*line" read returns the empty
      -- rest of the maxneuron line. A "*number" read positioned at a rate
      -- name returns nil, so those passes assign nothing and the loop picks
      -- up the next name/value pair until it reaches "done".
      local line = file:read("*line")
      while line ~= "done" do
        genome.mutationRates[line] = file:read("*number")
        line = file:read("*line")
      end
      local numGenes = file:read("*number")
      for _ = 1, numGenes do
        local gene = newGene()
        table.insert(genome.genes, gene)
        local enabled
        gene.into, gene.out, gene.weight, gene.innovation, enabled =
          file:read("*number", "*number", "*number", "*number", "*number")
        gene.enabled = enabled ~= 0
        if gene.innovation and gene.innovation > maxInnovation then
          maxInnovation = gene.innovation
        end
      end
    end
  end
  file:close()
  pool.innovation = maxInnovation
  while fitnessAlreadyMeasured() do
    nextGenome()
  end
  initializeRun()
  pool.currentFrame = pool.currentFrame + 1
end
--- "Load" button handler: load the pool file named in the save/load textbox.
function loadPool()
  loadFile(forms.gettext(saveLoadFile))
end
--- "Play Top" button handler: replay the best genome found so far.
-- Picks the genome with the highest strictly positive fitness (the first
-- one found on ties) and makes it current. It then records that fitness as
-- pool.maxFitness, updates the label and restarts the run.
-- If no genome has positive fitness, it now logs a message and leaves the
-- pool untouched. Previously currentSpecies was set to nil and
-- initializeRun crashed.
function playTop()
  local maxfitness = 0
  local maxs, maxg
  for s, species in ipairs(pool.species) do
    for g, genome in ipairs(species.genomes) do
      if genome.fitness > maxfitness then
        maxfitness = genome.fitness
        maxs, maxg = s, g
      end
    end
  end
  if maxs == nil then
    console.writeline("Play Top: no genome with positive fitness yet.")
    return
  end
  pool.currentSpecies = maxs
  pool.currentGenome = maxg
  pool.maxFitness = maxfitness
  forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  initializeRun()
  pool.currentFrame = pool.currentFrame + 1
end
--- Script-exit handler: destroy the control form stored in the global `form`.
function onExit()
  local window = form
  forms.destroy(window)
end
--- Fitness of the current run: 700 per distinct screen visited, minus one
-- per elapsed frame. An exact 0 becomes -1, because a fitness of 0 is what
-- fitnessAlreadyMeasured treats as "not yet measured".
function calcFitness()
  local score = numVisited * 700 - pool.currentFrame
  if score == 0 then
    return -1
  end
  return score
end
-- One-time UI setup; runs after the pool exists.
-- Snapshot the starting pool to temp.pool.
writeFile("temp.pool")
-- Destroy the control form when the script stops.
event.onexit(onExit)
-- Control window. The widget handles below are globals read elsewhere:
--   showNetwork ("Show Map")  gates displayGenome in the main loop
--   showMutationRates         enables the rate overlay in displayGenome
--   hideBanner                hides the status banner at the top
--   maxFitnessLabel           updated whenever a new best fitness is found
--   saveLoadFile              file name used by Save/Load and by the
--                             per-generation "backup.*" files
form = forms.newform(200, 260, "Fitness")
maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
showNetwork = forms.checkbox(form, "Show Map", 5, 30)
showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
restartButton = forms.button(form, "Restart", initializePool, 5, 77)
saveButton = forms.button(form, "Save", savePool, 5, 102)
loadButton = forms.button(form, "Load", loadPool, 80, 102)
saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
-- Main loop: one iteration per emulated frame.
-- Each frame:
--   * draws the banner and, optionally, the network
--   * re-evaluates the network every 5th frame
--   * tracks newly visited screens, which reset the timeout
--   * when the timeout (plus a bonus that grows with the frame count) runs
--     out, records the genome's fitness, moves to the next unmeasured
--     genome and restarts the run
-- `currentScreen` is now local (it previously leaked as a global).
while true do
  local backgroundColor = 0xD0FFFFFF
  if not forms.ischecked(hideBanner) then
    gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
  end
  local species = pool.species[pool.currentSpecies]
  local genome = species.genomes[pool.currentGenome]
  if forms.ischecked(showNetwork) then
    displayGenome(genome)
  end
  -- The network runs every 5th frame; its last decision is held in between.
  if pool.currentFrame % 5 == 0 then
    evaluateCurrent()
  end
  joypad.set(controller)
  -- Screen key built from the bytes at 0x10 and 0xEB (presumably the level
  -- and room ids — TODO confirm).
  local currentScreen = string.format('%x %x', memory.readbyte(0x10), memory.readbyte(0xEB))
  if not visited[currentScreen] then
    visited[currentScreen] = true
    numVisited = numVisited + 1
    timeout = TimeoutConstant
  end
  timeout = timeout - 1
  -- The allowance grows by one frame for every four frames survived.
  local timeoutBonus = pool.currentFrame / 4
  if timeout + timeoutBonus <= 0 then
    genome.fitness = calcFitness()
    if genome.fitness > pool.maxFitness then
      pool.maxFitness = genome.fitness
      forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
      writeFile("backup." .. pool.generation .. "." ..forms.gettext(saveLoadFile))
    end
    console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. genome.fitness)
    -- Rescan from the start for the next genome without a fitness.
    pool.currentSpecies = 1
    pool.currentGenome = 1
    while fitnessAlreadyMeasured() do
      nextGenome()
    end
    initializeRun()
  end
  -- Progress through the generation, shown as a percentage in the banner.
  local measured = 0
  local total = 0
  for _, sp in pairs(pool.species) do
    for _, gn in pairs(sp.genomes) do
      total = total + 1
      if gn.fitness ~= 0 then
        measured = measured + 1
      end
    end
  end
  if not forms.ischecked(hideBanner) then
    gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured / total * 100) .. "%)", 0xFF000000, 11)
    -- NOTE(review): this on-screen estimate also subtracts
    -- (timeout + timeoutBonus) * 2 / 3, so it differs from calcFitness().
    gui.drawText(0, 12, "Fitness: " .. math.floor(numVisited * 700 - pool.currentFrame - (timeout + timeoutBonus) * 2 / 3), 0xFF000000, 11)
    gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
  end
  pool.currentFrame = pool.currentFrame + 1
  emu.frameadvance()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment