// montecarlo.cpp: Monte Carlo tree search player for SECCON CTF 2019
// International, problem "five".
// Source: GitHub gist n-ari/7771d6450d1098d18e9e91ee77895767.
// GitHub's viewer warned that this file may contain bidirectional Unicode
// text; inspect it in an editor that reveals hidden Unicode characters.
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;

// ---- Tunables ----
// Search time budget in milliseconds (an earlier setting used 1000).
#define TL 1200
// Number of visits a search-tree leaf needs before its children are created.
#define GROW_TH 10

// ---- Contest shorthands ----
#define REP(i,n) for(int i=0; i<(int)(n); i++)
#define FOR(i,a,b) for(int i=(int)(a); i<(int)(b); i++)
#define FORR(i,a,b) for(int i=(int)(b)-1; i>=(int)(a); i--)
#define CHMIN(a,b) (a)=min((a),(b))
#define CHMAX(a,b) (a)=max((a),(b))
#define ALL(v) (v).begin(),(v).end()

// ---- Type shorthands ----
using ll = long long;
using vi = vector<int>;
using vl = vector<ll>;
using pii = pair<int,int>;
using pll = pair<ll,ll>;
// Ordered sets with order statistics (find_by_order / order_of_key).
using treeset_int = tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>;
using treeset_pii = tree<pii, null_type, less<pii>, rb_tree_tag, tree_order_statistics_node_update>;
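/*
Input format (parsed by input()):
  H W A B T          board height, board width, length of snake 0 (A),
                     length of snake 1 (B), turn parameter T
  s1 ... sH          H rows of W whitespace-separated cell tokens, classified
                     by their leading letters: "HO..." hole, "E..." empty,
                     "A..." apple, "B..." body, anything else head
  x1 y1 ... xA yA    cells of snake 0, head first
  x1 y1 ... xB yB    cells of snake 1, head first
Output: the best move for snake 0, as a direction index
(0 = UP, 1 = DOWN, 2 = RIGHT, 3 = LEFT).
*/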
// Direction indices; montecarlo() returns one of these and output() prints it.
const int UP = 0;
const int DOWN = 1;
const int RIGHT = 2;
const int LEFT = 3;
// Per-direction offsets, indexed by the constants above (U D R L).
// y is the row index and x the column index; UP decreases y.
// constexpr: these tables are read-only lookups and must never be modified.
constexpr int dx[4] = { 0, 0, +1, -1};
constexpr int dy[4] = {-1, +1, 0, 0};
// Cell kinds stored in the global grid by input().
const int HOLE = -1;
const int EMPTY = 0;
const int APPLE = 1;
const int BODY = 2;
const int HEAD = 3;
// ---- Parsed input (filled by input(), read by the search) ----
int H;          // board height (number of rows)
int W;          // board width (number of columns)
int T;          // turn parameter from the header; the search plays 2*(T+1) plies
int mp[252][252];                          // cell kind per [row][col]; fixed 252 x 252 capacity
std::vector<std::pair<int, int>> snake[2]; // each snake's cells as (y, x), in input order
int num_apples;                            // count of APPLE cells seen by input()
std::mt19937 mt;                           // shared random engine; main() reseeds it
struct State { | |
int turn; // which player do move | |
int rest_turn; // rest turns | |
vector<bool> map_hole, map_apple; // map of holes (or obstacles), map of apples | |
deque<pii> snakes[2]; // snakes | |
treeset_pii empties; | |
set<pii> apples; | |
State(){} | |
bool canstep(int dir){ | |
int y,x; | |
tie(y,x) = snakes[turn][0]; | |
int ny = y + dy[dir]; | |
int nx = x + dx[dir]; | |
int ty,tx; | |
tie(ty,tx) = snakes[turn].back(); | |
return rest_turn > 0 && (!map_hole[ny*W+nx] || (ny==ty && nx==tx)); | |
} | |
void step(int dir){ | |
assert(canstep(dir)); | |
int y,x; | |
tie(y,x) = snakes[turn][0]; | |
int ny = y + dy[dir]; | |
int nx = x + dx[dir]; | |
bool grow = map_apple[ny*W+nx]; | |
if(grow){ | |
// lengthen | |
map_hole[ny*W+nx] = true; | |
map_apple[ny*W+nx] = false; | |
apples.erase(pii(ny,nx)); | |
snakes[turn].push_front(pii(ny, nx)); | |
// randomly put apple | |
int cnt = empties.size(); | |
if(cnt > 0){ | |
int choice = mt() % cnt; | |
int cy, cx; | |
auto it = empties.find_by_order(choice); | |
tie(cy, cx) = *it; | |
empties.erase(it); | |
map_apple[cy*W+cx] = true; | |
apples.insert(pii(ny,nx)); | |
} | |
}else{ | |
int ty,tx; | |
tie(ty,tx) = snakes[turn].back(); | |
map_hole[ty*W+tx] = false; | |
empties.insert(pii(ty,tx)); | |
map_hole[ny*W+nx] = true; | |
empties.erase(pii(ny,nx)); | |
snakes[turn].push_front(pii(ny, nx)); | |
snakes[turn].pop_back(); | |
} | |
turn ^= 1; | |
rest_turn -= 1; | |
} | |
}; | |
// Node of the Monte Carlo search tree.
// `turn` is the player to move at this node and `rest_turn` the plies left.
// `cnt` counts visits; `win` accumulates results from the viewpoint of `turn`
// (1 per win, 0.5 per draw). montecarlo() creates all four children (indexed
// by direction) at once, so a node's children are either all null or all set.
// Every member has a default so a default-initialized node is never garbage.
struct MCTSNode {
    int turn = 0;
    int rest_turn = 0;
    double cnt = 0;
    double win = 0;
    MCTSNode *parent = nullptr;
    MCTSNode *child[4] = {nullptr, nullptr, nullptr, nullptr};
};
void input(){ | |
int lena, lenb; | |
scanf("%d%d%d%d%d", &H, &W, &lena, &lenb, &T); | |
num_apples = 0; | |
REP(i,H)REP(j,W){ | |
char buf[10]; | |
scanf("%s", buf); | |
if(buf[0]=='H' && buf[1]=='O'){ | |
mp[i][j] = HOLE; | |
}else if(buf[0] == 'E'){ | |
mp[i][j] = EMPTY; | |
}else if(buf[0] == 'A'){ | |
mp[i][j] = APPLE; | |
num_apples++; | |
}else if(buf[0] == 'B'){ | |
mp[i][j] = BODY; | |
}else{ | |
mp[i][j] = HEAD; | |
} | |
} | |
REP(i,lena){ | |
int x,y; | |
scanf("%d %d", &x, &y); | |
snake[0].push_back(pii(y,x)); | |
} | |
REP(i,lenb){ | |
int x,y; | |
scanf("%d %d", &x, &y); | |
snake[1].push_back(pii(y,x)); | |
} | |
} | |
// Writes the chosen direction index to stdout, followed by a newline.
void output(int dir){
    std::fprintf(stdout, "%d\n", dir);
}
int montecarlo(){ | |
auto start = chrono::system_clock::now(); | |
MCTSNode *root = new MCTSNode(); | |
root->turn = 0; | |
root->rest_turn = 2*(T+1); | |
root->cnt = root->win = 0; | |
root->parent = NULL; | |
fill(root->child, root->child+4, (MCTSNode*)NULL); | |
while(true){ | |
auto end = chrono::system_clock::now(); | |
double elapsed = chrono::duration_cast<chrono::milliseconds>(end-start).count(); | |
if(elapsed > TL)break; | |
State state; | |
state.turn = 0; | |
state.rest_turn = 2*(T+1); | |
state.map_hole.assign(H*W, false); | |
state.map_apple.assign(H*W, false); | |
REP(i,H)REP(j,W){ | |
if(mp[i][j] == APPLE){ | |
state.map_apple[i*W+j] = true; | |
state.apples.insert(pii(i,j)); | |
}else if(mp[i][j] != EMPTY){ | |
state.map_hole[i*W+j] = true; | |
} | |
if(mp[i][j] == EMPTY){ | |
state.empties.insert(pii(i,j)); | |
} | |
} | |
copy(ALL(snake[0]), back_inserter(state.snakes[0])); | |
copy(ALL(snake[1]), back_inserter(state.snakes[1])); | |
MCTSNode *cur = root; | |
bool stopend = false; | |
int win = -1; | |
while(cur->child[0] != NULL && cur->rest_turn != 0){ | |
// go to child | |
// random choice | |
// using UCT: win/cnt + c sqrt(ln root->cnt / cnt) | |
double best_score = -1e18; | |
vi choices; | |
REP(i,4){ | |
if(!state.canstep(i))continue; | |
double score = 0.0; | |
if(cur->child[i]->cnt > 0)score += 1.0 - cur->child[i]->win / cur->child[i]->cnt; | |
if(cur->child[i]->cnt > 0){ | |
score += 1.0 * sqrt(2.0 * log(cur->cnt) / cur->child[i]->cnt); | |
}else{ | |
score += 1e18; | |
} | |
if(score > best_score){ | |
choices.clear(); | |
best_score = score; | |
} | |
if(score >= best_score){ | |
choices.push_back(i); | |
} | |
} | |
if(choices.size() == 0){ | |
win = state.turn ^ 1; | |
stopend = true; | |
break; | |
} | |
int choice = choices[mt() % choices.size()]; | |
cur = cur->child[choice]; | |
state.step(choice); | |
} | |
// leaf node | |
double score = 1.0; | |
if(stopend){ | |
// | |
}else if(cur->rest_turn == 0){ | |
// done | |
int lena = state.snakes[0].size(); | |
int lenb = state.snakes[1].size(); | |
if(lena>lenb){ | |
win = 0; | |
}else if(lena<lenb){ | |
win = 1; | |
} | |
}else{ | |
// grow leaf | |
if(cur->cnt == GROW_TH){ | |
REP(i,4){ | |
MCTSNode *child = cur->child[i] = new MCTSNode(); | |
child->turn = cur->turn^1; | |
child->rest_turn = cur->rest_turn-1; | |
child->cnt = child->win = 0; | |
child->parent = cur; | |
fill(child->child, child->child+4, (MCTSNode*)NULL); | |
} | |
} | |
// random choice | |
vi choices; | |
REP(i,4)if(state.canstep(i))choices.push_back(i); | |
if(choices.size() == 0){ | |
win = state.turn ^ 1; | |
}else{ | |
int choice = choices[mt() % choices.size()]; | |
if(cur->child[choice] != NULL)cur = cur->child[choice]; | |
state.step(choice); | |
// playout | |
while(state.rest_turn != 0){ | |
int hy,hx; | |
tie(hy,hx) = state.snakes[state.turn][0]; | |
vector<pair<double,int>> choices; | |
REP(i,4)if(state.canstep(i)){ | |
double score = 0.0; | |
int sy = hy + dy[i]; | |
int sx = hx + dx[i]; | |
for(pii apple : state.apples){ | |
int py, px; | |
tie(py, px) = apple; | |
int dist = abs(sy-py) + abs(sx-px); | |
score += pow((double)dist, 0.5); | |
// score = min(score, (double)dist); | |
} | |
choices.push_back(make_pair(score, i)); | |
} | |
sort(ALL(choices)); | |
if(choices.size() == 0){ | |
win = state.turn ^ 1; | |
break; | |
} | |
int sel = mt()%choices.size(); | |
sel = mt()%(sel+1); | |
sel = mt()%(sel+1); | |
int choice = choices[sel].second; | |
state.step(choice); | |
} | |
if(win == -1){ | |
int lena = state.snakes[0].size(); | |
int lenb = state.snakes[1].size(); | |
if(lena>lenb){ | |
win = 0; | |
score = lena; | |
}else if(lena<lenb){ | |
win = 1; | |
score = lenb; | |
} | |
} | |
} | |
} | |
while(cur != NULL){ | |
cur->cnt += 1.0; | |
if(cur->turn == win){ | |
cur->win += 1.0; | |
}else if(win == -1){ | |
cur->win += 0.5; | |
} | |
cur = cur->parent; | |
} | |
} | |
// REP(i,4)printf("%d: %f / %f = %f\n",i, root->child[i]->win, root->child[i]->cnt, root->child[i]->win / root->child[i]->cnt); | |
// int best = -1; | |
// REP(i,4){ | |
// if(best == -1 && root->child[i]->cnt > 0){ | |
// best = i; | |
// }else if(root->child[i]->cnt > 0 && root->child[i]->win / root->child[i]->cnt < root->child[best]->win / root->child[best]->cnt){ | |
// best = i; | |
// } | |
// } | |
int best = 0; | |
REP(i,4)if(root->child[i]->cnt > root->child[best]->cnt)best=i; | |
return best; | |
} | |
int main(){ | |
input(); | |
mt = mt19937(random_device{}()); | |
output(montecarlo()); | |
return 0; | |
} |
// End of gist. Sign up for free (or sign in) on GitHub to comment on it.