题目描述
给你两个长度都为 n
的字符串 s
和 t
。你可以对字符串 s
执行以下操作:
- 将
s
长度为l
(0 < l < n
)的 后缀字符串 删除,并将它添加在s
的开头。- 比方说,
s = 'abcd'
,那么一次操作中,你可以删除后缀'cd'
,并将它添加到s
的开头,得到s = 'cdab'
。
- 比方说,
给你一个整数 k
,请你返回 恰好 k
次操作将 s
变为 t
的方案数。
由于答案可能很大,返回答案对 10^9 + 7
取余 后的结果。
样例
输入:s = "abcd", t = "cdab", k = 2
输出:2
解释:
第一种方案:
第一次操作,选择 index = 3 开始的后缀,得到 s = "dabc"。
第二次操作,选择 index = 3 开始的后缀,得到 s = "cdab"。
第二种方案:
第一次操作,选择 index = 1 开始的后缀,得到 s = "bcda"。
第二次操作,选择 index = 1 开始的后缀,得到 s = "cdab"。
输入:s = "ababab", t = "ababab", k = 1
输出:2
解释:
第一种方案:
选择 index = 2 开始的后缀,得到 s = "ababab"。
第二种方案:
选择 index = 4 开始的后缀,得到 s = "ababab"。
限制
2 <= s.length <= 5 * 10^5
1 <= k <= 10^15
s.length == t.length
s
和t
都只包含小写英文字母。
算法
(KMP,动态规划,矩阵快速幂) $O(n + \log k)$
- 使用 KMP 算法求出模式字符串 $s$ 的 $next$ 数组。
- 使用 $next$ 数组,让模式字符串 $s$ 匹配字符串 $t + t$。如果找不到匹配的位置,则说明 $s$ 无论如何都变不成 $t$,直接返回 $0$。否则,拿到第一次匹配的位置 $fst$。
- 通过 $next$ 数组,可以判断出字符串 $s$ 最小的循环节。如果 $n \% (n - next(n - 1) - 1) == 0$,则说明存在小于 $n$ 的循环节为 $r = n - next(n - 1) - 1$。
- 现在可以考虑如何求出操作 $k$ 次后,以下标 $i$ 开头的方案数。初始(操作 $0$ 次)时,以下标 $0$ 的开头的方案数为 $1$,其余为 $0$。
- 注意到,除了下标 $0$ 之外,以其他下标开头的方案数都是一样的,所以可以看成一种情况。
- 设 $f(t, 0)$ 表示操作了 $t$ 次后,以下标 $0$ 开头的方案数;$f(t, 1)$ 表示非下标 $0$ 开头的方案数。
- 存在如下转移:$f(t, 0) = (n - 1)f(t - 1, 1)$,$f(t, 1) = f(t - 1, 0) + (n - 2)f(t - 1, 1)$。
- 以上递推过程可以写成一个
2 x 2
的矩阵连乘。矩阵为:
mat = [[0, n - 1],
[1, n - 2]]
- 通过快速幂求出矩阵的 $k$ 次幂,便可以得到以下标 $0$ 开头的方案数 $x = mat(0, 0)$,和非下标 $0$ 开头的方案数 $y = mat(1, 0)$。
- 如果 $fst = 0$,则答案为 $x + (n - fst - 1) / r * y$;否则答案为 $y + (n - fst - 1) / r * y$。
时间复杂度
- 预处理 $next$ 和匹配的时间复杂度为 $O(n)$。
- 矩阵快速幂的时间复杂度为 $O(\log k)$。
- 故总时间复杂度为 $O(n + \log k)$。
空间复杂度
- 仅需要 $O(n)$ 的额外空间存储 $next$ 数组。
C++ 代码
#define LL long long
const int mod = 1000000007;
struct Mat {
int a[2][2];
Mat() {
memset(a, 0, sizeof(a));
}
Mat operator * (const Mat &y) {
Mat res;
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
for (int k = 0; k < 2; k++)
res.a[i][j] = (res.a[i][j] + (LL)(a[i][k]) * y.a[k][j]) % mod;
return res;
}
};
class Solution {
private:
int n;
vector<int> get_nxt(const string &s) {
vector<int> nxt(n);
nxt[0] = -1;
for (int i = 1, j = -1; i < n; i++) {
while (j >= 0 && s[j + 1] != s[i])
j = nxt[j];
if (s[j + 1] == s[i])
++j;
nxt[i] = j;
}
return nxt;
}
int find(const string &s, const string &t, const vector<int> &nxt) {
string w = t + t;
for (int i = 0, j = -1; i < w.size(); i++) {
while (j >= 0 && s[j + 1] != w[i])
j = nxt[j];
if (s[j + 1] == w[i])
++j;
if (j == n - 1)
return i - n + 1;
}
return -1;
}
Mat power(Mat x, LL y) {
Mat tot, p = x;
tot.a[0][0] = tot.a[1][1] = 1;
for(; y; y >>= 1) {
if (y & 1) tot = tot * p;
p = p * p;
}
return tot;
}
public:
int numberOfWays(string s, string t, LL k) {
n = s.size();
vector<int> nxt = get_nxt(s);
int fst = find(s, t, nxt);
if (fst == -1)
return 0;
Mat p;
p.a[0][0] = 0; p.a[0][1] = n - 1;
p.a[1][0] = 1; p.a[1][1] = n - 2;
Mat tot = power(p, k);
int ans = 0;
if (fst == 0) ans += tot.a[0][0];
else ans += tot.a[1][0];
int r = n;
if (n % (n - nxt[n - 1] - 1) == 0)
r = n - nxt[n - 1] - 1;
ans = (ans + (LL)((n - fst - 1) / r) * tot.a[1][0]) % mod;
return ans;
}
};