Description
here
lca+桶+差分我不会写
我好菜啊/kk
Solution
可以转化为线段树合并的模板,每个点放一个类型为 $w$ 的物品,统计物品数,可以使用树上差分。
用树上差分把处理转化为,对于 $s->lca$ ,在 $s$ 的线段树上让 $dep[s]+1$ ,然后统计 $dep[x]+w[x]$ 出现了几次即可。
对于 $lca->t$ ,统计 $dep[x]-w[x]$ 出现的次数。
由于下标可能出现负数,把下标都平移 $n$ 即可。
在acwing要开吸氧,不然最后一个点会 T
Code
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=300010,M=600010;
int n,m;
int h[N],to[M],ne[M],idx;
int dep[N],fa[N][19];
int root[N],cnt;
int q[N],w[N];
int ans[N];
struct SegTree{
int l,r,v;
}tr[N*4*19];
inline void read(int &x){
x=0;int f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
x*=f;
}
inline void add(int u,int v){
to[idx]=v,ne[idx]=h[u],h[u]=idx++;
}
void bfs(){
dep[1]=1;
int hh=0,tt=0;
q[0]=1;
while(hh <= tt){
int x=q[hh++];
for(int i=h[x];~i;i=ne[i]){
int y=to[i];
if(!dep[y]){
dep[y]=dep[x]+1;
q[++tt]=y;
fa[y][0]=x;
for(int k=1;k<=18;++k)
fa[y][k]=fa[fa[y][k-1]][k-1];
}
}
}
}
inline int lca(int x,int y){
if(dep[x] < dep[y]) swap(x,y);
for(int k=18;k>=0;--k)
if(dep[fa[x][k]] >= dep[y])
x=fa[x][k];
if(x == y) return x;
for(int k=18;k>=0;--k)
if(fa[x][k] != fa[y][k])
x=fa[x][k],y=fa[y][k];
return fa[x][0];
}
void insert(int &u,int l,int r,int p,int d){
if(l == r){
tr[u].v+=d;
return;
}
int mid=(l+r)>>1;
if(p <= mid){
if(!tr[u].l) tr[u].l=++cnt;
insert(tr[u].l,l,mid,p,d);
}
else{
if(!tr[u].r) tr[u].r=++cnt;
insert(tr[u].r,mid+1,r,p,d);
}
}
int query(int u,int l,int r,int p){
if(!u) return 0;
if(l == r) return tr[u].v;
int mid=(l+r)>>1;
if(p <= mid) return query(tr[u].l,l,mid,p);
else return query(tr[u].r,mid+1,r,p);
}
int merge(int p,int q,int l,int r){
if(!p) return q;
if(!q) return p;
if(l == r){
tr[p].v+=tr[q].v;
return p;
}
int mid=(l+r)>>1;
tr[p].l=merge(tr[p].l,tr[q].l,l,mid);
tr[p].r=merge(tr[p].r,tr[q].r,mid+1,r);
return p;
}
void dfs(int x){
for(int i=h[x];~i;i=ne[i]){
int y=to[i];
if(y == fa[x][0]) continue;
dfs(y);
root[x]=merge(root[x],root[y],1,n<<1);
}
if(w[x] && dep[x]+w[x]+n<=n*2)
ans[x]+=query(root[x],1,n<<1,dep[x]+w[x]+n);
ans[x]+=query(root[x],1,n<<1,dep[x]-w[x]+n);
}
int main(){
memset(h,-1,sizeof h);
read(n),read(m);
for(int i=1;i<n;++i){
int a,b;
read(a),read(b);
add(a,b),add(b,a);
}
for(int i=1;i<=n;++i) read(w[i]);
bfs();
for(int i=1;i<=n*2;++i) root[i]=++cnt;
for(int i=1;i<=m;++i){
int a,b;
read(a),read(b);
int anc=lca(a,b);
insert(root[a],1,n<<1,dep[a]+n,1);
insert(root[b],1,n<<1,dep[anc]*2-dep[a]+n,1);
insert(root[anc],1,n<<1,dep[a]+n,-1);
if(anc != 1) insert(root[fa[anc][0]],1,n<<1,dep[anc]*2-dep[a]+n,-1);
}
dfs(1);
for(int i=1;i<=n;++i)
printf("%d ",ans[i]);
return 0;
}