Create a gist now

Instantly share code, notes, and snippets.

#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type used for sums and products.
using ll = long long;
// Modulus for the modular arithmetic below: 10^9 + 7.
constexpr int MOD = 1000000007;
// Prime enumeration with the sieve of Eratosthenes over [0, 10^6].
// prime[i] is true iff i is prime (valid after init_prime()).
bool prime[1000001]; // 10^6 + 1 entries
// Every prime <= 10^6 in ascending order (valid after init_prime()).
vector<int> prs;
// Fills `prime` with the sieve and rebuilds `prs` from it.
// Safe to call repeatedly: `prs` is cleared first, so primes are never duplicated.
void init_prime() {
    // Element count rather than byte count: sizeof(bool) is implementation-defined.
    const int limit = int(sizeof(prime) / sizeof(prime[0]));
    fill(prime, prime + limit, true);
    prime[0] = prime[1] = false;
    // Composites below i*i were already crossed out by a smaller prime factor.
    for (int i = 2; i * i < limit; i++) {
        if (!prime[i]) continue;
        for (int j = i * i; j < limit; j += i) prime[j] = false;
    }
    prs.clear();
    for (int i = 2; i < limit; i++) {
        if (prime[i]) prs.push_back(i);
    }
}
// Prefix sums of primes: prime_sum_small_part_memo[v] = sum of all primes <= v,
// one entry per entry of `prime`. Values are exact (not reduced modulo MOD);
// the largest, the sum of primes below 10^6, is about 3.76e10 and fits in ll.
ll prime_sum_small_part_memo[sizeof(prime) / sizeof(prime[0])];
// Fills prime_sum_small_part_memo from the `prime` table.
// Precondition: init_prime() has already been called (it reads prime[]).
void init_prime_sum_small_part() {
    // Element count rather than byte count: sizeof(bool) is implementation-defined.
    const int limit = int(sizeof(prime) / sizeof(prime[0]));
    prime_sum_small_part_memo[0] = 0;
    for (int v = 1; v < limit; v++) {
        prime_sum_small_part_memo[v] = prime_sum_small_part_memo[v - 1] + (prime[v] ? v : 0);
    }
}
// Returns (sum of all primes p <= n) mod MOD.
//
// Uses the prime-sum sieve recurrence
//   S(v, p) = S(v, p-1) - p * (S(v/p, p-1) - S(p-1, p-1)),
// evaluated only at the distinct values floor(n / i).
//
// Preconditions:
//   * init_prime() and init_prime_sum_small_part() have been called.
//   * floor(sqrt(n)) lies inside the sieve range (n < (10^6 + 1)^2), because every
//     prime <= sqrt(n) must be present in `prs` (checked by assert).
// Returns 0 for n < 2 (there are no primes there).
ll prime_sum(ll n) {
    // Also avoids sqrt of a negative number and dp.back() on an empty vector.
    if (n < 2) return 0;

    // floor(sqrt(n)), corrected in both directions for floating-point error.
    // The comparisons use division so they cannot overflow.
    ll root = (ll)sqrt((double)n);
    while (root > n / root) root--;
    while (root + 1 <= n / (root + 1)) root++;
    assert(root < ll(sizeof(prime) / sizeof(prime[0])) && "n is too large for the prime sieve");
    const int n2 = int(root);

    // Every distinct value of floor(n / i), in ascending order:
    // the small values 1..n2, then the large values n / i > n2 for i = n2..1.
    // For i <= sqrt(n), n / i is strictly decreasing in i, so no sort/dedup is needed.
    vector<ll> vals;
    vals.reserve(2 * size_t(n2));
    for (int v = 1; v <= n2; v++) vals.push_back(v);
    for (int i = n2; i >= 1; i--) {
        if (n / i > n2) vals.push_back(n / i);
    }
    const int m = int(vals.size());

    // Position of a value v of the form floor(n / i) inside `vals`.
    // Small values sit at v - 1. The large value n / i sits at m - i, and
    // floor(n / floor(n / i)) == i holds for i <= sqrt(n).
    auto index_of = [&](ll v) {
        return v <= n2 ? int(v) - 1 : m - int(n / v);
    };

    // dp[j] starts as (2 + 3 + ... + vals[j]) mod MOD, i.e. x(x+1)/2 - 1.
    // The even factor is halved first so the division is exact.
    vector<ll> dp(m);
    for (int j = 0; j < m; j++) {
        const ll x = vals[j];
        if (x % 2 == 0) dp[j] = (x / 2 % MOD) * ((x + 1) % MOD) % MOD - 1;
        else dp[j] = (x % MOD) * ((x + 1) / 2 % MOD) % MOD - 1;
    }

    // One pass per prime p <= sqrt(n). Values below p*p are unaffected. Walking
    // from the largest value down lets dp be updated in place, because dp[qi]
    // (qi < j) still holds the previous round's value when it is read.
    for (size_t t = 0; t < prs.size() && prs[t] <= n2; t++) {
        const ll p = prs[t];
        const ll below_p = prime_sum_small_part_memo[p - 1];  // sum of primes < p
        for (int j = m - 1; j >= 0 && vals[j] >= p * p; j--) {
            const ll q = vals[j] / p;
            const int qi = index_of(q);
            assert(vals[qi] == q);
            // Entries may go negative; the result is normalized on return.
            dp[j] = (dp[j] - p * (dp[qi] - below_p)) % MOD;
        }
    }
    return (dp.back() % MOD + MOD) % MOD;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment