听课笔记:
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <vector>
#include <cmath>
//如果顺序有关系,涉及修改,需要用树链剖分
using namespace std;
//不涉及修改,没有顺序关系的可以用欧拉序列
const int N = 500007, M = 500007, INF = 0x3f3f3f3f;
int n, m;
int cnt[N], vis[N];
vector<int>v;
int f[N][20], dep[N];
int head[N], ver[M], nex[M], tot;
int w[M];
int seq[N], top, first[N], last[N];
int block;
int ans[N];
void add_edge(int x, int y)
{
ver[tot] = y;
nex[tot] = head[x];
head[x] = tot ++ ;
}
struct Query{
int id, l, r, p;
}q[N];
int get_block(int x)
{
return x / block;
}
bool cmp(Query &a, Query &b)
{
int x = get_block(a.l);
int y = get_block(b.l);
if(x != y)return x < y;
return a.r < b.r;
}
void dfs(int x, int fa){
seq[ ++ top] = x;
first[x] = top;
for(int i = head[x]; ~i; i = nex[i]){
int y = ver[i];
if(y == fa) continue;
dfs(y, x);
}
seq[ ++ top] = x;
last[x] = top;
}
int que[N];
void bfs()
{
memset(dep, 0x3f, sizeof dep);
int hh = 0, tt = 0;
que[0] = 1;
dep[0] = 0, dep[1] = 1;
while(hh <= tt){
int x = que[hh ++ ];
if(hh == N) hh = 0;
for(int i = head[x]; ~i; i = nex[i]){
int y = ver[i];
if(dep[y] > dep[x] + 1){
dep[y] = dep[x] + 1;
f[y][0] = x;
for(int k = 1; k <= 15; ++ k){
f[y][k] = f[f[y][k - 1]][k - 1];
}
que[ ++ tt] = y;
if(tt == N) tt = 0;
}
}
}
}
int lca(int x, int y)
{
if(dep[x] < dep[y]) swap(x, y);
for(int k = 15; k >= 0; -- k){
if(dep[f[x][k]] >= dep[y]){
x = f[x][k];
}
}
if(x == y) return x;
for(int k = 15; k >= 0; -- k){
if(f[x][k] != f[y][k]){
x = f[x][k];
y = f[y][k];
}
}
return f[x][0];
}
void add(int x, int &res)
{
//!欧拉序列中出现两次就不是路径上的点了!要删掉
//要删掉的点一定是只出现一次的,添加的时候add一次,删除的时候add一次,两次即为删除
vis[x] ^= 1;//需要的是点的编号
if(vis[x] == 0){
cnt[w[x]] -- ;//需要的是点的权值(离散化过了)
if(cnt[w[x]] == 0) res -- ;
}
else {
cnt[w[x]] ++ ;
if(cnt[w[x]] == 1) res ++ ;
}
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; ++ i){
scanf("%d", &w[i]);
v.push_back(w[i]);
}
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
for(int i = 1; i <= n; ++ i){
w[i] = lower_bound(v.begin(), v.end(), w[i]) - v.begin();
}
memset(head, -1, sizeof head);
for(int i = 1; i <= n - 1; ++ i){
int x, y;
scanf("%d%d", &x, &y);
add_edge(x, y);
add_edge(y, x);
}
dfs(1, -1);//得到欧拉序列
bfs();//lca预处理
for(int i = 0; i < m; ++ i){
int a, b;
scanf("%d%d", &a, &b);
//a,b是树上的点
//first[a], first[b], last[a], last[b]才是数列上的点,也是我们莫队要处理的点
if(first[a] > first[b]) swap(a, b);
int p = lca(a, b);
if(a == p)
q[i] = {i, first[a], first[b]};
else q[i] = {i, last[a], first[b], p};
}
block = sqrt(top);//这里应该是欧拉序列里的点的个数
sort(q, q + m, cmp);
//右指针i左指针j, 右指针先冲左指针跟上
//左指针在1,右指针在0,初始状态形成一个空集
for(int k = 0, i = 0, j = 1, res = 0; k < m; ++ k){
int l = q[k].l, r = q[k].r, id = q[k].id, p = q[k].p;
//这里走的应该是欧拉序列里的点了
while(i < r) add(seq[ ++ i], res);//add
while(i > r) add(seq[i -- ], res);//del
while(j < l) add(seq[j ++ ], res);//del
while(j > l) add(seq[ -- j], res);//add
if(p) add(p, res);
ans[id] = res;
if(p) add(p, res);//最后一定要删除p,因为它不属于 i 到 i 这一连续序列中
}
for(int i = 0; i < m; ++ i)
printf("%d\n", ans[i]);
return 0;
}
可以树上带修莫队吧