算法分析
先考虑静态查询树上两点的距离,可以通过 $\operatorname{dist}(a, \operatorname{lca}(a, b)) + \operatorname{dist}(b, \operatorname{lca}(a, b))$ 计算得到,其中 $\operatorname{dist}(u, v)$ 表示点 $u$ 和点 $v$ 之间的距离
对于 $\operatorname{lca}(a, b)$,可以通过以下方法求得:
- 倍增
- 树剖
- tarjan
- 欧拉序+线段树
下面主要介绍第四种方法
对树跑 $\operatorname{dfs}$ 可得到序列 $(1, ~2, ~3, ~5, ~3, ~6, ~3, ~2, ~4, ~2, ~1, ~7, ~1)$
对应的欧拉序序列就是 $(1, ~2, ~3, ~4, ~3, ~5, ~3, ~2, ~6, ~2, ~1, ~7, ~1)$
此时,点 $4$ 和点 $5$ 的 $\operatorname{lca}$ 就是位于两点之间欧拉序最小的那个点
然后用线段树对两点间的欧拉序做一遍 $\operatorname{RMQ}$
再考虑加上边权,对于相邻两个点的距离,定义 $\operatorname{dfs}$ 时方向向下则符号为正,反之符号为负
对于 $\operatorname{dist}(a, \operatorname{lca}(a, b))$ 和 $\operatorname{dist}(b, \operatorname{lca}(a, b))$,可以用树状数组来求
再考虑操作 $1$,其实就是树上差分, 可以用树状数组维护树边的 $\operatorname{dfs}$ 序
C++ 代码
#include <bits/stdc++.h>
#if __has_include(<atcoder/all>)
#include <atcoder/all>
using namespace atcoder;
#endif
#define rep(i, n) for (int i = 0; i < (n); ++i)
using namespace std;
using ll = long long;
using P = pair<int, int>;
struct Edge {
int to, cost, id;
Edge() {}
Edge(int to, int cost, int id): to(to), cost(cost), id(id) {}
};
int op(int a, int b) { return min(a, b); }
int rmq_e() { return 1e9; }
int main() {
cin.tie(nullptr) -> sync_with_stdio(false);
int n;
cin >> n;
vector<vector<Edge>> g(n);
rep(i, n-1) {
int a, b, c;
cin >> a >> b >> c;
--a; --b;
g[a].emplace_back(b, c, i);
g[b].emplace_back(a, c, i);
}
vector<int> et;
vector<int> in(n);
vector<int> ein(n-1), eout(n-1), ew(n-1);
auto dfs = [&](auto f, int v, int p=-1) -> void {
in[v] = et.size();
et.push_back(v);
for (auto e : g[v]) {
if (e.to == p) continue;
ein[e.id] = et.size()-1;
ew[e.id] = e.cost;
f(f, e.to, v);
eout[e.id] = et.size()-1;
et.push_back(v);
}
};
dfs(dfs, 0);
segtree<int, op, rmq_e> rmq(et.size());
rep(i, et.size()) rmq.set(i, in[et[i]]);
fenwick_tree<ll> d(et.size());
rep(i, n-1) d.add(ein[i], ew[i]);
rep(i, n-1) d.add(eout[i], -ew[i]);
int q;
cin >> q;
rep(qi, q) {
int type, a, b;
cin >> type >> a >> b;
if (type == 1) {
--a;
int dif = b-ew[a];
ew[a] = b;
d.add(ein[a], dif);
d.add(eout[a], -dif);
}
else {
--a; --b;
int l = in[a], r = in[b];
if (l > r) swap(l, r);
int c = et[rmq.prod(l, r+1)];
ll ans = 0;
if (a != c) ans += d.sum(in[c], in[a]);
if (b != c) ans += d.sum(in[c], in[b]);
cout << ans << '\n';
}
}
return 0;
}