只做代码分享
组合数
#include<iostream>
using namespace std;
// Shared declarations for the combinatorial (prime-factorisation) solution.
using LL = long long;
constexpr int N = 2005;          // table capacity; get_primes needs its `a` argument < N
constexpr int mod = 998244353;   // modulus of the printed answer

// Globals have static storage duration, so everything below starts zeroed.
int primes[N];  // primes discovered by get_primes, in increasing order
int cnt;        // number of valid entries in primes[]
// index[t]: exponent of primes[t] in C(a, b), accumulated by get_primes.
// NOTE(review): the name `index` can clash with the legacy POSIX index()
// from <strings.h> on some toolchains; consider renaming it file-wide.
int index[N];
bool st[N];     // st[x] becomes true once the sieve marks x as composite
// Linear (Euler) sieve over [2, a]. For every prime p it discovers, it appends
// p to primes[] and stores the exponent of p in C(a, b), via Legendre's formula:
//   v_p(C(a, b)) = v_p(a!) - v_p(b!) - v_p((a - b)!)
//   v_p(x!)      = floor(x/p) + floor(x/p^2) + ...
// Side effects: appends to primes/cnt, accumulates into index[], and marks
// composites in st[]. Nothing is reset, so it is meant to be called once.
// Callers must keep 0 <= b <= a < N; otherwise the exponents are meaningless.
void get_primes(int a, int b)
{
    // Sum of the successive quotients x/p, x/p^2, ... (v_p(x!) for x >= 0).
    auto legendre = [](int x, int p) {
        int e = 0;
        while (x != 0) {
            x /= p;
            e += x;
        }
        return e;
    };

    for (int i = 2; i <= a; ++i) {
        if (!st[i]) {
            const int slot = cnt++;
            primes[slot] = i;
            index[slot] += legendre(a, i) - legendre(b, i) - legendre(a - b, i);
        }
        // Mark i * p for primes p up to the smallest prime factor of i, so
        // each composite is crossed off exactly once. The loop always breaks
        // before j reaches cnt; the j < cnt test only makes that explicit.
        for (int j = 0; j < cnt && primes[j] <= a / i; ++j) {
            st[primes[j] * i] = true;
            if (i % primes[j] == 0) break;
        }
    }
}
// Binary (square-and-multiply) modular exponentiation: returns a^b mod `mod`
// as a value in [0, mod).
// The base is normalised into [0, mod) first. Without that, a negative base
// yields negative results, and a base above ~3e9 overflows `a * a`.
// The exponent must be non-negative. For b <= 0 the loop does not run and 1
// is returned; the old `for (; b; b >>= 1)` spun forever on a negative b,
// because an arithmetic right shift keeps -1 at -1.
int qmi(LL a, LL b)
{
    a %= mod;
    if (a < 0) a += mod;

    LL ans = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
    }
    return static_cast<int>(ans);   // ans < mod < 2^31, so this is lossless
}
// Reads n, m, k and prints C(n-1, k) * m * (m-1)^k mod 998244353.
// The counting argument behind the formula:
// - pick which k of the n-1 adjacent pairs are "changes" (C(n-1, k) ways);
// - give the first child any of m fruits;
// - give each change one of the m-1 other fruits.
// C(n-1, k) is assembled from the prime exponents produced by get_primes.
// Assumes n - 1 < N, which is the sieve table size.
int main()
{
    int n, m, k;
    cin >> n >> m >> k;

    // C(n-1, k) is 0 unless 0 <= k <= n-1, so answer that case directly.
    // Otherwise get_primes would receive b > a and produce meaningless
    // exponents. And for n == 1 the empty prime product would leave the
    // binomial at 1, printing m * (m-1)^k instead of 0.
    if (k < 0 || k > n - 1) {
        cout << 0 << '\n';
        return 0;
    }

    get_primes(n - 1, k);
    LL ans = static_cast<LL>(m) * qmi(m - 1, k) % mod;   // m * (m-1)^k
    for (int t = 0; t < cnt; ++t)                        // * C(n-1, k)
        ans = ans * qmi(primes[t], index[t]) % mod;
    cout << ans << '\n';
    return 0;
}
dp（二维）
#include<iostream>
using namespace std;
// Shared declarations for the 2-D DP solution.
using LL = long long;
constexpr int N = 2005;          // table capacity: needs n < N and k < N
constexpr int mod = 998244353;   // modulus of the printed answer

// f[i][j]: number of ways for the first i children such that exactly j of
// them took a fruit different from the child on their left (mod `mod`).
// Static storage, so the whole table starts at zero.
LL f[N][N];
int main()
{
int n, m, k;
cin >> n >> m >> k;
for(int i = 1; i <= n; i++)
f[i][0] = m;
for(int i = 2; i <= n; i++)
for(int j = 1; j <= min(i, k); j++)
f[i][j] = (f[i-1][j] + f[i-1][j-1] * (m - 1) % mod) % mod;
cout << f[n][k];
return 0;
}
一维 dp（滚动数组）
#include<iostream>
using namespace std;
// Shared declarations for the 1-D (rolling-array) DP solution.
using LL = long long;
constexpr int N = 2005;          // table capacity: needs k < N
constexpr int mod = 998244353;   // modulus of the printed answer

// f[j]: one row of the 2-D table f[i][j], overwritten in place as i grows.
// Each entry counts the ways for the first i children with exactly j changes
// relative to the left neighbour. Static storage, so it starts at zero.
LL f[N];
int main()
{
int n, m, k;
cin >> n >> m >> k;
f[0] = m;
for(int i = 2; i <= n; i++)
for(int j = min(i, k); j; j--)
f[j] = (f[j] + f[j-1] * (m - 1) % mod) % mod;
cout << f[k];
return 0;
}