Created March 18, 2022 04:39
-
-
Save math314/56dd73a27c3b0f23935cc592b3f527da to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define _CRT_SECURE_NO_WARNINGS | |
#include <cstdio> | |
#include <cstdlib> | |
#include <cstring> | |
#include <cmath> | |
#include <limits> | |
#include <ctime> | |
#include <cassert> | |
#include <map> | |
#include <set> | |
#include <iostream> | |
#include <memory> | |
#include <string> | |
#include <vector> | |
#include <algorithm> | |
#include <functional> | |
#include <sstream> | |
#include <stack> | |
#include <queue> | |
#include <numeric> | |
#include <iterator> | |
#include <bitset> | |
#include <unordered_map> | |
#include <unordered_set> | |
using namespace std; | |
// Loop shorthand: the index i runs over [0, n).
#define FOR(i,n) for(int i = 0; i < (n); i++)
// Size of a container as a signed int. The argument is parenthesized so that
// expressions such as sz(*p) expand to ((int)(*p).size()).
#define sz(c) ((int)(c).size())
// ten(n) is 10^n as an int: the digit n is token-pasted into the literal 1e<n>.
#define ten(n) ((int)1e##n)
using ll = long long;
using ull = unsigned long long;
using Pii = std::pair<int, int>;
using Pll = std::pair<ll, ll>;
// Portability shims for the POSIX *_unlocked stdio calls. Where the platform
// declares real getchar_unlocked/putchar_unlocked functions, overload
// resolution prefers those non-template declarations. Elsewhere, these
// variadic templates are the only candidates and forward to the ordinary
// locked calls.
template <typename... Unused>
static inline int getchar_unlocked(void) {
    return getchar();
}
template <typename... Unused>
static inline void putchar_unlocked(int c) {
    putchar(c);
}
// Read one character into c / write one character, through the fastest
// available stdio primitive.
#define mygc(c) ((c) = getchar_unlocked())
#define mypc(c) putchar_unlocked(c)
// Reads the next decimal integer (optionally preceded by '-') from stdin into x.
// Any leading characters that are neither a digit nor '-' are skipped.
// If EOF is hit before a number starts, x is left as 0 instead of looping forever.
void reader(int& x) {
    int k;
    bool negative = false;
    x = 0;
    for (;;) {
        mygc(k);
        if (k == EOF) return;  // input exhausted: nothing to parse
        if (k == '-') { negative = true; break; }
        if ('0' <= k && k <= '9') { x = k - '0'; break; }
    }
    for (;;) {
        mygc(k);
        if (k < '0' || k > '9') break;  // EOF (-1) also terminates the number
        x = x * 10 + k - '0';
    }
    if (negative) x = -x;
}
// Reads the next decimal integer (optionally preceded by '-') from stdin into x.
// Any leading characters that are neither a digit nor '-' are skipped.
// If EOF is hit before a number starts, x is left as 0 instead of looping forever.
void reader(ll& x) {
    int k;
    bool negative = false;
    x = 0;
    for (;;) {
        mygc(k);
        if (k == EOF) return;  // input exhausted: nothing to parse
        if (k == '-') { negative = true; break; }
        if ('0' <= k && k <= '9') { x = k - '0'; break; }
    }
    for (;;) {
        mygc(k);
        if (k < '0' || k > '9') break;  // EOF (-1) also terminates the number
        x = x * 10 + k - '0';
    }
    if (negative) x = -x;
}
// Reads the next whitespace-delimited token (delimiters: ' ', '\n', '\r', '\t')
// from stdin into c, NUL-terminates it, and returns its length.
// At EOF with no token left, c becomes "" and 0 is returned; previously the
// whitespace-skipping loop never terminated at EOF.
// The caller must supply a buffer large enough for the token plus '\0'.
int reader(char c[]) {
    int i;
    int s = 0;
    for (;;) {
        mygc(i);
        if (i == EOF) { c[0] = '\0'; return 0; }
        if (i != ' ' && i != '\n' && i != '\r' && i != '\t') break;
    }
    c[s++] = static_cast<char>(i);
    for (;;) {
        mygc(i);
        if (i == ' ' || i == '\n' || i == '\r' || i == '\t' || i == EOF) break;
        c[s++] = static_cast<char>(i);
    }
    c[s] = '\0';
    return s;
}
// Reads the next whitespace-delimited token (delimiters: ' ', '\n', '\r', '\t')
// from stdin into c, replacing its contents, and returns its length.
// At EOF with no token left, c is empty and 0 is returned; previously the
// whitespace-skipping loop never terminated at EOF.
int reader(string& c) {
    c.clear();
    int i;
    for (;;) {
        mygc(i);
        if (i == EOF) return 0;
        if (i != ' ' && i != '\n' && i != '\r' && i != '\t') break;
    }
    c.push_back(static_cast<char>(i));
    for (;;) {
        mygc(i);
        if (i == ' ' || i == '\n' || i == '\r' || i == '\t' || i == EOF) break;
        c.push_back(static_cast<char>(i));
    }
    return sz(c);
}
// reader(unsigned long long&) is defined further down the file. It is declared
// here so the dependent reader(...) calls in the templates below can see it:
// two-phase lookup does not find later-declared overloads for fundamental types.
void reader(unsigned long long& x);
// Multi-argument convenience overloads: read each argument in order with the
// single-argument reader overloads.
template <class T, class S>
void reader(T& x, S& y) {
    reader(x);
    reader(y);
}
template <class T, class S, class U>
void reader(T& x, S& y, U& z) {
    reader(x, y);
    reader(z);
}
template <class T, class S, class U, class V>
void reader(T& x, S& y, U& z, V& w) {
    reader(x, y, z);
    reader(w);
}
// Writes x in decimal, then the separator c.
// The magnitude is computed in unsigned arithmetic, so INT_MIN no longer hits
// the signed overflow of -x.
void writer(int x, char c) {
    char f[10];  // enough digits for a 32-bit int magnitude
    int s = 0;
    unsigned int u = static_cast<unsigned int>(x);
    if (x < 0) {
        mypc('-');
        u = 0u - u;
    }
    do {
        f[s++] = static_cast<char>(u % 10);
        u /= 10;
    } while (u);
    while (s--) mypc(f[s] + '0');
    mypc(c);
}
// Writes x in decimal, then the separator c.
// The magnitude is computed in unsigned long long, so LLONG_MIN no longer hits
// the signed overflow of -x.
void writer(ll x, char c) {
    char f[20];  // a 64-bit magnitude has at most 20 decimal digits
    int s = 0;
    unsigned long long u = static_cast<unsigned long long>(x);
    if (x < 0) {
        mypc('-');
        u = 0ULL - u;
    }
    do {
        f[s++] = static_cast<char>(u % 10);
        u /= 10;
    } while (u);
    while (s--) mypc(f[s] + '0');
    mypc(c);
}
// Writes the NUL-terminated string c with no trailing separator.
void writer(const char c[]) {
    for (const char* p = c; *p; ++p) {
        mypc(*p);
    }
}
void writer(const string& x, char c) { int i; for (i = 0; x[i] != '\0'; i++)mypc(x[i]); mypc(c); } | |
// Writes the NUL-terminated string x, then the separator c.
void writer(const char x[], char c) {
    const char* p = x;
    while (*p != '\0') {
        mypc(*p);
        ++p;
    }
    mypc(c);
}
// writer(unsigned long long, char) is defined further down the file. It is
// declared here so the dependent writer(...) calls below can see it. Without
// it, writerLn(ull) resolves among only the int/ll overloads and is ambiguous.
void writer(unsigned long long x, char c);
// writerLn(a, b, ...) prints its arguments separated by single spaces and
// terminated by '\n', using the two-argument writer(value, separator) overloads.
template<class T> void writerLn(T x) {
    writer(x, '\n');
}
template<class T, class S> void writerLn(T x, S y) {
    writer(x, ' ');
    writerLn(y);
}
template<class T, class S, class U> void writerLn(T x, S y, U z) {
    writer(x, ' ');
    writerLn(y, z);
}
template<class T, class S, class U, class V> void writerLn(T x, S y, U z, V v) {
    writer(x, ' ');
    writerLn(y, z, v);
}
template<class T> void writerArr(T x[], int n) { if (!n) { mypc('\n'); return; }FOR(i, n - 1)writer(x[i], ' '); writer(x[n - 1], '\n'); } | |
template<class T> void writerArr(vector<T>& x) { writerArr(x.data(), (int)x.size()); } | |
// Reads the next decimal unsigned integer from stdin into x. Any leading
// characters that are neither a digit nor '-' are skipped. A leading '-' is
// accepted but ignored, as before; the old sign flag was never used.
// If EOF is hit before a number starts, x is left as 0 instead of looping forever.
void reader(ull& x) {
    int k;
    x = 0;
    for (;;) {
        mygc(k);
        if (k == EOF) return;  // input exhausted: nothing to parse
        if (k == '-') break;
        if ('0' <= k && k <= '9') { x = ull(k - '0'); break; }
    }
    for (;;) {
        mygc(k);
        if (k < '0' || k > '9') break;  // EOF (-1) also terminates the number
        x = x * 10 + k - '0';
    }
}
// Writes the unsigned value x in decimal, then the separator c.
// The never-set sign flag and its unreachable '-' branch are removed.
void writer(ull x, char c) {
    char f[20];  // ULLONG_MAX has exactly 20 decimal digits
    int s = 0;
    do {
        f[s++] = static_cast<char>(x % 10);
        x /= 10;
    } while (x);
    while (s--) mypc(f[s] + '0');
    mypc(c);
}
// Reads n values of type T from stdin, in order, using the reader overloads.
template<class T> vector<T> readerArray(int n) {
    vector<T> values(n);
    for (T& v : values) {
        reader(v);
    }
    return values;
}
// Lowers a to b when a > b (T needs operator> and assignment).
template<class T> void chmin(T& a, const T& b) {
    if (!(a > b)) return;
    a = b;
}
// Raises a to b when a < b (T needs operator< and assignment).
template<class T> void chmax(T& a, const T& b) {
    if (!(a < b)) return;
    a = b;
}
// Greatest common divisor by Euclid's algorithm (iterative). gcd(a, 0) == a.
// With negative inputs, the sign of the result follows C++ '%' semantics.
template<class T> T gcd(T a, T b) {
    while (b != 0) {
        const T r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Least common multiple; dividing before multiplying keeps intermediates small.
// lcm(x, 0) == lcm(0, x) == 0. Previously lcm(0, 0) divided by gcd(0, 0) == 0.
template<class T> T lcm(T a, T b) {
    if (a == 0 || b == 0) return 0;
    return a / gcd(a, b) * b;
}
// Computes a^n mod `mod` by binary exponentiation. The result is in [0, mod).
//  - a may be negative; it is first reduced into [0, mod). Previously a
//    negative base produced a negative result.
//  - n must be >= 0. Previously a negative n looped forever, because an
//    arithmetic right shift of a negative value never reaches 0.
//  - mod must be > 0. mod == 1 now yields 0 (previously 1 when n == 0).
//  - Products of two residues must fit in long long, so mod should stay below
//    about 3.03e9.
long long mod_pow(long long a, long long n, long long mod) {
    assert(mod > 0 && n >= 0);
    long long base = a % mod;
    if (base < 0) base += mod;
    long long result = 1 % mod;
    for (; n > 0; n >>= 1) {
        if (n & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Extended Euclid. Returns g = gcd(a, b) and sets x, y so that
// a*x + b*y == g, where a and b are the original arguments.
// Loop invariant: b == A*x + B*y and a == A*u + B*v (A, B = original a, b).
template<class T> T extgcd(T a, T b, T& x, T& y) {
    T u = 1, v = 0;
    y = 1;
    x = 0;
    while (a) {
        const T q = b / a;
        T t = x - q * u; x = u; u = t;
        t = y - q * v;   y = v; v = t;
        t = b - q * a;   b = a; a = t;
    }
    return b;
}
// Modular inverse of a modulo m, returned in [0, m). Requires m > 0 and
// gcd(a, m) == 1 (checked with assert). a is first reduced into [0, m).
// Without that reduction, a negative a makes extgcd return -1, which yields
// the negated (wrong) inverse.
template<class T> T mod_inv(T a, T m) {
    assert(m > 0);
    a %= m;
    if (a < 0) a += m;
    T x, y;
    const T g = extgcd<T>(a, m, x, y);
    assert(g == 1 && "mod_inv: a is not invertible modulo m");
    (void)g;
    return (m + x % m) % m;
}
#if defined(_MSC_VER)
// MSVC builds: printf-style trace to stderr, flushed after every call.
template <typename... Args>
void debugPrintf(const char* format, const Args&... args) {
    fprintf(stderr, format, args...);
    fflush(stderr);
}
#else
// Other compilers: the call expands to nothing, so its arguments are not evaluated.
#define debugPrintf(...)
#endif
// Integer square root: returns floor(sqrt(x)) for x >= 0.
// Starts from the floating-point estimate and corrects it in both directions.
// The comparisons use division so that (p+1)^2 is never formed. Previously it
// overflowed for x near LLONG_MAX. Negative x is rejected by assert; under
// NDEBUG it returns 0 instead of feeding NaN to an integer conversion.
long long nSqrt(long long x) {
    assert(x >= 0);
    if (x <= 0) return 0;
    long long p = static_cast<long long>(std::sqrt(static_cast<double>(x)));
    while (p + 1 <= x / (p + 1)) ++p;  // (p+1)^2 <= x
    while (p > x / p) --p;              // p^2 > x (p stays >= 1 since x >= 1)
    return p;
}
// Returns the cnt-th convergent of sqrt(2) as (numerator, denominator):
// cnt = 0 -> 1/1, 1 -> 3/2, 2 -> 7/5, 3 -> 17/12, 4 -> 41/29, ...
// Denominators follow d' = 2*d + d_prev; the numerator is d + d_prev.
Pll sqrt2DiophantineApproximation(int cnt) {
    ll prevDen = 0;
    ll den = 1;
    for (int step = 0; step < cnt; ++step) {
        const ll nextDen = 2 * den + prevDen;
        prevDen = den;
        den = nextDen;
    }
    return {prevDen + den, den};
}
// Brute-force reference: sum over i = 1..n of floor(sqrt(2*i*i)),
// i.e. of floor(i * sqrt(2)), evaluated term by term.
ll naive(ll n) {
    ll total = 0;
    for (ll x = n; x >= 1; --x) {
        total += nSqrt(2 * x * x);
    }
    return total;
}
// True when the fraction first/second exceeds sqrt(2), tested exactly in
// integers as first^2 > 2 * second^2.
bool isLargerThanSqrt2(Pll qp) {
    const ll num = qp.first;
    const ll den = qp.second;
    return num * num > 2 * den * den;
}
// For a fraction num/den = qp.first/qp.second, returns
// sum_{x=1..den} floor(x * num / den), minus one when num/den > sqrt(2)
// (then den*sqrt(2) < num, so the final term rounds down to num - 1).
// The sum is counted with Pick's theorem, I = A - B/2 + 1, on the triangle
// (0,0), (den,0), (den,num): 2A = num*den and B = num + den + 1.
// The +1 assumes gcd(num, den) == 1.
ll sumDiophantineApproximatedLatticePoints(Pll qp) {
    const ll num = qp.first;
    const ll den = qp.second;
    const ll twiceArea = num * den;
    const ll boundary = num + den + 1;
    const ll interior = (twiceArea - boundary) / 2 + 1;
    const bool overshoots = isLargerThanSqrt2(qp);
    if (overshoots) {
        debugPrintf("[*] %lld/%lld is slightly larger than sqrt(2)\n", qp.first, qp.second);
    }
    // Interior points cover x = 1..den-1; the column x = den adds num more.
    return interior + num - (overshoots ? 1 : 0);
}
// Fast counterpart of naive(n): sum of floor(x * sqrt(2)) for x = 1..n.
// compare() checks the two against each other.
//
// While more than 20 terms remain, take the next run of x values whose length
// Q is the largest convergent denominator below the remaining count. Add
// baseLine * Q plus the lattice-point sum for that convergent P/Q, where
// baseLine = floor(done * sqrt(2)). baseLine is bumped by one when the
// previous run's convergent exceeded sqrt(2), i.e. Q*sqrt(2) fell just below
// the integer P. The final <= 20 terms are summed directly.
// Prints one trace line to stdout per run.
ll solve(ll n) {
    // Convergents from index 2 (7/5) up to and including the first with a
    // denominator above n; that last entry bounds the search below.
    vector<Pll> points;
    int order = 2;
    do {
        points.push_back(sqrt2DiophantineApproximation(order++));
    } while (points.back().second <= n);

    ll total = 0;
    ll done = 0;               // x in (0, done] are already summed
    bool roundUpBase = false;  // previous run's convergent was > sqrt(2)
    while (n - done > 20) {
        const ll remaining = n - done;
        int idx = 0;
        while (points[idx + 1].second < remaining) ++idx;
        const Pll& run = points[idx];

        ll baseLine = nSqrt(2 * done * done);
        if (roundUpBase) ++baseLine;
        const ll dalAns = sumDiophantineApproximatedLatticePoints(run);
        const ll add = dalAns + baseLine * run.second;
        printf(" cumulated range (%lld, %lld]. baseLine = %lld, dalAns = %lld, add = %lld, isSq2SlightlySmallerThanInteger = %d\n",
            done, done + run.second, baseLine, dalAns, add, roundUpBase);
        if (roundUpBase) {
            printf(" baseline was incremented since the baseLine is slightly smaller than integer.\n");
        }

        roundUpBase = isLargerThanSqrt2(run);
        total += add;
        done += run.second;
    }

    // Tail, term by term. If any run was taken, x = done already belongs to
    // it, so start just after; otherwise start at 0 (which contributes 0).
    for (ll x = (done == 0) ? 0 : done + 1; x <= n; ++x) {
        total += nSqrt(2 * x * x);
    }
    return total;
}
void compare(int i) { | |
puts("--------------------"); | |
printf("answer for %d\n", i); | |
ll ans = naive(i); | |
ll a2 = solve(i); | |
printf("naive = %lld, solve = %lld\n", ans, a2); | |
if (ans != a2) { | |
cout << "?"; | |
} | |
} | |
int main() { | |
//compare(77); | |
//for(int i = 10; i < 500; i++) { | |
// compare(i); | |
//} | |
for (int i = 1000; i < ten(7); i *= 2) { | |
compare(i); | |
} | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.