题意
如果一个字符串可以被拆分为 $\text{AABB}$ 的形式,其中 $\text{A}$ 和 $\text{B}$ 是任意非空字符串,则我们称该字符串的这种拆分是优秀的。
一个字符串的权值定义为这个字符串优秀的拆分的方案数。求一个长度为 $n$ 的字符串所有子串的权值和。
分析
首先容易发现我们可以枚举分界点。定义 $L_i,R_i$ 表示以 $i$ 为左右端点的 $\text{AA}$ 串的数量。那么我们就是求 $\sum_{i=1}^{n-1} R_i \times L_{i+1}$。现在我们只要快速求 $L_i,R_i$。
如果你直接 $O(n^2)$ 枚举,$\text{hash}$ 判断,你就可以有 $95 \text{pts}$ 的好成绩。
剩下 $5$ 分呢?下载数据打表。
这个点我思考一夜无果,看完题解感觉十分巧妙。
考虑枚举一个 $len$,表示 $\text{AA}$ 串一个 $\text{A}$ 的长度。那么这个 $\text{AA}$ 串一定包含两个 $i \times len$ 的点,且这两个点距离为 $len$。设这两个点是 $x,y$。然后考虑计算这个 $\text{AA}$ 的位置。因为这是 $\text{SA}$ 的标签,所以一定和 $\text{lcp}$ 有关。
考虑算他的 $lcs(s[1,x],s[1,y]),lcp(s[x,n],[y,n])$,记为 $Lcs,Lcp$。显然的,如果 $Lcs + Lcp \le len$,那么一定没有这样的子串。而如果有,因为要一样,所以左端点 $l \in [x-Lcs+1,x+Lcp-len+1]$,右端点同样可以算。直接差分即可。这一部分枚举是调和级数,$O(n\log n)$。
至于求 $Lcs,Lcp$ 的话,用 $\text{SA+ST 表}$ 可以做到 $O(1)$ 的查询,总时间复杂度是 $O(n\log n)$ 的。但是我们并不想写 $\text{SA}$,所以用了无耻的 $\text{hash}$,时间复杂度就是 $O(n\log^2n)$,比较劣,但是可以过。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define il inline
#define N 100005
il int rd(){
int s = 0, w = 1;
char ch = getchar();
for (;ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
for (;ch >= '0' && ch <= '9'; ch = getchar()) s = ((s << 1) + (s << 3) + ch - '0');
return s * w;
}
const int P = 1e9 + 9, base = 237;
int n, hsh[N], pbase[N], L[N], R[N];
char s[N];
il int hashh(int l, int r){return (hsh[r] - 1ll * hsh[l - 1] * pbase[r - l + 1] % P + P) % P;}
il int lcs(int i, int j){
int l = 1, r = min(i, j), ans = 0;
while (l <= r){
int mid = (l + r) >> 1;
if (hashh(i - mid + 1, i) == hashh(j - mid + 1, j)) l = mid + 1, ans = mid;
else r = mid - 1;
}
return ans;
}
il int lcp(int i, int j){
int l = 1, r = min(n - i + 1, n - j + 1), ans = 0;
while (l <= r){
int mid = (l + r) >> 1;
if (hashh(i, i + mid - 1) == hashh(j, j + mid - 1)) l = mid + 1, ans = mid;
else r = mid - 1;
}
return ans;
}
int Main(){
scanf ("%s", s + 1), n = strlen(s + 1);
for (int i = 1; i <= n; i++) hsh[i] = (1ll * hsh[i - 1] * base % P + s[i]) % P;
for (int i = 0; i <= n + 1; i++) L[i] = R[i] = 0;
for (int len = 1; len <= n; len++){
for (int x = len, y = len + len; y <= n; x += len, y += len){
int Lcs = min(lcs(x, y), len), Lcp = min(lcp(x, y), len);
int xll = x - Lcs + 1, xlr = x + Lcp - len;
int xrl = y - Lcs + len, xrr = y + Lcp - 1;
if (Lcs + Lcp <= len) continue;
L[xll]++, L[xlr + 1]--, R[xrl]++, R[xrr + 1]--;
}
}
for (int i = 1; i <= n + 1; i++) L[i] += L[i - 1], R[i] += R[i - 1];
ll ans = 0;
for (int i = 1; i <= n; i++) ans += 1ll * R[i] * L[i + 1];
printf ("%lld\n", ans);
return 0;
}
int main(){
pbase[0] = 1;
for (int i = 1; i <= N - 5; i++) pbase[i] = 1ll * pbase[i - 1] * base % P;
for (int T = rd(); T--;) Main();
return 0;
}