参考视频
关于此题主要考树的性质,树上两点有唯一的最短路径
将所有情况分为两种
1.去掉首尾其中一点:相当于总路径中删除1->2 或者 (n-1) ->n
2.去掉中间的点例如3号点,原路径例如1->2->3->4
新路径1->2->4
我们发现减少了2->3 , 3->4的路径
多了2->4的路径
我们需要快速求树上两点的距离,这里直接背LCA模板,作用是快速返回树上两点距离
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
typedef pair<int, int> PII;
typedef long long LL;
const int N = 100010 , M = 2 * N, K = 20;
int h[N], e[M], ne[M], idx , w[M];
int n,k;
int path[N];
LL sum,init[N];
void add(int a, int b, int c) // 添加一条边a->b,边权为c
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}
int depth[N] , q[N] ,fa[N][K] ;
LL dist[N];
void bfs(int root) // 预处理倍增数组
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[root] = 1; // depth存储节点所在层数
int hh = 0, tt = 0;
q[0] = root;
while (hh <= tt)
{
int t = q[hh ++ ];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
dist[j] = dist[t] + w[i];
q[ ++ tt] = j;
fa[j][0] = t; // j的第二次幂个父节点
for (int k = 1; k < K; k ++ )
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
int lca(int a, int b) // 返回a和b的最近公共祖先
{
if (depth[a] < depth[b]) swap(a, b);
for (int k = K - 1; k >= 0; k -- )
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return a;
for (int k = K - 1; k >= 0; k -- )
if (fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
LL get_dist(int a,int b)
{
int p = lca(a,b);
return dist[a] + dist[b] - 2 * dist[p];
}
int main()
{
memset(h, -1, sizeof h);
scanf("%d%d", &n, &k);
for(int i = 1;i <= n - 1;i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add(a,b,c) , add(b,a,c);
}
for(int i = 1;i <= k;i++)
scanf("%d",&path[i]);
bfs(1);
for(int i = 1;i <= k - 1;i++)
{
int st = path[i] ,ed = path[i + 1];
//init[i] = dijkstra(st,ed);
init[i] = get_dist(st,ed);
sum += init[i];
}
//cout << sum << endl;
LL backup = sum;
for(int i = 1;i <= k;i++)
{
sum = backup;
if(i == 1)
{
sum -= init[1];
printf("%lld ",sum);
}
else if(i == k)
{
sum -= init[k-1];
printf("%lld ",sum);
}
else
{
int a = i - 1,b = i + 1;
sum -= init[a] + init[i];
//sum += dijkstra(path[a],path[b]);
sum += get_dist(path[a],path[b]);
printf("%lld ",sum);
}
}
return 0;
}