在树上求A到B的路径只需要写上个LCA就可以解决,剩下就是从A到B路径上的深度 k
次方和该如何解决。
设点 LCA(u,v)
为u
,v
的最近公共祖先, dep[u]
为 u
点的深度, 那我们可以发现从 A
点到根节点的深度 k
次方和为 $0 + 1^k + 2^k + 3^k .... dep[A]^k$,每次查询的时候再去计算时间复杂度会爆炸,所以我们可以直接预处理出来在深度为 i
的情况下,次方为 k
的值为多少,设数组为 mul[i][j]
预处理出来即可
那这部分代码可以写成
for (int i = 1; i <= t; i ++){
mul[i][0] = 1;
for (int j = 1; j <= 50; j ++){
mul[i][j] = mul[i][j-1] % mod * 1ll*i % mod;
}
}
然后我们看到从根节点到 A
节点是一段有序的和,假设我要求点 C
(假设点 C
是根节点到 A
节点上的一个点 )到点 A
之间的距离,那就是要求 $dep[C]^k + (dep[C]+1)^k + ....+ dep[A]^k$的值,那我们可以使用前缀和来处理,dist[i][k],表示在深度为 i
的情况下,k次方的前缀和
for (int i = 1; i <= 50; i ++){
for (int j = 1; j <= t; j ++){
dist[j][i] = (dist[j-1][i] % mod + 1ll*mul[j][i] % mod) % mod;
}
}
那最后我们从 A
点到 B
点的值就等于
dist[dep[A]][k] + dist[dep[B]][k] - dist[dep[LCA(A,B)]][k] - dist[dep[LCA(A,B)]-1][k]
;
#include<bits/stdc++.h>
using namespace std;
const int N = 3e5+10, mod = 998244353;
int dep[N], fa[N][21];
int mul[N][51], dist[N][51];
vector<int> g[N];
struct T{
int i, j, k;
}ask[N];
int maxdep;
void dfs(int x, int f){
dep[x] = dep[f] + 1;
maxdep = max(maxdep, dep[x]);
fa[x][0] = f;
for (int i = 1; i <= 20; i ++){
fa[x][i] = fa[fa[x][i-1]][i-1];
}
for (auto u : g[x]){
if(u == f) continue;
dfs(u, x);
}
}
int lca(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
for (int i = 20; i>=0 ; i--){
if(dep[fa[x][i]] >= dep[y]){
x = fa[x][i];
}
}
if(x == y) return x;
for (int i = 20; i >= 0; i --){
if(fa[x][i] != fa[y][i]){
x = fa[x][i], y = fa[y][i];
}
}
return fa[x][0];
}
void init(int t,int k){
for (int i = 1; i <= t; i ++){
mul[i][0] = 1;
for (int j = 1; j <= 50; j ++){
mul[i][j] = mul[i][j-1] % mod * 1ll*i % mod;
}
}
for (int i = 1; i <= 50; i ++){
for (int j = 1; j <= t; j ++){
dist[j][i] = (dist[j-1][i] % mod + 1ll*mul[j][i] % mod) % mod;
}
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int n;
cin >> n;
for (int i = 1; i < n; i ++) {
int a, b;
cin >> a >> b;
g[a].push_back(b);
g[b].push_back(a);
}
dep[0] = -1;
dfs(1,0);
int maxn = -1;
int m;
cin >> m;
for (int i = 1; i <= m; i ++){
cin >> ask[i].i >> ask[i].j >> ask[i].k;
maxn = max(maxn, ask[i].k);
}
init(maxdep, maxn);//预处理
//cout << dist[2][2] << '\n';
for (int i = 1; i <= m; i ++){
int a = ask[i].i, b = ask[i].j, k = ask[i].k;
int t = lca(a, b);
int t1 = dist[dep[a]][k], t2 = dist[dep[b]][k], t3, t4 = dist[dep[t]][k];
//cout << t1 << " " << t2 << '\n';
if(t == 1){
t3 = 0;
}else{
//cout << dep[t]-1 << " " << k << '\n';
t3 = dist[dep[t]-1][k] % mod;
}
//cout << t3 << "\n";
cout << (t1 + t2 + 2ll*mod - t3 - t4) % mod << '\n';
}
return 0;
}
/*
预处理出来
dep[i][j]
i的j次方,存储起来
统计最大有多少层;
离线统计一下即可
*/