Skip to content

Instantly share code, notes, and snippets.

@n-ari
Created March 13, 2020 05:53
Show Gist options
  • Save n-ari/7771d6450d1098d18e9e91ee77895767 to your computer and use it in GitHub Desktop.
Save n-ari/7771d6450d1098d18e9e91ee77895767 to your computer and use it in GitHub Desktop.
montecarlo.cpp for SECCON CTF 2019 International problem-five
// #define TL 1000
// Search time budget: montecarlo() compares its elapsed time in milliseconds
// against TL and stops iterating once the budget is exceeded.
#define TL 1200
// Visit count at which montecarlo() gives a leaf node its four children.
#define GROW_TH 10
#include <bits/stdc++.h>
using namespace std;
// Loop shorthands: REP iterates i over [0,n), FOR over [a,b), FORR over [a,b) in descending order.
#define REP(i,n) for(int i=0; i<(int)(n); i++)
#define FOR(i,a,b) for(int i=(int)(a); i<(int)(b); i++)
#define FORR(i,a,b) for(int i=(int)(b)-1; i>=(int)(a); i--)
// Short type aliases used throughout the file.
using ll = long long;
using vi = vector<int>;
using vl = vector<ll>;
using pii = pair<int,int>;
using pll = pair<ll,ll>;
// GNU policy-based data structures: red-black trees with order statistics,
// which provide find_by_order() for picking the k-th smallest element.
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
using treeset_int = tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>;
using treeset_pii = tree<pii, null_type, less<pii>, rb_tree_tag, tree_order_statistics_node_update>;
// In-place min/max assignment; note that both arguments are evaluated more than once.
#define CHMIN(a,b) (a)=min((a),(b))
#define CHMAX(a,b) (a)=max((a),(b))
// Expands to the begin/end iterator pair of a container.
#define ALL(v) (v).begin(),(v).end()
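/*
Input format (stdin):
H W A B T          (grid height, grid width, length of snake A, length of snake B,
                    and T, from which the search derives its horizon of 2*(T+1) moves)
s1                 (row 1 of the grid: W whitespace-separated cell names)
...
sH                 (row H of the grid)
x1 y1 ... xA yA    (cells of snake A as "x y" pairs; the first pair is treated as the head)
x1 y1 ... xB yB    (cells of snake B as "x y" pairs; the first pair is treated as the head)
Output (stdout): the best move for snake A, printed as a direction index
                 (0 = UP, 1 = DOWN, 2 = RIGHT, 3 = LEFT)
*/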
// Direction indices; each one selects the matching entry of dx[] / dy[].
constexpr int UP    = 0;
constexpr int DOWN  = 1;
constexpr int RIGHT = 2;
constexpr int LEFT  = 3;

// Per-direction coordinate deltas, in the order U D R L.
// UP decreases y and DOWN increases it; RIGHT increases x and LEFT decreases it.
int dx[4] = { 0,  0, +1, -1};
int dy[4] = {-1, +1,  0,  0};

// Cell kinds stored in the parsed grid.
constexpr int HOLE  = -1;
constexpr int EMPTY = 0;
constexpr int APPLE = 1;
constexpr int BODY  = 2;
constexpr int HEAD  = 3;
// inputs
// Grid height, grid width, and the T value from the input header (all set by input()).
int H,W,T;
// Parsed grid: mp[row][col] holds one of HOLE / EMPTY / APPLE / BODY / HEAD.
int mp[252][252];
// Cells of the two snakes as (y,x) pairs, in the order they were read; index 0 is snake A.
vector<pii> snake[2];
// Number of APPLE cells counted while parsing the grid.
int num_apples;
// Shared random engine; seeded in main() before the search starts.
mt19937 mt;
// Mutable game position used during one search iteration.
// Cells are addressed row-major as y*W+x in the two bitmaps.
struct State {
    int turn;      // index (0 or 1) of the snake that moves next
    int rest_turn; // number of single-snake moves still to be played
    vector<bool> map_hole, map_apple; // blocked cells (holes, obstacles, snake cells) / apple cells
    deque<pii> snakes[2]; // snake cells as (y,x); front is the head, back is the tail
    treeset_pii empties;  // free cells without an apple; order-statistics tree for uniform random picks
    set<pii> apples;      // (y,x) of every apple currently on the board

    State(){}

    // Returns true if the snake to move may step its head in direction `dir`
    // (an index into dx/dy). A step is legal while moves remain and the target
    // cell lies on the board and is either unblocked or the mover's own tail.
    bool canstep(int dir){
        if(rest_turn <= 0) return false;
        int y,x;
        tie(y,x) = snakes[turn][0];
        const int ny = y + dy[dir];
        const int nx = x + dx[dir];
        // Reject targets outside the board instead of indexing past the bitmap.
        if(nx < 0 || nx >= W || ny < 0) return false;
        const size_t cell = static_cast<size_t>(ny) * W + nx;
        if(cell >= map_hole.size()) return false;
        int ty,tx;
        tie(ty,tx) = snakes[turn].back();
        return !map_hole[cell] || (ny==ty && nx==tx);
    }

    // Moves the current snake one cell in direction `dir`, then hands the turn
    // to the other snake and consumes one remaining move.
    // Precondition: canstep(dir) is true.
    void step(int dir){
        assert(canstep(dir));
        int y,x;
        tie(y,x) = snakes[turn][0];
        const int ny = y + dy[dir];
        const int nx = x + dx[dir];
        const int cell = ny*W+nx;
        if(map_apple[cell]){
            // Eating an apple: the head advances and the tail stays, so the snake grows by one.
            map_hole[cell] = true;
            map_apple[cell] = false;
            apples.erase(pii(ny,nx));
            snakes[turn].push_front(pii(ny, nx));
            // Spawn a replacement apple on a uniformly random free cell, if any is left.
            const int cnt = empties.size();
            if(cnt > 0){
                const int choice = mt() % cnt;
                auto it = empties.find_by_order(choice);
                int cy, cx;
                tie(cy, cx) = *it;
                empties.erase(it);
                map_apple[cy*W+cx] = true;
                // Record the apple at the cell it was actually placed on, so that
                // `apples` stays consistent with `map_apple`.
                apples.insert(pii(cy,cx));
            }
        }else{
            // Plain move: release the tail cell first, then occupy the new head cell.
            // This order keeps the bookkeeping right when the head steps onto the old tail cell.
            int ty,tx;
            tie(ty,tx) = snakes[turn].back();
            map_hole[ty*W+tx] = false;
            empties.insert(pii(ty,tx));
            map_hole[cell] = true;
            empties.erase(pii(ny,nx));
            snakes[turn].push_front(pii(ny, nx));
            snakes[turn].pop_back();
        }
        turn ^= 1;
        rest_turn -= 1;
    }
};
// One node of the search tree built by montecarlo().
// Every member has a default initializer so that a default-constructed node
// is a well-defined, unvisited, unlinked leaf.
struct MCTSNode {
    int turn = 0;               // index (0 or 1) of the player to move at this node
    int rest_turn = 0;          // moves remaining when this node is reached
    double cnt = 0.0;           // number of simulations that passed through this node
    double win = 0.0;           // accumulated result for player `turn`: 1 per win, 0.5 per draw
    MCTSNode *parent = nullptr; // null for the root
    MCTSNode *child[4] = {nullptr, nullptr, nullptr, nullptr}; // one slot per direction; all null until expanded
};
void input(){
int lena, lenb;
scanf("%d%d%d%d%d", &H, &W, &lena, &lenb, &T);
num_apples = 0;
REP(i,H)REP(j,W){
char buf[10];
scanf("%s", buf);
if(buf[0]=='H' && buf[1]=='O'){
mp[i][j] = HOLE;
}else if(buf[0] == 'E'){
mp[i][j] = EMPTY;
}else if(buf[0] == 'A'){
mp[i][j] = APPLE;
num_apples++;
}else if(buf[0] == 'B'){
mp[i][j] = BODY;
}else{
mp[i][j] = HEAD;
}
}
REP(i,lena){
int x,y;
scanf("%d %d", &x, &y);
snake[0].push_back(pii(y,x));
}
REP(i,lenb){
int x,y;
scanf("%d %d", &x, &y);
snake[1].push_back(pii(y,x));
}
}
// Writes the chosen direction index to stdout as a decimal number followed by a newline.
void output(int dir){
    std::fprintf(stdout, "%d\n", dir);
}
int montecarlo(){
auto start = chrono::system_clock::now();
MCTSNode *root = new MCTSNode();
root->turn = 0;
root->rest_turn = 2*(T+1);
root->cnt = root->win = 0;
root->parent = NULL;
fill(root->child, root->child+4, (MCTSNode*)NULL);
while(true){
auto end = chrono::system_clock::now();
double elapsed = chrono::duration_cast<chrono::milliseconds>(end-start).count();
if(elapsed > TL)break;
State state;
state.turn = 0;
state.rest_turn = 2*(T+1);
state.map_hole.assign(H*W, false);
state.map_apple.assign(H*W, false);
REP(i,H)REP(j,W){
if(mp[i][j] == APPLE){
state.map_apple[i*W+j] = true;
state.apples.insert(pii(i,j));
}else if(mp[i][j] != EMPTY){
state.map_hole[i*W+j] = true;
}
if(mp[i][j] == EMPTY){
state.empties.insert(pii(i,j));
}
}
copy(ALL(snake[0]), back_inserter(state.snakes[0]));
copy(ALL(snake[1]), back_inserter(state.snakes[1]));
MCTSNode *cur = root;
bool stopend = false;
int win = -1;
while(cur->child[0] != NULL && cur->rest_turn != 0){
// go to child
// random choice
// using UCT: win/cnt + c sqrt(ln root->cnt / cnt)
double best_score = -1e18;
vi choices;
REP(i,4){
if(!state.canstep(i))continue;
double score = 0.0;
if(cur->child[i]->cnt > 0)score += 1.0 - cur->child[i]->win / cur->child[i]->cnt;
if(cur->child[i]->cnt > 0){
score += 1.0 * sqrt(2.0 * log(cur->cnt) / cur->child[i]->cnt);
}else{
score += 1e18;
}
if(score > best_score){
choices.clear();
best_score = score;
}
if(score >= best_score){
choices.push_back(i);
}
}
if(choices.size() == 0){
win = state.turn ^ 1;
stopend = true;
break;
}
int choice = choices[mt() % choices.size()];
cur = cur->child[choice];
state.step(choice);
}
// leaf node
double score = 1.0;
if(stopend){
//
}else if(cur->rest_turn == 0){
// done
int lena = state.snakes[0].size();
int lenb = state.snakes[1].size();
if(lena>lenb){
win = 0;
}else if(lena<lenb){
win = 1;
}
}else{
// grow leaf
if(cur->cnt == GROW_TH){
REP(i,4){
MCTSNode *child = cur->child[i] = new MCTSNode();
child->turn = cur->turn^1;
child->rest_turn = cur->rest_turn-1;
child->cnt = child->win = 0;
child->parent = cur;
fill(child->child, child->child+4, (MCTSNode*)NULL);
}
}
// random choice
vi choices;
REP(i,4)if(state.canstep(i))choices.push_back(i);
if(choices.size() == 0){
win = state.turn ^ 1;
}else{
int choice = choices[mt() % choices.size()];
if(cur->child[choice] != NULL)cur = cur->child[choice];
state.step(choice);
// playout
while(state.rest_turn != 0){
int hy,hx;
tie(hy,hx) = state.snakes[state.turn][0];
vector<pair<double,int>> choices;
REP(i,4)if(state.canstep(i)){
double score = 0.0;
int sy = hy + dy[i];
int sx = hx + dx[i];
for(pii apple : state.apples){
int py, px;
tie(py, px) = apple;
int dist = abs(sy-py) + abs(sx-px);
score += pow((double)dist, 0.5);
// score = min(score, (double)dist);
}
choices.push_back(make_pair(score, i));
}
sort(ALL(choices));
if(choices.size() == 0){
win = state.turn ^ 1;
break;
}
int sel = mt()%choices.size();
sel = mt()%(sel+1);
sel = mt()%(sel+1);
int choice = choices[sel].second;
state.step(choice);
}
if(win == -1){
int lena = state.snakes[0].size();
int lenb = state.snakes[1].size();
if(lena>lenb){
win = 0;
score = lena;
}else if(lena<lenb){
win = 1;
score = lenb;
}
}
}
}
while(cur != NULL){
cur->cnt += 1.0;
if(cur->turn == win){
cur->win += 1.0;
}else if(win == -1){
cur->win += 0.5;
}
cur = cur->parent;
}
}
// REP(i,4)printf("%d: %f / %f = %f\n",i, root->child[i]->win, root->child[i]->cnt, root->child[i]->win / root->child[i]->cnt);
// int best = -1;
// REP(i,4){
// if(best == -1 && root->child[i]->cnt > 0){
// best = i;
// }else if(root->child[i]->cnt > 0 && root->child[i]->win / root->child[i]->cnt < root->child[best]->win / root->child[best]->cnt){
// best = i;
// }
// }
int best = 0;
REP(i,4)if(root->child[i]->cnt > root->child[best]->cnt)best=i;
return best;
}
int main(){
input();
mt = mt19937(random_device{}());
output(montecarlo());
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment