Skip to content

Instantly share code, notes, and snippets.

@meooow25
Created April 15, 2018 01:17
Show Gist options
  • Save meooow25/ba3316e7455cf65544792bc1e2f06778 to your computer and use it in GitHub Desktop.
Save meooow25/ba3316e7455cf65544792bc1e2f06778 to your computer and use it in GitHub Desktop.
#include <bits/stdc++.h>
using namespace std;
// Exclusive upper bound of the sieve tables. calcf() indexes its table up to
// min(n, m) - 1, so both grid dimensions must stay below this value.
constexpr int MAX_VAL = 1000005;
// Prime modulus used for every reduction in this file.
constexpr int MOD = 1000000007;

// Problem dimensions, read from stdin in main().
long long n;
long long m;

// comp[i] becomes true once sieve() has proven i composite.
bool comp[MAX_VAL];
// phi[i]: Euler's totient of i. phi2[i]: Jordan's totient J_2(i) reduced mod MOD.
long long phi[MAX_VAL];
long long phi2[MAX_VAL];
// Primes discovered by sieve(), in increasing order.
std::vector<int> pr;
void sieve() {
phi[1] = phi2[1] = 1;
for (int i = 2; i < MAX_VAL; i++) {
if (!comp[i]) {
pr.push_back(i);
phi[i] = i - 1;
phi2[i] = (1LL * i * i - 1) % MOD;
}
for (int j = 0; j < pr.size() && i * pr[j] < MAX_VAL; j++) {
comp[i * pr[j]] = true;
if (i % pr[j] == 0) {
phi[i * pr[j]] = phi[i] * pr[j];
phi2[i * pr[j]] = phi2[i] * pr[j] % MOD * pr[j] % MOD;
break;
} else {
phi[i * pr[j]] = phi[i] * phi[pr[j]];
phi2[i * pr[j]] = phi2[i] * phi2[pr[j]] % MOD;
}
}
}
}
// Modular inverse of a modulo MOD, computed as a^(MOD - 2) by binary
// exponentiation. This relies on Fermat's little theorem, so it requires MOD
// to be prime (1e9 + 7 is).
//
// a may be any long long, including negative values or values far outside
// [0, MOD). It is reduced into [0, MOD) first, which keeps every product below
// under (MOD - 1)^2 and avoids signed overflow. Callers can therefore pass
// unreduced differences.
//
// Returns a value in [0, MOD). Returns 0 when a == 0 (mod MOD), since no
// inverse exists in that case.
long long inv(long long a) {
    a %= MOD;
    if (a < 0) a += MOD;
    long long r = 1;
    for (long long e = MOD - 2; e > 0; e >>= 1) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
    }
    return r;
}
// Computes, modulo MOD,
//
//   sum_{a=1}^{lo-1} sum_{b=1}^{hi-1} (lo - a) * (hi - b) * G(gcd(a, b)),
//
// where lo = min(n, m), hi = max(n, m) and G(x) = sum_{d | x} g[d].
// With g = phi this gives G(x) = x. With g = phi2 (Jordan's J_2) it gives
// G(x) = x^2.
//
// How it works: pulling the divisor sum outward, for each k the contributing
// pairs are a = k*i, b = k*j with 1 <= i <= (lo-1)/k and 1 <= j <= (hi-1)/k.
// Expanding (lo - a)(hi - b) = lo*hi - lo*b - hi*a + a*b yields the four
// accumulators s0 (const), s1 (b), s2 (a) and s3 (a*b). The sum over a
// multiple index is a triangular number.
//
// The globals n and m are only read, never modified. Mutating them here would
// make any caller that reads n/m in the same expression as this call depend on
// unspecified evaluation order.
//
// @param g  table that must be valid for indices 1 .. min(n, m) - 1.
// @return   the sum reduced into [0, MOD).
long long calcf(long long *g) {
    const long long lo = std::min(n, m);
    const long long hi = std::max(n, m);
    const long long A = lo - 1;  // a ranges over 1..A
    const long long B = hi - 1;  // b ranges over 1..B

    long long s0 = 0, s1 = 0, s2 = 0, s3 = 0;
    for (long long k = 1; k <= A; k++) {
        const long long i1 = A / k;                     // multiples of k in 1..A
        const long long j1 = B / k;                     // multiples of k in 1..B
        const long long i2 = i1 * (i1 + 1) / 2 % MOD;   // 1 + 2 + ... + i1
        const long long j2 = j1 * (j1 + 1) / 2 % MOD;   // 1 + 2 + ... + j1
        s0 += i1 * j1 % MOD * g[k] % MOD;
        s1 += i1 * j2 % MOD * k % MOD * g[k] % MOD;
        s2 += i2 * j1 % MOD * k % MOD * g[k] % MOD;
        s3 += i2 * j2 % MOD * k % MOD * k % MOD * g[k] % MOD;
    }
    s0 = s0 % MOD * lo % MOD * hi % MOD;
    s1 = s1 % MOD * lo % MOD;
    s2 = s2 % MOD * hi % MOD;
    s3 %= MOD;

    const long long res = (s0 - s1 - s2 + s3) % MOD;
    return res < 0 ? res + MOD : res;
}
long long calcsump1p2c(int mode) {
long long s0, s1, sum; s0 = s1 = 0;
if (mode == 1) {
for (int x = 1; x < n; x++)
s0 += (n - x) * x % MOD;
for (int y = 1; y < m; y++)
s1 += (m - y) * y % MOD;
} else {
for (int x = 1; x < n; x++)
s0 += (n - x) * x % MOD * x % MOD;
for (int y = 1; y < m; y++)
s1 += (m - y) * y % MOD * y % MOD;
}
sum = s0 % MOD * 2 * m + s1 % MOD * 2 * n + 4 * calcf(mode == 1 ? phi : phi2);
return sum % MOD;
}
long long solve() {
sieve();
long long allsum, allcnt, badsum, badcnt, ans;
long long sump1p2c = calcsump1p2c(1), sump1p2c2 = calcsump1p2c(2);
allsum = 3 * n * m % MOD * sump1p2c % MOD;
allcnt = n % MOD * n % MOD * n % MOD * m % MOD * m % MOD * m % MOD;
badsum = 6 * sump1p2c2 % MOD;
badcnt = 3 * sump1p2c + n * m % MOD;
ans = (allsum - badsum) * inv(allcnt - badcnt) % MOD;
return (ans + MOD) % MOD;
}
// Entry point: reads n and m from stdin and prints solve()'s result on one line.
int main() {
    // Fast, unsynchronized stream I/O; this program does not use C stdio.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> m;
    const long long answer = solve();
    std::cout << answer << '\n';
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment