codeforce每日一题链接
今日题目讲解视频
题目链接
题目分数:1900
题目描述
输入$n(2≤n≤2e5)$和长为n 的数组$a(1≤a[i]≤n,a[i]!=i)$,表示一个n点n边的无向图(节点编号从1开始),点$i$和$a[i]$相连。你需要给每条边定向(无向变有向),这一共有$2^n$种方案。其中有多少种方案,可以使图中没有环? 结果模$1e9+7$。
样例
输入样例1
3
2 3 1
输出样例1
6
输入样例2
4
2 1 1 1
输出样例2
8
输入样例3
5
2 4 2 5 3
输出样例3
28
算法
(快速幂 + 环) $O(n*log(n))$
每一条边都有两种方向,不考虑有没有环,总的方案数为$2^n$。那么考虑环之后,我们发现在给定的图中,只要出现了环,在这个环中的情况就要减去两种。所以我们只要找出所有的环,求出它们的二次幂减去2后的乘积,最后再乘上不参与形成环的点的二次幂就行了。
C++ 代码
// https://www.acwing.com/blog/content/34755/
#include<bits/stdc++.h>
#define endl '\n'
#define fi first
#define se second
#define all(a) a.begin(), a.end()
#define pd push_back
using namespace std;
typedef long long LL;
typedef pair<int,int> PII;
typedef pair<LL,LL> PLL;
typedef pair<double, double> PDD;
const int N = 2e5+10, M = N * 2, INF = 0x3f3f3f3f, mod = 1e9+7;
int n, m;
int w[N], ti[N];
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cin>>n;
for(int i=1;i<=n;i++) cin>>w[i];
auto qmi = [&](LL a, LL k, LL mod){
LL res = 1;
while(k)
{
if(k&1) res = res * a % mod;
a = a * a % mod;
k >>= 1;
}
return res;
};
int t = 1, m = n;
LL res = 1;
for(int i=1;i<=n;i++)
{
if(ti[i]) continue;
int start = t;
for(int j=i;j>0;j=w[j])
{
if(ti[j]>0){
if(ti[j]>=start){
int tmp = t - ti[j];
res = (res*(qmi(2, tmp, mod) - 2)%mod+mod)%mod;
m -= tmp;
}
break;
}
ti[j] = t;
t++;
}
}
cout<<(res*qmi(2, m, mod)%mod+mod)%mod<<endl;
return 0;
}
java代码
import java.util.Scanner;
import java.util.Arrays;
public class Main{
public static long mod = 1000000007;
public static long qmi(long a, long k, long mod)
{
long res = 1;
while(k > 0)
{
if((k & 1) == 1) res = res * a % mod;
k >>= 1;
a = a * a % mod;
}
return res;
}
public static void main(String[] args){
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] w = new int[n + 1];
for(int i=1;i<=n;i++) w[i] = sc.nextInt();
int[] ti = new int[n + 1];
int t = 0, m = n;
long res = 1;
for(int i = 1; i <= n; i ++)
{
if(ti[i] > 0) continue;
int start = t;
for(int j = i; ; j = w[j])
{
if(ti[j] > 0){
if(ti[j] >= start)
{
int s = t - ti[j];
res = (res * (qmi(2, s, mod) - 2) % mod + mod) % mod;
m -= s;
}
break;
}
ti[j] = t;
t ++;
}
}
res = (res * qmi(2, m, mod) % mod + mod) % mod;
System.out.println(res);
}
}
python代码
n = int(input())
w = [int(i) - 1 for i in input().split()]
ti = [0] * n
t, m = 0, n
res = 1
mod = int(1e9 + 7)
def qmi(a, k, mod) -> int :
res = 1
while k > 0:
if (k & 1) == 1:
res = res * a % mod
k >>= 1
a = a * a % mod
return res
for i in range(n):
if ti[i] > 0:
continue
j, start = i, t
while ti[j] == 0:
ti[j] = t
t += 1
j = w[j]
if ti[j] >= start:
s = t - ti[j]
m -= s
res = (res * (qmi(2, s, mod) - 2) % mod + mod) % mod
res = (res * qmi(2, m, mod) + mod) % mod
print(res)