Skip to content

Instantly share code, notes, and snippets.

/gist:5943298

Created Jul 7, 2013
Embed
What would you like to do?
#include <algorithm>
#include <bitset>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <deque>
#include <functional>
#include <iomanip>
#include <iostream>
#include <list>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <utility>
#include <vector>
#include <cstring>
// Modulus applied to every counting result.
const int mod = 1000000007;
// Capacity of a[], perm[] and mark[]; perm[n] is written, so n must stay below this.
const int maxn = 30;
// Capacity of cnt[]: upper bound on the number of distinct ids getID() may hand out.
const int maxf = 1000000;
// Capacity of hole[].
const int maxk = 5;
using namespace std;
int n,K;                             // number of elements, number of holes
int a[maxn];                         // element values
int hole[maxk];                      // prefix sums that must be avoided (sorted in main)
map< pair<int,int>,int > ID;         // (sum1, sum2) -> dense id indexing cnt[]
int IDcnt = 0;                       // next id getID() will hand out
map< pair<int,int>,int > cnt[maxf];  // cnt[id][(size1, size2)] = number of left-half assignments
long long perm[maxn];                // perm[i] = i! % mod
long long ans = 0;                   // accumulator shared by recurse()/compute() and backtrack()
int PX,PY;                           // target sums of the two groups for the current compute() call
bool mark[maxn];                     // backtrack(): element i has already been placed
int getID(int alpha,int beta) {
pair<int,int> D = make_pair(alpha,beta);
if (!ID.count(D)) ID[D] = IDcnt++;
return ID[D];
}
// Left-half enumeration: assigns each of a[pos..high] to the first group, the second
// group, or neither. For every complete assignment it increments
// cnt[getID(sum1, sum2)][(size1, size2)]. n1/s1 and n2/s2 carry the running size and sum
// of the two groups. The parameter `low` is not read.
inline void generate(int low,int high,int pos,int n1,int s1,int n2,int s2) {
    if (pos <= high) {
        const int value = a[pos];
        const int next = pos + 1;
        generate(low,high,next,n1 + 1,s1 + value,n2,s2);  // a[pos] joins the first group
        generate(low,high,next,n1,s1,n2 + 1,s2 + value);  // a[pos] joins the second group
        generate(low,high,next,n1,s1,n2,s2);              // a[pos] is left out
        return;
    }
    ++cnt[getID(s1,s2)][make_pair(n1,n2)];
}
inline void recurse(int low,int high,int pos,int n1,int s1,int n2,int s2) {
if (pos > high) {
if (!ID.count(make_pair(PX - s1,PY - s2))) return;
int idx = ID[make_pair(PX - s1,PY - s2)];
for (map< pair<int,int>,int >::iterator it = cnt[idx].begin(); it != cnt[idx].end(); it++) {
int l1 = it->first.first,l2 = it->first.second,val = it->second;
long long prod = (perm[l1 + n1] * perm[l2 + n2]) % mod;
prod = (prod * perm[n - (l1 + n1) - (l2 + n2)]) % mod;
ans = (ans + prod * val) % mod;
}
return;
}
if (s1 + a[pos] <= PX) recurse(low,high,pos + 1,n1 + 1,s1 + a[pos],n2,s2);
if (s2 + a[pos] <= PY) recurse(low,high,pos + 1,n1,s1,n2 + 1,s2 + a[pos]);
recurse(low,high,pos + 1,n1,s1,n2,s2);
}
// Combines the right half a[n/2..n-1] with the left-half tables in cnt[] (filled by
// generate()). The result, modulo mod, is the sum over all disjoint index sets (S1, S2)
// with sum(S1) == X and sum(S2) == Y of
//   |S1|! * |S2|! * (n - |S1| - |S2|)!
// The global ans is reset, filled by recurse(), and returned.
long long compute(int X,int Y) {
    ans = 0;
    PX = X;
    PY = Y;
    const int mid = n / 2;
    recurse(mid,n - 1,mid,0,0,0,0);
    return ans;
}
// Brute-force counter. It extends the current partial ordering (pos elements placed,
// running total `sum`) with every unused element i whose new prefix sum is not one of
// hole[0..K-1]. Each full ordering of all n elements adds one to the global ans.
void backtrack(int pos,long long sum) {
    if (pos < n) {
        for (int i = 0; i < n; i++) {
            if (mark[i]) continue;
            const long long nextSum = sum + a[i];
            // Skip i if placing it would make the prefix sum land on a hole.
            if (std::find(hole,hole + K,nextSum) != hole + K) continue;
            mark[i] = true;
            backtrack(pos + 1,nextSum);
            mark[i] = false;
        }
        return;
    }
    ans++;
}
int main() {
cin >> n;
perm[0] = 1;
for (int i = 1; i <= n; i++) perm[i] = (perm[i - 1] * i) % mod;
for (int i = 0; i < n; i++) cin >> a[i];
cin >> K;
for (int i = 0; i < K; i++) cin >> hole[i];
sort(hole,hole + K);
if (n <= 10) {
ans = 0;
backtrack(0,0);
cout << ans << endl;
return 0;
}
generate(0,n/2 - 1,0,0,0,0,0);
long long ret = perm[n];
if (K > 0) {
ret -= compute(hole[0],0);
if (K > 1) ret -= compute(hole[1],0);
}
if (K > 1) ret += compute(hole[0],hole[1] - hole[0]);
ret %= mod;
if (ret < 0) ret += mod;
cout << ret << endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.