Skip to content

Instantly share code, notes, and snippets.

@taotao54321
Created September 1, 2019 23:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save taotao54321/b8bc72c645b1846e20d078ba406ecdef to your computer and use it in GitHub Desktop.
Save taotao54321/b8bc72c645b1846e20d078ba406ecdef to your computer and use it in GitHub Desktop.
Vice: Project Doom (NES) 区間最適化テスト
--[[
Vice: Project Doom (NES) 区間最適化テスト
TAS Editor上で選択した区間を焼きなまし法で最適化する。
--]]
-- Simulated-annealing parameters.
-- The score being optimized is the x coordinate.
local SA_TEMP_START = 200.0; -- initial temperature (a score regression of roughly this size is still tolerated)
local SA_ITER = 100; -- number of annealing iterations
-- Joypad bit masks; rand_input() ORs these together into one input value per frame.
-- NOTE(review): presumably A, B, Select, sTart, Up, Down, Left, Right in the bit
-- layout used by taseditor.getinput()/submitinputchange() — confirm.
local BUTTON_A = 0x01;
local BUTTON_B = 0x02;
local BUTTON_S = 0x04;
local BUTTON_T = 0x08;
local BUTTON_U = 0x10;
local BUTTON_D = 0x20;
local BUTTON_L = 0x40;
local BUTTON_R = 0x80;
--- Render the sequence part of a table as "[e1,e2,...]" for debug output.
-- Each element is converted with tostring(); as with ipairs(), traversal
-- stops at the first nil element.
-- @param ary table treated as a sequence
-- @return string such as "[1,2,3]", or "[]" when the sequence is empty
local function dbg_array(ary)
  local parts = {};
  for i, e in ipairs(ary) do
    parts[i] = tostring(e);
  end
  -- table.concat joins in a single pass instead of rebuilding the string on
  -- every ".." inside the loop.
  return "[" .. table.concat(parts, ",") .. "]";
end
--- Marks a code path that must never execute.
-- Always raises an error whose message is "UNREACHABLE".
local function UNREACHABLE()
assert(false, "UNREACHABLE");
end
--- Abort the script with an error message.
-- Always raises; never returns.
-- @param msg error message to raise
local function ERROR(msg)
  -- Level 2 attributes the error to the code that called ERROR, so the
  -- reported position is the failing check rather than this wrapper's line.
  error(msg, 2);
end
--- Logging helper: forwards all of its arguments to print() unchanged.
-- Called with no arguments it forwards an empty argument list.
local function INFO(...)
print(...);
end
--- Write a candidate input sequence into the movie and play it through.
-- Moves playback to state.frame_start, submits each entry of state.inputs for
-- joypad 1 on consecutive frames, applies the submitted changes, then advances
-- the emulator by one frame per input.
-- @param state table with `frame_start` (first frame) and `inputs` (sequence of input values)
local function apply_state(state)
  taseditor.setplayback(state.frame_start);

  for offset, joy in ipairs(state.inputs) do
    local frame = state.frame_start + offset - 1;
    taseditor.submitinputchange(frame, 1, joy);
  end
  taseditor.applyinputchanges();

  -- Run the emulator across the whole interval that was just written.
  local remaining = #state.inputs;
  while remaining > 0 do
    emu.frameadvance();
    remaining = remaining - 1;
  end
end
--- Build the starting state from the inputs currently stored in the movie.
-- Reads joypad 1's input for each of the `frame_n` frames beginning at
-- `frame_start`.
-- @param frame_start first frame of the interval
-- @param frame_n number of frames in the interval
-- @return table { frame_start = ..., inputs = { ... } }
local function gen_initial_state(frame_start, frame_n)
  local recorded = {};
  for offset = 0, frame_n - 1 do
    recorded[#recorded + 1] = taseditor.getinput(frame_start + offset, 1);
  end

  local state = {};
  state.frame_start = frame_start;
  state.inputs = recorded;
  return state;
end
--- Draw one random joypad input value.
-- Three independent draws decide horizontal movement, crouching and jumping;
-- the chosen button masks are OR-ed into a single value.
-- @return number combined button mask
local function rand_input()
  local bits = 0;

  -- Horizontal movement: below 0.02 press left, from 0.1 upward press right,
  -- in between press neither.
  local roll = math.random();
  if roll < 0.02 then
    bits = OR(bits, BUTTON_L);
  elseif roll >= 0.1 then
    bits = OR(bits, BUTTON_R);
  end

  -- Crouch: pressed when the draw is below 0.4.
  if math.random() < 0.4 then
    bits = OR(bits, BUTTON_D);
  end

  -- Jump: below 0.2 is a normal jump (A); 0.2 up to 0.4 is a high jump
  -- (A together with S); otherwise no jump.
  roll = math.random();
  if roll < 0.4 then
    bits = OR(bits, BUTTON_A);
    if roll >= 0.2 then
      bits = OR(bits, BUTTON_S);
    end
  end

  return bits;
end
--- Neighbor move: replace the input of one randomly chosen frame.
-- The given state is not modified; a copy of its inputs is changed instead.
-- @param state current state ({ frame_start, inputs })
-- @return table new state with one input replaced by rand_input()
local function neighbor_mutate(state)
  local first = state.frame_start;
  local mutated = copytable(state.inputs);

  local pos = math.random(1, #mutated);
  local replacement = rand_input();
  local previous = mutated[pos];
  INFO(string.format(" mutate: frame %d: %02X -> %02X", first + pos - 1, previous, replacement));

  mutated[pos] = replacement;
  return { frame_start = first, inputs = mutated };
end
--- Neighbor move: exchange the inputs of two randomly chosen frames.
-- Both positions are drawn independently, so they may coincide.
-- The given state is not modified; a copy of its inputs is changed instead.
-- @param state current state ({ frame_start, inputs })
-- @return table new state with the two inputs exchanged
local function neighbor_swap(state)
  local first = state.frame_start;
  local swapped = copytable(state.inputs);

  local count = #swapped;
  local a = math.random(1, count);
  local b = math.random(1, count);
  INFO(string.format(" swap: frame %d <-> %d", first + a - 1, first + b - 1));

  local held = swapped[a];
  swapped[a] = swapped[b];
  swapped[b] = held;

  return { frame_start = first, inputs = swapped };
end
--- Neighbor move: insert a random input at a randomly chosen frame.
-- Later inputs shift one frame towards the end and the final input is
-- dropped, so the sequence keeps its length.
-- The given state is not modified; a copy of its inputs is changed instead.
-- @param state current state ({ frame_start, inputs })
-- @return table new state with the input inserted
local function neighbor_insert(state)
  local first = state.frame_start;
  local shifted = copytable(state.inputs);

  local pos = math.random(1, #shifted);
  local fresh = rand_input();
  INFO(string.format(" insert: frame %d: %d", first + pos - 1, fresh));

  -- Drop the tail element first so the insertion restores the original length.
  table.remove(shifted);
  table.insert(shifted, pos, fresh);

  return { frame_start = first, inputs = shifted };
end
--- Neighbor move: delete the input at a randomly chosen frame.
-- Later inputs shift one frame towards the start and a random input is
-- appended at the end, so the sequence keeps its length.
-- The given state is not modified; a copy of its inputs is changed instead.
-- @param state current state ({ frame_start, inputs })
-- @return table new state with the input removed
local function neighbor_delete(state)
  local first = state.frame_start;
  local shifted = copytable(state.inputs);

  local pos = math.random(1, #shifted);
  local filler = rand_input();
  INFO(string.format(" delete: frame %d: %d", first + pos - 1, filler));

  -- Remove the chosen slot, then pad the end back to the original length.
  table.remove(shifted, pos);
  shifted[#shifted + 1] = filler;

  return { frame_start = first, inputs = shifted };
end
--- Produce a neighboring state using one of four moves, chosen from a single
-- random draw split into equal quarters: mutate, swap, insert, delete.
-- @param state current state ({ frame_start, inputs })
-- @return table the new state returned by the chosen move
local function gen_neighbor_state(state)
  local roll = math.random();
  if roll < 0.25 then return neighbor_mutate(state); end
  if roll < 0.5 then return neighbor_swap(state); end
  if roll < 0.75 then return neighbor_insert(state); end
  return neighbor_delete(state);
end
--- Energy of a state (lower is better).
-- Plays the state through the emulator, then combines the bytes at 0x0200
-- (multiplied by 0x100) and 0x01F0 into one value and negates it.
-- @param state state to evaluate ({ frame_start, inputs })
-- @return number negated position value
local function f_energy(state)
  apply_state(state);

  local high = memory.readbyte(0x0200);
  local low = memory.readbyte(0x01F0);
  local x = 0x100 * high + low;

  -- Negated so that a larger x yields a lower energy.
  return -1.0 * x;
end
-- Temperature schedule.
-- Linear for now.
--- Temperature for a given progress through the annealing run.
-- @param progress fraction of the run completed (anneal() passes i/SA_ITER)
-- @return number SA_TEMP_START scaled by the remaining fraction
local function f_temperature(progress)
  local remaining = 1.0 - progress;
  return SA_TEMP_START * remaining;
end
--- Acceptance probability for moving from `energy` to `energy_nex`.
-- A move that does not raise the energy is always accepted; otherwise the
-- probability is exp((energy - energy_nex) / temp).
-- @param energy energy of the current state
-- @param energy_nex energy of the candidate state
-- @param temp current temperature
-- @return number probability
local function f_probability(energy, energy_nex, temp)
  local not_worse = energy_nex <= energy;
  if not_worse then
    return 1.0;
  else
    -- The small offset avoids dividing by zero once the temperature reaches 0.
    local safe_temp = temp + 1e-9;
    local delta = energy - energy_nex;
    return math.exp(delta / safe_temp);
  end
end
--- Run simulated annealing for SA_ITER iterations starting from `state`.
-- Every iteration generates a neighbor, evaluates it through f_energy, and
-- accepts it with the probability given by f_probability at the current
-- temperature. The lowest-energy state seen is tracked separately.
-- @param state initial state ({ frame_start, inputs })
-- @return table the best (lowest-energy) state encountered
local function anneal(state)
  local current = state;
  local current_energy = f_energy(current);
  local best = copytable(current);
  local best_energy = current_energy;

  for iter = 1, SA_ITER do
    local temp = f_temperature(iter / SA_ITER);
    INFO(string.format("[iter=%d, temp=%.2f]", iter, temp));

    local candidate = gen_neighbor_state(current);
    local candidate_energy = f_energy(candidate);
    local accept_prob = f_probability(current_energy, candidate_energy, temp);

    if math.random() < accept_prob then
      local is_new_best = candidate_energy < best_energy;
      if is_new_best then
        INFO(string.format(" energy_best updated: %.2f -> %.2f", best_energy, candidate_energy));
        best = copytable(candidate);
        best_energy = candidate_energy;
      end
      INFO(string.format(" -- accept: %.2f -> %.2f, energy_best=%.2f",
        current_energy, candidate_energy, best_energy));
      current = candidate;
      current_energy = candidate_energy;
    else
      INFO(string.format(" --- decline: %.2f -> %.2f, energy_best=%.2f",
        current_energy, candidate_energy, best_energy));
    end
  end

  return best;
end
--- Determine the frame interval selected in the TAS Editor.
-- Raises an error when nothing is selected or when the selected rows do not
-- form one contiguous run.
-- @return number first selected frame
-- @return number count of selected frames
local function get_interval()
  local rows = taseditor.getselection();
  if not rows then
    ERROR("no selection");
  end

  -- Closed interval [first, last].
  local first, last = rows[1], rows[#rows];
  local span = last - first + 1;

  -- A contiguous selection has exactly one row per frame in the span.
  if span ~= #rows then
    ERROR("noncontiguous selection");
  end

  return first, span;
end
--- Optimize the interval selected in the TAS Editor.
-- Requires the TAS Editor to be engaged. Switches the emulator to maximum
-- speed, anneals the selected interval, writes the best state found back into
-- the movie and stops seeking.
local function main()
  if not taseditor.engaged() then
    ERROR("TAS Editor not engaged");
  end
  emu.speedmode("maximum");

  local first, count = get_interval();
  local last = first + count - 1;
  INFO(string.format("interval: [%d,%d]", first, last));
  INFO();

  local initial = gen_initial_state(first, count);
  local optimized = anneal(initial);

  -- Leave the movie holding the best inputs found.
  apply_state(optimized);
  taseditor.stopseeking();
end
-- Script entry point: run the optimization once when the script is loaded.
main();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment