local callback, set_timer_timeout, zip = callback, set_timer_timeout, zip local base = string.gsub(@@LUA_SCRIPT_FILENAME@@, "(.*[/\\])(.*)", "%1") local Promise = dofile(base.."/promise.lua") local config = dofile(base.."/config.lua") local util = dofile(base.."/util.lua")(Promise) local serpent = dofile(base.."/serpent.lua") local zzlib = dofile(base.."/zzlib.lua") local hasThreads = config.NeatConfig.Threads > 1 -- Only the parent should manage ticks! callback.register('timer', function() Promise.update() set_timer_timeout(1) end) set_timer_timeout(1) local warn = '========== The ROM file to use comes from config.lua.' io.stderr:write(warn) print(warn) local Runner = nil if hasThreads then Runner = dofile(base.."/runner-wrapper.lua") else Runner = dofile(base.."/runner.lua") end local Inputs = config.InputSize+1 local Outputs = #config.ButtonNames local _M = { saveLoadFile = config.NeatConfig.SaveFile, onMessageHandler = {}, onRenderFormHandler = {}, } local pool = nil local function message(msg, color) if color == nil then color = 0x00009900 end for i=#_M.onMessageHandler,1,-1 do _M.onMessageHandler[i](msg, color) end end local function newGenome() local genome = {} genome.genes = {} genome.fitness = 0 genome.adjustedFitness = 0 genome.network = {} genome.maxneuron = 0 genome.globalRank = 0 genome.mutationRates = {} genome.mutationRates["connections"] = config.NeatConfig.MutateConnectionsChance genome.mutationRates["link"] = config.NeatConfig.LinkMutationChance genome.mutationRates["bias"] = config.NeatConfig.BiasMutationChance genome.mutationRates["node"] = config.NeatConfig.NodeMutationChance genome.mutationRates["enable"] = config.NeatConfig.EnableMutationChance genome.mutationRates["disable"] = config.NeatConfig.DisableMutationChance genome.mutationRates["step"] = config.NeatConfig.StepSize return genome end local function randomNeuron(genes, nonInput) local neurons = {} if not nonInput then for i=1,Inputs do neurons[i] = true end end for o=1,Outputs do neurons[config.NeatConfig.MaxNodes+o] = true end for i=1,#genes do if (not nonInput) or genes[i].into > Inputs then neurons[genes[i].into] = true end if (not nonInput) or genes[i].out > Inputs then neurons[genes[i].out] = true end end local count = 0 for _,_ in pairs(neurons) do count = count + 1 end local n = math.random(1, count) for k,v in pairs(neurons) do n = n-1 if n == 0 then return k end end return 0 end local function newGene() local gene = {} gene.into = 0 gene.out = 0 gene.weight = 0.0 gene.enabled = true gene.innovation = 0 return gene end local function containsLink(genes, link) for i=1,#genes do local gene = genes[i] if gene.into == link.into and gene.out == link.out then return true end end end local function newInnovation() pool.innovation = pool.innovation + 1 return pool.innovation end local function linkMutate(genome, forceBias) local neuron1 = randomNeuron(genome.genes, false) local neuron2 = randomNeuron(genome.genes, true) local newLink = newGene() if neuron1 <= Inputs and neuron2 <= Inputs then --Both input nodes return end if neuron2 <= Inputs then -- Swap output and input local temp = neuron1 neuron1 = neuron2 neuron2 = temp end newLink.into = neuron1 newLink.out = neuron2 if forceBias then newLink.into = Inputs end if containsLink(genome.genes, newLink) then return end newLink.innovation = newInnovation() newLink.weight = math.random()*4-2 table.insert(genome.genes, newLink) end local function copyGene(gene) local gene2 = newGene() gene2.into = gene.into gene2.out = gene.out gene2.weight = gene.weight gene2.enabled = gene.enabled gene2.innovation = gene.innovation return gene2 end local function nodeMutate(genome) if #genome.genes == 0 then return end genome.maxneuron = genome.maxneuron + 1 local gene = genome.genes[math.random(1,#genome.genes)] if not gene.enabled then return end gene.enabled = false local gene1 = copyGene(gene) gene1.out = genome.maxneuron gene1.weight = 1.0 gene1.innovation = newInnovation() gene1.enabled = true table.insert(genome.genes, gene1) local gene2 = copyGene(gene) gene2.into = genome.maxneuron gene2.innovation = newInnovation() gene2.enabled = true table.insert(genome.genes, gene2) end local function pointMutate(genome) local step = genome.mutationRates["step"] for i=1,#genome.genes do local gene = genome.genes[i] if math.random() < config.NeatConfig.PerturbChance then gene.weight = gene.weight + math.random() * step*2 - step else gene.weight = math.random()*4-2 end end end local function enableDisableMutate(genome, enable) local candidates = {} for _,gene in pairs(genome.genes) do if gene.enabled == not enable then table.insert(candidates, gene) end end if #candidates == 0 then return end local gene = candidates[math.random(1,#candidates)] gene.enabled = not gene.enabled end local function mutate(genome) for mutation,rate in pairs(genome.mutationRates) do if math.random(1,2) == 1 then genome.mutationRates[mutation] = 0.95*rate else genome.mutationRates[mutation] = 1.05263*rate end end if math.random() < genome.mutationRates["connections"] then pointMutate(genome) end local p = genome.mutationRates["link"] while p > 0 do if math.random() < p then linkMutate(genome, false) end p = p - 1 end p = genome.mutationRates["bias"] while p > 0 do if math.random() < p then linkMutate(genome, true) end p = p - 1 end p = genome.mutationRates["node"] while p > 0 do if math.random() < p then nodeMutate(genome) end p = p - 1 end p = genome.mutationRates["enable"] while p > 0 do if math.random() < p then enableDisableMutate(genome, true) end p = p - 1 end p = genome.mutationRates["disable"] while p > 0 do if math.random() < p then enableDisableMutate(genome, false) end p = p - 1 end end local function basicGenome() local genome = newGenome() local innovation = 1 genome.maxneuron = Inputs mutate(genome) return genome end local function newPool() local pool = {} pool.speciesId = 1 pool.species = {} pool.generation = 0 pool.innovation = Outputs pool.maxFitness = 0 return pool end local function newSpecies() local species = {} species.id = pool.speciesId pool.speciesId = pool.speciesId + 1 species.topFitness = 0 species.staleness = 0 species.genomes = {} species.averageFitness = 0 return species end local function disjoint(genes1, genes2) local i1 = {} for i = 1,#genes1 do local gene = genes1[i] i1[gene.innovation] = true end local i2 = {} for i = 1,#genes2 do local gene = genes2[i] i2[gene.innovation] = true end local disjointGenes = 0 for i = 1,#genes1 do local gene = genes1[i] if not i2[gene.innovation] then disjointGenes = disjointGenes+1 end end for i = 1,#genes2 do local gene = genes2[i] if not i1[gene.innovation] then disjointGenes = disjointGenes+1 end end local n = math.max(#genes1, #genes2) return disjointGenes / n end local function weights(genes1, genes2) local i2 = {} for i = 1,#genes2 do local gene = genes2[i] i2[gene.innovation] = gene end local sum = 0 local coincident = 0 for i = 1,#genes1 do local gene = genes1[i] if i2[gene.innovation] ~= nil then local gene2 = i2[gene.innovation] sum = sum + math.abs(gene.weight - gene2.weight) coincident = coincident + 1 end end return sum / coincident end local function sameSpecies(genome1, genome2) local dd = config.NeatConfig.DeltaDisjoint*disjoint(genome1.genes, genome2.genes) local dw = config.NeatConfig.DeltaWeights*weights(genome1.genes, genome2.genes) return dd + dw < config.NeatConfig.DeltaThreshold end local function addToSpecies(child) local foundSpecies = false for s=1,#pool.species do local species = pool.species[s] if not foundSpecies and sameSpecies(child, species.genomes[1]) then table.insert(species.genomes, child) foundSpecies = true end end if not foundSpecies then local childSpecies = newSpecies() table.insert(childSpecies.genomes, child) table.insert(pool.species, childSpecies) end end local function initializePool() return util.promiseWrap(function() pool = newPool() for i=1,config.NeatConfig.Population do local basic = basicGenome() addToSpecies(basic) end end) end local function bytes(x) local b4=x%256 x=(x-x%256)/256 local b3=x%256 x=(x-x%256)/256 local b2=x%256 x=(x-x%256)/256 local b1=x%256 x=(x-x%256)/256 return string.char(b1,b2,b3,b4) end --- Saves the pool to a gzipped Serpent file ---@param filename string Filename to write ---@return Promise Promise A promise that resolves when the file is saved. local function writeFile(filename) return util.promiseWrap(function () local file = zip.writer.new(filename) file:create_file('data.serpent') local dump = serpent.dump(pool) file:write(dump) file:close_file() file:commit() end) end -- FIXME This isn't technically asynchronous. Probably can't be though. local function loadFile(filename) return util.promiseWrap(function() message("Loading pool from " .. filename, 0x00999900) local file = io.open(filename, "r") if file == nil then message("File could not be loaded", 0x00990000) return end local contents = file:read("*all") file:close() local decomp = zzlib.unzip(contents, 'data.serpent') local ok, obj = serpent.load(decomp) if not ok then message("Error parsing pool file", 0x00990000) return end pool = obj end) end local function savePool() local filename = _M.saveLoadFile return writeFile(filename):next(function() message(string.format("Saved \"%s\"!", filename:sub(#filename - 50)), 0x00009900) end) end local function loadPool() return loadFile(_M.saveLoadFile) end local function processRenderForm(form) for i=#_M.onRenderFormHandler,1,-1 do _M.onRenderFormHandler[i](form) end end local function copyGenome(genome) local genome2 = newGenome() for g=1,#genome.genes do table.insert(genome2.genes, copyGene(genome.genes[g])) end genome2.maxneuron = genome.maxneuron genome2.mutationRates["connections"] = genome.mutationRates["connections"] genome2.mutationRates["link"] = genome.mutationRates["link"] genome2.mutationRates["bias"] = genome.mutationRates["bias"] genome2.mutationRates["node"] = genome.mutationRates["node"] genome2.mutationRates["enable"] = genome.mutationRates["enable"] genome2.mutationRates["disable"] = genome.mutationRates["disable"] return genome2 end local function crossover(g1, g2) -- Make sure g1 is the higher fitness genome if g2.fitness > g1.fitness then local tempg = g1 g1 = g2 g2 = tempg end local child = newGenome() local innovations2 = {} for i=1,#g2.genes do local gene = g2.genes[i] innovations2[gene.innovation] = gene end for i=1,#g1.genes do local gene1 = g1.genes[i] local gene2 = innovations2[gene1.innovation] if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then table.insert(child.genes, copyGene(gene2)) else table.insert(child.genes, copyGene(gene1)) end end child.maxneuron = math.max(g1.maxneuron,g2.maxneuron) for mutation,rate in pairs(g1.mutationRates) do child.mutationRates[mutation] = rate end return child end local function rankGlobally() local global = {} for s = 1,#pool.species do local species = pool.species[s] for g = 1,#species.genomes do table.insert(global, species.genomes[g]) end end table.sort(global, function (a,b) return (a.fitness < b.fitness) end) for g=1,#global do global[g].globalRank = g end end local function calculateAverageFitness(species) local total = 0 for g=1,#species.genomes do local genome = species.genomes[g] total = total + genome.globalRank end species.averageFitness = total / #species.genomes end local function totalAverageFitness() local total = 0 for s = 1,#pool.species do local species = pool.species[s] total = total + species.averageFitness end return total end local function cullSpecies(cutToOne) for s = 1,#pool.species do local species = pool.species[s] table.sort(species.genomes, function (a,b) return (a.fitness > b.fitness) end) local remaining = math.ceil(#species.genomes/2) if cutToOne then remaining = 1 end while #species.genomes > remaining do table.remove(species.genomes) end end end local function breedChild(species) local child = {} if math.random() < config.NeatConfig.CrossoverChance then local g1 = species.genomes[math.random(1, #species.genomes)] local g2 = species.genomes[math.random(1, #species.genomes)] child = crossover(g1, g2) else local g = species.genomes[math.random(1, #species.genomes)] child = copyGenome(g) end mutate(child) return child end local function removeStaleSpecies() local survived = {} for s = 1,#pool.species do local species = pool.species[s] table.sort(species.genomes, function (a,b) return (a.fitness > b.fitness) end) if species.genomes[1].fitness > species.topFitness then species.topFitness = species.genomes[1].fitness species.staleness = 0 else species.staleness = species.staleness + 1 end if species.staleness < config.NeatConfig.StaleSpecies or species.topFitness >= pool.maxFitness then table.insert(survived, species) end end pool.species = survived end local function removeWeakSpecies() local survived = {} local sum = totalAverageFitness() for s = 1,#pool.species do local species = pool.species[s] local breed = math.floor(species.averageFitness / sum * config.NeatConfig.Population) if breed >= 1 then table.insert(survived, species) end end pool.species = survived end local function newGeneration() cullSpecies(false) -- Cull the bottom half of each species rankGlobally() removeStaleSpecies() rankGlobally() for s = 1,#pool.species do local species = pool.species[s] calculateAverageFitness(species) end removeWeakSpecies() local sum = totalAverageFitness() local children = {} for s = 1,#pool.species do local species = pool.species[s] local breed = math.floor(species.averageFitness / sum * config.NeatConfig.Population) - 1 for i=1,breed do table.insert(children, breedChild(species)) end end cullSpecies(true) -- Cull all but the top member of each species while #children + #pool.species < config.NeatConfig.Population do local species = pool.species[math.random(1, #pool.species)] table.insert(children, breedChild(species)) end for c=1,#children do local child = children[c] addToSpecies(child) end table.sort(pool.species, function(a,b) return (#a.genomes < #b.genomes) end) pool.generation = pool.generation + 1 return writeFile(_M.saveLoadFile .. ".gen" .. pool.generation .. ".pool"):next(function() if config.NeatConfig.AutoSave then return writeFile(_M.saveLoadFile) end end) end local function reset() return _M.run(true) end local runner = Runner(Promise) runner.onMessage(function(msg, color) message(msg, color) end) runner.onSave(function(filename) _M.requestSave(filename) end) runner.onLoad(function(filename) _M.requestLoad(filename) end) runner.onReset(function(filename) _M.requestReset() end) runner.onRenderForm(function(form) processRenderForm(form) end) local playTop = nil local topRequested = false local loadRequested = config.NeatConfig.AutoSave local saveRequested = false local resetRequested = false local function mainLoop(currentSpecies, topGenome) if currentSpecies == nil then currentSpecies = 1 end local slice = pool.species[currentSpecies] return util.promiseWrap(function() if loadRequested then loadRequested = false currentSpecies = nil -- FIXME return loadPool() end if saveRequested then saveRequested = false return savePool() end if resetRequested then resetRequested = false return reset() end if topRequested then topRequested = false return playTop() end if not config.Running then -- FIXME Tick? end if hasThreads then slice = pool.species end return runner.run( slice, pool.generation, function() -- Genome callback -- FIXME Should we do something here??? What was your plan, past me? end ):next(function(maxFitness) if maxFitness > pool.maxFitness then pool.maxFitness = maxFitness end if hasThreads then currentSpecies = currentSpecies + #slice else currentSpecies = currentSpecies + 1 end if currentSpecies > #pool.species then newGeneration() currentSpecies = 1 end end) end):next(function () if topGenome == nil then return mainLoop(currentSpecies) end end) end playTop = function() local maxfitness = 0 local maxs, maxg for s,species in pairs(pool.species) do for g,genome in pairs(species.genomes) do if genome.fitness > maxfitness then maxfitness = genome.fitness maxs = s maxg = g end end end -- FIXME genome return mainLoop(maxs, maxg) end function _M.requestLoad(filename) _M.saveLoadFile = filename loadRequested = true end function _M.requestSave(filename) _M.saveLoadFile = filename saveRequested = true end function _M.requestReset() resetRequested = true end function _M.onMessage(handler) table.insert(_M.onMessageHandler, handler) end function _M.onRenderForm(handler) table.insert(_M.onRenderFormHandler, handler) end function _M.requestTop() topRequested = true end function _M.run(reset) local promise = nil if pool == nil or reset == true then promise = initializePool() else promise = Promise.new() promise:resolve() end return promise:next(function() return writeFile(config.PoolDir.."temp.pool") end):next(function () if not hasThreads then return util.loadAndStart(config.ROM) end end):next(function() return mainLoop() end) end return _M