题意
有一颗 $n$ 个点的树,点有点权,只会是 $0$ 或 $1$,要支持以下操作。
1 u
,表示把 $u$ 的权值异或 $1$。2
,表示查询树上最远的两个权值为 $0$ 的点对的距离。
分析
我们知道,树上两个点集的并集中最远的两个点,一定属于原来两个集合中分别最远的两个点的并集中四个点。所以我们线段树维护区间最远两个点,然后直接合并,在树上同样暴力合并重链,时间复杂度就是 $O(n\log^2n)$。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define il inline
#define N 100005
#define get(x, y) (dep[x] < dep[y] ? (x) : (y))
il ll rd(){
ll s = 0, w = 1;
char ch = getchar();
for (;ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
for (;ch >= '0' && ch <= '9'; ch = getchar()) s = ((s << 1) + (s << 3) + ch - '0');
return s * w;
}
ll n, u, v, lg[N << 1];
char op;
vector <ll> e[N];
struct Graph{
ll euler[N << 1][35], pos[N], et, dep[N];
void dfs(ll u, ll fa){
dep[u] = dep[fa] + 1, euler[++et][0] = u, pos[u] = et;
for (int v : e[u]) if (v != fa) dfs(v, u), euler[++et][0] = u;
}il void init(){
dfs(1, 0);
for (int i = 2; i <= et; i++) lg[i] = lg[i >> 1] + 1;
for (int j = 1; j <= 30; j++) for (int i = 1; i + (1 << j) - 1 <= et; i++)
euler[i][j] = get(euler[i][j - 1], euler[i + (1 << (j - 1))][j - 1]);
}il ll lca(ll u, ll v){
ll l = pos[u], r = pos[v];
if (l > r) swap(l, r);
ll p = lg[r - l + 1];
return get(euler[l][p], euler[r - (1 << p) + 1][p]);
}
}G;
il ll dis(ll u, ll v){return G.dep[u] + G.dep[v] - 2 * G.dep[G.lca(u, v)];}
struct ST{
ll u, v, dist = -1;
}tr[N << 2];
ST operator + (const ST &x, const ST &y){
ST ans, a = x, b = y;
if (a.dist == -1) return b;
if (b.dist == -1) return a;
ll d1 = dis(a.u, b.u), d2 = dis(a.u, b.v), d3 = dis(a.v, b.u), d4 = dis(a.v, b.v);
if (a.dist > b.dist) ans = a;
else ans = b;
if (d1 > ans.dist) ans.u = a.u, ans.v = b.u, ans.dist = d1;
if (d2 > ans.dist) ans.u = a.u, ans.v = b.v, ans.dist = d2;
if (d3 > ans.dist) ans.u = a.v, ans.v = b.u, ans.dist = d3;
if (d4 > ans.dist) ans.u = a.v, ans.v = b.v, ans.dist = d4;
return ans;
}void add(ll p, ll l, ll r, ll x){
if (l == r){
if (!tr[p].dist) tr[p].dist = -1;
else tr[p].dist = 0, tr[p].u = tr[p].v = l;
return ;
}ll mid = (l + r) >> 1;
if (x <= mid) add(p << 1, l, mid, x);
else add(p << 1 | 1, mid + 1, r, x);
tr[p] = tr[p << 1] + tr[p << 1 | 1];
}
int main(){
n = rd();
for (int i = 1; i < n; i++) u = rd(), v = rd(), e[u].push_back(v), e[v].push_back(u);
G.init();
for (int i = 1; i <= n; i++) add(1, 1, n, i);
for (int T = rd(); T--;){
scanf ("%s", &op);
if (op == 'C') add(1, 1, n, rd());
else printf ("%lld\n", tr[1].dist);
}
return 0;
}