Add files via upload
This commit is contained in:
parent
08087c748b
commit
afbbbdd010
8 changed files with 1370 additions and 15 deletions
40
neat-mario/FitnessChangesOnly/config.lua
Normal file
40
neat-mario/FitnessChangesOnly/config.lua
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
local _M = {}
|
||||||
|
|
||||||
|
_M.NeatConfig = {
|
||||||
|
--Filename = "DP1.state",
|
||||||
|
Filename = "C:/Users/mmill/Downloads/BizHawk-2.2/Lua/SNES/neat-mario/pool/DP1.state",
|
||||||
|
Population = 300,
|
||||||
|
DeltaDisjoint = 2.0,
|
||||||
|
DeltaWeights = 0.4,
|
||||||
|
DeltaThreshold = 1.0,
|
||||||
|
StaleSpecies = 15,
|
||||||
|
MutateConnectionsChance = 0.25,
|
||||||
|
PerturbChance = 0.90,
|
||||||
|
CrossoverChance = 0.75,
|
||||||
|
LinkMutationChance = 2.0,
|
||||||
|
NodeMutationChance = 0.50,
|
||||||
|
BiasMutationChance = 0.40,
|
||||||
|
StepSize = 0.1,
|
||||||
|
DisableMutationChance = 0.4,
|
||||||
|
EnableMutationChance = 0.2,
|
||||||
|
TimeoutConstant = 20,
|
||||||
|
MaxNodes = 1000000,
|
||||||
|
}
|
||||||
|
|
||||||
|
_M.ButtonNames = {
|
||||||
|
"A",
|
||||||
|
"B",
|
||||||
|
"X",
|
||||||
|
"Y",
|
||||||
|
"Up",
|
||||||
|
"Down",
|
||||||
|
"Left",
|
||||||
|
"Right",
|
||||||
|
}
|
||||||
|
|
||||||
|
_M.BoxRadius = 6
|
||||||
|
_M.InputSize = (_M.BoxRadius*2+1)*(_M.BoxRadius*2+1)
|
||||||
|
|
||||||
|
_M.Running = false
|
||||||
|
|
||||||
|
return _M
|
126
neat-mario/FitnessChangesOnly/game.lua
Normal file
126
neat-mario/FitnessChangesOnly/game.lua
Normal file
|
@ -0,0 +1,126 @@
|
||||||
|
--Notes here
|
||||||
|
config = require "config"
|
||||||
|
local _M = {}
|
||||||
|
function _M.getPositions()
|
||||||
|
marioX = memory.read_s16_le(0x94)
|
||||||
|
marioY = memory.read_s16_le(0x96)
|
||||||
|
|
||||||
|
local layer1x = memory.read_s16_le(0x1A);
|
||||||
|
local layer1y = memory.read_s16_le(0x1C);
|
||||||
|
|
||||||
|
screenX = marioX-layer1x
|
||||||
|
screenY = marioY-layer1y
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.getCoins()
|
||||||
|
local coins = memory.readbyte(0x0DBF)
|
||||||
|
return coins
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.getScore()
|
||||||
|
local scoreLeft = memory.read_s16_le(0x0F34)
|
||||||
|
local scoreRight = memory.read_s16_le(0x0F36)
|
||||||
|
local score = ( scoreLeft * 10 ) + scoreRight
|
||||||
|
return score
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.getMarioHit(alreadyHit)
|
||||||
|
local timer = memory.read_s16_le(0x1497)
|
||||||
|
if timer > 0 then
|
||||||
|
if alreadyHit == false then
|
||||||
|
return true
|
||||||
|
else
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
else
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.getMarioHitTimer()
|
||||||
|
local timer = memory.read_s16_le(0x1497)
|
||||||
|
return timer
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.getTile(dx, dy)
|
||||||
|
x = math.floor((marioX+dx+8)/16)
|
||||||
|
y = math.floor((marioY+dy)/16)
|
||||||
|
|
||||||
|
return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.getSprites()
|
||||||
|
local sprites = {}
|
||||||
|
for slot=0,11 do
|
||||||
|
local status = memory.readbyte(0x14C8+slot)
|
||||||
|
if status ~= 0 then
|
||||||
|
spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
|
||||||
|
spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
|
||||||
|
sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return sprites
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.getExtendedSprites()
|
||||||
|
local extended = {}
|
||||||
|
for slot=0,11 do
|
||||||
|
local number = memory.readbyte(0x170B+slot)
|
||||||
|
if number ~= 0 then
|
||||||
|
spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
|
||||||
|
spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
|
||||||
|
extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return extended
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.getInputs()
|
||||||
|
_M.getPositions()
|
||||||
|
|
||||||
|
sprites = _M.getSprites()
|
||||||
|
extended = _M.getExtendedSprites()
|
||||||
|
|
||||||
|
local inputs = {}
|
||||||
|
|
||||||
|
for dy=-config.BoxRadius*16,config.BoxRadius*16,16 do
|
||||||
|
for dx=-config.BoxRadius*16,config.BoxRadius*16,16 do
|
||||||
|
inputs[#inputs+1] = 0
|
||||||
|
|
||||||
|
tile = _M.getTile(dx, dy)
|
||||||
|
if tile == 1 and marioY+dy < 0x1B0 then
|
||||||
|
inputs[#inputs] = 1
|
||||||
|
end
|
||||||
|
|
||||||
|
for i = 1,#sprites do
|
||||||
|
distx = math.abs(sprites[i]["x"] - (marioX+dx))
|
||||||
|
disty = math.abs(sprites[i]["y"] - (marioY+dy))
|
||||||
|
if distx <= 8 and disty <= 8 then
|
||||||
|
inputs[#inputs] = -1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
for i = 1,#extended do
|
||||||
|
distx = math.abs(extended[i]["x"] - (marioX+dx))
|
||||||
|
disty = math.abs(extended[i]["y"] - (marioY+dy))
|
||||||
|
if distx < 8 and disty < 8 then
|
||||||
|
inputs[#inputs] = -1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
end
|
||||||
|
|
||||||
|
function _M.clearJoypad()
|
||||||
|
controller = {}
|
||||||
|
for b = 1,#config.ButtonNames do
|
||||||
|
controller["P1 " .. config.ButtonNames[b]] = false
|
||||||
|
end
|
||||||
|
joypad.set(controller)
|
||||||
|
end
|
||||||
|
|
||||||
|
return _M
|
1128
neat-mario/FitnessChangesOnly/mario-neat.lua
Normal file
1128
neat-mario/FitnessChangesOnly/mario-neat.lua
Normal file
File diff suppressed because it is too large
Load diff
9
neat-mario/FitnessChangesOnly/mathFunctions.lua
Normal file
9
neat-mario/FitnessChangesOnly/mathFunctions.lua
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
--Notes here
|
||||||
|
|
||||||
|
local _M = {}
|
||||||
|
|
||||||
|
function _M.sigmoid(x)
|
||||||
|
return 2/(1+math.exp(-4.9*x))-1
|
||||||
|
end
|
||||||
|
|
||||||
|
return _M
|
BIN
neat-mario/FitnessChangesOnly/pool/DP1.State
Normal file
BIN
neat-mario/FitnessChangesOnly/pool/DP1.State
Normal file
Binary file not shown.
|
@ -1,6 +1,7 @@
|
||||||
--Notes here
|
--Notes here
|
||||||
config = require "config"
|
config = require "config"
|
||||||
local _M = {}
|
local _M = {}
|
||||||
|
|
||||||
function _M.getPositions()
|
function _M.getPositions()
|
||||||
marioX = memory.read_s16_le(0x94)
|
marioX = memory.read_s16_le(0x94)
|
||||||
marioY = memory.read_s16_le(0x96)
|
marioY = memory.read_s16_le(0x96)
|
||||||
|
@ -8,8 +9,8 @@ function _M.getPositions()
|
||||||
local layer1x = memory.read_s16_le(0x1A);
|
local layer1x = memory.read_s16_le(0x1A);
|
||||||
local layer1y = memory.read_s16_le(0x1C);
|
local layer1y = memory.read_s16_le(0x1C);
|
||||||
|
|
||||||
screenX = marioX-layer1x
|
_M.screenX = marioX-layer1x
|
||||||
screenY = marioY-layer1y
|
_M.screenY = marioY-layer1y
|
||||||
end
|
end
|
||||||
|
|
||||||
function _M.getCoins()
|
function _M.getCoins()
|
||||||
|
@ -84,10 +85,16 @@ function _M.getInputs()
|
||||||
extended = _M.getExtendedSprites()
|
extended = _M.getExtendedSprites()
|
||||||
|
|
||||||
local inputs = {}
|
local inputs = {}
|
||||||
|
local inputDeltaDistance = {}
|
||||||
|
|
||||||
|
local layer1x = memory.read_s16_le(0x1A);
|
||||||
|
local layer1y = memory.read_s16_le(0x1C);
|
||||||
|
|
||||||
|
|
||||||
for dy=-config.BoxRadius*16,config.BoxRadius*16,16 do
|
for dy=-config.BoxRadius*16,config.BoxRadius*16,16 do
|
||||||
for dx=-config.BoxRadius*16,config.BoxRadius*16,16 do
|
for dx=-config.BoxRadius*16,config.BoxRadius*16,16 do
|
||||||
inputs[#inputs+1] = 0
|
inputs[#inputs+1] = 0
|
||||||
|
inputDeltaDistance[#inputDeltaDistance+1] = 1
|
||||||
|
|
||||||
tile = _M.getTile(dx, dy)
|
tile = _M.getTile(dx, dy)
|
||||||
if tile == 1 and marioY+dy < 0x1B0 then
|
if tile == 1 and marioY+dy < 0x1B0 then
|
||||||
|
@ -99,6 +106,12 @@ function _M.getInputs()
|
||||||
disty = math.abs(sprites[i]["y"] - (marioY+dy))
|
disty = math.abs(sprites[i]["y"] - (marioY+dy))
|
||||||
if distx <= 8 and disty <= 8 then
|
if distx <= 8 and disty <= 8 then
|
||||||
inputs[#inputs] = -1
|
inputs[#inputs] = -1
|
||||||
|
|
||||||
|
local dist = math.sqrt((distx * distx) + (disty * disty))
|
||||||
|
if dist > 8 then
|
||||||
|
inputDeltaDistance[#inputDeltaDistance] = mathFunctions.squashDistance(dist)
|
||||||
|
--gui.drawLine(screenX, screenY, sprites[i]["x"] - layer1x, sprites[i]["y"] - layer1y, 0x50000000)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -106,13 +119,26 @@ function _M.getInputs()
|
||||||
distx = math.abs(extended[i]["x"] - (marioX+dx))
|
distx = math.abs(extended[i]["x"] - (marioX+dx))
|
||||||
disty = math.abs(extended[i]["y"] - (marioY+dy))
|
disty = math.abs(extended[i]["y"] - (marioY+dy))
|
||||||
if distx < 8 and disty < 8 then
|
if distx < 8 and disty < 8 then
|
||||||
|
|
||||||
|
--console.writeline(screenX .. "," .. screenY .. " to " .. extended[i]["x"]-layer1x .. "," .. extended[i]["y"]-layer1y)
|
||||||
inputs[#inputs] = -1
|
inputs[#inputs] = -1
|
||||||
|
local dist = math.sqrt((distx * distx) + (disty * disty))
|
||||||
|
if dist > 8 then
|
||||||
|
inputDeltaDistance[#inputDeltaDistance] = mathFunctions.squashDistance(dist)
|
||||||
|
--gui.drawLine(screenX, screenY, extended[i]["x"] - layer1x, extended[i]["y"] - layer1y, 0x50000000)
|
||||||
|
end
|
||||||
|
--if dist > 100 then
|
||||||
|
--dw = mathFunctions.squashDistance(dist)
|
||||||
|
--console.writeline(dist .. " to " .. dw)
|
||||||
|
--gui.drawLine(screenX, screenY, extended[i]["x"] - layer1x, extended[i]["y"] - layer1y, 0x50000000)
|
||||||
|
--end
|
||||||
|
--inputs[#inputs] = {["value"]=-1, ["dw"]=dw}
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return inputs
|
return inputs, inputDeltaDistance
|
||||||
end
|
end
|
||||||
|
|
||||||
function _M.clearJoypad()
|
function _M.clearJoypad()
|
||||||
|
|
|
@ -107,7 +107,7 @@ function newNeuron()
|
||||||
local neuron = {}
|
local neuron = {}
|
||||||
neuron.incoming = {}
|
neuron.incoming = {}
|
||||||
neuron.value = 0.0
|
neuron.value = 0.0
|
||||||
|
--neuron.dw = 1
|
||||||
return neuron
|
return neuron
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -143,15 +143,18 @@ function generateNetwork(genome)
|
||||||
genome.network = network
|
genome.network = network
|
||||||
end
|
end
|
||||||
|
|
||||||
function evaluateNetwork(network, inputs)
|
function evaluateNetwork(network, inputs, inputDeltas)
|
||||||
table.insert(inputs, 1)
|
table.insert(inputs, 1)
|
||||||
|
table.insert(inputDeltas,99)
|
||||||
if #inputs ~= Inputs then
|
if #inputs ~= Inputs then
|
||||||
console.writeline("Incorrect number of neural network inputs.")
|
console.writeline("Incorrect number of neural network inputs.")
|
||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
for i=1,Inputs do
|
for i=1,Inputs do
|
||||||
network.neurons[i].value = inputs[i]
|
network.neurons[i].value = inputs[i] * inputDeltas[i]
|
||||||
|
--network.neurons[i].value = inputs[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
for _,neuron in pairs(network.neurons) do
|
for _,neuron in pairs(network.neurons) do
|
||||||
|
@ -655,9 +658,11 @@ end
|
||||||
function evaluateCurrent()
|
function evaluateCurrent()
|
||||||
local species = pool.species[pool.currentSpecies]
|
local species = pool.species[pool.currentSpecies]
|
||||||
local genome = species.genomes[pool.currentGenome]
|
local genome = species.genomes[pool.currentGenome]
|
||||||
|
|
||||||
inputs = game.getInputs()
|
local inputDeltas = {}
|
||||||
controller = evaluateNetwork(genome.network, inputs)
|
inputs, inputDeltas = game.getInputs()
|
||||||
|
|
||||||
|
controller = evaluateNetwork(genome.network, inputs, inputDeltas)
|
||||||
|
|
||||||
if controller["P1 Left"] and controller["P1 Right"] then
|
if controller["P1 Left"] and controller["P1 Right"] then
|
||||||
controller["P1 Left"] = false
|
controller["P1 Left"] = false
|
||||||
|
@ -964,7 +969,7 @@ function flipState()
|
||||||
end
|
end
|
||||||
|
|
||||||
function loadPool()
|
function loadPool()
|
||||||
filename = forms.openfile("DP1.state.pool","C:\Users\mmill\Downloads\BizHawk-2.2\Lua\SNES\neat-mario\pool")
|
filename = forms.openfile("DP1.state.pool","C:/Users/mmill/Downloads/BizHawk-2.2/Lua/SNES/neat-mario/pool/")
|
||||||
--local filename = forms.gettext(saveLoadFile)
|
--local filename = forms.gettext(saveLoadFile)
|
||||||
forms.settext(saveLoadFile, filename)
|
forms.settext(saveLoadFile, filename)
|
||||||
loadFile(filename)
|
loadFile(filename)
|
||||||
|
@ -1006,7 +1011,7 @@ GenomeLabel = forms.label(form, "Genome: " .. pool.currentGenome, 230, 5)
|
||||||
MeasuredLabel = forms.label(form, "Measured: " .. "", 330, 5)
|
MeasuredLabel = forms.label(form, "Measured: " .. "", 330, 5)
|
||||||
|
|
||||||
FitnessLabel = forms.label(form, "Fitness: " .. "", 5, 30)
|
FitnessLabel = forms.label(form, "Fitness: " .. "", 5, 30)
|
||||||
MaxLabel = forms.label(form, "Maximum: " .. "", 130, 30)
|
MaxLabel = forms.label(form, "Max: " .. "", 130, 30)
|
||||||
|
|
||||||
CoinsLabel = forms.label(form, "Coins: " .. "", 5, 65)
|
CoinsLabel = forms.label(form, "Coins: " .. "", 5, 65)
|
||||||
ScoreLabel = forms.label(form, "Score: " .. "", 130, 65)
|
ScoreLabel = forms.label(form, "Score: " .. "", 130, 65)
|
||||||
|
@ -1048,7 +1053,7 @@ while true do
|
||||||
if checkMarioCollision == true then
|
if checkMarioCollision == true then
|
||||||
if hitTimer > 0 then
|
if hitTimer > 0 then
|
||||||
marioHitCounter = marioHitCounter + 1
|
marioHitCounter = marioHitCounter + 1
|
||||||
console.writeline("Mario took damage, hit counter: " .. marioHitCounter)
|
--console.writeline("Mario took damage, hit counter: " .. marioHitCounter)
|
||||||
checkMarioCollision = false
|
checkMarioCollision = false
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1065,14 +1070,14 @@ while true do
|
||||||
local coins = game.getCoins() - startCoins
|
local coins = game.getCoins() - startCoins
|
||||||
local score = game.getScore() - startScore
|
local score = game.getScore() - startScore
|
||||||
|
|
||||||
console.writeline("Coins: " .. coins .. " score: " .. score)
|
--console.writeline("Coins: " .. coins .. " score: " .. score)
|
||||||
|
|
||||||
local coinScoreFitness = (coins * 50) + (score * 0.2)
|
local coinScoreFitness = (coins * 50) + (score * 0.2)
|
||||||
if (coins + score) > 0 then
|
if (coins + score) > 0 then
|
||||||
console.writeline("Coins and Score added " .. coinScoreFitness .. " fitness")
|
console.writeline("Coins and Score added " .. coinScoreFitness .. " fitness")
|
||||||
end
|
end
|
||||||
|
|
||||||
local hitPenalty = marioHitCounter * 200
|
local hitPenalty = marioHitCounter * 100
|
||||||
|
|
||||||
local fitness = coinScoreFitness - hitPenalty + rightmost - pool.currentFrame / 2
|
local fitness = coinScoreFitness - hitPenalty + rightmost - pool.currentFrame / 2
|
||||||
if rightmost > 4816 then
|
if rightmost > 4816 then
|
||||||
|
@ -1110,11 +1115,12 @@ while true do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
gui.drawEllipse(game.screenX-84, game.screenY-84, 192, 192, 0x50000000)
|
||||||
forms.settext(FitnessLabel, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3))
|
forms.settext(FitnessLabel, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3))
|
||||||
forms.settext(GenerationLabel, "Generation: " .. pool.generation)
|
forms.settext(GenerationLabel, "Generation: " .. pool.generation)
|
||||||
forms.settext(SpeciesLabel, "Species: " .. pool.currentSpecies)
|
forms.settext(SpeciesLabel, "Species: " .. pool.currentSpecies)
|
||||||
forms.settext(GenomeLabel, "Genome: " .. pool.currentGenome)
|
forms.settext(GenomeLabel, "Genome: " .. pool.currentGenome)
|
||||||
forms.settext(MaxLabel, "Maximum: " .. math.floor(pool.maxFitness))
|
forms.settext(MaxLabel, "Max: " .. math.floor(pool.maxFitness))
|
||||||
forms.settext(MeasuredLabel, "Measured: " .. math.floor(measured/total*100) .. "%")
|
forms.settext(MeasuredLabel, "Measured: " .. math.floor(measured/total*100) .. "%")
|
||||||
forms.settext(CoinsLabel, "Coins: " .. (game.getCoins() - startCoins))
|
forms.settext(CoinsLabel, "Coins: " .. (game.getCoins() - startCoins))
|
||||||
forms.settext(ScoreLabel, "Score: " .. (game.getScore() - startScore))
|
forms.settext(ScoreLabel, "Score: " .. (game.getScore() - startScore))
|
||||||
|
|
|
@ -6,4 +6,24 @@ function _M.sigmoid(x)
|
||||||
return 2/(1+math.exp(-4.9*x))-1
|
return 2/(1+math.exp(-4.9*x))-1
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function _M.squashDistance(x)
|
||||||
|
local window = 0.20
|
||||||
|
local delta = 0.25
|
||||||
|
|
||||||
|
local dist = (x-8)
|
||||||
|
local newDist = 1
|
||||||
|
|
||||||
|
while dist > 0 do
|
||||||
|
newDist = newDist - (window*delta)
|
||||||
|
dist = dist - 1
|
||||||
|
end
|
||||||
|
|
||||||
|
if newDist < 0.80 then
|
||||||
|
newDist = 0.80
|
||||||
|
end
|
||||||
|
|
||||||
|
return newDist
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
return _M
|
return _M
|
Loading…
Add table
Reference in a new issue