全是套路,全都不会,但关键点这一转化确实比较难想。
下文记 $pre_i$ 表示 $[1,i]$ 这一段字符串前缀,$suf_i$ 表示 $[i,n]$ 这一段字符串后缀,$S[i,j]$ 表示 $[i,j]$ 这段子串。
最开始想过设什么 $f_{i,j}$ 表示 $[i,j]$ 是否是 $\texttt{AA}$ 型字符串,$g_{i,j}$ 表示 $[i,j]$ 的优秀拆分方案数,但这样 $O(n^2)$ 的做法很难优化,状态就已经二维了。
所以设 $l_i$ 表示以 $i$ 为左端点的 $\texttt{AA}$ 型字符串数量,$r_i$ 表示以 $i$ 为右端点的 $\texttt{AA}$ 型字符串数量。
答案即为 $\sum\limits_{i=1}^{n-1} r_i \times l_{i+1}$。
有很多方法预处理 $l,r$,但是有些并不好优化,因此考虑下面这种预处理方法:
既然正向的都不好做,那就反向枚举 $\texttt{AA}$ 中 $\texttt{A}$ 的长度 $len$。
枚举点对 $(i,j)$,其中 $i+len=j$。
若满足 $\text{LCP}(suf_i,suf_j) \geq len$,那么显然有 $S[i,j-1]=S[j,j+len-1]$,因此 $r_{j+len-1} \leftarrow r_{j+len-1} + 1$,$l$ 数组的计算同理。
其中 LCP 可以用后缀数组 $O(1)$ 求解,枚举 $len,i$ 复杂度为 $O(n^2)$。
如何优化?
考虑设“关键点”。
还是枚举 $len$,每间隔 $len-1$ 个标记一个“关键点”,那么 $\texttt{A}$ 子串一定至少经过一个关键点。
考虑求出 $\text{LCS}(pre_i,pre_j)$ 和 $\text{LCP}(suf_i,suf_j)$,如果两者的和 $\geq len$ 那么显然可以组成 $\texttt{AA}$ 型字符串。
然后发现 $l,r$ 每次都是一段区间加上 $1$,可以用差分维护。
至于前后缀的 $\text{LCS}$ 和 $\text{LCP}$,可以对原串和反串跑 SA。
这样枚举 $len$ 是调和级数,预处理 SA 是 $O(n \log n)$,因此总复杂度为 $O(n \log n)$,数据范围还是开太小了。
由于 LCP 和 LCS 可能很长,会导致重复计算,因此要和 $len$ 取 $\min$。
#include <bits/stdc++.h>
using namespace std;
const int N = 3e4 + 5;
int T, n;
char s[N];
struct Str {
int n;
char s[N];
int sa[N], rk[N], x[N << 1], y[N << 1], cnt[N];
void init_SA() {
int v = 128;
for (int i = 0; i <= v; i++) cnt[i] = 0;
for (int i = 1; i <= n; i++) ++cnt[x[i] = (int)s[i]];
for (int i = 1; i <= v; i++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i--) sa[ cnt[x[i]]-- ] = i;
for (int len = 1; ; len <<= 1) {
int tot = 0;
for (int i = n - len + 1; i <= n; i++) y[++tot] = i;
for (int i = 1; i <= n; i++)
if (sa[i] > len) y[++tot] = sa[i] - len;
for (int i = 0; i <= v; i++) cnt[i] = 0;
for (int i = 1; i <= n; i++) ++cnt[x[i]];
for (int i = 1; i <= v; i++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i--) sa[ cnt[x[y[i]]]-- ] = y[i], y[i] = 0;
swap(x, y), tot = 0;
for (int i = 1; i <= n; i++)
if (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + len] == y[sa[i - 1] + len]) x[sa[i]] = tot;
else x[sa[i]] = ++tot;
v = tot;
if (v == n) break;
}
}
int ht[N];
void init_height() {
for (int i = 1; i <= n; i++) rk[sa[i]] = i;
for (int i = 1, j = 0, now = 0; i <= n; i++) {
if (rk[i] == 1) { ht[rk[i]] = 0; continue; }
j = sa[rk[i] - 1], now = max(0, now - 1);
while (i + now <= n && j + now <= n && s[i + now] == s[j + now]) now++;
ht[rk[i]] = now;
}
}
int mn[N][17], lg[N];
void init_ST() {
lg[1] = 0; for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
for (int i = 1; i <= n; i++) mn[i][0] = ht[i];
for (int i = 1; i <= 15; i++)
for (int j = 1; j + (1 << i) - 1 <= n; j++)
mn[j][i] = min(mn[j][i - 1], mn[j + (1 << i - 1)][i - 1]);
}
inline int query(int l, int r) {
if (l > r) swap(l, r);
int len = lg[r - l + 1];
return min(mn[l][len], mn[r - (1 << len) + 1][len]);
}
inline int LCP(int a, int b) { // a, b 两后缀的 LCP
a = rk[a], b = rk[b];
if (a == b) return n - sa[a] + 1;
if (a > b) swap(a, b);
return query(a + 1, b);
}
void init() { init_SA(), init_height(), init_ST(); }
void clr() {
for (int i = 0; i <= n + 3; i++) sa[i] = rk[i] = ht[i] = 0;
for (int i = 0; i <= n * 2 + 3; i++) x[i] = y[i] = 0;
n = 0;
}
} s1, s2; //原串、反串
inline int LCP(int a, int b) { return s1.LCP(a, b); } // a, b 两后缀的 LCP
inline int LCS(int a, int b) { return s2.LCP(n - a + 1, n - b + 1); } // a, b 两前缀的 LCS
long long ans = 0;
int l[N], r[N];
void solve() {
scanf("\n %s", s + 1), n = strlen(s + 1);
s1.n = s2.n = n;
for (int i = 1; i <= n; i++) s1.s[i] = s[i], s2.s[n - i + 1] = s[i];
s1.init(), s2.init();
for (int len = 1; len <= n / 2; len++) {
for (int i = 1, j; i + len <= n; i += len) {
j = i + len;
int lcp = min(len, LCP(i, j) ), lcs = min(len, LCS(i, j) );
if (lcp + lcs < len) continue;
int ll = i - lcs + 1, lr = i + lcp - len; // l 数组
int rl = j - lcs + len, rr = j + lcp - 1; // r 数组
if (ll <= lr) l[ll]++, l[lr + 1]--;
if (rl <= rr) r[rl]++, r[rr + 1]--;
}
}
for (int i = 1; i <= n; i++) l[i] += l[i - 1], r[i] += r[i - 1];
for (int i = 1; i < n; i++) ans += r[i] * 1ll * l[i + 1];
printf("%lld\n", ans);
s1.clr(), s2.clr();
for (int i = 0; i <= n + 3; i++) l[i] = r[i] = 0; ans = 0ll;
}
int main() {
scanf("%d", &T);
while (T--) solve();
return 0;
}
数据太小指的是我的 hash + 二分 $O(n\log^2n)$ 过了吗
是指单 $\log$ 居然只开 $3 \times 10^4$ 而不是 $5 \times 10^5$。
就算对于双 $\log$ 标程没开 $10^5$ 也说不过去啊。