From ed5e64c6b74dd68712718713f02ed431a227b21f Mon Sep 17 00:00:00 2001 From: empathicqubit Date: Thu, 13 May 2021 05:55:09 -0400 Subject: [PATCH] Handle threading in the wrapper --- pool.lua | 9 +-------- runner-process.lua | 12 +++++++++++- runner-wrapper.lua | 45 +++++++++++++++++++++++++++++---------------- 3 files changed, 41 insertions(+), 25 deletions(-) diff --git a/pool.lua b/pool.lua index 80e08eb..ef40899 100644 --- a/pool.lua +++ b/pool.lua @@ -729,14 +729,7 @@ local function mainLoop(currentSpecies, topGenome) end if hasThreads then - slice = {} - for i=currentSpecies, currentSpecies + config.NeatConfig.Threads - 1, 1 do - if pool.species[i] == nil then - break - end - - table.insert(slice, pool.species[i]) - end + slice = pool.species end return runner.run( diff --git a/runner-process.lua b/runner-process.lua index 8daf9f8..5dbb81d 100644 --- a/runner-process.lua +++ b/runner-process.lua @@ -40,6 +40,14 @@ local function writeResponse(object) outputPipe:flush() end +local function unblockLoop() + return util.delay(1000000):next(function() + outputPipe:write(".\n") + outputPipe:flush() + return unblockLoop() + end) +end + local runner = Runner(Promise) runner.onMessage(function(msg, color) statusLine = msg @@ -147,7 +155,9 @@ writeResponse({ type = 'onInit', ts = ts }) print(string.format('Wrote init to output at %d', ts)) -waiter:next(waitLoop):catch(function(error) +waiter:next(function(inputLine) + return waitLoop(inputLine) +end):catch(function(error) if type(error) == "table" then error = "\n"..table.concat(error, "\n") end diff --git a/runner-wrapper.lua b/runner-wrapper.lua index 129fc0c..2e12154 100644 --- a/runner-wrapper.lua +++ b/runner-wrapper.lua @@ -158,22 +158,14 @@ return function(promise) onReset(_M, handler) end - _M.run = function(speciesSlice, generationIdx, genomeCallback) + _M.run = function(species, generationIdx, genomeCallback) local promise = Promise.new() promise:resolve() return promise:next(function() - return launchChildren(_M, #speciesSlice) + return launchChildren(_M, config.NeatConfig.Threads) end):next(function() message(_M, 'Setting up child processes') - for i=1,#speciesSlice,1 do - local inputPipe = _M.poppets[i].input - inputPipe:write(serpent.dump({speciesSlice[i], generationIdx}).."\n") - inputPipe:flush() - end - - message(_M, 'Waiting for child processes to finish') - local maxFitness = nil local function readLoop(outputPipe) return util.promiseWrap(function() @@ -201,8 +193,8 @@ return function(promise) elseif obj.type == 'onReset' then reset(_M) elseif obj.type == 'onGenome' then - for i=1,#speciesSlice,1 do - local s = speciesSlice[i] + for i=1,#species,1 do + local s = species[i] if s.id == obj.speciesId then message(_M, string.format('Write Species %d Genome %d', obj.speciesId, obj.genomeIndex)) s.genomes[obj.genomeIndex] = obj.genome @@ -227,12 +219,33 @@ return function(promise) end local waiters = {} - for i=1,#speciesSlice,1 do - local outputPipe = _M.poppets[i].output - local waiter = readLoop(outputPipe) - table.insert(waiters, waiter) + for t=1,config.NeatConfig.Threads,1 do + waiters[t] = Promise.new() + waiters[t]:resolve() end + local currentSpecies = 1 + while currentSpecies < #species do + for t=1,config.NeatConfig.Threads,1 do + local s = species[currentSpecies] + if s == nil then + break + end + + waiters[t] = waiters[t]:next(function() + local inputPipe = _M.poppets[t].input + inputPipe:write(serpent.dump({s, generationIdx}).."\n") + inputPipe:flush() + + local outputPipe = _M.poppets[t].output + return readLoop(outputPipe) + end) + currentSpecies = currentSpecies + 1 + end + end + + message(_M, 'Waiting for child processes to finish') + return Promise.all(table.unpack(waiters)) end):next(function(maxFitnesses) message(_M, 'Child processes finished')