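Every vertex starts with its own distinct color. After heavy-light decomposition, an update becomes painting every vertex on a path with one new color (a color never used before).
A query becomes counting the pairs of adjacent vertices on the path that share the same color.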
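A naive reference for this reduced problem is handy for checking the idea, or for diffing the HLD solution against on small trees. The sketch below is a hypothetical helper (`BruteTree`, `paint`, `count` are illustration names, not part of the original solution). It assumes vertices are numbered 1..n, roots the tree at 1, gives vertex i initial color i, and walks each path explicitly, so every operation costs O(path length).

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical brute-force reference for the reduced problem (not the original solution).
// Vertices are 1..n, the tree is rooted at 1, and vertex i starts with color i.
struct BruteTree {
    std::vector<int> parent, depth, color;
    int nextColor;  // last color handed out; paint() uses the next one

    BruteTree(int n, const std::vector<std::pair<int, int>>& edges)
        : parent(n + 1, 0), depth(n + 1, 0), color(n + 1, 0), nextColor(n) {
        assert(n >= 1);
        std::vector<std::vector<int>> adj(n + 1);
        for (const auto& e : edges) {
            adj[e.first].push_back(e.second);
            adj[e.second].push_back(e.first);
        }
        // Iterative DFS from the root to fill parent[] and depth[].
        std::vector<int> stack = {1};
        std::vector<bool> seen(n + 1, false);
        seen[1] = true;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            for (int w : adj[u]) {
                if (seen[w]) continue;
                seen[w] = true;
                parent[w] = u;
                depth[w] = depth[u] + 1;
                stack.push_back(w);
            }
        }
        for (int i = 1; i <= n; i++) color[i] = i;  // all colors distinct at the start
    }

    // Vertices of the path u..v, listed in order from u to v.
    std::vector<int> path(int u, int v) const {
        std::vector<int> fromU, fromV;
        while (u != v) {  // climb the deeper endpoint (u on ties) until both meet at the LCA
            if (depth[u] >= depth[v]) { fromU.push_back(u); u = parent[u]; }
            else { fromV.push_back(v); v = parent[v]; }
        }
        fromU.push_back(u);  // the LCA
        fromU.insert(fromU.end(), fromV.rbegin(), fromV.rend());
        return fromU;
    }

    // Update: paint the whole path with a color never used before.
    void paint(int u, int v) {
        int c = ++nextColor;
        for (int x : path(u, v)) color[x] = c;
    }

    // Query: number of adjacent vertex pairs on the path with equal colors.
    int count(int u, int v) const {
        std::vector<int> p = path(u, v);
        int s = 0;
        for (int i = 0; i + 1 < (int)p.size(); i++)
            if (color[p[i]] == color[p[i + 1]]) s++;
        return s;
    }
};
```

For example, on the tree with edges 1-2, 1-3, 2-4, 2-5, painting path 4..3 and then path 5..1 leaves vertices 2 and 1 sharing a color while 4-2 and 1-3 no longer match, so `count(4, 3)` is 1.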
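The segment tree only needs to store, for each interval, the colors at its left and right ends together with the pair count: two adjacent intervals then merge by adding their counts, plus one if the colors meeting at the boundary are equal.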
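On a plain array, that segment tree can be sketched as follows. `Seg`, `combine` and `ColorSegTree` are hypothetical names for illustration; the merge rule and the lazy paint tag mirror `pushup` and `pushcov` in the solution code, but indices here are array positions rather than HLD `id`s, and colors must be nonzero because 0 is used to mean "no pending tag".

```cpp
#include <cassert>
#include <vector>

// Summary of one contiguous range: colors at both ends and the number of
// adjacent equal-colored pairs inside it.
struct Seg {
    int lc, rc, s;
};

// Concatenate range a (on the left) with range b (on the right):
// the only new pair is the boundary between a's last and b's first element.
Seg combine(const Seg& a, const Seg& b) {
    return {a.lc, b.rc, a.s + b.s + (a.rc == b.lc ? 1 : 0)};
}

// Hypothetical array version: range "paint with color c" and range pair-count query, indices 1..n.
class ColorSegTree {
    int n;
    std::vector<Seg> t;
    std::vector<int> tag;  // pending paint color for the whole node, 0 = none

    void apply(int u, int l, int r, int c) {
        t[u] = {c, c, r - l};  // r - l + 1 equal elements form r - l equal pairs
        tag[u] = c;
    }
    void push(int u, int l, int r) {
        if (!tag[u]) return;
        int m = (l + r) / 2;
        apply(2 * u, l, m, tag[u]);
        apply(2 * u + 1, m + 1, r, tag[u]);
        tag[u] = 0;
    }
    void build(int u, int l, int r, const std::vector<int>& a) {
        if (l == r) { t[u] = {a[l], a[l], 0}; return; }
        int m = (l + r) / 2;
        build(2 * u, l, m, a);
        build(2 * u + 1, m + 1, r, a);
        t[u] = combine(t[2 * u], t[2 * u + 1]);
    }
    void assign(int u, int l, int r, int ql, int qr, int c) {
        if (ql <= l && r <= qr) { apply(u, l, r, c); return; }
        push(u, l, r);
        int m = (l + r) / 2;
        if (ql <= m) assign(2 * u, l, m, ql, qr, c);
        if (qr > m) assign(2 * u + 1, m + 1, r, ql, qr, c);
        t[u] = combine(t[2 * u], t[2 * u + 1]);
    }
    Seg query(int u, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) return t[u];
        push(u, l, r);
        int m = (l + r) / 2;
        if (qr <= m) return query(2 * u, l, m, ql, qr);
        if (ql > m) return query(2 * u + 1, m + 1, r, ql, qr);
        return combine(query(2 * u, l, m, ql, qr), query(2 * u + 1, m + 1, r, ql, qr));
    }

public:
    // a[1..n] holds the initial colors; a[0] is ignored.
    explicit ColorSegTree(const std::vector<int>& a)
        : n((int)a.size() - 1), t(4 * a.size()), tag(4 * a.size(), 0) {
        assert(n >= 1);
        build(1, 1, n, a);
    }
    void assign(int l, int r, int c) { assert(c != 0); assign(1, 1, n, l, r, c); }
    Seg query(int l, int r) { return query(1, 1, n, l, r); }
};
```

With initial colors 1 1 2 2 2 3, `query(1, 6).s` is 3; after `assign(2, 5, 9)` the array reads 1 9 9 9 9 3 and the count is again 3. On the tree, the extra work is keeping each chain piece in the right orientation, which is what the `swap(left.lc, left.rc)` step in `query_path` takes care of.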
#include <iostream>
#include <cstring>
#include <cstdio>   // scanf / printf
#include <utility>  // std::swap
#pragma GCC optimize(1)
using namespace std;
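const int N = 100010, M = 200010;  // max vertices; max directed edges (2 per tree edge)
int c;                             // color counter: build() hands out 1..n, each path update takes ++c
int e[M], ne[M], h[N], idx;        // adjacency list stored as array-based linked lists
int son[N], fa[N], sz[N];          // heavy child, parent, subtree size
int dep[N], top[N], ts;            // depth, top of the heavy chain, DFS timestamp
int id[N];                         // position of each vertex in the HLD order
int n, m;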
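struct Node {
    int l, r;    // interval [l, r] of HLD positions
    int lc, rc;  // colors at the left and right ends of the interval
    int s;       // number of adjacent equal-colored pairs inside [l, r]
    int cov;     // pending "paint the whole interval" color, 0 = none
}tr[N << 2];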
void insert(int u, int v) {
    e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}
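// First DFS: depth, parent, subtree size and heavy child (son) of every vertex.
void dfs1(int u, int father) {
    dep[u] = dep[father] + 1, fa[u] = father, sz[u] = 1;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        dfs1(j, u);
        sz[u] += sz[j];
        if (sz[j] > sz[son[u]]) son[u] = j;  // sz[0] == 0, so the first child always qualifies
    }
}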
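// Second DFS: visit the heavy child first so that every heavy chain occupies a
// contiguous block of id[], with the chain top at the smallest position.
void dfs2(int u, int t) {
    id[u] = ++ts, top[u] = t;
    if (!son[u]) return;  // leaf
    dfs2(son[u], t);      // the heavy child continues the current chain
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa[u] || j == son[u]) continue;
        dfs2(j, j);       // each light child starts a new chain
    }
}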
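// Merge interval l (on the left) with interval r (on the right) into u.
// The only new pair is at the boundary: l's right end next to r's left end.
// Also safe when u and r are the same object, as in pushup(left, tmp, left) in query_path:
// r.s and r.lc are read before they are overwritten, and u.rc = r.rc is then a no-op.
void pushup(Node& u, Node& l, Node& r) {
    u.s = l.s + r.s;
    if (l.rc == r.lc) u.s++;
    u.lc = l.lc, u.rc = r.rc;
}
void pushup(int u) {
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}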
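void build(int u, int l, int r) {
    tr[u] = {l, r};  // lc, rc, s and cov are zero-initialized
    if (l == r) {
        tr[u].lc = tr[u].rc = ++c;  // every vertex gets its own color, so no pair matches yet
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}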
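// Paint the whole interval of u with color v: its r - l + 1 vertices form r - l equal pairs.
void pushcov(Node& u, int v) {
    u.s = u.r - u.l;
    u.lc = u.rc = u.cov = v;
}
void pushdown(int u) {
    auto &rt = tr[u], &l = tr[u << 1], &r = tr[u << 1 | 1];
    if (rt.cov) {  // colors are always >= 1, so 0 means "no pending paint"
        pushcov(l, rt.cov);
        pushcov(r, rt.cov);
        rt.cov = 0;
    }
}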
void modify(int u, int l, int r, int v) {
    if (tr[u].l >= l && tr[u].r <= r) {
        pushcov(tr[u], v);
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid) modify(u << 1, l, r, v);
    if (r > mid) modify(u << 1 | 1, l, r, v);
    pushup(u);
}
Node query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u];
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);
    if (r <= mid) return query(u << 1, l, r);
    if (l > mid) return query(u << 1 | 1, l, r);
    Node res;  // only s, lc and rc are filled in and used
    Node left = query(u << 1, l, r);
    Node right = query(u << 1 | 1, l, r);
    pushup(res, left, right);
    return res;
}
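// Paint every vertex on the tree path u..v with color c (the parameter shadows the global counter).
void modify_path(int u, int v, int c) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);  // lift the endpoint whose chain top is deeper
        modify(1, id[top[u]], id[u], c);
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v);  // same chain now; v is the shallower endpoint
    modify(1, id[v], id[u], c);
}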
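// Count equal-colored adjacent pairs on the path u..v.
// left collects the part on u's side, right the part on v's side. Both are kept top-down
// (lc = end closer to the LCA), because a chain query over [id[top], id[x]] returns them that way.
// The all-zero initial Node never matches a real color, since colors start at 1.
Node query_path(int u, int v) {
    Node left = {0, 0, 0, 0, 0, 0}, right = {0, 0, 0, 0, 0, 0};
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(left, right), swap(u, v);
        Node tmp = query(1, id[top[u]], id[u]);
        pushup(left, tmp, left);  // the new chain piece sits above everything collected so far
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(left, right), swap(u, v);
    Node tmp = query(1, id[v], id[u]);
    pushup(left, tmp, left);
    Node res = {0, 0, 0, 0, 0, 0};
    swap(left.lc, left.rc);  // reverse u's side so it reads u -> LCA, then join it with LCA -> v
    pushup(res, left, right);
    return res;
}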
void solve() {
    scanf("%d%d", &n, &m);
    c = idx = ts = 0;  // reset the global state for each test case
    for (int i = 1; i <= n; i++) h[i] = -1, son[i] = 0;
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        insert(u, v);
        insert(v, u);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    build(1, 1, n);  // initial colors 1..n
    while (m--) {
        int op, u, v;
        scanf("%d%d%d", &op, &u, &v);
        if (op == 1) modify_path(u, v, ++c);  // fresh color, larger than every color used so far
        else printf("%d\n", query_path(u, v).s);
    }
}
int main() {
    int T;
    scanf("%d", &T);
    while (T--) solve();
    return 0;
}
Which problem list is this from? Past NOI problems?