diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..88d9686 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,5 @@ +[*] +indent_size = 4 +indent_style = space +charset = utf-8 +end_of_line = lf diff --git a/.gitignore b/.gitignore index 7585af4..00b85ff 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ -*.lsmv *.log catchem/ +state/ +crashsave* +*.backup diff --git a/config.lua b/config.lua index 6d727c6..4f2f0ac 100644 --- a/config.lua +++ b/config.lua @@ -1,69 +1,70 @@ -local _M = {} - ---[[ - Change BizhawkDir to your BizHawk directory. ---]] ---_M.BizhawkDir = "C:/Users/mmill/Downloads/BizHawk-2.2/" -_M.BizhawkDir = "X:/B2_BizHawkLab/BizHawk-2.2.2/" - -_M.StateDir = _M.BizhawkDir .. "Lua/SNES/neat-mario/state/" -_M.PoolDir = _M.BizhawkDir .. "Lua/SNES/neat-mario/pool/" - ---[[ - At the moment the first in list will get loaded. - Rearrange for other savestates. (will be redone soon) ---]] -_M.State = { - "DP1.state", -- Donut Plains 1 - "YI1.state", -- Yoshi's Island 1 - "YI2.state", -- Yoshi's Island 2 -} - ---[[ - Start game with specific powerup. - 0 = No powerup - 1 = Mushroom - 2 = Feather - 3 = Flower - Comment out to disable. ---]] -_M.StartPowerup = 0 - -_M.NeatConfig = { ---Filename = "DP1.state", -Filename = _M.PoolDir .. _M.State[1], -Population = 300, -DeltaDisjoint = 2.0, -DeltaWeights = 0.4, -DeltaThreshold = 1.0, -StaleSpecies = 15, -MutateConnectionsChance = 0.25, -PerturbChance = 0.90, -CrossoverChance = 0.75, -LinkMutationChance = 2.0, -NodeMutationChance = 0.50, -BiasMutationChance = 0.40, -StepSize = 0.1, -DisableMutationChance = 0.4, -EnableMutationChance = 0.2, -TimeoutConstant = 20, -MaxNodes = 1000000, -} - -_M.ButtonNames = { - "A", - "B", - "X", - "Y", - "Up", - "Down", - "Left", - "Right", - } - -_M.BoxRadius = 6 -_M.InputSize = (_M.BoxRadius*2+1)*(_M.BoxRadius*2+1) - -_M.Running = false - +local _M = {} + +--[[ + Change script dir to your script directory +--]] +_M.ScriptDir = "/media/removable/Main/user1000/neat-donk" + +_M.StateDir = _M.ScriptDir .. "/state/" +_M.PoolDir = _M.ScriptDir .. "/pool/" + +--[[ + At the moment the first in list will get loaded. + Rearrange for other savestates. (will be redone soon) +--]] +_M.State = { + "PiratePanic.lsmv", +} + +--[[ + Start game with specific powerup. + 0 = No powerup + 1 = Mushroom + 2 = Feather + 3 = Flower + Comment out to disable. +--]] +_M.StartPowerup = 0 + +_M.NeatConfig = { +--Filename = "DP1.state", +Filename = _M.PoolDir .. _M.State[1], +Population = 300, +DeltaDisjoint = 2.0, +DeltaWeights = 0.4, +DeltaThreshold = 1.0, +StaleSpecies = 15, +MutateConnectionsChance = 0.25, +PerturbChance = 0.90, +CrossoverChance = 0.75, +LinkMutationChance = 2.0, +NodeMutationChance = 0.50, +BiasMutationChance = 0.40, +StepSize = 0.1, +DisableMutationChance = 0.4, +EnableMutationChance = 0.2, +TimeoutConstant = 20, +MaxNodes = 1000000, +} + +_M.ButtonNames = { + "B", + "Y", + "Select", + "Start", + "Up", + "Down", + "Left", + "Right", + "A", + "X", + "L", + "R", +} + +_M.BoxRadius = 6 +_M.InputSize = (_M.BoxRadius*2+1)*(_M.BoxRadius*2+1) + +_M.Running = true + return _M \ No newline at end of file diff --git a/donkutil.lua b/donkutil.lua index 496d88b..9897037 100644 --- a/donkutil.lua +++ b/donkutil.lua @@ -1,63 +1,65 @@ +util = require "util" + +FG_COLOR = 0x00ffffff +BG_COLOR = 0x99000000 +TILEDATA_POINTER = 0x7e0098 +TILE_SIZE = 32 +TILE_RADIUS = 4 +SPRITE_BASE = 0x7e0de2 +SOLID_LESS_THAN = 0x7e00a0 +DIDDY_X_VELOCITY = 0x7e0e02 +DIDDY_Y_VELOCITY = 0x7e0e06 +DIXIE_X_VELOCITY = 0x7e0e60 +DIXIE_Y_VELOCITY = 0x7e0e64 +CAMERA_X = 0x7e17ba +CAMERA_Y = 0x7e17c0 +CAMERA_MODE = 0x7e054f +TILE_COLLISION_MATH_POINTER = 0x7e17b2 +VERTICAL_POINTER = 0xc414 +PARTY_X = 0x7e0a2a +PARTY_Y = 0x7e0a2c + count = 0 detailsidx = -1 +jumping = false helddown = false floatmode = false +rulers = true pokemon = false pokecount = 0 showhelp = false locked = false lockdata = nil incsprite = 0 -fgcolor = 0x00ffffff -bgcolor = 0x99000000 -function table_to_string(tbl) - local result = "{" - local keys = {} - for k in pairs(tbl) do - table.insert(keys, k) - end - table.sort(keys) - for _, k in ipairs(keys) do - local v = tbl[k] - if type(v) == "number" and v == 0 then - goto continue - end +party_tile_offset = 0 +party_y_ground = 0 - -- Check the key type (ignore any numerical keys - assume its an array) - if type(k) == "string" then - result = result.."[\""..k.."\"]".."=" - end +last_called = 0 +function set_party_tile_offset (val) + if party_tile_offset_debounce == val then + return + end + local sec, usec = utime() + last_called = sec * 1000000 + usec + party_tile_offset_debounce = val +end - -- Check the value type - if type(v) == "table" then - result = result..table_to_string(v) - elseif type(v) == "boolean" then - result = result..tostring(v) - else - result = result.."\""..v.."\"" - end - result = result..",\n" - ::continue:: - end - -- Remove leading commas from the result - if result ~= "" then - result = result:sub(1, result:len()-1) - end - return result.."}" +function text(x, y, msg) + gui.text(x, y, msg, FG_COLOR, BG_COLOR) end function on_keyhook (key, state) - if not helddown and state["value"] == 1 then + if not helddown and state.value == 1 then if key == "1" and not locked then helddown = true detailsidx = detailsidx - 1 if detailsidx < -1 then - detailsidx = 20 + detailsidx = 22 end elseif key == "2" and not locked then helddown = true detailsidx = detailsidx + 1 - if detailsidx > 20 then + if detailsidx > 22 then detailsidx = -1 end elseif key == "3" then @@ -80,89 +82,90 @@ function on_keyhook (key, state) elseif key == "7" then helddown = true floatmode = not floatmode + elseif key == "8" then + helddown = true + rulers = not rulers elseif key == "0" then showhelp = true end - elseif state["value"] == 0 then + elseif state.value == 0 then helddown = false showhelp = false end end function on_input (subframe) + jumping = input.get(0,0) ~= 0 + if floatmode then memory.writebyte(0x7e19ce, 0x16) memory.writebyte(0x7e0e12, 0x99) memory.writebyte(0x7e0e70, 0x99) if input.get(0, 6) == 1 then - memory.writeword(0x7e0e02, -0x5ff) - memory.writeword(0x7e0e60, -0x5ff) + memory.writeword(DIDDY_X_VELOCITY, -0x5ff) + memory.writeword(DIXIE_X_VELOCITY, -0x5ff) - memory.writeword(0x7e0e06, 0) - memory.writeword(0x7e0e64, 0) + memory.writeword(DIDDY_Y_VELOCITY, 0) + memory.writeword(DIXIE_Y_VELOCITY, 0) elseif input.get(0, 7) == 1 then - memory.writeword(0x7e0e02, 0x5ff) - memory.writeword(0x7e0e60, 0x5ff) + memory.writeword(DIDDY_X_VELOCITY, 0x5ff) + memory.writeword(DIXIE_X_VELOCITY, 0x5ff) - memory.writeword(0x7e0e06, 0) - memory.writeword(0x7e0e64, 0) + memory.writeword(DIDDY_Y_VELOCITY, 0) + memory.writeword(DIXIE_Y_VELOCITY, 0) end if input.get(0, 4) == 1 then - memory.writeword(0x7e0e06, -0x05ff) - memory.writeword(0x7e0e64, -0x05ff) + memory.writeword(DIDDY_Y_VELOCITY, -0x05ff) + memory.writeword(DIXIE_Y_VELOCITY, -0x05ff) elseif input.get(0, 5) == 1 then - memory.writeword(0x7e0e06, 0x5ff) - memory.writeword(0x7e0e64, 0x5ff) + memory.writeword(DIDDY_Y_VELOCITY, 0x5ff) + memory.writeword(DIXIE_Y_VELOCITY, 0x5ff) end end end -function file_exists(name) - local f=io.open(name,"r") - if f~=nil then io.close(f) return true else return false end -end - function get_sprite(base_addr) return { - ["control"] = memory.readword(base_addr), - ["draworder"] = memory.readword(base_addr + 0x02), - ["x"] = memory.readword(base_addr + 0x06), - ["y"] = memory.readword(base_addr + 0x0a), - ["jumpheight"] = memory.readword(base_addr + 0x0e), - ["style"] = memory.readword(base_addr + 0x12), - ["currentframe"] = memory.readword(base_addr + 0x18), - ["nextframe"] = memory.readword(base_addr + 0x1a), - ["state"] = memory.readword(base_addr + 0x1e), - ["velox"] = memory.readsword(base_addr + 0x20), - ["veloy"] = memory.readsword(base_addr + 0x24), - ["velomaxx"] = memory.readsword(base_addr + 0x26), - ["velomaxy"] = memory.readsword(base_addr + 0x2a), - ["motion"] = memory.readword(base_addr + 0x2e), - ["attr"] = memory.readword(base_addr + 0x30), - ["animnum"] = memory.readword(base_addr + 0x36), - ["remainingframe"] = memory.readword(base_addr + 0x38), - ["animcontrol"] = memory.readword(base_addr + 0x3a), - ["animreadpos"] = memory.readword(base_addr + 0x3c), - ["animcontrol2"] = memory.readword(base_addr + 0x3e), - ["animformat"] = memory.readword(base_addr + 0x40), - ["damage1"] = memory.readword(base_addr + 0x44), - ["damage2"] = memory.readword(base_addr + 0x46), - ["damage3"] = memory.readword(base_addr + 0x48), - ["damage4"] = memory.readword(base_addr + 0x4a), - ["damage5"] = memory.readword(base_addr + 0x4c), - ["damage6"] = memory.readword(base_addr + 0x4e), - ["spriteparam"] = memory.readword(base_addr + 0x58), + base_addr = string.format("%04x", base_addr), + control = memory.readword(base_addr), + draworder = memory.readword(base_addr + 0x02), + x = memory.readword(base_addr + 0x06), + y = memory.readword(base_addr + 0x0a), + jumpheight = memory.readword(base_addr + 0x0e), + style = memory.readword(base_addr + 0x12), + currentframe = memory.readword(base_addr + 0x18), + nextframe = memory.readword(base_addr + 0x1a), + state = memory.readword(base_addr + 0x1e), + velox = memory.readsword(base_addr + 0x20), + veloy = memory.readsword(base_addr + 0x24), + velomaxx = memory.readsword(base_addr + 0x26), + velomaxy = memory.readsword(base_addr + 0x2a), + motion = memory.readword(base_addr + 0x2e), + attr = memory.readword(base_addr + 0x30), + animnum = memory.readword(base_addr + 0x36), + remainingframe = memory.readword(base_addr + 0x38), + animcontrol = memory.readword(base_addr + 0x3a), + animreadpos = memory.readword(base_addr + 0x3c), + animcontrol2 = memory.readword(base_addr + 0x3e), + animformat = memory.readword(base_addr + 0x40), + damage1 = memory.readword(base_addr + 0x44), + damage2 = memory.readword(base_addr + 0x46), + damage3 = memory.readword(base_addr + 0x48), + damage4 = memory.readword(base_addr + 0x4a), + damage5 = memory.readword(base_addr + 0x4c), + damage6 = memory.readword(base_addr + 0x4e), + spriteparam = memory.readword(base_addr + 0x58), } end function sprite_details(idx) - local base_addr = idx * 94 + 0x7e0e9e + local base_addr = idx * 94 + SPRITE_BASE local sprite = get_sprite(base_addr) - if sprite["control"] == 0 then - gui.text(0, 0, "Sprite "..idx.." (Empty)", fgcolor, bgcolor) + if sprite.control == 0 then + text(0, 0, "Sprite "..idx.." (Empty)") incsprite = 0 locked = false lockdata = nil @@ -170,7 +173,7 @@ function sprite_details(idx) end if incsprite ~= 0 then - memory.writeword(base_addr + 0x36, sprite["animnum"] + incsprite) + memory.writeword(base_addr + 0x36, sprite.animnum + incsprite) lockdata = nil incsprite = 0 @@ -184,7 +187,7 @@ function sprite_details(idx) memory.writeregion(base_addr, 94, lockdata) end - gui.text(0, 0, "Sprite "..idx..(locked and " (Locked)" or "")..":\n\n"..table_to_string(sprite), fgcolor, bgcolor) + text(0, 0, "Sprite "..idx..(locked and " (Locked)" or "")..":\n\n"..util.table_to_string(sprite)) end function on_paint (not_synth) @@ -193,7 +196,7 @@ function on_paint (not_synth) local guiWidth, guiHeight = gui.resolution() if showhelp then - gui.text(0, 0, [[ + text(0, 0, [[ Keyboard Help =============== @@ -207,70 +210,156 @@ Sprite Details: [6] Enable / Disable Pokemon mode (take screenshots of enemies) [7] Enable / Disable float mode (fly with up/down) -]], fgcolor, bgcolor) +[8] Enable / Disable stage tile rulers +]]) return end - gui.text(guiWidth - 75, 0, "Help [0]", fgcolor, bgcolor) - - local stats = "" + local toggles = "" if pokemon then - stats = stats.."Pokemon: "..pokecount.."\n" + toggles = toggles..string.format("Pokemon: %d\n", pokecount) end if floatmode then - stats = stats.."Float on\n" + toggles = toggles.."Float on\n" end - gui.text(0, guiHeight - 40, stats, fgcolor, bgcolor) + text(0, guiHeight - 40, toggles) - stats = stats.."\nPokemon: "..pokecount + local directions = { + "Standard", + "Blur", + "Up" + } - local cameraX = memory.readword(0x7e17ba) - 256 - local cameraY = memory.readword(0x7e17c0) - 256 + local cameraX = memory.readword(CAMERA_X) - 256 + local cameraY = memory.readword(CAMERA_Y) - 256 + local cameraDir = memory.readbyte(CAMERA_MODE) - local partyScreenX = (memory.readword(0x7e0a2a) - 256 - cameraX) * 2 - local partyScreenY = (memory.readword(0x7e0a2c) - 256 - cameraY) * 2 + local direction = directions[cameraDir+1] - if detailsidx ~= -1 then - sprite_details(detailsidx) - else - gui.text(0, 0, "[1] <- Sprite Details Off -> [2]", fgcolor, bgcolor) - end + local vertical = memory.readword(TILE_COLLISION_MATH_POINTER) == VERTICAL_POINTER - gui.text(guiWidth - 200, guiHeight - 20, "Camera: "..tostring(cameraX)..","..tostring(cameraY), fgcolor, bgcolor) + local stats = string.format([[ +%s camera %d,%d +Vertical: %s +Tile offset: %04x +]], direction, cameraX, cameraY, vertical, party_tile_offset) - gui.text(partyScreenX, partyScreenY, "Party", fgcolor, bgcolor) + text(guiWidth - 200, guiHeight - 60, stats) + + local partyX = memory.readword(PARTY_X) - 256 + local partyY = memory.readword(PARTY_Y) - 256 + + text((partyX - cameraX) * 2, (partyY - cameraY) * 2 + 20, "Party") local sprites = {} - for idx = 0,20,1 do - local base_addr = idx * 94 + 0x7e0e9e + for idx = 0,22,1 do + local base_addr = idx * 94 + SPRITE_BASE local sprite = get_sprite(base_addr) sprites[idx] = sprite - if sprite["control"] == 0 then + if sprite.control == 0 then goto continue end - local spriteScreenX = (sprite["x"] - 256 - cameraX) * 2 - local spriteScreenY = (sprite["y"] - 256 - cameraY) * 2 + local spriteScreenX = (sprite.x - 256 - cameraX) * 2 + local spriteScreenY = (sprite.y - 256 - cameraY) * 2 - local sprcolor = bgcolor + local sprcolor = BG_COLOR if detailsidx == idx then sprcolor = 0x00ff0000 end - gui.text(spriteScreenX, spriteScreenY, sprite["animnum"]..","..sprite["attr"], fgcolor, sprcolor) + gui.text(spriteScreenX, spriteScreenY, sprite.control..","..sprite.animnum..","..sprite.attr, FG_COLOR, sprcolor) - local filename = os.getenv("HOME").."/neat-donk/catchem/"..sprite["animnum"]..","..sprite["attr"]..".png" - if pokemon and spriteScreenX > (guiWidth / 4) and spriteScreenX < (guiWidth / 4) * 3 and spriteScreenY > (guiHeight / 3) and spriteScreenY < guiHeight and not file_exists(filename) then + local filename = os.getenv("HOME").."/neat-donk/catchem/"..sprite.animnum..","..sprite.attr..".png" + if pokemon and spriteScreenX > (guiWidth / 4) and spriteScreenX < (guiWidth / 4) * 3 and spriteScreenY > (guiHeight / 3) and spriteScreenY < guiHeight and not util.file_exists(filename) then gui.screenshot(filename) pokecount = pokecount + 1 end ::continue:: end + + if rulers and cameraX >= 0 then + local halfWidth = math.floor(guiWidth / 2) + local halfHeight = math.floor(guiHeight / 2) + + local cameraTileX = math.floor(cameraX / TILE_SIZE) + gui.line(0, halfHeight, guiWidth, halfHeight, BG_COLOR) + for i = cameraTileX, cameraTileX + guiWidth / TILE_SIZE / 2,1 do + gui.text((i * TILE_SIZE - cameraX) * 2, halfHeight, tostring(i), FG_COLOR, BG_COLOR) + end + + local cameraTileY = math.floor(cameraY / TILE_SIZE) + gui.line(halfWidth, 0, halfWidth, guiHeight, BG_COLOR) + for i = cameraTileY, cameraTileY + guiHeight / TILE_SIZE / 2,1 do + gui.text(halfWidth, (i * TILE_SIZE - cameraY) * 2, tostring(i), FG_COLOR, BG_COLOR) + end + end + + local tilePtr = memory.readhword(TILEDATA_POINTER) + local solidLessThan = memory.readword(SOLID_LESS_THAN) + + for x = -TILE_RADIUS, TILE_RADIUS, 1 do + for y = -TILE_RADIUS, TILE_RADIUS, 1 do + local offset = 0 + if vertical then + offset = party_tile_offset + (y * 24 + x) * 2 + else + offset = party_tile_offset + (x * 16 + y) * 2 + end + + local tile = memory.readword(tilePtr + offset) + + if tile == 0 or tile >= solidLessThan then + goto continue + end + + local tileX = (math.floor(partyX / TILE_SIZE + x) * TILE_SIZE - cameraX) + local tileY = (math.floor(party_y_ground / TILE_SIZE + y) * TILE_SIZE - cameraY) + gui.text(tileX * 2, tileY * 2, string.format("%04x,%02x", offset & 0xffff, tile), FG_COLOR, 0x66888800) + + ::continue:: + end + end + + + + if detailsidx ~= -1 then + sprite_details(detailsidx) + else + text(0, 20, "[1] <- Sprite Details Off -> [2]") + end + + text(guiWidth - 125, 20, "Help [Hold 0]") +end + +function tile_retrieval() + local tile = math.floor(memory.getregister("y") / 2) * 2 + local newX = memory.readword(0x7e00a6) + local partyX = memory.readword(PARTY_X) + local oldX = partyX & 0x1f + local partyY = memory.readword(PARTY_Y) + + if oldX - 5 < newX and newX < oldX + 5 and + not jumping and + memory.readword(0x7e0034) == partyY then + set_party_tile_offset(tile) + party_y_ground = partyY - 256 + end +end + +function on_timer() + local sec, usec = utime() + local now = sec * 1000000 + usec + if last_called + 100 * 1000 < now then + party_tile_offset = party_tile_offset_debounce + end + + set_timer_timeout(100 * 1000) end input.keyhook("1", true) @@ -280,4 +369,9 @@ input.keyhook("4", true) input.keyhook("5", true) input.keyhook("6", true) input.keyhook("7", true) -input.keyhook("0", true) \ No newline at end of file +input.keyhook("8", true) +input.keyhook("0", true) + +memory2.BUS:registerexec(TILE_RETRIEVAL, tile_retrieval) + +set_timer_timeout(100 * 1000) \ No newline at end of file diff --git a/game.lua b/game.lua index 3cf43e0..0fa173f 100644 --- a/game.lua +++ b/game.lua @@ -1,164 +1,161 @@ ---Notes here -config = require "config" -spritelist = require "spritelist" -local _M = {} - -function _M.getPositions() - partyX = memory.readword(0x7e0a2a) - 256 - partyY = memory.readword(0x7e0a2c) - 256 - - local cameraX = memory.readword(0x7e17ba) - 256 - local cameraY = memory.readword(0x7e17c0) - 256 - - _M.screenX = partyX-layer1x - _M.screenY = partyY-layer1y -end - -function _M.getBananas() - local bananas = memory.readword(0x7e08bc) - return bananas -end - -function _M.getCoins() - local coins = memory.readword(0x7e08ca) - return coins -end - -function _M.getLives() - local lives = memory.readsbyte(0x7e08be) + 1 - return lives -end - -function _M.writeLives(lives) - memory.writebyte(0x7e08be, lives - 1) - memory.writebyte(0x7e08c0, lives - 1) -end - -function _M.getPowerup() - return 0 -end - -function _M.writePowerup(powerup) - return - -- memory.writebyte(0x0019, powerup) -end - - -function _M.getHit(alreadyHit) - return not alreadyHit and memory.readbyte(0x7e08be) < memory.readbyte(0x7e08c0) -end - -function _M.getHitTimer() - return memory.readbyte(0x7e08c0) - memory.readbyte(0x7e08be) -end - -function _M.getTile(dx, dy) - local partyScreenX = (partyX - cameraX) * 2 - local partyScreenY = (partyY - cameraY) * 2 - - x = math.floor((partyX+dx+8)/16) - y = math.floor((partyY+dy)/16) - - return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10) -end - -function _M.getSprites() - local sprites = {} - for slot=0,11 do - local status = memory.readbyte(0x14C8+slot) - if status ~= 0 then - spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256 - spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256 - sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey, ["good"] = spritelist.Sprites[memory.readbyte(0x009e + slot) + 1]} - end - end - - return sprites -end - -function _M.getExtendedSprites() - local extended = {} - for slot=0,11 do - local number = memory.readbyte(0x170B+slot) - if number ~= 0 then - spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256 - spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256 - extended[#extended+1] = {["x"]=spritex, ["y"]=spritey, ["good"] = spritelist.extSprites[memory.readbyte(0x170B + slot) + 1]} - end - end - - return extended -end - -function _M.getInputs() - _M.getPositions() - - sprites = _M.getSprites() - extended = _M.getExtendedSprites() - - local inputs = {} - local inputDeltaDistance = {} - - local layer1x = memory.read_s16_le(0x1A); - local layer1y = memory.read_s16_le(0x1C); - - - for dy=-config.BoxRadius*16,config.BoxRadius*16,16 do - for dx=-config.BoxRadius*16,config.BoxRadius*16,16 do - inputs[#inputs+1] = 0 - inputDeltaDistance[#inputDeltaDistance+1] = 1 - - tile = _M.getTile(dx, dy) - if tile == 1 and partyY+dy < 0x1B0 then - inputs[#inputs] = 1 - end - - for i = 1,#sprites do - distx = math.abs(sprites[i]["x"] - (partyX+dx)) - disty = math.abs(sprites[i]["y"] - (partyY+dy)) - if distx <= 8 and disty <= 8 then - inputs[#inputs] = sprites[i]["good"] - - local dist = math.sqrt((distx * distx) + (disty * disty)) - if dist > 8 then - inputDeltaDistance[#inputDeltaDistance] = mathFunctions.squashDistance(dist) - --gui.drawLine(screenX, screenY, sprites[i]["x"] - layer1x, sprites[i]["y"] - layer1y, 0x50000000) - end - end - end - - for i = 1,#extended do - distx = math.abs(extended[i]["x"] - (partyX+dx)) - disty = math.abs(extended[i]["y"] - (partyY+dy)) - if distx < 8 and disty < 8 then - - --console.writeline(screenX .. "," .. screenY .. " to " .. extended[i]["x"]-layer1x .. "," .. extended[i]["y"]-layer1y) - inputs[#inputs] = extended[i]["good"] - local dist = math.sqrt((distx * distx) + (disty * disty)) - if dist > 8 then - inputDeltaDistance[#inputDeltaDistance] = mathFunctions.squashDistance(dist) - --gui.drawLine(screenX, screenY, extended[i]["x"] - layer1x, extended[i]["y"] - layer1y, 0x50000000) - end - --if dist > 100 then - --dw = mathFunctions.squashDistance(dist) - --console.writeline(dist .. " to " .. dw) - --gui.drawLine(screenX, screenY, extended[i]["x"] - layer1x, extended[i]["y"] - layer1y, 0x50000000) - --end - --inputs[#inputs] = {["value"]=-1, ["dw"]=dw} - end - end - end - end - - return inputs, inputDeltaDistance -end - -function _M.clearJoypad() - controller = {} - for b = 1,#config.ButtonNames do - controller["P1 " .. config.ButtonNames[b]] = false - end - joypad.set(controller) -end - -return _M +--Notes here +config = require "config" +spritelist = require "spritelist" +local _M = {} + +function _M.getPositions() + partyX = memory.readword(0x7e0a2a) - 256 + partyY = memory.readword(0x7e0a2c) - 256 + + local cameraX = memory.readword(0x7e17ba) - 256 + local cameraY = memory.readword(0x7e17c0) - 256 + + _M.screenX = (partyX-cameraX)*2 + _M.screenY = (partyY-cameraY)*2 +end + +function _M.getBananas() + local bananas = memory.readword(0x7e08bc) + return bananas +end + +function _M.getCoins() + local coins = memory.readword(0x7e08ca) + return coins +end + +function _M.getLives() + local lives = memory.readsbyte(0x7e08be) + 1 + return lives +end + +function _M.writeLives(lives) + memory.writebyte(0x7e08be, lives - 1) + memory.writebyte(0x7e08c0, lives - 1) +end + +function _M.getPowerup() + return 0 +end + +function _M.writePowerup(powerup) + return + -- memory.writebyte(0x0019, powerup) +end + + +function _M.getHit(alreadyHit) + return not alreadyHit and memory.readbyte(0x7e08be) < memory.readbyte(0x7e08c0) +end + +function _M.getHitTimer() + return memory.readbyte(0x7e08c0) - memory.readbyte(0x7e08be) +end + +function _M.getTile(dx, dy) + local partyScreenX = (partyX - cameraX) * 2 + local partyScreenY = (partyY - cameraY) * 2 + + x = math.floor((partyX+dx+8)/16) + y = math.floor((partyY+dy)/16) + + return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10) +end + +function _M.getSprites() + local sprites = {} + for slot=0,11 do + local status = memory.readbyte(0x14C8+slot) + if status ~= 0 then + spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256 + spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256 + sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey, ["good"] = spritelist.Sprites[memory.readbyte(0x009e + slot) + 1]} + end + end + + return sprites +end + +function _M.getExtendedSprites() + local extended = {} + for slot=0,11 do + local number = memory.readbyte(0x170B+slot) + if number ~= 0 then + spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256 + spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256 + extended[#extended+1] = {["x"]=spritex, ["y"]=spritey, ["good"] = spritelist.extSprites[memory.readbyte(0x170B + slot) + 1]} + end + end + + return extended +end + +function _M.getInputs() + _M.getPositions() + + sprites = _M.getSprites() + extended = _M.getExtendedSprites() + + local inputs = {} + local inputDeltaDistance = {} + + local layer1x = memory.readword(0x7f0000); + local layer1y = memory.read_s16_le(0x1C); + + for dy=-config.BoxRadius*16,config.BoxRadius*16,16 do + for dx=-config.BoxRadius*16,config.BoxRadius*16,16 do + inputs[#inputs+1] = 0 + inputDeltaDistance[#inputDeltaDistance+1] = 1 + + tile = _M.getTile(dx, dy) + if tile == 1 and partyY+dy < 0x1B0 then + inputs[#inputs] = 1 + end + + for i = 1,#sprites do + distx = math.abs(sprites[i]["x"] - (partyX+dx)) + disty = math.abs(sprites[i]["y"] - (partyY+dy)) + if distx <= 8 and disty <= 8 then + inputs[#inputs] = sprites[i]["good"] + + local dist = math.sqrt((distx * distx) + (disty * disty)) + if dist > 8 then + inputDeltaDistance[#inputDeltaDistance] = mathFunctions.squashDistance(dist) + --gui.drawLine(screenX, screenY, sprites[i]["x"] - layer1x, sprites[i]["y"] - layer1y, 0x50000000) + end + end + end + + for i = 1,#extended do + distx = math.abs(extended[i]["x"] - (partyX+dx)) + disty = math.abs(extended[i]["y"] - (partyY+dy)) + if distx < 8 and disty < 8 then + + --console.writeline(screenX .. "," .. screenY .. " to " .. extended[i]["x"]-layer1x .. "," .. extended[i]["y"]-layer1y) + inputs[#inputs] = extended[i]["good"] + local dist = math.sqrt((distx * distx) + (disty * disty)) + if dist > 8 then + inputDeltaDistance[#inputDeltaDistance] = mathFunctions.squashDistance(dist) + --gui.drawLine(screenX, screenY, extended[i]["x"] - layer1x, extended[i]["y"] - layer1y, 0x50000000) + end + --if dist > 100 then + --dw = mathFunctions.squashDistance(dist) + --console.writeline(dist .. " to " .. dw) + --gui.drawLine(screenX, screenY, extended[i]["x"] - layer1x, extended[i]["y"] - layer1y, 0x50000000) + --end + --inputs[#inputs] = {["value"]=-1, ["dw"]=dw} + end + end + end + end + + return inputs, inputDeltaDistance +end + +function _M.clearJoypad() + for b = 1,#config.ButtonNames do + input.set(0, b - 1, 0) + end +end + +return _M diff --git a/mathFunctions.lua b/mathFunctions.lua index e9c9823..8870847 100644 --- a/mathFunctions.lua +++ b/mathFunctions.lua @@ -1,29 +1,29 @@ ---Notes here - -local _M = {} - -function _M.sigmoid(x) - return 2/(1+math.exp(-4.9*x))-1 -end - -function _M.squashDistance(x) - local window = 0.20 - local delta = 0.25 - - local dist = (x-8) - local newDist = 1 - - while dist > 0 do - newDist = newDist - (window*delta) - dist = dist - 1 - end - - if newDist < 0.80 then - newDist = 0.80 - end - - return newDist -end - - +--Notes here + +local _M = {} + +function _M.sigmoid(x) + return 2/(1+math.exp(-4.9*x))-1 +end + +function _M.squashDistance(x) + local window = 0.20 + local delta = 0.25 + + local dist = (x-8) + local newDist = 1 + + while dist > 0 do + newDist = newDist - (window*delta) + dist = dist - 1 + end + + if newDist < 0.80 then + newDist = 0.80 + end + + return newDist +end + + return _M \ No newline at end of file diff --git a/neat-donk.lua b/neat-donk.lua index ef36548..c1fdbcd 100644 --- a/neat-donk.lua +++ b/neat-donk.lua @@ -1,1163 +1,1165 @@ ---Update to Seth-Bling's MarI/O app - -config = require "config" -spritelist = require "spritelist" -game = require "game" -mathFunctions = require "mathFunctions" - -Inputs = config.InputSize+1 -Outputs = #config.ButtonNames - -function newInnovation() - pool.innovation = pool.innovation + 1 - return pool.innovation -end - -function newPool() - local pool = {} - pool.species = {} - pool.generation = 0 - pool.innovation = Outputs - pool.currentSpecies = 1 - pool.currentGenome = 1 - pool.currentFrame = 0 - pool.maxFitness = 0 - - return pool -end - -function newSpecies() - local species = {} - species.topFitness = 0 - species.staleness = 0 - species.genomes = {} - species.averageFitness = 0 - - return species -end - -function newGenome() - local genome = {} - genome.genes = {} - genome.fitness = 0 - genome.adjustedFitness = 0 - genome.network = {} - genome.maxneuron = 0 - genome.globalRank = 0 - genome.mutationRates = {} - genome.mutationRates["connections"] = config.NeatConfig.MutateConnectionsChance - genome.mutationRates["link"] = config.NeatConfig.LinkMutationChance - genome.mutationRates["bias"] = config.NeatConfig.BiasMutationChance - genome.mutationRates["node"] = config.NeatConfig.NodeMutationChance - genome.mutationRates["enable"] = config.NeatConfig.EnableMutationChance - genome.mutationRates["disable"] = config.NeatConfig.DisableMutationChance - genome.mutationRates["step"] = config.NeatConfig.StepSize - - return genome -end - -function copyGenome(genome) - local genome2 = newGenome() - for g=1,#genome.genes do - table.insert(genome2.genes, copyGene(genome.genes[g])) - end - genome2.maxneuron = genome.maxneuron - genome2.mutationRates["connections"] = genome.mutationRates["connections"] - genome2.mutationRates["link"] = genome.mutationRates["link"] - genome2.mutationRates["bias"] = genome.mutationRates["bias"] - genome2.mutationRates["node"] = genome.mutationRates["node"] - genome2.mutationRates["enable"] = genome.mutationRates["enable"] - genome2.mutationRates["disable"] = genome.mutationRates["disable"] - - return genome2 -end - -function basicGenome() - local genome = newGenome() - local innovation = 1 - - genome.maxneuron = Inputs - mutate(genome) - - return genome -end - -function newGene() - local gene = {} - gene.into = 0 - gene.out = 0 - gene.weight = 0.0 - gene.enabled = true - gene.innovation = 0 - - return gene -end - -function copyGene(gene) - local gene2 = newGene() - gene2.into = gene.into - gene2.out = gene.out - gene2.weight = gene.weight - gene2.enabled = gene.enabled - gene2.innovation = gene.innovation - - return gene2 -end - -function newNeuron() - local neuron = {} - neuron.incoming = {} - neuron.value = 0.0 - --neuron.dw = 1 - return neuron -end - -function generateNetwork(genome) - local network = {} - network.neurons = {} - - for i=1,Inputs do - network.neurons[i] = newNeuron() - end - - for o=1,Outputs do - network.neurons[config.NeatConfig.MaxNodes+o] = newNeuron() - end - - table.sort(genome.genes, function (a,b) - return (a.out < b.out) - end) - for i=1,#genome.genes do - local gene = genome.genes[i] - if gene.enabled then - if network.neurons[gene.out] == nil then - network.neurons[gene.out] = newNeuron() - end - local neuron = network.neurons[gene.out] - table.insert(neuron.incoming, gene) - if network.neurons[gene.into] == nil then - network.neurons[gene.into] = newNeuron() - end - end - end - - genome.network = network -end - -function evaluateNetwork(network, inputs, inputDeltas) - table.insert(inputs, 1) - table.insert(inputDeltas,99) - if #inputs ~= Inputs then - console.writeline("Incorrect number of neural network inputs.") - return {} - end - - - for i=1,Inputs do - network.neurons[i].value = inputs[i] * inputDeltas[i] - --network.neurons[i].value = inputs[i] - end - - for _,neuron in pairs(network.neurons) do - local sum = 0 - for j = 1,#neuron.incoming do - local incoming = neuron.incoming[j] - local other = network.neurons[incoming.into] - sum = sum + incoming.weight * other.value - end - - if #neuron.incoming > 0 then - neuron.value = mathFunctions.sigmoid(sum) - end - end - - local outputs = {} - for o=1,Outputs do - local button = "P1 " .. config.ButtonNames[o] - if network.neurons[config.NeatConfig.MaxNodes+o].value > 0 then - outputs[button] = true - else - outputs[button] = false - end - end - - return outputs -end - -function crossover(g1, g2) - -- Make sure g1 is the higher fitness genome - if g2.fitness > g1.fitness then - tempg = g1 - g1 = g2 - g2 = tempg - end - - local child = newGenome() - - local innovations2 = {} - for i=1,#g2.genes do - local gene = g2.genes[i] - innovations2[gene.innovation] = gene - end - - for i=1,#g1.genes do - local gene1 = g1.genes[i] - local gene2 = innovations2[gene1.innovation] - if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then - table.insert(child.genes, copyGene(gene2)) - else - table.insert(child.genes, copyGene(gene1)) - end - end - - child.maxneuron = math.max(g1.maxneuron,g2.maxneuron) - - for mutation,rate in pairs(g1.mutationRates) do - child.mutationRates[mutation] = rate - end - - return child -end - -function randomNeuron(genes, nonInput) - local neurons = {} - if not nonInput then - for i=1,Inputs do - neurons[i] = true - end - end - for o=1,Outputs do - neurons[config.NeatConfig.MaxNodes+o] = true - end - for i=1,#genes do - if (not nonInput) or genes[i].into > Inputs then - neurons[genes[i].into] = true - end - if (not nonInput) or genes[i].out > Inputs then - neurons[genes[i].out] = true - end - end - - local count = 0 - for _,_ in pairs(neurons) do - count = count + 1 - end - local n = math.random(1, count) - - for k,v in pairs(neurons) do - n = n-1 - if n == 0 then - return k - end - end - - return 0 -end - -function containsLink(genes, link) - for i=1,#genes do - local gene = genes[i] - if gene.into == link.into and gene.out == link.out then - return true - end - end -end - -function pointMutate(genome) - local step = genome.mutationRates["step"] - - for i=1,#genome.genes do - local gene = genome.genes[i] - if math.random() < config.NeatConfig.PerturbChance then - gene.weight = gene.weight + math.random() * step*2 - step - else - gene.weight = math.random()*4-2 - end - end -end - -function linkMutate(genome, forceBias) - local neuron1 = randomNeuron(genome.genes, false) - local neuron2 = randomNeuron(genome.genes, true) - - local newLink = newGene() - if neuron1 <= Inputs and neuron2 <= Inputs then - --Both input nodes - return - end - if neuron2 <= Inputs then - -- Swap output and input - local temp = neuron1 - neuron1 = neuron2 - neuron2 = temp - end - - newLink.into = neuron1 - newLink.out = neuron2 - if forceBias then - newLink.into = Inputs - end - - if containsLink(genome.genes, newLink) then - return - end - newLink.innovation = newInnovation() - newLink.weight = math.random()*4-2 - - table.insert(genome.genes, newLink) -end - -function nodeMutate(genome) - if #genome.genes == 0 then - return - end - - genome.maxneuron = genome.maxneuron + 1 - - local gene = genome.genes[math.random(1,#genome.genes)] - if not gene.enabled then - return - end - gene.enabled = false - - local gene1 = copyGene(gene) - gene1.out = genome.maxneuron - gene1.weight = 1.0 - gene1.innovation = newInnovation() - gene1.enabled = true - table.insert(genome.genes, gene1) - - local gene2 = copyGene(gene) - gene2.into = genome.maxneuron - gene2.innovation = newInnovation() - gene2.enabled = true - table.insert(genome.genes, gene2) -end - -function enableDisableMutate(genome, enable) - local candidates = {} - for _,gene in pairs(genome.genes) do - if gene.enabled == not enable then - table.insert(candidates, gene) - end - end - - if #candidates == 0 then - return - end - - local gene = candidates[math.random(1,#candidates)] - gene.enabled = not gene.enabled -end - -function mutate(genome) - for mutation,rate in pairs(genome.mutationRates) do - if math.random(1,2) == 1 then - genome.mutationRates[mutation] = 0.95*rate - else - genome.mutationRates[mutation] = 1.05263*rate - end - end - - if math.random() < genome.mutationRates["connections"] then - pointMutate(genome) - end - - local p = genome.mutationRates["link"] - while p > 0 do - if math.random() < p then - linkMutate(genome, false) - end - p = p - 1 - end - - p = genome.mutationRates["bias"] - while p > 0 do - if math.random() < p then - linkMutate(genome, true) - end - p = p - 1 - end - - p = genome.mutationRates["node"] - while p > 0 do - if math.random() < p then - nodeMutate(genome) - end - p = p - 1 - end - - p = genome.mutationRates["enable"] - while p > 0 do - if math.random() < p then - enableDisableMutate(genome, true) - end - p = p - 1 - end - - p = genome.mutationRates["disable"] - while p > 0 do - if math.random() < p then - enableDisableMutate(genome, false) - end - p = p - 1 - end -end - -function disjoint(genes1, genes2) - local i1 = {} - for i = 1,#genes1 do - local gene = genes1[i] - i1[gene.innovation] = true - end - - local i2 = {} - for i = 1,#genes2 do - local gene = genes2[i] - i2[gene.innovation] = true - end - - local disjointGenes = 0 - for i = 1,#genes1 do - local gene = genes1[i] - if not i2[gene.innovation] then - disjointGenes = disjointGenes+1 - end - end - - for i = 1,#genes2 do - local gene = genes2[i] - if not i1[gene.innovation] then - disjointGenes = disjointGenes+1 - end - end - - local n = math.max(#genes1, #genes2) - - return disjointGenes / n -end - -function weights(genes1, genes2) - local i2 = {} - for i = 1,#genes2 do - local gene = genes2[i] - i2[gene.innovation] = gene - end - - local sum = 0 - local coincident = 0 - for i = 1,#genes1 do - local gene = genes1[i] - if i2[gene.innovation] ~= nil then - local gene2 = i2[gene.innovation] - sum = sum + math.abs(gene.weight - gene2.weight) - coincident = coincident + 1 - end - end - - return sum / coincident -end - -function sameSpecies(genome1, genome2) - local dd = config.NeatConfig.DeltaDisjoint*disjoint(genome1.genes, genome2.genes) - local dw = config.NeatConfig.DeltaWeights*weights(genome1.genes, genome2.genes) - return dd + dw < config.NeatConfig.DeltaThreshold -end - -function rankGlobally() - local global = {} - for s = 1,#pool.species do - local species = pool.species[s] - for g = 1,#species.genomes do - table.insert(global, species.genomes[g]) - end - end - table.sort(global, function (a,b) - return (a.fitness < b.fitness) - end) - - for g=1,#global do - global[g].globalRank = g - end -end - -function calculateAverageFitness(species) - local total = 0 - - for g=1,#species.genomes do - local genome = species.genomes[g] - total = total + genome.globalRank - end - - species.averageFitness = total / #species.genomes -end - -function totalAverageFitness() - local total = 0 - for s = 1,#pool.species do - local species = pool.species[s] - total = total + species.averageFitness - end - - return total -end - -function cullSpecies(cutToOne) - for s = 1,#pool.species do - local species = pool.species[s] - - table.sort(species.genomes, function (a,b) - return (a.fitness > b.fitness) - end) - - local remaining = math.ceil(#species.genomes/2) - if cutToOne then - remaining = 1 - end - while #species.genomes > remaining do - table.remove(species.genomes) - end - end -end - -function breedChild(species) - local child = {} - if math.random() < config.NeatConfig.CrossoverChance then - g1 = species.genomes[math.random(1, #species.genomes)] - g2 = species.genomes[math.random(1, #species.genomes)] - child = crossover(g1, g2) - else - g = species.genomes[math.random(1, #species.genomes)] - child = copyGenome(g) - end - - mutate(child) - - return child -end - -function removeStaleSpecies() - local survived = {} - - for s = 1,#pool.species do - local species = pool.species[s] - - table.sort(species.genomes, function (a,b) - return (a.fitness > b.fitness) - end) - - if species.genomes[1].fitness > species.topFitness then - species.topFitness = species.genomes[1].fitness - species.staleness = 0 - else - species.staleness = species.staleness + 1 - end - if species.staleness < config.NeatConfig.StaleSpecies or species.topFitness >= pool.maxFitness then - table.insert(survived, species) - end - end - - pool.species = survived -end - -function removeWeakSpecies() - local survived = {} - - local sum = totalAverageFitness() - for s = 1,#pool.species do - local species = pool.species[s] - breed = math.floor(species.averageFitness / sum * config.NeatConfig.Population) - if breed >= 1 then - table.insert(survived, species) - end - end - - pool.species = survived -end - - -function addToSpecies(child) - local foundSpecies = false - for s=1,#pool.species do - local species = pool.species[s] - if not foundSpecies and sameSpecies(child, species.genomes[1]) then - table.insert(species.genomes, child) - foundSpecies = true - end - end - - if not foundSpecies then - local childSpecies = newSpecies() - table.insert(childSpecies.genomes, child) - table.insert(pool.species, childSpecies) - end -end - -function newGeneration() - cullSpecies(false) -- Cull the bottom half of each species - rankGlobally() - removeStaleSpecies() - rankGlobally() - for s = 1,#pool.species do - local species = pool.species[s] - calculateAverageFitness(species) - end - removeWeakSpecies() - local sum = totalAverageFitness() - local children = {} - for s = 1,#pool.species do - local species = pool.species[s] - breed = math.floor(species.averageFitness / sum * config.NeatConfig.Population) - 1 - for i=1,breed do - table.insert(children, breedChild(species)) - end - end - cullSpecies(true) -- Cull all but the top member of each species - while #children + #pool.species < config.NeatConfig.Population do - local species = pool.species[math.random(1, #pool.species)] - table.insert(children, breedChild(species)) - end - for c=1,#children do - local child = children[c] - addToSpecies(child) - end - - pool.generation = pool.generation + 1 - - --writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile)) - writeFile(forms.gettext(saveLoadFile) .. ".gen" .. pool.generation .. ".pool") -end - -function initializePool() - pool = newPool() - - for i=1,config.NeatConfig.Population do - basic = basicGenome() - addToSpecies(basic) - end - - initializeRun() -end - -function initializeRun() - savestate.load(config.NeatConfig.Filename); - if config.StartPowerup ~= NIL then - game.writePowerup(config.StartPowerup) - end - rightmost = 0 - pool.currentFrame = 0 - timeout = config.NeatConfig.TimeoutConstant - game.clearJoypad() - startBananas = game.getBananas() - startCoins = game.getCoins() - startLives = game.getLives() - checkMarioCollision = true - marioHitCounter = 0 - powerUpCounter = 0 - powerUpBefore = game.getPowerup() - local species = pool.species[pool.currentSpecies] - local genome = species.genomes[pool.currentGenome] - generateNetwork(genome) - evaluateCurrent() -end - -function evaluateCurrent() - local species = pool.species[pool.currentSpecies] - local genome = species.genomes[pool.currentGenome] - - local inputDeltas = {} - inputs, inputDeltas = game.getInputs() - - controller = evaluateNetwork(genome.network, inputs, inputDeltas) - - if controller["P1 Left"] and controller["P1 Right"] then - controller["P1 Left"] = false - controller["P1 Right"] = false - end - if controller["P1 Up"] and controller["P1 Down"] then - controller["P1 Up"] = false - controller["P1 Down"] = false - end - - joypad.set(controller) -end - -if pool == nil then - initializePool() -end - - -function nextGenome() - pool.currentGenome = pool.currentGenome + 1 - if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then - pool.currentGenome = 1 - pool.currentSpecies = pool.currentSpecies+1 - if pool.currentSpecies > #pool.species then - newGeneration() - pool.currentSpecies = 1 - end - end -end - -function fitnessAlreadyMeasured() - local species = pool.species[pool.currentSpecies] - local genome = species.genomes[pool.currentGenome] - - return genome.fitness ~= 0 -end - -form = forms.newform(500, 500, "Mario-Neat") -netPicture = forms.pictureBox(form, 5, 250,470, 200) - - ---int forms.pictureBox(int formhandle, [int? x = null], [int? y = null], [int? width = null], [int? height = null]) - -function displayGenome(genome) - forms.clear(netPicture,0x80808080) - local network = genome.network - local cells = {} - local i = 1 - local cell = {} - for dy=-config.BoxRadius,config.BoxRadius do - for dx=-config.BoxRadius,config.BoxRadius do - cell = {} - cell.x = 50+5*dx - cell.y = 70+5*dy - cell.value = network.neurons[i].value - cells[i] = cell - i = i + 1 - end - end - local biasCell = {} - biasCell.x = 80 - biasCell.y = 110 - biasCell.value = network.neurons[Inputs].value - cells[Inputs] = biasCell - - for o = 1,Outputs do - cell = {} - cell.x = 220 - cell.y = 30 + 8 * o - cell.value = network.neurons[config.NeatConfig.MaxNodes + o].value - cells[config.NeatConfig.MaxNodes+o] = cell - local color - if cell.value > 0 then - color = 0xFF0000FF - else - color = 0xFF000000 - end - --gui.drawText(223, 24+8*o, config.ButtonNames[o], color, 9) - forms.drawText(netPicture,223, 24+8*o, config.ButtonNames[o], color, 9) - end - - for n,neuron in pairs(network.neurons) do - cell = {} - if n > Inputs and n <= config.NeatConfig.MaxNodes then - cell.x = 140 - cell.y = 40 - cell.value = neuron.value - cells[n] = cell - end - end - - for n=1,4 do - for _,gene in pairs(genome.genes) do - if gene.enabled then - local c1 = cells[gene.into] - local c2 = cells[gene.out] - if gene.into > Inputs and gene.into <= config.NeatConfig.MaxNodes then - c1.x = 0.75*c1.x + 0.25*c2.x - if c1.x >= c2.x then - c1.x = c1.x - 40 - end - if c1.x < 90 then - c1.x = 90 - end - - if c1.x > 220 then - c1.x = 220 - end - c1.y = 0.75*c1.y + 0.25*c2.y - - end - if gene.out > Inputs and gene.out <= config.NeatConfig.MaxNodes then - c2.x = 0.25*c1.x + 0.75*c2.x - if c1.x >= c2.x then - c2.x = c2.x + 40 - end - if c2.x < 90 then - c2.x = 90 - end - if c2.x > 220 then - c2.x = 220 - end - c2.y = 0.25*c1.y + 0.75*c2.y - end - end - end - end - - --gui.drawBox(50-config.BoxRadius*5-3,70-config.BoxRadius*5-3,50+config.BoxRadius*5+2,70+config.BoxRadius*5+2,0xFF000000, 0x80808080) - forms.drawBox(netPicture, 50-config.BoxRadius*5-3,70-config.BoxRadius*5-3,50+config.BoxRadius*5+2,70+config.BoxRadius*5+2,0xFF000000, 0x80808080) - --oid forms.drawBox(int componenthandle, int x, int y, int x2, int y2, [color? line = null], [color? background = null]) - for n,cell in pairs(cells) do - if n > Inputs or cell.value ~= 0 then - local color = math.floor((cell.value+1)/2*256) - if color > 255 then color = 255 end - if color < 0 then color = 0 end - local opacity = 0xFF000000 - if cell.value == 0 then - opacity = 0x50000000 - end - color = opacity + color*0x10000 + color*0x100 + color - forms.drawBox(netPicture,cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color) - --gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color) - end - end - for _,gene in pairs(genome.genes) do - if gene.enabled then - local c1 = cells[gene.into] - local c2 = cells[gene.out] - local opacity = 0xA0000000 - if c1.value == 0 then - opacity = 0x20000000 - end - - local color = 0x80-math.floor(math.abs(mathFunctions.sigmoid(gene.weight))*0x80) - if gene.weight > 0 then - color = opacity + 0x8000 + 0x10000*color - else - color = opacity + 0x800000 + 0x100*color - end - --gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color) - forms.drawLine(netPicture,c1.x+1, c1.y, c2.x-3, c2.y, color) - end - end - - --gui.drawBox(49,71,51,78,0x00000000,0x80FF0000) - forms.drawBox(netPicture, 49,71,51,78,0x00000000,0x80FF0000) - --if forms.ischecked(showMutationRates) then - local pos = 100 - for mutation,rate in pairs(genome.mutationRates) do - --gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10) - forms.drawText(netPicture,100, pos, mutation .. ": " .. rate, 0xFF000000, 10) - --forms.drawText(pictureBox,400,pos, mutation .. ": " .. rate) - - --void forms.drawText(int componenthandle, int x, int y, string message, [color? forecolor = null], [color? backcolor = null], [int? fontsize = null], [string fontfamily = null], [string fontstyle = null], [string horizalign = null], [string vertalign = null]) - - pos = pos + 8 - end - --end - forms.refresh(netPicture) -end - -function writeFile(filename) - local file = io.open(filename, "w") - file:write(pool.generation .. "\n") - file:write(pool.maxFitness .. "\n") - file:write(#pool.species .. "\n") - for n,species in pairs(pool.species) do - file:write(species.topFitness .. "\n") - file:write(species.staleness .. "\n") - file:write(#species.genomes .. "\n") - for m,genome in pairs(species.genomes) do - file:write(genome.fitness .. "\n") - file:write(genome.maxneuron .. "\n") - for mutation,rate in pairs(genome.mutationRates) do - file:write(mutation .. "\n") - file:write(rate .. "\n") - end - file:write("done\n") - - file:write(#genome.genes .. "\n") - for l,gene in pairs(genome.genes) do - file:write(gene.into .. " ") - file:write(gene.out .. " ") - file:write(gene.weight .. " ") - file:write(gene.innovation .. " ") - if(gene.enabled) then - file:write("1\n") - else - file:write("0\n") - end - end - end - end - file:close() -end - -function savePool() - local filename = forms.gettext(saveLoadFile) - print(filename) - writeFile(filename) -end - -function mysplit(inputstr, sep) - if sep == nil then - sep = "%s" - end - local t={} ; i=1 - for str in string.gmatch(inputstr, "([^"..sep.."]+)") do - t[i] = str - i = i + 1 - end - return t -end - -function loadFile(filename) - print("Loading pool from " .. filename) - local file = io.open(filename, "r") - pool = newPool() - pool.generation = file:read("*number") - pool.maxFitness = file:read("*number") - forms.settext(MaxLabel, "Max Fitness: " .. math.floor(pool.maxFitness)) - local numSpecies = file:read("*number") - for s=1,numSpecies do - local species = newSpecies() - table.insert(pool.species, species) - species.topFitness = file:read("*number") - species.staleness = file:read("*number") - local numGenomes = file:read("*number") - for g=1,numGenomes do - local genome = newGenome() - table.insert(species.genomes, genome) - genome.fitness = file:read("*number") - genome.maxneuron = file:read("*number") - local line = file:read("*line") - while line ~= "done" do - - genome.mutationRates[line] = file:read("*number") - line = file:read("*line") - end - local numGenes = file:read("*number") - for n=1,numGenes do - - local gene = newGene() - local enabled - - local geneStr = file:read("*line") - local geneArr = mysplit(geneStr) - gene.into = tonumber(geneArr[1]) - gene.out = tonumber(geneArr[2]) - gene.weight = tonumber(geneArr[3]) - gene.innovation = tonumber(geneArr[4]) - enabled = tonumber(geneArr[5]) - - - if enabled == 0 then - gene.enabled = false - else - gene.enabled = true - end - - table.insert(genome.genes, gene) - end - end - end - file:close() - - while fitnessAlreadyMeasured() do - nextGenome() - end - initializeRun() - pool.currentFrame = pool.currentFrame + 1 - print("Pool loaded.") -end - -function flipState() - if config.Running == true then - config.Running = false - forms.settext(startButton, "Start") - else - config.Running = true - forms.settext(startButton, "Stop") - end -end - -function loadPool() - filename = forms.openfile("DP1.state.pool",config.PoolDir) - --local filename = forms.gettext(saveLoadFile) - forms.settext(saveLoadFile, filename) - loadFile(filename) -end - -function playTop() - local maxfitness = 0 - local maxs, maxg - for s,species in pairs(pool.species) do - for g,genome in pairs(species.genomes) do - if genome.fitness > maxfitness then - maxfitness = genome.fitness - maxs = s - maxg = g - end - end - end - - pool.currentSpecies = maxs - pool.currentGenome = maxg - pool.maxFitness = maxfitness - forms.settext(MaxLabel, "Max Fitness: " .. math.floor(pool.maxFitness)) - initializeRun() - pool.currentFrame = pool.currentFrame + 1 - return -end - -function onExit() - forms.destroy(form) -end - -writeFile(config.PoolDir.."temp.pool") - -event.onexit(onExit) - -GenerationLabel = forms.label(form, "Generation: " .. pool.generation, 5, 5) -SpeciesLabel = forms.label(form, "Species: " .. pool.currentSpecies, 130, 5) -GenomeLabel = forms.label(form, "Genome: " .. pool.currentGenome, 230, 5) -MeasuredLabel = forms.label(form, "Measured: " .. "", 330, 5) - -FitnessLabel = forms.label(form, "Fitness: " .. "", 5, 30) -MaxLabel = forms.label(form, "Max: " .. "", 130, 30) - -BananasLabel = forms.label(form, "Bananas: " .. "", 5, 65) -CoinsLabel = forms.label(form, "Coins: " .. "", 130, 65, 90, 14) -LivesLabel = forms.label(form, "Lives: " .. "", 130, 80, 90, 14) -DmgLabel = forms.label(form, "Damage: " .. "", 230, 65, 110, 14) -PowerUpLabel = forms.label(form, "PowerUp: " .. "", 230, 80, 110, 14) - -startButton = forms.button(form, "Start", flipState, 155, 102) - -restartButton = forms.button(form, "Restart", initializePool, 155, 102) -saveButton = forms.button(form, "Save", savePool, 5, 102) -loadButton = forms.button(form, "Load", loadPool, 80, 102) -playTopButton = forms.button(form, "Play Top", playTop, 230, 102) - -saveLoadFile = forms.textbox(form, config.NeatConfig.Filename .. ".pool", 170, 25, nil, 5, 148) -saveLoadLabel = forms.label(form, "Save/Load:", 5, 129) -spritelist.InitSpriteList() -spritelist.InitExtSpriteList() -while true do - - if config.Running == true then - - local species = pool.species[pool.currentSpecies] - local genome = species.genomes[pool.currentGenome] - - displayGenome(genome) - - if pool.currentFrame%5 == 0 then - evaluateCurrent() - end - - joypad.set(controller) - - game.getPositions() - if partyX > rightmost then - rightmost = partyX - timeout = config.NeatConfig.TimeoutConstant - end - - local hitTimer = game.getHitTimer() - - if checkMarioCollision == true then - if hitTimer > 0 then - marioHitCounter = marioHitCounter + 1 - --console.writeline("Mario took damage, hit counter: " .. marioHitCounter) - checkMarioCollision = false - end - end - - if hitTimer == 0 then - checkMarioCollision = true - end - - powerUp = game.getPowerup() - if powerUp > 0 then - if powerUp ~= powerUpBefore then - powerUpCounter = powerUpCounter+1 - powerUpBefore = powerUp - end - end - - Lives = game.getLives() - - timeout = timeout - 1 - - local timeoutBonus = pool.currentFrame / 4 - if timeout + timeoutBonus <= 0 then - - local bananas = game.getBananas() - startBananas - local coins = game.getCoins() - startCoins - - --console.writeline("Bananas: " .. bananas .. " coins: " .. coins) - - local bananaCoinsFitness = (bananas * 50) + (coins * 0.2) - if (bananas + coins) > 0 then - console.writeline("Bananas and Coins added " .. bananaCoinsFitness .. " fitness") - end - - local hitPenalty = marioHitCounter * 100 - local powerUpBonus = powerUpCounter * 100 - - local fitness = bananaCoinsFitness - hitPenalty + powerUpBonus + rightmost - pool.currentFrame / 2 - - if startLives < Lives then - local ExtraLiveBonus = (Lives - startLives)*1000 - fitness = fitness + ExtraLiveBonus - console.writeline("ExtraLiveBonus added " .. ExtraLiveBonus) - end - - if rightmost > 4816 then - fitness = fitness + 1000 - console.writeline("!!!!!!Beat level!!!!!!!") - end - if fitness == 0 then - fitness = -1 - end - genome.fitness = fitness - - if fitness > pool.maxFitness then - pool.maxFitness = fitness - --writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile)) - writeFile(forms.gettext(saveLoadFile) .. ".gen" .. pool.generation .. ".pool") - end - - console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness) - pool.currentSpecies = 1 - pool.currentGenome = 1 - while fitnessAlreadyMeasured() do - nextGenome() - end - initializeRun() - end - - local measured = 0 - local total = 0 - for _,species in pairs(pool.species) do - for _,genome in pairs(species.genomes) do - total = total + 1 - if genome.fitness ~= 0 then - measured = measured + 1 - end - end - end - - gui.drawEllipse(game.screenX-84, game.screenY-84, 192, 192, 0x50000000) - forms.settext(FitnessLabel, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3)) - forms.settext(GenerationLabel, "Generation: " .. pool.generation) - forms.settext(SpeciesLabel, "Species: " .. pool.currentSpecies) - forms.settext(GenomeLabel, "Genome: " .. pool.currentGenome) - forms.settext(MaxLabel, "Max: " .. math.floor(pool.maxFitness)) - forms.settext(MeasuredLabel, "Measured: " .. math.floor(measured/total*100) .. "%") - forms.settext(BananasLabel, "Bananas: " .. (game.getBananas() - startBananas)) - forms.settext(CoinsLabel, "Coins: " .. (game.getCoins() - startCoins)) - forms.settext(LivesLabel, "Lives: " .. Lives) - forms.settext(DmgLabel, "Damage: " .. marioHitCounter) - forms.settext(PowerUpLabel, "PowerUp: " .. powerUpCounter) - - pool.currentFrame = pool.currentFrame + 1 - - end - emu.frameadvance(); - -end +--Update to Seth-Bling's MarI/O app + +config = require "config" +spritelist = require "spritelist" +game = require "game" +mathFunctions = require "mathFunctions" +util = require "util" + +Inputs = config.InputSize+1 +Outputs = #config.ButtonNames + +function newInnovation() + pool.innovation = pool.innovation + 1 + return pool.innovation +end + +function newPool() + local pool = {} + pool.species = {} + pool.generation = 0 + pool.innovation = Outputs + pool.currentSpecies = 1 + pool.currentGenome = 1 + pool.currentFrame = 0 + pool.maxFitness = 0 + + return pool +end + +function newSpecies() + local species = {} + species.topFitness = 0 + species.staleness = 0 + species.genomes = {} + species.averageFitness = 0 + + return species +end + +function newGenome() + local genome = {} + genome.genes = {} + genome.fitness = 0 + genome.adjustedFitness = 0 + genome.network = {} + genome.maxneuron = 0 + genome.globalRank = 0 + genome.mutationRates = {} + genome.mutationRates["connections"] = config.NeatConfig.MutateConnectionsChance + genome.mutationRates["link"] = config.NeatConfig.LinkMutationChance + genome.mutationRates["bias"] = config.NeatConfig.BiasMutationChance + genome.mutationRates["node"] = config.NeatConfig.NodeMutationChance + genome.mutationRates["enable"] = config.NeatConfig.EnableMutationChance + genome.mutationRates["disable"] = config.NeatConfig.DisableMutationChance + genome.mutationRates["step"] = config.NeatConfig.StepSize + + return genome +end + +function copyGenome(genome) + local genome2 = newGenome() + for g=1,#genome.genes do + table.insert(genome2.genes, copyGene(genome.genes[g])) + end + genome2.maxneuron = genome.maxneuron + genome2.mutationRates["connections"] = genome.mutationRates["connections"] + genome2.mutationRates["link"] = genome.mutationRates["link"] + genome2.mutationRates["bias"] = genome.mutationRates["bias"] + genome2.mutationRates["node"] = genome.mutationRates["node"] + genome2.mutationRates["enable"] = genome.mutationRates["enable"] + genome2.mutationRates["disable"] = genome.mutationRates["disable"] + + return genome2 +end + +function basicGenome() + local genome = newGenome() + local innovation = 1 + + genome.maxneuron = Inputs + mutate(genome) + + return genome +end + +function newGene() + local gene = {} + gene.into = 0 + gene.out = 0 + gene.weight = 0.0 + gene.enabled = true + gene.innovation = 0 + + return gene +end + +function copyGene(gene) + local gene2 = newGene() + gene2.into = gene.into + gene2.out = gene.out + gene2.weight = gene.weight + gene2.enabled = gene.enabled + gene2.innovation = gene.innovation + + return gene2 +end + +function newNeuron() + local neuron = {} + neuron.incoming = {} + neuron.value = 0.0 + --neuron.dw = 1 + return neuron +end + +function generateNetwork(genome) + local network = {} + network.neurons = {} + + for i=1,Inputs do + network.neurons[i] = newNeuron() + end + + for o=1,Outputs do + network.neurons[config.NeatConfig.MaxNodes+o] = newNeuron() + end + + table.sort(genome.genes, function (a,b) + return (a.out < b.out) + end) + for i=1,#genome.genes do + local gene = genome.genes[i] + if gene.enabled then + if network.neurons[gene.out] == nil then + network.neurons[gene.out] = newNeuron() + end + local neuron = network.neurons[gene.out] + table.insert(neuron.incoming, gene) + if network.neurons[gene.into] == nil then + network.neurons[gene.into] = newNeuron() + end + end + end + + genome.network = network +end + +function evaluateNetwork(network, inputs, inputDeltas) + table.insert(inputs, 1) + table.insert(inputDeltas,99) + if #inputs ~= Inputs then + console.writeline("Incorrect number of neural network inputs.") + return {} + end + + + for i=1,Inputs do + network.neurons[i].value = inputs[i] * inputDeltas[i] + --network.neurons[i].value = inputs[i] + end + + for _,neuron in pairs(network.neurons) do + local sum = 0 + for j = 1,#neuron.incoming do + local incoming = neuron.incoming[j] + local other = network.neurons[incoming.into] + sum = sum + incoming.weight * other.value + end + + if #neuron.incoming > 0 then + neuron.value = mathFunctions.sigmoid(sum) + end + end + + local outputs = {} + for o=1,Outputs do + local button = "P1 " .. config.ButtonNames[o] + if network.neurons[config.NeatConfig.MaxNodes+o].value > 0 then + outputs[button] = true + else + outputs[button] = false + end + end + + return outputs +end + +function crossover(g1, g2) + -- Make sure g1 is the higher fitness genome + if g2.fitness > g1.fitness then + tempg = g1 + g1 = g2 + g2 = tempg + end + + local child = newGenome() + + local innovations2 = {} + for i=1,#g2.genes do + local gene = g2.genes[i] + innovations2[gene.innovation] = gene + end + + for i=1,#g1.genes do + local gene1 = g1.genes[i] + local gene2 = innovations2[gene1.innovation] + if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then + table.insert(child.genes, copyGene(gene2)) + else + table.insert(child.genes, copyGene(gene1)) + end + end + + child.maxneuron = math.max(g1.maxneuron,g2.maxneuron) + + for mutation,rate in pairs(g1.mutationRates) do + child.mutationRates[mutation] = rate + end + + return child +end + +function randomNeuron(genes, nonInput) + local neurons = {} + if not nonInput then + for i=1,Inputs do + neurons[i] = true + end + end + for o=1,Outputs do + neurons[config.NeatConfig.MaxNodes+o] = true + end + for i=1,#genes do + if (not nonInput) or genes[i].into > Inputs then + neurons[genes[i].into] = true + end + if (not nonInput) or genes[i].out > Inputs then + neurons[genes[i].out] = true + end + end + + local count = 0 + for _,_ in pairs(neurons) do + count = count + 1 + end + local n = math.random(1, count) + + for k,v in pairs(neurons) do + n = n-1 + if n == 0 then + return k + end + end + + return 0 +end + +function containsLink(genes, link) + for i=1,#genes do + local gene = genes[i] + if gene.into == link.into and gene.out == link.out then + return true + end + end +end + +function pointMutate(genome) + local step = genome.mutationRates["step"] + + for i=1,#genome.genes do + local gene = genome.genes[i] + if math.random() < config.NeatConfig.PerturbChance then + gene.weight = gene.weight + math.random() * step*2 - step + else + gene.weight = math.random()*4-2 + end + end +end + +function linkMutate(genome, forceBias) + local neuron1 = randomNeuron(genome.genes, false) + local neuron2 = randomNeuron(genome.genes, true) + + local newLink = newGene() + if neuron1 <= Inputs and neuron2 <= Inputs then + --Both input nodes + return + end + if neuron2 <= Inputs then + -- Swap output and input + local temp = neuron1 + neuron1 = neuron2 + neuron2 = temp + end + + newLink.into = neuron1 + newLink.out = neuron2 + if forceBias then + newLink.into = Inputs + end + + if containsLink(genome.genes, newLink) then + return + end + newLink.innovation = newInnovation() + newLink.weight = math.random()*4-2 + + table.insert(genome.genes, newLink) +end + +function nodeMutate(genome) + if #genome.genes == 0 then + return + end + + genome.maxneuron = genome.maxneuron + 1 + + local gene = genome.genes[math.random(1,#genome.genes)] + if not gene.enabled then + return + end + gene.enabled = false + + local gene1 = copyGene(gene) + gene1.out = genome.maxneuron + gene1.weight = 1.0 + gene1.innovation = newInnovation() + gene1.enabled = true + table.insert(genome.genes, gene1) + + local gene2 = copyGene(gene) + gene2.into = genome.maxneuron + gene2.innovation = newInnovation() + gene2.enabled = true + table.insert(genome.genes, gene2) +end + +function enableDisableMutate(genome, enable) + local candidates = {} + for _,gene in pairs(genome.genes) do + if gene.enabled == not enable then + table.insert(candidates, gene) + end + end + + if #candidates == 0 then + return + end + + local gene = candidates[math.random(1,#candidates)] + gene.enabled = not gene.enabled +end + +function mutate(genome) + for mutation,rate in pairs(genome.mutationRates) do + if math.random(1,2) == 1 then + genome.mutationRates[mutation] = 0.95*rate + else + genome.mutationRates[mutation] = 1.05263*rate + end + end + + if math.random() < genome.mutationRates["connections"] then + pointMutate(genome) + end + + local p = genome.mutationRates["link"] + while p > 0 do + if math.random() < p then + linkMutate(genome, false) + end + p = p - 1 + end + + p = genome.mutationRates["bias"] + while p > 0 do + if math.random() < p then + linkMutate(genome, true) + end + p = p - 1 + end + + p = genome.mutationRates["node"] + while p > 0 do + if math.random() < p then + nodeMutate(genome) + end + p = p - 1 + end + + p = genome.mutationRates["enable"] + while p > 0 do + if math.random() < p then + enableDisableMutate(genome, true) + end + p = p - 1 + end + + p = genome.mutationRates["disable"] + while p > 0 do + if math.random() < p then + enableDisableMutate(genome, false) + end + p = p - 1 + end +end + +function disjoint(genes1, genes2) + local i1 = {} + for i = 1,#genes1 do + local gene = genes1[i] + i1[gene.innovation] = true + end + + local i2 = {} + for i = 1,#genes2 do + local gene = genes2[i] + i2[gene.innovation] = true + end + + local disjointGenes = 0 + for i = 1,#genes1 do + local gene = genes1[i] + if not i2[gene.innovation] then + disjointGenes = disjointGenes+1 + end + end + + for i = 1,#genes2 do + local gene = genes2[i] + if not i1[gene.innovation] then + disjointGenes = disjointGenes+1 + end + end + + local n = math.max(#genes1, #genes2) + + return disjointGenes / n +end + +function weights(genes1, genes2) + local i2 = {} + for i = 1,#genes2 do + local gene = genes2[i] + i2[gene.innovation] = gene + end + + local sum = 0 + local coincident = 0 + for i = 1,#genes1 do + local gene = genes1[i] + if i2[gene.innovation] ~= nil then + local gene2 = i2[gene.innovation] + sum = sum + math.abs(gene.weight - gene2.weight) + coincident = coincident + 1 + end + end + + return sum / coincident +end + +function sameSpecies(genome1, genome2) + local dd = config.NeatConfig.DeltaDisjoint*disjoint(genome1.genes, genome2.genes) + local dw = config.NeatConfig.DeltaWeights*weights(genome1.genes, genome2.genes) + return dd + dw < config.NeatConfig.DeltaThreshold +end + +function rankGlobally() + local global = {} + for s = 1,#pool.species do + local species = pool.species[s] + for g = 1,#species.genomes do + table.insert(global, species.genomes[g]) + end + end + table.sort(global, function (a,b) + return (a.fitness < b.fitness) + end) + + for g=1,#global do + global[g].globalRank = g + end +end + +function calculateAverageFitness(species) + local total = 0 + + for g=1,#species.genomes do + local genome = species.genomes[g] + total = total + genome.globalRank + end + + species.averageFitness = total / #species.genomes +end + +function totalAverageFitness() + local total = 0 + for s = 1,#pool.species do + local species = pool.species[s] + total = total + species.averageFitness + end + + return total +end + +function cullSpecies(cutToOne) + for s = 1,#pool.species do + local species = pool.species[s] + + table.sort(species.genomes, function (a,b) + return (a.fitness > b.fitness) + end) + + local remaining = math.ceil(#species.genomes/2) + if cutToOne then + remaining = 1 + end + while #species.genomes > remaining do + table.remove(species.genomes) + end + end +end + +function breedChild(species) + local child = {} + if math.random() < config.NeatConfig.CrossoverChance then + g1 = species.genomes[math.random(1, #species.genomes)] + g2 = species.genomes[math.random(1, #species.genomes)] + child = crossover(g1, g2) + else + g = species.genomes[math.random(1, #species.genomes)] + child = copyGenome(g) + end + + mutate(child) + + return child +end + +function removeStaleSpecies() + local survived = {} + + for s = 1,#pool.species do + local species = pool.species[s] + + table.sort(species.genomes, function (a,b) + return (a.fitness > b.fitness) + end) + + if species.genomes[1].fitness > species.topFitness then + species.topFitness = species.genomes[1].fitness + species.staleness = 0 + else + species.staleness = species.staleness + 1 + end + if species.staleness < config.NeatConfig.StaleSpecies or species.topFitness >= pool.maxFitness then + table.insert(survived, species) + end + end + + pool.species = survived +end + +function removeWeakSpecies() + local survived = {} + + local sum = totalAverageFitness() + for s = 1,#pool.species do + local species = pool.species[s] + breed = math.floor(species.averageFitness / sum * config.NeatConfig.Population) + if breed >= 1 then + table.insert(survived, species) + end + end + + pool.species = survived +end + + +function addToSpecies(child) + local foundSpecies = false + for s=1,#pool.species do + local species = pool.species[s] + if not foundSpecies and sameSpecies(child, species.genomes[1]) then + table.insert(species.genomes, child) + foundSpecies = true + end + end + + if not foundSpecies then + local childSpecies = newSpecies() + table.insert(childSpecies.genomes, child) + table.insert(pool.species, childSpecies) + end +end + +function newGeneration() + cullSpecies(false) -- Cull the bottom half of each species + rankGlobally() + removeStaleSpecies() + rankGlobally() + for s = 1,#pool.species do + local species = pool.species[s] + calculateAverageFitness(species) + end + removeWeakSpecies() + local sum = totalAverageFitness() + local children = {} + for s = 1,#pool.species do + local species = pool.species[s] + breed = math.floor(species.averageFitness / sum * config.NeatConfig.Population) - 1 + for i=1,breed do + table.insert(children, breedChild(species)) + end + end + cullSpecies(true) -- Cull all but the top member of each species + while #children + #pool.species < config.NeatConfig.Population do + local species = pool.species[math.random(1, #pool.species)] + table.insert(children, breedChild(species)) + end + for c=1,#children do + local child = children[c] + addToSpecies(child) + end + + pool.generation = pool.generation + 1 + + --writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile)) + writeFile(forms.gettext(saveLoadFile) .. ".gen" .. pool.generation .. ".pool") +end + +function initializePool() + pool = newPool() + + for i=1,config.NeatConfig.Population do + basic = basicGenome() + addToSpecies(basic) + end + + initializeRun() +end + +function initializeRun() + print("Hello") + print(config.NeatConfig.Filename) + local rew = movie.to_rewind(config.NeatConfig.Filename) + movie.unsafe_rewind(rew) + if config.StartPowerup ~= NIL then + game.writePowerup(config.StartPowerup) + end + rightmost = 0 + pool.currentFrame = 0 + timeout = config.NeatConfig.TimeoutConstant + game.clearJoypad() + startBananas = game.getBananas() + startCoins = game.getCoins() + startLives = game.getLives() + checkMarioCollision = true + marioHitCounter = 0 + powerUpCounter = 0 + powerUpBefore = game.getPowerup() + local species = pool.species[pool.currentSpecies] + local genome = species.genomes[pool.currentGenome] + generateNetwork(genome) + evaluateCurrent() +end + +function evaluateCurrent() + local species = pool.species[pool.currentSpecies] + local genome = species.genomes[pool.currentGenome] + + local inputDeltas = {} + inputs, inputDeltas = game.getInputs() + + controller = evaluateNetwork(genome.network, inputs, inputDeltas) + + if controller["P1 Left"] and controller["P1 Right"] then + controller["P1 Left"] = false + controller["P1 Right"] = false + end + if controller["P1 Up"] and controller["P1 Down"] then + controller["P1 Up"] = false + controller["P1 Down"] = false + end + + joypad.set(controller) +end + +if pool == nil then + initializePool() +end + +function nextGenome() + pool.currentGenome = pool.currentGenome + 1 + if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then + pool.currentGenome = 1 + pool.currentSpecies = pool.currentSpecies+1 + if pool.currentSpecies > #pool.species then + newGeneration() + pool.currentSpecies = 1 + end + end +end + +function fitnessAlreadyMeasured() + local species = pool.species[pool.currentSpecies] + local genome = species.genomes[pool.currentGenome] + + return genome.fitness ~= 0 +end + +form = forms.newform(500, 500, "Mario-Neat") +netPicture = forms.pictureBox(form, 5, 250,470, 200) + +--int forms.pictureBox(int formhandle, [int? x = null], [int? y = null], [int? width = null], [int? height = null]) + +function displayGenome(genome) + forms.clear(netPicture,0x80808080) + local network = genome.network + local cells = {} + local i = 1 + local cell = {} + for dy=-config.BoxRadius,config.BoxRadius do + for dx=-config.BoxRadius,config.BoxRadius do + cell = {} + cell.x = 50+5*dx + cell.y = 70+5*dy + cell.value = network.neurons[i].value + cells[i] = cell + i = i + 1 + end + end + local biasCell = {} + biasCell.x = 80 + biasCell.y = 110 + biasCell.value = network.neurons[Inputs].value + cells[Inputs] = biasCell + + for o = 1,Outputs do + cell = {} + cell.x = 220 + cell.y = 30 + 8 * o + cell.value = network.neurons[config.NeatConfig.MaxNodes + o].value + cells[config.NeatConfig.MaxNodes+o] = cell + local color + if cell.value > 0 then + color = 0xFF0000FF + else + color = 0xFF000000 + end + --gui.drawText(223, 24+8*o, config.ButtonNames[o], color, 9) + forms.drawText(netPicture,223, 24+8*o, config.ButtonNames[o], color, 9) + end + + for n,neuron in pairs(network.neurons) do + cell = {} + if n > Inputs and n <= config.NeatConfig.MaxNodes then + cell.x = 140 + cell.y = 40 + cell.value = neuron.value + cells[n] = cell + end + end + + for n=1,4 do + for _,gene in pairs(genome.genes) do + if gene.enabled then + local c1 = cells[gene.into] + local c2 = cells[gene.out] + if gene.into > Inputs and gene.into <= config.NeatConfig.MaxNodes then + c1.x = 0.75*c1.x + 0.25*c2.x + if c1.x >= c2.x then + c1.x = c1.x - 40 + end + if c1.x < 90 then + c1.x = 90 + end + + if c1.x > 220 then + c1.x = 220 + end + c1.y = 0.75*c1.y + 0.25*c2.y + + end + if gene.out > Inputs and gene.out <= config.NeatConfig.MaxNodes then + c2.x = 0.25*c1.x + 0.75*c2.x + if c1.x >= c2.x then + c2.x = c2.x + 40 + end + if c2.x < 90 then + c2.x = 90 + end + if c2.x > 220 then + c2.x = 220 + end + c2.y = 0.25*c1.y + 0.75*c2.y + end + end + end + end + + --gui.drawBox(50-config.BoxRadius*5-3,70-config.BoxRadius*5-3,50+config.BoxRadius*5+2,70+config.BoxRadius*5+2,0xFF000000, 0x80808080) + forms.drawBox(netPicture, 50-config.BoxRadius*5-3,70-config.BoxRadius*5-3,50+config.BoxRadius*5+2,70+config.BoxRadius*5+2,0xFF000000, 0x80808080) + --oid forms.drawBox(int componenthandle, int x, int y, int x2, int y2, [color? line = null], [color? background = null]) + for n,cell in pairs(cells) do + if n > Inputs or cell.value ~= 0 then + local color = math.floor((cell.value+1)/2*256) + if color > 255 then color = 255 end + if color < 0 then color = 0 end + local opacity = 0xFF000000 + if cell.value == 0 then + opacity = 0x50000000 + end + color = opacity + color*0x10000 + color*0x100 + color + forms.drawBox(netPicture,cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color) + --gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color) + end + end + for _,gene in pairs(genome.genes) do + if gene.enabled then + local c1 = cells[gene.into] + local c2 = cells[gene.out] + local opacity = 0xA0000000 + if c1.value == 0 then + opacity = 0x20000000 + end + + local color = 0x80-math.floor(math.abs(mathFunctions.sigmoid(gene.weight))*0x80) + if gene.weight > 0 then + color = opacity + 0x8000 + 0x10000*color + else + color = opacity + 0x800000 + 0x100*color + end + --gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color) + forms.drawLine(netPicture,c1.x+1, c1.y, c2.x-3, c2.y, color) + end + end + + --gui.drawBox(49,71,51,78,0x00000000,0x80FF0000) + forms.drawBox(netPicture, 49,71,51,78,0x00000000,0x80FF0000) + --if forms.ischecked(showMutationRates) then + local pos = 100 + for mutation,rate in pairs(genome.mutationRates) do + --gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10) + forms.drawText(netPicture,100, pos, mutation .. ": " .. rate, 0xFF000000, 10) + --forms.drawText(pictureBox,400,pos, mutation .. ": " .. rate) + + --void forms.drawText(int componenthandle, int x, int y, string message, [color? forecolor = null], [color? backcolor = null], [int? fontsize = null], [string fontfamily = null], [string fontstyle = null], [string horizalign = null], [string vertalign = null]) + + pos = pos + 8 + end + --end + forms.refresh(netPicture) +end + +function writeFile(filename) + local file = io.open(filename, "w") + file:write(pool.generation .. "\n") + file:write(pool.maxFitness .. "\n") + file:write(#pool.species .. "\n") + for n,species in pairs(pool.species) do + file:write(species.topFitness .. "\n") + file:write(species.staleness .. "\n") + file:write(#species.genomes .. "\n") + for m,genome in pairs(species.genomes) do + file:write(genome.fitness .. "\n") + file:write(genome.maxneuron .. "\n") + for mutation,rate in pairs(genome.mutationRates) do + file:write(mutation .. "\n") + file:write(rate .. "\n") + end + file:write("done\n") + + file:write(#genome.genes .. "\n") + for l,gene in pairs(genome.genes) do + file:write(gene.into .. " ") + file:write(gene.out .. " ") + file:write(gene.weight .. " ") + file:write(gene.innovation .. " ") + if(gene.enabled) then + file:write("1\n") + else + file:write("0\n") + end + end + end + end + file:close() +end + +function savePool() + local filename = forms.gettext(saveLoadFile) + print(filename) + writeFile(filename) +end + +function mysplit(inputstr, sep) + if sep == nil then + sep = "%s" + end + local t={} ; i=1 + for str in string.gmatch(inputstr, "([^"..sep.."]+)") do + t[i] = str + i = i + 1 + end + return t +end + +function loadFile(filename) + print("Loading pool from " .. filename) + local file = io.open(filename, "r") + pool = newPool() + pool.generation = file:read("*number") + pool.maxFitness = file:read("*number") + forms.settext(MaxLabel, "Max Fitness: " .. math.floor(pool.maxFitness)) + local numSpecies = file:read("*number") + for s=1,numSpecies do + local species = newSpecies() + table.insert(pool.species, species) + species.topFitness = file:read("*number") + species.staleness = file:read("*number") + local numGenomes = file:read("*number") + for g=1,numGenomes do + local genome = newGenome() + table.insert(species.genomes, genome) + genome.fitness = file:read("*number") + genome.maxneuron = file:read("*number") + local line = file:read("*line") + while line ~= "done" do + + genome.mutationRates[line] = file:read("*number") + line = file:read("*line") + end + local numGenes = file:read("*number") + for n=1,numGenes do + + local gene = newGene() + local enabled + + local geneStr = file:read("*line") + local geneArr = mysplit(geneStr) + gene.into = tonumber(geneArr[1]) + gene.out = tonumber(geneArr[2]) + gene.weight = tonumber(geneArr[3]) + gene.innovation = tonumber(geneArr[4]) + enabled = tonumber(geneArr[5]) + + + if enabled == 0 then + gene.enabled = false + else + gene.enabled = true + end + + table.insert(genome.genes, gene) + end + end + end + file:close() + + while fitnessAlreadyMeasured() do + nextGenome() + end + initializeRun() + pool.currentFrame = pool.currentFrame + 1 + print("Pool loaded.") +end + +function flipState() + if config.Running == true then + config.Running = false + forms.settext(startButton, "Start") + else + config.Running = true + forms.settext(startButton, "Stop") + end +end + +function loadPool() + filename = forms.openfile("DP1.state.pool",config.PoolDir) + --local filename = forms.gettext(saveLoadFile) + forms.settext(saveLoadFile, filename) + loadFile(filename) +end + +function playTop() + local maxfitness = 0 + local maxs, maxg + for s,species in pairs(pool.species) do + for g,genome in pairs(species.genomes) do + if genome.fitness > maxfitness then + maxfitness = genome.fitness + maxs = s + maxg = g + end + end + end + + pool.currentSpecies = maxs + pool.currentGenome = maxg + pool.maxFitness = maxfitness + forms.settext(MaxLabel, "Max Fitness: " .. math.floor(pool.maxFitness)) + initializeRun() + pool.currentFrame = pool.currentFrame + 1 + return +end + +function onExit() + forms.destroy(form) +end + +writeFile(config.PoolDir.."temp.pool") + +event.onexit(onExit) + +GenerationLabel = forms.label(form, "Generation: " .. pool.generation, 5, 5) +SpeciesLabel = forms.label(form, "Species: " .. pool.currentSpecies, 130, 5) +GenomeLabel = forms.label(form, "Genome: " .. pool.currentGenome, 230, 5) +MeasuredLabel = forms.label(form, "Measured: " .. "", 330, 5) + +FitnessLabel = forms.label(form, "Fitness: " .. "", 5, 30) +MaxLabel = forms.label(form, "Max: " .. "", 130, 30) + +BananasLabel = forms.label(form, "Bananas: " .. "", 5, 65) +CoinsLabel = forms.label(form, "Coins: " .. "", 130, 65, 90, 14) +LivesLabel = forms.label(form, "Lives: " .. "", 130, 80, 90, 14) +DmgLabel = forms.label(form, "Damage: " .. "", 230, 65, 110, 14) +PowerUpLabel = forms.label(form, "PowerUp: " .. "", 230, 80, 110, 14) + +startButton = forms.button(form, "Start", flipState, 155, 102) + +restartButton = forms.button(form, "Restart", initializePool, 155, 102) +saveButton = forms.button(form, "Save", savePool, 5, 102) +loadButton = forms.button(form, "Load", loadPool, 80, 102) +playTopButton = forms.button(form, "Play Top", playTop, 230, 102) + +saveLoadFile = forms.textbox(form, config.NeatConfig.Filename .. ".pool", 170, 25, nil, 5, 148) +saveLoadLabel = forms.label(form, "Save/Load:", 5, 129) +spritelist.InitSpriteList() +spritelist.InitExtSpriteList() +while true do + + if config.Running == true then + + local species = pool.species[pool.currentSpecies] + local genome = species.genomes[pool.currentGenome] + + displayGenome(genome) + + if pool.currentFrame%5 == 0 then + evaluateCurrent() + end + + joypad.set(controller) + + game.getPositions() + if partyX > rightmost then + rightmost = partyX + timeout = config.NeatConfig.TimeoutConstant + end + + local hitTimer = game.getHitTimer() + + if checkMarioCollision == true then + if hitTimer > 0 then + marioHitCounter = marioHitCounter + 1 + --console.writeline("Mario took damage, hit counter: " .. marioHitCounter) + checkMarioCollision = false + end + end + + if hitTimer == 0 then + checkMarioCollision = true + end + + powerUp = game.getPowerup() + if powerUp > 0 then + if powerUp ~= powerUpBefore then + powerUpCounter = powerUpCounter+1 + powerUpBefore = powerUp + end + end + + Lives = game.getLives() + + timeout = timeout - 1 + + local timeoutBonus = pool.currentFrame / 4 + if timeout + timeoutBonus <= 0 then + + local bananas = game.getBananas() - startBananas + local coins = game.getCoins() - startCoins + + --console.writeline("Bananas: " .. bananas .. " coins: " .. coins) + + local bananaCoinsFitness = (bananas * 50) + (coins * 0.2) + if (bananas + coins) > 0 then + console.writeline("Bananas and Coins added " .. bananaCoinsFitness .. " fitness") + end + + local hitPenalty = marioHitCounter * 100 + local powerUpBonus = powerUpCounter * 100 + + local fitness = bananaCoinsFitness - hitPenalty + powerUpBonus + rightmost - pool.currentFrame / 2 + + if startLives < Lives then + local ExtraLiveBonus = (Lives - startLives)*1000 + fitness = fitness + ExtraLiveBonus + console.writeline("ExtraLiveBonus added " .. ExtraLiveBonus) + end + + if rightmost > 4816 then + fitness = fitness + 1000 + console.writeline("!!!!!!Beat level!!!!!!!") + end + if fitness == 0 then + fitness = -1 + end + genome.fitness = fitness + + if fitness > pool.maxFitness then + pool.maxFitness = fitness + --writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile)) + writeFile(forms.gettext(saveLoadFile) .. ".gen" .. pool.generation .. ".pool") + end + + console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness) + pool.currentSpecies = 1 + pool.currentGenome = 1 + while fitnessAlreadyMeasured() do + nextGenome() + end + initializeRun() + end + + local measured = 0 + local total = 0 + for _,species in pairs(pool.species) do + for _,genome in pairs(species.genomes) do + total = total + 1 + if genome.fitness ~= 0 then + measured = measured + 1 + end + end + end + + gui.drawEllipse(game.screenX-84, game.screenY-84, 192, 192, 0x50000000) + forms.settext(FitnessLabel, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3)) + forms.settext(GenerationLabel, "Generation: " .. pool.generation) + forms.settext(SpeciesLabel, "Species: " .. pool.currentSpecies) + forms.settext(GenomeLabel, "Genome: " .. pool.currentGenome) + forms.settext(MaxLabel, "Max: " .. math.floor(pool.maxFitness)) + forms.settext(MeasuredLabel, "Measured: " .. math.floor(measured/total*100) .. "%") + forms.settext(BananasLabel, "Bananas: " .. (game.getBananas() - startBananas)) + forms.settext(CoinsLabel, "Coins: " .. (game.getCoins() - startCoins)) + forms.settext(LivesLabel, "Lives: " .. Lives) + forms.settext(DmgLabel, "Damage: " .. marioHitCounter) + forms.settext(PowerUpLabel, "PowerUp: " .. powerUpCounter) + + pool.currentFrame = pool.currentFrame + 1 + + end + emu.frameadvance(); + +end \ No newline at end of file diff --git a/pool/PiratePanic.lsmv b/pool/PiratePanic.lsmv new file mode 100644 index 0000000..816fa22 Binary files /dev/null and b/pool/PiratePanic.lsmv differ diff --git a/util.lua b/util.lua new file mode 100644 index 0000000..b3ad09c --- /dev/null +++ b/util.lua @@ -0,0 +1,44 @@ +local _M = {} + +function _M.table_to_string(tbl) + local result = "{" + local keys = {} + for k in pairs(tbl) do + table.insert(keys, k) + end + table.sort(keys) + for _, k in ipairs(keys) do + local v = tbl[k] + if type(v) == "number" and v == 0 then + goto continue + end + + -- Check the key type (ignore any numerical keys - assume its an array) + if type(k) == "string" then + result = result.."[\""..k.."\"]".."=" + end + + -- Check the value type + if type(v) == "table" then + result = result..table_to_string(v) + elseif type(v) == "boolean" then + result = result..tostring(v) + else + result = result.."\""..v.."\"" + end + result = result..",\n" + ::continue:: + end + -- Remove leading commas from the result + if result ~= "" then + result = result:sub(1, result:len()-1) + end + return result.."}" +end + +function _M.file_exists(name) + local f=io.open(name,"r") + if f~=nil then io.close(f) return true else return false end +end + +return _M \ No newline at end of file