Skip to content

Instantly share code, notes, and snippets.

@fishy15
Created April 20, 2024 23:06
Show Gist options
  • Save fishy15/34579e79b78e54e488057f0f3d213b85 to your computer and use it in GitHub Desktop.
Save fishy15/34579e79b78e54e488057f0f3d213b85 to your computer and use it in GitHub Desktop.
#include <iostream>
#include <iomanip>
#include <fstream>
#include <vector>
#include <array>
#include <algorithm>
#include <utility>
#include <map>
#include <queue>
#include <set>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <functional>
#include <numeric>
#include <chrono>
#include <random>
#define ll long long
#define ld long double
#define eps 1e-8
#define MOD 1000000007
#define INF 0x3f3f3f3f
#define INFLL 0x3f3f3f3f3f3f3f3f
// change if necessary
#define MAXN 1000000
using namespace std;
// Manacher's algorithm: palindrome radii around every position of s, O(|s|).
// Returns {odd, even} where, for each index c:
//   odd[c]  = k  such that s[c-k+1 .. c+k-1] is the longest odd-length
//              palindrome centred on c (length 2k-1, so k >= 1);
//   even[c] = k  such that s[c-k .. c+k-1] is the longest even-length
//              palindrome whose right half starts at c (length 2k; 0 if none).
pair<vector<int>, vector<int>> manachers(const string &s) {
    const int len = static_cast<int>(s.size());
    vector<int> odd(len), even(len);

    // [lo, hi] is the palindrome window reaching furthest right so far.
    int lo = 0, hi = -1;
    for (int c = 0; c < len; ++c) {
        // Inside the window, start from the mirrored centre's radius,
        // clipped to what the window can vouch for.
        int rad = 1;
        if (c <= hi) rad = min(odd[lo + hi - c], hi - c + 1);
        for (; c - rad >= 0 && c + rad < len; ++rad) {
            if (s[c - rad] != s[c + rad]) break;
        }
        odd[c] = rad;
        const int reach = rad - 1;
        if (c + reach > hi) {
            lo = c - reach;
            hi = c + reach;
        }
    }

    lo = 0;
    hi = -1;
    for (int c = 0; c < len; ++c) {
        int rad = 0;
        if (c <= hi) rad = min(even[lo + hi - c + 1], hi - c + 1);
        for (; c - rad - 1 >= 0 && c + rad < len; ++rad) {
            if (s[c - rad - 1] != s[c + rad]) break;
        }
        even[c] = rad;
        const int reach = rad - 1;
        if (c + reach > hi) {
            lo = c - reach - 1;
            hi = c + reach;
        }
    }

    return {odd, even};
}
// Computes n^e mod `mod` by binary exponentiation in O(log e) multiplications.
//   n   : base; any value, including negative (reduced into [0, mod) first)
//   e   : exponent; must be >= 0 (a negative exponent is unsupported and
//         yields 1 % mod instead of looping forever)
//   mod : modulus; must be >= 1 with (mod - 1)^2 fitting in ll, which holds
//         for the ~1e9 primes used in this file
// Returns the result in [0, mod).
ll modpow(ll n, ll e, ll mod) {
    // Reduce the base up front: an unreduced n makes n * n overflow (UB)
    // once |n| exceeds ~3e9, and a negative n would give a negative result.
    n %= mod;
    if (n < 0) n += mod;
    // `1 % mod` so that mod == 1 correctly yields 0.
    ll res = 1 % mod;
    while (e > 0) {
        if (e & 1) res = res * n % mod;
        n = n * n % mod;
        e >>= 1;
    }
    return res;
}
// Process-wide Mersenne Twister seeded from the steady clock, so the values
// below differ from run to run.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Draws a uniformly random int in the closed range [2, 1000000005]; this file
// uses it to pick the random polynomial-hash bases in string_hash.
int randval() {
    uniform_int_distribution<int> pick(2, 1000000005);
    return pick(rng);
}
// Polynomial string hashing, doubled over the primes 1e9+7 and 1e9+9 with
// random bases chosen once per process. Stores the hash of every prefix of a
// string so that the hash of any substring is available in O(1).
// N is the size of the precomputed power tables: strings hashed with this
// type must have length at most N.
template<int N>
struct string_hash {
    // calc_base() writes slots 0 and 1 of each table unconditionally.
    static_assert(N >= 2, "string_hash<N> requires N >= 2");

    // Hash of one string t of length n:
    //   h[i] = sum_j code(t[j]) * base_i^j  (mod p[i]),  each h[i] in [0, p[i]),
    // where code(c) is the value of c as an unsigned char.
    struct hash {
        int n;            // length of the hashed string
        array<int, 2> h;  // one residue per modulus
        // Shared power tables, built once:
        //   base[i][j] = base_i^j,  inv[i][j] = base_i^(-j)   (mod p[i]).
        static inline array<array<int, N>, 2> base, inv;
        static constexpr int p[2] = {1000000007, 1000000009};

        // Hash of the empty string; builds the power tables on first use.
        hash() {
            n = 0;
            if (!base[0][0]) calc_base();  // base[0][0] becomes 1 once built
            h = {0, 0};
        }
        // Hash of the single character c.
        hash(char c) : hash() {
            *this += c;
        }
        hash(const hash &h2) : hash() {
            n = h2.n;
            h = h2.h;
        }
        // Picks a random base per modulus and fills base/inv for 0 .. N-1.
        static void calc_base() {
            for (int i = 0; i < 2; i++) {
                base[i][0] = 1;
                base[i][1] = randval();
                inv[i][0] = 1;
                // Fermat's little theorem: base^(p-2) is base^(-1) mod prime p.
                inv[i][1] = modpow(base[i][1], p[i] - 2, p[i]);
            }
            for (int j = 2; j < N; j++) {
                for (int i = 0; i < 2; i++) {
                    base[i][j] = (ll) base[i][j - 1] * base[i][1] % p[i];
                    inv[i][j] = (ll) inv[i][j - 1] * inv[i][1] % p[i];
                }
            }
        }
        bool operator==(const hash &h2) const {
            return n == h2.n && h == h2.h;
        }
        // Strict weak order (length first, then residues) for ordered containers.
        bool operator<(const hash &h2) const {
            if (n == h2.n) return h < h2.h;
            return n < h2.n;
        }
        // Appends c to the hashed string. The resulting length must not exceed N.
        hash &operator+=(const char c) {
            // Go through unsigned char: on signed-char platforms a byte >= 0x80
            // would otherwise leave a negative residue, and equal substrings
            // could then compare unequal after substr()'s subtraction.
            const int code = static_cast<unsigned char>(c);
            for (int i = 0; i < 2; i++) {
                h[i] += (ll) base[i][n] * code % p[i];
                if (h[i] >= p[i]) h[i] -= p[i];
            }
            n++;
            return *this;
        }
        // Removes a prefix. If *this hashes string P and h2 hashes the prefix
        // P[0, h2.n), *this becomes the hash of P[h2.n, P.size()), re-based so
        // it starts at exponent 0.
        hash &operator-=(const hash &h2) {
            for (int i = 0; i < 2; i++) {
                h[i] -= h2.h[i];
                if (h[i] < 0) h[i] += p[i];
                h[i] = (ll) h[i] * inv[i][h2.n] % p[i];
            }
            n -= h2.n;
            return *this;
        }
    };
    using hash_t = hash;

    // h[i] is the hash of the prefix s[0, i].
    vector<hash> h;

    string_hash() : string_hash("") {}
    string_hash(const string &s) {
        h.resize(s.size());
        hash cur;
        for (int i = 0; i < (int)(s.size()); i++) {
            cur += s[i];
            h[i] = cur;
        }
    }
    // Hash of s[n, n + sz). Requires 0 <= n and n + sz <= s.size().
    // A zero-length request yields the empty-string hash.
    hash substr(int n, int sz) const {
        if (sz <= 0) return hash();  // avoid reading h[n - 1] for n == 0
        hash res = h[n + sz - 1];
        if (n) res -= h[n - 1];
        return res;
    }
    // Hash of the whole string; the empty hash when the string is empty.
    hash val() const {
        return h.empty() ? hash() : h.back();
    }
    // Hash of s computed directly, without storing prefixes.
    static hash calc(const string &s) {
        hash cur;
        for (char c : s) {
            cur += c;
        }
        return cur;
    }
};
using shash = string_hash<MAXN>;
int main() {
cin.tie(0)->sync_with_stdio(0);
int n;
cin >> n;
string s;
cin >> s;
auto [d1, d2] = manachers(s);
vector<int> psum(n + 1);
for (int i = 0; i < n; i++) {
psum[i + 1] = psum[i] + (s[i] - 'a' + 1);
}
shash sh(s);
set<shash::hash_t> hashes;
ll sum = 0;
for (int i = 0; i < n; i++) {
int odd_sz = d1[i];
while (odd_sz > 0) {
int l = i - odd_sz + 1;
int sz = 2 * odd_sz - 1;
auto hash = sh.substr(l, sz);
if (!hashes.count(hash)) {
hashes.insert(hash);
sum += psum[l + sz] - psum[l];
odd_sz--;
} else {
break;
}
}
if (i >= 0) {
int even_sz = d2[i];
while (even_sz > 0) {
int l = i - even_sz;
int sz = 2 * even_sz;
auto hash = sh.substr(l, sz);
if (!hashes.count(hash)) {
hashes.insert(hash);
sum += psum[l + sz] - psum[l];
even_sz--;
} else {
break;
}
}
}
}
cout << sum << '\n';
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment