codeforce每日一题
题目链接
题目分数:1400
题目描述
输入$n(2≤n≤1e5)k(2≤k≤100)$和一棵无向树的$n-1$条边(节点编号从$1$开始),每条边包含$3$个数$x, y, c$,表示有一条颜色为$c$的边连接$x$和$y$,其中$c$等于$0$或$1$。对于长为$k$节点序列$a$,走最短路,按顺序经过节点$a1 -> a2 -> …->ak$。对于所有长为$k$的节点序列$a$(这有$n^k$个),统计至少经过一条$c=1$的边的序列$a$的个数。
样例
输入样例1
4 4
1 2 1
2 3 1
3 4 1
输出样例1
252
输入样例2
4 6
1 2 0
1 3 0
1 4 0
输出样例2
0
输入样例3
3 5
1 2 1
2 3 0
输出样例3
210
算法
(dfs) $O(n*log(k))$
这道题我们从正面去想,看哪些点最短路之间有$1$会比较难做,不如从反方向去想,我们去看哪些点之间的最短路没有$1$。因为这是一棵树,所以我们可以发现一些特殊的性质,假设以k为根节点的子树,以$a$的子节点$a1$为根节点的树中有$m$个节点之间的最短路没有$1$,且$a1$到这$m$个点的最短路中也没有$1$,那么只要$a$到$a1$的边不是$1$,$a$到这$m$个点的最短路中也没有1。如果$a$到$a1$的边为$1$,则把$m$存起来。最后只需要将所有的$m$取出来,答案减去$k^m$就好了。
C++ 代码
// https://www.acwing.com/blog/content/34755/
#include<bits/stdc++.h>
#define endl '\n'
#define fi first
#define se second
#define all(a) a.begin(), a.end()
#define pd push_back
using namespace std;
typedef long long LL;
typedef pair<int,int> PII;
typedef pair<LL,LL> PLL;
typedef pair<double, double> PDD;
const int N = 1e5 + 10, M = N * 2, INF = 0x3f3f3f3f, mod = 1e9 + 7;
int n, m;
int e[M], ne[M], w[M], h[N], idx;
vector<int> cnt;
LL qmi(LL a, LL k, LL mod)
{
LL res = 1;
while(k)
{
if(k&1) res = res * a % mod;
a = a * a % mod;
k >>= 1;
}
return res;
}
inline void add(int a, int b, int c)
{
e[idx]=b, w[idx]=c, ne[idx]=h[a], h[a]=idx++;
}
inline int dfs(int u, int fa)
{
int d = 1;
for(int i=h[u];~i;i=ne[i])
{
int j = e[i];
if(j==fa) continue;
int t = dfs(j, u);
if(w[i]==0) d += t;
else{
if(t) cnt.pd(t);
}
}
return d;
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
memset(h, -1, sizeof h);
cin>>n>>m;
for(int i=0;i<n-1;i++){
int a, b, c;
cin>>a>>b>>c;
add(a, b, c), add(b, a, c);
}
int t = dfs(1, -1);
if(t) cnt.pd(t);
LL res = qmi(n, m, mod);
for(auto u:cnt){
LL tmp = qmi(u, m, mod);
res = (res - tmp + mod) % mod;
}
cout<<res<<endl;
return 0;
}