首先很容易意识到这题是LCA的模板题,对于一条路径,我们先将没有减去景点的完整路径走一遍看它的路径和S是多少,先不考虑首尾节点,当我们要删去一个节点时,我们只需要将S减去这个节点对应前一个节点和后一节点的距离的和,再加上前一个节点到后一节点的距离即可(这样就相当于我们在这条路径中删去了该节点),想清楚了这一点首尾节点就更简单了,只要删去对应的一条边长度就行了。
一开始的思路是将所有查询用二维数组来表示(毕竟这样可以减少思维量),但tarjan作为一种离线算法,并不适合存储大量数据,所以不能用二维数组存储两个节点的结果,否则只能过三个点,因此我们重新考虑使用vector的方式来存储查询。
我们注意到对于某个点并不需要查询它对其余所有节点的距离,而是只需要查询在路线中它和它的下一个点或者下下个点的距离。
例如对于样例路径,对于第一个景点2来说只需要记录2-6的距离(正常路线)和2-5(跳过景点6)的距离即可,因此我们可以用两个查询id来表示一个节点对应的走法。用res存储查询id且下标从1开始。这样对于路径的第i个节点ri,它的正常走法距离是res[2 * i - 1],跳过它的下一个节点的走法距离是res[2 * i]。
(实际代码里road从0开始,发现没啥必要但不想改了呜呜呜,主要是res从1开始)
接下来的就都很简单了,直接使用tarjan算法即可,tarjan算法模板在蓝桥杯辅导课的10-4。
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
typedef pair<int, int> pii;
const int N = 1e5 + 10, M = 2 * N;
int e[M], ne[M], h[N], w[M], idx;
int p[N], st[N];
long long dist[N];
long long res[M];
int road[N];
int n, k;
vector<pii> query[N];
void add(int a, int b,int c) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
int find(int x) {
if (p[x] != x)
p[x] = find(p[x]);
return p[x];
}
void dfs(int u, int fat) {
for (int i = h[u]; i != -1; i=ne[i]) {
int j = e[i];
if (j != fat) {
dist[j] = dist[u] + w[i];
dfs(j, u);
}
}
}
void tarjan(int u) {
st[u] = 1;
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (!st[j]) {
tarjan(j);
p[j] = u;
}
}
for (auto item : query[u]) {
int y = item.first, id = item.second;
if (st[y] == 2) {
int anc = find(y);
res[id] = dist[u] + dist[y] - dist[anc] * 2;
}
}
st[u] = 2;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> k;
int t = n - 1;
memset(h, -1, sizeof h);
while (t--) {
int a, b, c;
cin >> a >> b >> c;
add(a, b, c);
add(b, a, c);
}
for (int i = 0; i < k; i++)
cin >> road[i];
int ct = 0;
// 初始化距离
dfs(1, -1);
// 初始化并查集
for (int i = 1; i <= n; i++)
p[i] = i;
// 初始化查询
for (int i = 0, ct = 1; i < k; i++) {
// 下一个点
query[road[i]].push_back({ road[i + 1],ct });
query[road[i + 1]].push_back({ road[i],ct });
ct++;
// 下下个点
query[road[i]].push_back({ road[i + 2],ct });
query[road[i + 2]].push_back({ road[i],ct });
ct++;
}
tarjan(1);
long long ans = 0;
// 总距离
for (int i = 1; i <= k; i++) {
ans += res[2 * i - 1];
}
// 删去第一个点
cout << ans - res[1] << ' ';
// 删去中间点
for (int i = 2; i < k; i++) {
cout << ans - res[2 * (i - 1) - 1] - res[2 * i - 1] + res[2 * (i - 1)] << ' ';
}
// 删去最后一个点
cout << ans - res[2 * k - 3];
return 0;
}