树上差分有多种写法.
写法一:直接对于所有深度的链单独进行处理:
Code:
#include<bits/stdc++.h>
#pragma optimize(2)
#define endl '\n'
#define ll() to_ullong()
#define string() to_string()
#define Endl endl
using namespace std;
typedef long long ll;
typedef pair<int,int>PII;
typedef unsigned long long ull;
const int M=2010;
const int P=13331;
const ll llinf=0x3f3f3f3f3f3f3f3f;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
const int N=2e6+10;
int dx[4]={0,1,0,-1};
int dy[4]={-1,0,1,0};
int id[N],a[N],sum[N];
int e[2*N],ne[2*N],idx,h[2*N];
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs1(int u,int fa,int depth)
{
id[depth]=u;
sum[id[max(0,depth-a[u]-1)]]--;
sum[u]++;
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==fa)continue;
dfs1(j,u,depth+1);
}
}
int ans[N];
int dfs2(int u,int fa)
{
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==fa)continue;
sum[u]+=dfs2(j,u);
}
return sum[u];
}
void solve()
{
memset(h,-1,sizeof h);
int n;cin>>n;
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
add(x,y),add(y,x);
}
for(int i=1;i<=n;i++)cin>>a[i];
dfs1(1,-1,1);
dfs2(1,-1);
for(int i=1;i<=n;i++)cout<<sum[i]<<' ';
cout<<endl;
return ;
}
int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
// freopen("test.in","r",stdin);
solve();
return 0;
}
写法二:倍增法,求出每个点对应的链上祖先.
Code:
#include<bits/stdc++.h>
#pragma optimize(2)
#define endl '\n'
#define ll() to_ullong()
#define string() to_string()
#define Endl endl
using namespace std;
typedef long long ll;
typedef pair<int,int>PII;
typedef unsigned long long ull;
const int M=2010;
const int P=13331;
const ll llinf=0x3f3f3f3f3f3f3f3f;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
const int N=2e6+10;
int dx[4]={0,1,0,-1};
int dy[4]={-1,0,1,0};
int a[N],sum[N];
int e[2*N],ne[2*N],idx,h[2*N];
int fa[N][21];
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs1(int u,int father,int depth)
{
fa[u][0]=father;
for(int k=1;k<=20;k++)fa[u][k]=fa[fa[u][k-1]][k-1];
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==father)continue;
dfs1(j,u,depth+1);
}
}//倍增
int dfs2(int u,int father)
{
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==father)continue;
sum[u]+=dfs2(j,u);
}
return sum[u];
}
void solve()
{
memset(h,-1,sizeof h);
int n;cin>>n;
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
add(x,y),add(y,x);
}
for(int i=1;i<=n;i++)cin>>a[i];
//注意dfs和倍增的顺序
dfs1(1,0,1);
for(int i=1;i<=n;i++)
{
int x=i,y=i;
for(int k=20;k>=0;k--)
if(a[i]>>k&1)y=fa[y][k];
y=fa[y][0];
sum[i]++,sum[y]--;
}
dfs2(1,0);
for(int i=1;i<=n;i++)cout<<sum[i]<<' ';
cout<<endl;
return ;
}
int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
// freopen("test.in","r",stdin);
solve();
return 0;
}