Skip to content

Instantly share code, notes, and snippets.

@justiceHui
Created December 28, 2024 11:47
Show Gist options
  • Select an option

  • Save justiceHui/812a34497dafe5d659e1090789803e89 to your computer and use it in GitHub Desktop.

Select an option

Save justiceHui/812a34497dafe5d659e1090789803e89 to your computer and use it in GitHub Desktop.
#include "testlib.h"
#include <bits/stdc++.h>
using namespace std;
// http://boj.kr/cf68deedf9b44ba1b7317778f47aece2 603-690 (modified)
// http://boj.kr/42e2f7ab27004489bbf98a8866855a4c verified with 603-665
namespace parser{
// Recursive-descent evaluator for expressions built from non-negative decimal
// integers, '+', '*' and parentheses. Every intermediate value is reduced mod P.
//
//   expr   := term   ('+' term)*
//   term   := factor ('*' factor)*
//   factor := NUM | '(' expr ')'
//
// Each parsing function returns {ok, value}; ok == false means a syntax error.
// All state lives in the namespace-scope variables below, so the functions are
// not reentrant; eval() resets the state at the start of every call.
enum{ ADD, MUL, NUM, LEP, RIP, END } token; // kind of the current lookahead token
std::string str; // expression text being scanned
int idx, num;    // idx: index of the next unread character; num: value (mod P) of the latest NUM token
int P;           // modulus applied to every intermediate result
[[nodiscard]] std::pair<bool, int> expr();
[[nodiscard]] std::pair<bool, int> term();
[[nodiscard]] std::pair<bool, int> factor();

// Scans the next token of str into `token` (and into `num` for a number).
// Returns false when the next character cannot start any token.
[[nodiscard]] bool get_token(){
    if(static_cast<std::size_t>(idx) == str.size()){ token = END; return true; }
    const char c = str[idx++];
    if(c == '+') token = ADD;
    else if(c == '*') token = MUL;
    else if(c == '(') token = LEP;
    else if(c == ')') token = RIP;
    // std::isdigit requires a value representable as unsigned char.
    else if(std::isdigit(static_cast<unsigned char>(c))){
        token = NUM;
        // Accumulate in 64 bits: value * 10 would overflow int for large moduli.
        long long value = (c - '0') % P;
        while(static_cast<std::size_t>(idx) < str.size()
              && std::isdigit(static_cast<unsigned char>(str[idx]))){
            value = (value * 10 + (str[idx] - '0')) % P;
            idx++;
        }
        num = static_cast<int>(value);
    }
    else return false;
    return true;
}

// expr := term ('+' term)* ; returns the sum mod P.
[[nodiscard]] std::pair<bool, int> expr(){
    auto [flag, res] = term();
    if(!flag) return {false, 0};
    while(token == ADD){
        if(!get_token()) return {false, 0};
        auto [f, nxt] = term();
        if(!f) return {false, 0};
        // Widen before adding so that res + nxt cannot overflow int.
        res = static_cast<int>((static_cast<long long>(res) + nxt) % P);
    }
    return {true, res};
}

// term := factor ('*' factor)* ; returns the product mod P.
[[nodiscard]] std::pair<bool, int> term(){
    auto [flag, res] = factor();
    if(!flag) return {false, 0};
    while(token == MUL){
        if(!get_token()) return {false, 0};
        auto [f, nxt] = factor();
        if(!f) return {false, 0};
        // Widen before multiplying so that res * nxt cannot overflow int.
        res = static_cast<int>(static_cast<long long>(res) * nxt % P);
    }
    return {true, res};
}

// factor := NUM | '(' expr ')'
[[nodiscard]] std::pair<bool, int> factor(){
    if(token == NUM){
        // Capture the operand before advancing: get_token() writes to `num`
        // whenever it scans a number.
        const int value = num;
        if(!get_token()) return {false, 0};
        return {true, value};
    }
    else if(token == LEP){
        if(!get_token()) return {false, 0};
        auto [flag, res] = expr();
        if(!flag) return {false, 0};
        if(token != RIP) return {false, 0};
        if(!get_token()) return {false, 0};
        return {true, res};
    }
    else return {false, 0};
}

// Evaluates the whole string s modulo p.
// Returns {true, value} when s is exactly one well-formed expression, and
// {false, 0} on any syntax error, trailing input, or a non-positive modulus.
std::pair<bool, int> eval(std::string s, int p){
    if(p <= 0) return {false, 0}; // `% p` would be undefined or meaningless
    str = std::move(s); idx = 0; P = p;
    if(!get_token()) return {false, 0};
    auto [flag, val] = expr();
    if(!flag) return {false, 0};
    if(token != END) return {false, 0};
    return {true, val % P};
}
}
// Does nothing when `flag` is true; otherwise calls in.quit(_wa, ...) with
// `result` as the message.
void output_assert(InStream &in, bool flag, std::string result){
    if(flag) return;
    in.quit(_wa, result.c_str());
}
// Checks that t can be obtained from s by inserting only parentheses.
// Walks t once, greedily consuming the next character of s whenever it matches;
// any other character of t must be '(' or ')'. Succeeds only if all of s was consumed.
bool check_modify(std::string s, std::string t){
    std::size_t matched = 0; // number of leading characters of s already found in t
    for(const char c : t){
        if(matched < s.size() && s[matched] == c){
            ++matched;
            continue;
        }
        const bool is_paren = (c == '(' || c == ')');
        if(!is_paren) return false;
    }
    return matched == s.size();
}
namespace dp{
// One element of a tokenized expression.
// Operand: `bit` is a bitmask of the residues it may take, `ch` is 0.
// Operator or parenthesis: `bit` is 0, `ch` is the character itself.
struct node{ int bit; char ch; };

// Converts s into a node sequence. Each maximal run of digits becomes one
// operand whose value is reduced mod p; every other character is stored as-is.
std::vector<node> tokenize(std::string s, int p){
    std::vector<node> out;
    std::size_t at = 0;
    while(at < s.size()){
        if(isdigit(s[at])){
            int value = (s[at] - '0') % p;
            for(++at; at < s.size() && isdigit(s[at]); ++at){
                value = (value * 10 + s[at] - '0') % p;
            }
            out.push_back({1 << value, 0});
        }
        else{
            out.push_back({0, s[at]});
            ++at;
        }
    }
    return out;
}

// Evaluates a parenthesis-free sequence operand, op, operand, ..., operand.
// reach[i][j] = bitmask of every residue obtainable from operands i..j when the
// operators between them may be applied in any order (interval DP over split points).
int calc(std::vector<node> v, int p){
    for(const node &nd : v) assert(nd.ch != '(' && nd.ch != ')');
    const int operands = static_cast<int>(v.size() / 2 + 1);
    std::vector<std::vector<int>> reach(operands, std::vector<int>(operands));
    for(int i = 0; i < operands; i++) reach[i][i] = v[i*2].bit;
    for(int len = 1; len < operands; len++){
        for(int lo = 0; lo + len < operands; lo++){
            const int hi = lo + len;
            for(int mid = lo; mid < hi; mid++){
                // The operator between operand mid and mid+1 is applied last.
                const bool is_add = (v[2*mid+1].ch == '+');
                const int lhs = reach[lo][mid];
                const int rhs = reach[mid+1][hi];
                for(int x = 0; x < p; x++){
                    if(!(lhs >> x & 1)) continue;
                    for(int y = 0; y < p; y++){
                        if(!(rhs >> y & 1)) continue;
                        const int val = is_add ? (x + y) % p : (x * y) % p;
                        reach[lo][hi] |= 1 << val;
                    }
                }
            }
        }
    }
    return reach[0][operands-1];
}

// Evaluates a node sequence that may contain parentheses: the first '(' and its
// matching ')' are collapsed into a single operand (recursively), then the
// shortened sequence is solved again. Without parentheses this is just calc().
int solve(std::vector<node> v, int p){
    const auto open_it = std::find_if(v.begin(), v.end(),
                                      [](const node &nd){ return nd.ch == '('; });
    if(open_it == v.end()) return calc(v, p);

    // Locate the ')' that closes open_it by tracking nesting depth.
    auto close_it = v.end();
    int depth = 1;
    for(auto it = open_it + 1; it != v.end(); ++it){
        if(it->ch == '(') depth++;
        if(it->ch == ')') depth--;
        if(depth == 0){ close_it = it; break; }
    }
    assert(close_it != v.end());

    std::vector<node> inner(open_it + 1, close_it);
    std::vector<node> collapsed(v.begin(), open_it);
    collapsed.push_back({ solve(inner, p), 0 });
    collapsed.insert(collapsed.end(), close_it + 1, v.end());
    return solve(collapsed, p);
}

// Tokenizes t and returns the bitmask of residues mod p that it can evaluate to.
int run(std::string t, int p){
    return solve(tokenize(t, p), p);
}
}
// Reads one answer from `in` (an integer M followed by a token T) and validates it
// against the original expression s and modulus p. Any failed condition ends the
// run through output_assert. The parameter n is accepted but not used here.
void check(int n, int p, std::string s, InStream &in){
    const int m = in.readInt(1, 100, "M");
    const std::string t = in.readToken("[0-9+*()]{1,100}", "T");

    // Structural checks: declared length, parentheses-only modification, valid syntax.
    output_assert(in, 1 <= m && m <= 100, "1 <= m <= 100");
    output_assert(in, t.size() == m, "|t| != m");
    output_assert(in, check_modify(s, t), "invalid modify");
    output_assert(in, parser::eval(t, p).first, "invalid expr");

    // Value checks: t must have exactly one reachable value in dp::run's result,
    // and that value must equal parser::eval's result for s.
    const int ans_val = parser::eval(s, p).second;
    const int bit = dp::run(t, p);
    const bool is_unique = (__builtin_popcount(bit) == 1);
    output_assert(in, is_unique, "not unique, bitmask: " + std::to_string(bit));
    const std::string mismatch =
        "wrong answer, out: " + std::to_string(__lg(bit)) + " answer: " + std::to_string(ans_val);
    output_assert(in, bit == (1 << ans_val), mismatch);
}
int main(int argc, char* argv[]){
registerTestlibCmd(argc, argv);
int n = inf.readInt();
int p = inf.readInt();
string s = inf.readToken();
check(n, p, s, ans);
check(n, p, s, ouf);
quit(_ok, "ok");
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment