-- SethBling's neural network & genetic algorithm for Super Mario
--
-- Published as GitHub gist m-kyle/ff1e701f0b50501af6b4 by @m-kyle,
-- created March 15, 2015 00:08. (GitHub page navigation text has been
-- turned into this comment so the file parses as Lua.)
-- Script-wide configuration, read by the functions and main loop below.
console.clear()

-- Savestate that initializeRun reloads at the start of every evaluation run.
filename = "DP1.state"

-- Half-width, in 16px cells, of the square input grid sampled around Mario.
boxRadius = 6

-- Controller buttons driven by the network, in output-neuron order.
buttonNames = {"A", "B", "X", "Y", "Up", "Down", "Left", "Right"}

-- Neuron count of each layer after the inputs; the last layer has one
-- neuron per entry of buttonNames.
layerSizes = {100, 20, 10, #buttonNames}
--- Read the level tile byte under a point offset from Mario's position.
-- Converts Mario's position plus the pixel offset to 16px tile coordinates,
-- then reads the level data using a 0x1B0-byte stride per group of 16 tile
-- columns. All temporaries are local, so no globals are clobbered.
-- @tparam number dx horizontal pixel offset from Mario
-- @tparam number dy vertical pixel offset from Mario
-- @treturn number raw tile byte (getInputs maps a value of 1 to a +1 input)
function getTile(dx, dy)
	local marioX = memory.read_s16_le(0x94)
	local marioY = memory.read_s16_le(0x96)
	local x = math.floor((marioX + dx) / 16)
	local y = math.floor((marioY + dy) / 16)
	return memory.readbyte(0x1C800 + math.floor(x / 0x10) * 0x1B0 + y * 0x10 + x % 0x10)
end
--- Collect the positions of all active normal sprites.
-- Scans the 12 normal-sprite slots; a slot counts as active when its status
-- byte (0x14C8 + slot) is non-zero. Each coordinate is low byte + high byte * 256.
-- Positions are in the same space as Mario's position, which getInputs
-- compares them against.
-- @treturn table array of {x = ..., y = ...}
function getSprites()
	local sprites = {}
	for slot = 0, 11 do
		if memory.readbyte(0x14C8 + slot) ~= 0 then
			local spritex = memory.readbyte(0xE4 + slot) + memory.readbyte(0x14E0 + slot) * 256
			local spritey = memory.readbyte(0xD8 + slot) + memory.readbyte(0x14D4 + slot) * 256
			sprites[#sprites + 1] = {x = spritex, y = spritey}
		end
	end
	return sprites
end
--- Collect the positions of all active extended sprites.
-- The extended-sprite tables sit 10 bytes apart (number 0x170B, y-low 0x1715,
-- x-low 0x171F, y-high 0x1729, x-high 0x1733), so there are 10 slots.
-- Reading slot 10 or 11 would alias into the following table and report
-- phantom sprites. A slot counts as active when its number byte is non-zero.
-- @treturn table array of {x = ..., y = ...} (low byte + high byte * 256)
function getExtendedSprites()
	local EXTENDED_SLOTS = 10
	local extended = {}
	for slot = 0, EXTENDED_SLOTS - 1 do
		if memory.readbyte(0x170B + slot) ~= 0 then
			local spritex = memory.readbyte(0x171F + slot) + memory.readbyte(0x1733 + slot) * 256
			local spritey = memory.readbyte(0x1715 + slot) + memory.readbyte(0x1729 + slot) * 256
			extended[#extended + 1] = {x = spritex, y = spritey}
		end
	end
	return extended
end
--- True when any {x, y} entry of `list` lies within 8px of (px, py) on both axes.
local function isNearAny(list, px, py)
	for i = 1, #list do
		local item = list[i]
		if math.abs(item.x - px) < 8 and math.abs(item.y - py) < 8 then
			return true
		end
	end
	return false
end

--- Build the neural-network input vector for the current frame.
-- Samples a (2*boxRadius+1)^2 grid of 16px cells centred on Mario, row by row
-- (top to bottom, each row left to right). A cell's value is:
--   -1 if any normal or extended sprite is within 8px of it on both axes;
--    1 otherwise, if getTile returns 1 and the cell's y is below 0x1B0;
--    0 otherwise.
-- Mario's horizontal (0x7B) and vertical (0x7D) speed bytes, divided by 70,
-- are appended as the last two inputs. All temporaries are local.
-- @treturn table array of (2*boxRadius+1)^2 + 2 numbers
function getInputs()
	local marioX = memory.read_s16_le(0x94)
	local marioY = memory.read_s16_le(0x96)
	local sprites = getSprites()
	local extended = getExtendedSprites()

	local inputs = {}
	local span = boxRadius * 16
	for dy = -span, span, 16 do
		local py = marioY + dy
		for dx = -span, span, 16 do
			local px = marioX + dx
			local value = 0
			if getTile(dx, dy) == 1 and py < 0x1B0 then
				value = 1
			end
			-- sprites override tiles
			if isNearAny(sprites, px, py) or isNearAny(extended, px, py) then
				value = -1
			end
			inputs[#inputs + 1] = value
		end
	end

	inputs[#inputs + 1] = memory.read_s8(0x7B) / 70
	inputs[#inputs + 1] = memory.read_s8(0x7D) / 70
	return inputs
end
--- Run a forward pass of the fully connected network.
-- `chromosome` is a flat list. For every neuron of every layer, in order, it
-- holds one weight per activation of the previous layer, followed by one bias.
-- A neuron's activation is atan(sum of weight * previous activation + bias).
-- @tparam table inputs activations fed into the first layer
-- @tparam table chromosome flat weight/bias list laid out as above
-- @treturn table activations of the last layer in layerSizes ({} if there are no layers)
function evaluate(inputs, chromosome)
	local output = {}
	local previous = inputs
	local k = 0
	for depth = 1, #layerSizes do
		output = {}
		for neuron = 1, layerSizes[depth] do
			local sum = 0
			for j = 1, #previous do
				k = k + 1
				sum = sum + chromosome[k] * previous[j]
			end
			k = k + 1
			output[neuron] = math.atan(sum + chromosome[k])
		end
		previous = output
	end
	return output
end
--- Create a random flat weight/bias list matching the layout `evaluate` reads.
-- The input count comes from the current frame's getInputs() result. About
-- 1 in 10 connection weights is drawn from [-1, 1); the rest from [-0.1, 0.1).
-- Every bias is drawn from [-1, 1). All temporaries are local.
-- @treturn table flat list of weights and biases
function randomChromosome()
	local c = {}
	local prevSize = #getInputs()
	for i = 1, #layerSizes do
		for m = 1, layerSizes[i] do
			for n = 1, prevSize do
				if math.random(10) == 1 then
					c[#c + 1] = math.random() * 2 - 1
				else
					c[#c + 1] = math.random() * 0.2 - 0.1
				end
			end
			-- bias for neuron m
			c[#c + 1] = math.random() * 2 - 1
		end
		prevSize = layerSizes[i]
	end
	return c
end
--- Rewind the game to the starting savestate and reset per-run state.
-- Resets the globals the main loop scores a run with: rightmost (furthest
-- X reached), frame (frames elapsed) and timeout (no-progress countdown,
-- reset to 20 whenever Mario passes rightmost).
function initializeRun()
	savestate.load(filename)
	rightmost, frame, timeout = 0, 0, 20
end
--- Produce a child by segment-wise crossover of two parents.
-- Walks the genes in order, copying each one from the currently selected
-- parent. Before each gene the selection flips with probability
-- 1/max(1, floor(n/2)), so the child is a series of contiguous runs taken
-- alternately from c1 and c2.
-- @tparam table c1 parent with a `chromosome` array of length n
-- @tparam table c2 parent whose `chromosome` is expected to be at least n long
-- @treturn table new individual {chromosome = ..., fitness = 0}
function crossover(c1, c2)
	local genes1, genes2 = c1.chromosome, c2.chromosome
	local n = #genes1
	-- math.random needs a positive integer bound. Lua 5.3+ rejects the
	-- fractional n/2 of odd lengths, and n < 2 would give an empty interval.
	local flipOdds = math.max(1, math.floor(n / 2))
	local genes = {}
	local pick = true
	for i = 1, n do
		if math.random(flipOdds) == 1 then
			pick = not pick
		end
		if pick then
			genes[i] = genes1[i]
		else
			genes[i] = genes2[i]
		end
	end
	return {chromosome = genes, fitness = 0}
end
--- Randomly reset genes in place.
-- Each gene independently has a 1-in-50 chance of being replaced by a fresh
-- value drawn uniformly from [-1, 1).
-- @tparam table c individual whose `chromosome` array is modified in place
function mutate(c)
	local genes = c["chromosome"]
	for index = 1, #genes do
		local roll = math.random(50)
		if roll == 1 then
			genes[index] = math.random() * 2 - 1
		end
	end
end
--- Replace the weaker half of the pool with offspring of the stronger half.
-- Sorts `pool` by descending fitness and keeps the first floor(#pool/2)
-- individuals untouched. Every slot from floor(#pool/2)+1 to #pool is
-- overwritten with a mutated crossover of two parents drawn uniformly from
-- the kept half. Finally increments `generation`.
function createNewGeneration()
	table.sort(pool, function(a, b)
		return a.fitness > b.fitness
	end)
	local half = math.floor(#pool / 2)
	-- Start at half+1: starting at half would overwrite the last kept
	-- (already-evaluated) individual with an unevaluated child.
	if half >= 1 then
		for i = half + 1, #pool do
			local c1 = pool[math.random(half)]
			local c2 = pool[math.random(half)]
			pool[i] = crossover(c1, c2)
			mutate(pool[i])
		end
	end
	generation = generation + 1
end
--- Release every network-controlled button on player 1's controller.
-- Sends joypad.set a table mapping "P1 <button>" to false for each entry
-- of buttonNames.
function clearJoypad()
	local released = {}
	for _, name in ipairs(buttonNames) do
		released["P1 " .. name] = false
	end
	joypad.set(released)
end
--- "Show Top" button handler: release all buttons, then restart the run
-- from the first individual in the pool.
-- pool[1] is the fittest individual as of the last createNewGeneration sort.
-- Before the first generation it is simply the first random individual.
function showTop()
	clearJoypad()
	currentChromosome = 1
	return initializeRun()
end
--- Close the fitness window; registered with event.onexit below.
-- Guarded because the script can stop before `form` exists. For example,
-- the first simulation is set up before the window is built, and an error
-- there would otherwise be followed by forms.destroy(nil).
function onExit()
	if form ~= nil then
		forms.destroy(form)
	end
end
event.onexit(onExit)
--- Sum of squared genes of an individual, used as a weight-size penalty.
-- @tparam table chromosome individual with a `chromosome` array of numbers
-- @treturn number sum of w*w over all genes (0 for an empty chromosome)
function connectionCost(chromosome)
	local total = 0
	local genes = chromosome.chromosome
	for i = 1, #genes do
		local w = genes[i]
		total = total + w * w
	end
	return total
end
--- Start a fresh evolutionary run.
-- Builds a pool of POPULATION_SIZE random individuals with fitness 0. Resets
-- currentChromosome, generation and maxfitness, then starts the first run.
-- When the fitness label already exists (e.g. via the "Restart" button), it
-- is reset as well, so it stops showing the previous simulation's best.
function initializeSimulation()
	local POPULATION_SIZE = 20
	pool = {}
	for i = 1, POPULATION_SIZE do
		pool[i] = {chromosome = randomChromosome(), fitness = 0}
	end
	currentChromosome = 1
	generation = 0
	maxfitness = 0
	-- nil on the very first call, which can happen before the window is built
	if maxFitnessLabel ~= nil then
		forms.settext(maxFitnessLabel, "Top Fitness: " .. math.floor(maxfitness))
	end
	initializeRun()
end
-- Build the fitness window first so maxFitnessLabel refers to the live label
-- before anything writes to it.
form = forms.newform(200, 164, "Fitness")
maxFitnessLabel = forms.label(form, "Top Fitness: ", 5, 8)
goButton = forms.button(form, "Show Top", showTop, 5, 30)
restartButton = forms.button(form, "Restart", initializeSimulation, 80, 30)
showUI = forms.checkbox(form, "Show Inputs", 5, 52)
inputsLabel = forms.label(form, "Inputs", 5, 74)
showChromosomes = forms.checkbox(form, "Show Map", 5, 96)

-- Only build a population when none exists yet. The else-path in the original
-- code implies `pool` can still be set when this runs (presumably after a
-- script reload in the host -- NOTE(review): confirm).
if pool == nil then
	initializeSimulation()
end
forms.settext(maxFitnessLabel, "Top Fitness: " .. math.floor(maxfitness))
--- Score the run that just ended, report it, and move to the next individual.
-- Fitness is the furthest X reached, minus frame/10 (time taken) and
-- connectionCost/10 (weight size). After the last individual is scored, a new
-- generation is bred and evaluation resumes just after the kept top half.
local function finishRun()
	local fitness = rightmost - frame / 10 - connectionCost(pool[currentChromosome]) / 10
	pool[currentChromosome].fitness = fitness
	if fitness > maxfitness then
		forms.settext(maxFitnessLabel, "Top Fitness: " .. math.floor(fitness))
		maxfitness = fitness
	end
	console.writeline("Generation " .. generation .. " chromosome " .. currentChromosome .. " fitness: " .. math.floor(fitness))
	if currentChromosome == #pool then
		createNewGeneration()
		-- integer index (plain #pool/2+1 is a float in Lua 5.3+)
		currentChromosome = math.floor(#pool / 2) + 1
	else
		currentChromosome = currentChromosome + 1
	end
	initializeRun()
end

--- Run the current individual's network on this frame's inputs and rebuild
-- the controller state. A button is pressed when its output neuron is
-- positive; the pressed button names are shown in the window.
-- `inputs` and `controller` stay global on purpose: they persist between the
-- frames where they are recomputed, and are reused every frame.
local function computeController()
	inputs = getInputs()
	local outputs = evaluate(inputs, pool[currentChromosome].chromosome)
	controller = {}
	local pressed = {}
	for n = 1, #buttonNames do
		local name = buttonNames[n]
		local down = outputs[n] > 0
		controller["P1 " .. name] = down
		if down then
			pressed[#pressed + 1] = name
		end
	end
	forms.settext(inputsLabel, table.concat(pressed))
end

--- Draw the most recent input grid over the screen, plus a box around Mario.
-- Red cells are -1 (sprite), green cells are 1 (tile), yellow cells are 0.
-- Positions are shifted by the values read from 0x1A/0x1C (layer-1 offsets).
local function drawInputOverlay()
	local layer1x = memory.read_s16_le(0x1A)
	local layer1y = memory.read_s16_le(0x1C)
	local width = boxRadius * 2 + 1
	for dy = 0, boxRadius * 2 do
		for dx = 0, boxRadius * 2 do
			local input = inputs[dy * width + dx + 1]
			local x = marioX + (dx - boxRadius) * 16 - layer1x
			local y = marioY + (dy - boxRadius) * 16 - layer1y
			if input == -1 then
				gui.drawBox(x, y, x + 16, y + 16, 0xFFFF0000, 0xA0FF0000)
			elseif input == 1 then
				gui.drawBox(x, y, x + 16, y + 16, 0xFF00FF00, 0xA000FF00)
			else
				gui.drawBox(x, y, x + 16, y + 16, 0xFFFFFF00, 0xA0FFFF00)
			end
		end
	end
	local x = marioX - layer1x
	local y = marioY - layer1y
	gui.drawBox(x, y, x + 16, y + 32, 0xFF000000)
end

--- Opaque ARGB colour for one gene: -1 is pure red, +1 pure green.
-- The green channel is (1+v)/2; the previous (v-1)/2 was never positive and
-- borrowed into the red channel. Channels are clamped to [0, 1] so
-- out-of-range genes cannot bleed into neighbouring channels.
local function geneColor(v)
	local r = math.max(0, math.min(1, (1 - v) / 2))
	local g = math.max(0, math.min(1, (1 + v) / 2))
	return 0xFF000000 + math.floor(r * 0xFF) * 0x10000 + math.floor(g * 0xFF) * 0x100
end

--- Draw one 3px-high row per individual, sampling about 200 genes per row.
local function drawChromosomeMap()
	gui.drawBox(0, 3, 201, 3 + #pool * 3, 0xFFFFFFFF, 0xFFFFFFFF)
	for c = 1, #pool do
		local genes = pool[c].chromosome
		local size = #genes
		local y = 1 + c * 3
		-- At least 1, so chromosomes shorter than 200 genes don't hit n % 0.
		local step = math.max(1, math.floor(size / 200))
		for n = step, size, step do
			local x = 1 + n * 200 / size
			gui.drawLine(x, y, x, y + 1, geneColor(genes[n]))
		end
	end
end

-- Main loop: one iteration per emulated frame.
while true do
	marioX = memory.read_s16_le(0x94)
	marioY = memory.read_s16_le(0x96)
	-- the longer a run lasts, the more slack its no-progress countdown gets
	local timeoutBonus = frame / 4
	if timeout + timeoutBonus <= 0 then
		finishRun()
	end
	-- the network is consulted every 5th frame; its last decision is held in between
	if timeout + timeoutBonus > 2 and frame % 5 == 0 then
		computeController()
	end
	joypad.set(controller)
	if timeout + timeoutBonus <= 2 then
		clearJoypad()
	end
	if marioX > rightmost then
		timeout = 20
		rightmost = marioX
	end
	timeout = timeout - 1
	frame = frame + 1
	if forms.ischecked(showUI) and inputs ~= nil then
		drawInputOverlay()
	end
	if forms.ischecked(showChromosomes) then
		drawChromosomeMap()
	end
	emu.frameadvance()
end
-- End of gist m-kyle/ff1e701f0b50501af6b4 (GitHub page footer text removed).