简化题意
给你一棵树,每个点有一个颜色,$Q$ 次询问,每次询问给出 $r_1$ 和 $r_2$,求祖孙节点 $(x,y)$ 满足 $x$ 的颜色为 $r_1$ 且 $y$ 的颜色为 $r_2$ 的祖孙节点的对数。
思路
运用根号分支,设树上某个颜色出现了大于 $\sqrt{n}$ 次,则称之为重颜色,否则称之为轻颜色。
对于每一轮询问 $(r_1,r_2)$:
-
若 $r_2$ 为重颜色,那么 $r_2$ 最多有 $\sqrt{n}$ 种可能,所以可能的询问 $(r_1,r_2)$ 的种数最多为 $n\sqrt{n}$ 个,对于整棵树进行预处理,将树转成 $dfn$ 序的形式,对于一个点,若其为重颜色,则在主席树中在对应颜色位置加一。那么对于每一个点,可以用它子树的 $dfn$ 序下标求出子树中每一个重颜色的个数,于是就可以预处理出所有 $(r_1,r_2)$ 的答案,时间复杂度 $O(n\sqrt{n}\log{\sqrt{n}})$(为了避免 MLE 所以使用了主席树)。
-
否则若 $r_1$ 为重颜色,同理也只有 $n\sqrt{n}$ 种可能,于是要求的是每一个节点的祖宗节点中,每个重颜色出现的次数,在遍历树的时候可以顺便处理,时间复杂度 $O(n)$。
-
否则 $r_1$ 和 $r_2$ 均为轻颜色,将所有颜色为 $r_2$ 的节点在 $dfn$ 序下加入树状数组,再对于每一个 $r_1$ 节点,在树状数组中求即可,时间复杂度 $O(q \sqrt{n}\log n)$。
综上所述,我们现在有了一个时间复杂度 $O(n\sqrt{n}\log n)$,空间复杂度 $O(R\sqrt{n})$ 的算法,由于空间卡的紧,所以将块长稍调大。
代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 200010, M = 300, K = 25010;
int n, r, q, L;
int h[N], e[N], ne[N], idx;
int c[N], cnt[K], g[K], le[N], ri[N], w[N], tmp, tot;
int s2[M], rt[N];
ll ans[K][M], ans2[M][K];
vector<int> gg[K];
struct Tree {
int tr[N];
void add(int x, int k) {
for (int i = x; i <= n; i += i & -i) tr[i] += k;
} int query(int x) {
int res = 0;
for (int i = x; i; i -= i & -i) res += tr[i];
return res;
} int qry(int lll, int rr) {return query(rr) - query(lll - 1);}
} T;
struct sgt {
int tt;
struct Node {
int l, r, s;
} tr[N * 10];
void modify(int &u, int l, int r, int p, int k) {
tr[++ tt] = tr[u], tr[tt].s += k, u = tt;
if (l == r) return;
int mid = l + r >> 1;
if (p <= mid) modify(tr[u].l, l, mid, p, k);
else modify(tr[u].r, mid + 1, r, p, k);
} int query(int u, int l, int r, int p) {
if (!u) return 0;
if (l == r) return tr[u].s;
int mid = l + r >> 1;
if (p <= mid) return query(tr[u].l, l, mid, p);
return query(tr[u].r, mid + 1, r, p);
}
} T2;
void add_edge(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void dfs(int u, int fa) {
for (int i = 1; i <= tmp; i ++ ) ans2[i][c[u]] += s2[i];
le[u] = ++ tot, w[tot] = u;
if (cnt[c[u]] > L) s2[g[c[u]]] ++ ;
for (int i = h[u]; ~i; i = ne[i]) {
int ver = e[i];
if (ver == fa) continue;
dfs(ver, u);
}
if (cnt[c[u]] > L) s2[g[c[u]]] -- ;
ri[u] = tot;
}
int main() {
scanf("%d%d%d%d", &n, &r, &q, &c[1]);
T2.tt = n;
L = max((int)sqrt(n), 650);
memset(h, -1, sizeof h);
for (int i = 2; i <= n; i ++ ) {
int p;
scanf("%d%d", &p, &c[i]);
add_edge(p, i);
cnt[c[i]] ++ ;
}
for (int i = 1; i <= n; i ++ ) gg[c[i]].push_back(i);
for (int i = 1; i <= r; i ++ ) if (cnt[i] > L) g[i] = ++ tmp;
dfs(1, 0);
for (int i = 1; i <= n; i ++ ) {
rt[i] = i;
T2.tr[rt[i]] = T2.tr[rt[i - 1]];
if (cnt[c[w[i]]] > L) T2.modify(rt[i], 1, tmp, g[c[w[i]]], 1);
}
for (int i = 1; i <= n; i ++ )
for (int j = 1; j <= tmp; j ++ )
ans[c[i]][j] += T2.query(rt[ri[i]], 1, tmp, j) - T2.query(rt[le[i] - 1], 1, tmp, j);
while (q -- ) {
int r1, r2;
scanf("%d%d", &r1, &r2);
if (cnt[r2] > L) printf("%lld\n", ans[r1][g[r2]]);
else if (cnt[r1] > L) printf("%lld\n", ans2[g[r1]][r2]);
else {
ll res = 0;
for (int k : gg[r2]) T.add(le[k], 1);
for (int k : gg[r1]) res += T.qry(le[k] + 1, ri[k]);
for (int k : gg[r2]) T.add(le[k], -1);
printf("%lld\n", res);
}
fflush(stdout);
}
return 0;
}