AcWing
  • 首页
  • 活动
  • 题库
  • 竞赛
  • 应用
  • 更多
    • 题解
    • 分享
    • 商店
    • 问答
    • 吐槽
  • App
  • 登录/注册

Codeforces 1139/C. C. Edgy Trees    原题链接    简单

作者: 作者的头像   啼莺修竹 ,  2023-05-23 09:32:23 ,  所有人可见 ,  阅读 55


5


3
codeforce每日一题
题目链接
题目分数:1400

题目描述

输入$n(2≤n≤1e5)k(2≤k≤100)$和一棵无向树的$n-1$条边(节点编号从$1$开始),每条边包含$3$个数$x, y, c$,表示有一条颜色为$c$的边连接$x$和$y$,其中$c$等于$0$或$1$。对于长为$k$节点序列$a$,走最短路,按顺序经过节点$a1 -> a2 -> …->ak$。对于所有长为$k$的节点序列$a$(这有$n^k$个),统计至少经过一条$c=1$的边的序列$a$的个数。

样例

输入样例1
4 4
1 2 1
2 3 1
3 4 1
输出样例1
252
输入样例2
4 6
1 2 0
1 3 0
1 4 0
输出样例2
0
输入样例3
3 5
1 2 1
2 3 0
输出样例3
210

算法

(dfs) $O(n*log(k))$

这道题我们从正面去想,看哪些点最短路之间有$1$会比较难做,不如从反方向去想,我们去看哪些点之间的最短路没有$1$。因为这是一棵树,所以我们可以发现一些特殊的性质,假设以k为根节点的子树,以$a$的子节点$a1$为根节点的树中有$m$个节点之间的最短路没有$1$,且$a1$到这$m$个点的最短路中也没有$1$,那么只要$a$到$a1$的边不是$1$,$a$到这$m$个点的最短路中也没有1。如果$a$到$a1$的边为$1$,则把$m$存起来。最后只需要将所有的$m$取出来,答案减去$k^m$就好了。

C++ 代码

//  https://www.acwing.com/blog/content/34755/

#include<bits/stdc++.h>

#define endl '\n'
#define fi first
#define se second
#define all(a) a.begin(), a.end()
#define pd push_back

using namespace std;

typedef long long LL;
typedef pair<int,int> PII;
typedef pair<LL,LL> PLL;
typedef pair<double, double> PDD;

const int N = 1e5 + 10, M = N * 2, INF = 0x3f3f3f3f, mod = 1e9 + 7;

int n, m;
int e[M], ne[M], w[M], h[N], idx;
vector<int> cnt;

LL qmi(LL a, LL k, LL mod)
{
    LL res = 1;
    while(k)
    {
        if(k&1) res = res * a % mod;
        a = a * a % mod;
        k >>= 1;
    }
    return res;
}

inline void add(int a, int b, int c)
{
    e[idx]=b, w[idx]=c, ne[idx]=h[a], h[a]=idx++;
}

inline int dfs(int u, int fa)
{
    int d = 1;
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==fa) continue;
        int t = dfs(j, u);
        if(w[i]==0) d += t;
        else{
            if(t) cnt.pd(t);
        }
    }

    return d;
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    memset(h, -1, sizeof h);
    cin>>n>>m;

    for(int i=0;i<n-1;i++){
        int a, b, c;
        cin>>a>>b>>c;
        add(a, b, c), add(b, a, c);
    }

    int t = dfs(1, -1);
    if(t) cnt.pd(t);

    LL res = qmi(n, m, mod);

    for(auto u:cnt){
        LL tmp = qmi(u, m, mod);
        res = (res - tmp + mod) % mod;
    }

    cout<<res<<endl;

    return 0;
}

0 评论

你确定删除吗?

© 2018-2023 AcWing 版权所有  |  京ICP备17053197号-1
用户协议  |  隐私政策  |  常见问题  |  联系我们
AcWing
请输入登录信息
更多登录方式: 微信图标 qq图标 qq图标
请输入绑定的邮箱地址
请输入注册信息