AcWing
  • 首页
  • 活动
  • 题库
  • 竞赛
  • 商店
  • 应用
  • 文章
    • 题解
    • 分享
    • 问答
  • 吐槽
  • 登录/注册

树上差分(倍增法)

作者: 作者的头像   Noe1017 ,  2022-08-06 22:06:51 ,  所有人可见 ,  阅读 29


4


树上差分有多种写法.

写法一:直接对于所有深度的链单独进行处理:

Code:

#include<bits/stdc++.h>
#pragma optimize(2)
#define endl '\n'
#define ll() to_ullong() 
#define string() to_string()
#define Endl endl
using namespace std;
typedef long long ll;
typedef pair<int,int>PII; 
typedef unsigned long long ull;
const int M=2010;
const int P=13331;
const ll  llinf=0x3f3f3f3f3f3f3f3f;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
const int N=2e6+10;
int dx[4]={0,1,0,-1};
int dy[4]={-1,0,1,0};
int id[N],a[N],sum[N];
int e[2*N],ne[2*N],idx,h[2*N];
void add(int a,int b)
{
        e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs1(int u,int fa,int depth)
{
        id[depth]=u;
        sum[id[max(0,depth-a[u]-1)]]--;
        sum[u]++; 
        for(int i=h[u];~i;i=ne[i])
        {
                int j=e[i];
                if(j==fa)continue;
                dfs1(j,u,depth+1);
        }
}
int ans[N];
int dfs2(int u,int fa)
{
        for(int i=h[u];~i;i=ne[i])
        {
                int j=e[i];
                if(j==fa)continue;
                sum[u]+=dfs2(j,u);
        }
        return sum[u];
}
void solve()
{
       memset(h,-1,sizeof h);
       int n;cin>>n;
       for(int i=1;i<n;i++)
       {
               int x,y;
               cin>>x>>y;
               add(x,y),add(y,x);
       }
       for(int i=1;i<=n;i++)cin>>a[i];
       dfs1(1,-1,1); 
       dfs2(1,-1);
       for(int i=1;i<=n;i++)cout<<sum[i]<<' ';
       cout<<endl;
       return ;
}
int main()
{
        ios_base::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);
  //      freopen("test.in","r",stdin);
        solve();
        return 0;
}

写法二:倍增法,求出每个点对应的链上祖先.

Code:

#include<bits/stdc++.h>
#pragma optimize(2)
#define endl '\n'
#define ll() to_ullong() 
#define string() to_string()
#define Endl endl
using namespace std;
typedef long long ll;
typedef pair<int,int>PII; 
typedef unsigned long long ull;
const int M=2010;
const int P=13331;
const ll  llinf=0x3f3f3f3f3f3f3f3f;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
const int N=2e6+10;
int dx[4]={0,1,0,-1};
int dy[4]={-1,0,1,0};
int a[N],sum[N];
int e[2*N],ne[2*N],idx,h[2*N];
int fa[N][21];
void add(int a,int b)
{
        e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs1(int u,int father,int depth)
{
        fa[u][0]=father;
        for(int k=1;k<=20;k++)fa[u][k]=fa[fa[u][k-1]][k-1];

        for(int i=h[u];~i;i=ne[i])
        {
                int j=e[i];
                if(j==father)continue;
                dfs1(j,u,depth+1);
        }
}//倍增 
int dfs2(int u,int father)
{
        for(int i=h[u];~i;i=ne[i])
        {
                int j=e[i];
                if(j==father)continue;
                sum[u]+=dfs2(j,u);
        }
        return sum[u];
}
void solve()
{
       memset(h,-1,sizeof h);
       int n;cin>>n;
       for(int i=1;i<n;i++)
       {
               int x,y;
               cin>>x>>y;
               add(x,y),add(y,x);
       }
       for(int i=1;i<=n;i++)cin>>a[i];

       //注意dfs和倍增的顺序 
       dfs1(1,0,1);
       for(int i=1;i<=n;i++)
       {
               int x=i,y=i;
               for(int k=20;k>=0;k--)
                if(a[i]>>k&1)y=fa[y][k];

                y=fa[y][0];
                sum[i]++,sum[y]--;
       }
       dfs2(1,0);
       for(int i=1;i<=n;i++)cout<<sum[i]<<' ';
       cout<<endl;
       return ;
}
int main()
{
        ios_base::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);
    //    freopen("test.in","r",stdin);
        solve();
        return 0;
}

0 评论

你确定删除吗?

© 2018-2022 AcWing 版权所有  |  京ICP备17053197号-1
用户协议  |  常见问题  |  联系我们
AcWing
请输入登录信息
更多登录方式: 微信图标 qq图标
请输入绑定的邮箱地址
请输入注册信息