算法
(prufer序列、生成函数)
可以考虑 prufer序列
结论:
在完全图 $K_n$ 中,满足顶点 $i$ 的度数恰为 $d_i$(其中 $\sum\limits_i d_i = 2n-2$)的生成树有 $\frac{(n-2)!}{\prod\limits_i (d_i-1)!}$ 种
我们需要求的是 $\sum\limits_d \frac{(n-2)!}{\prod\limits_{i} (d_i-1)!}$,其中 $d$ 取遍所有满足 $d_i \in S$ 且 $\sum\limits_i d_i = 2n-2$ 的度数序列
可以适当做一下变形,$(n-2)! \cdot \sum\limits_d \prod\limits_i\frac{1}{(d_i-1)!}$
其中,$\sum\limits_d \prod\limits_i\frac{1}{(d_i-1)!}$ 这一部分可以转成指数型生成函数求解
记 $F = \sum\limits_{s \in S} \frac{1}{(s-1)!}x^s$
由于 $n$ 个顶点的树的度数之和为 $2n-2$,$F^n$ 中 $x^{2n-2}$ 的系数恰好等于上面的和式,因此最后的答案就是 $(n-2)! \cdot [x^{2n-2}]F^n$
C++ 代码
#include <bits/stdc++.h>
#if __has_include(<atcoder/all>)
#include <atcoder/all>
using namespace atcoder;
#endif
#define rep(i, n) for (int i = 0; i < (n); ++i)
using namespace std;
using mint = modint998244353;
// combination mod prime
// Lazily grown table of modular inverses; the global instance `invs` is
// called as invs(i) to obtain the inverse of i modulo mint::mod().
struct modinv {
  int n;           // number of entries currently stored in d
  vector<mint> d;  // d[i] is the inverse of i; d[0] is a placeholder
  modinv(): n(2), d({0,1}) {}
  // Extends the table up to index i if necessary, then returns d[i].
  mint operator()(int i) {
    for (; n <= i; ++n) {
      // inverse(n) = -(p / n) * inverse(p mod n), with p = mint::mod()
      const mint next = -d[mint::mod()%n]*(mint::mod()/n);
      d.push_back(next);
    }
    return d[i];
  }
  // Reads an entry without extending the table; i must already be stored.
  mint operator[](int i) const { return d[i]; }
} invs;
// Lazily grown table of factorials in mint; the global instance `facts`
// is called as facts(i) to obtain i!.
struct modfact {
  int n;           // number of entries currently stored in d
  vector<mint> d;  // d[i] is i!
  modfact(): n(2), d({1,1}) {}
  // Extends the table up to index i if necessary, then returns d[i].
  mint operator()(int i) {
    for (; n <= i; ++n) {
      const mint next = d.back()*n;  // n! = (n-1)! * n
      d.push_back(next);
    }
    return d[i];
  }
  // Reads an entry without extending the table; i must already be stored.
  mint operator[](int i) const { return d[i]; }
} facts;
// Lazily grown table of inverse factorials in mint; the global instance
// `ifacts` is called as ifacts(i) to obtain 1/i!.
struct modfactinv {
  int n;           // number of entries currently stored in d
  vector<mint> d;  // d[i] is 1/i!
  modfactinv(): n(2), d({1,1}) {}
  // Extends the table up to index i if necessary, then returns d[i].
  mint operator()(int i) {
    for (; n <= i; ++n) {
      const mint next = d.back()*invs(n);  // 1/n! = 1/(n-1)! * 1/n
      d.push_back(next);
    }
    return d[i];
  }
  // Reads an entry without extending the table; i must already be stored.
  mint operator[](int i) const { return d[i]; }
} ifacts;
// Binomial coefficient C(n, k) in mint, computed as n! / (k! * (n-k)!).
// Returns 0 when k is negative or exceeds n.
mint comb(int n, int k) {
  const bool in_range = 0 <= k && k <= n;
  if (!in_range) return 0;
  const mint numerator = facts(n);
  const mint denominator_inv = ifacts(k)*ifacts(n-k);
  return numerator*denominator_inv;
}
// Formal Power Series
using vm = vector<mint>;
// Formal power series over mint, stored as a coefficient vector in which
// index i holds the coefficient of x^i.
struct fps : vm {
// Shorthands used only inside this struct (undefined again at its end):
// `d` is *this and `s` is the current number of coefficients as an int.
#define d (*this)
#define s int(vm::size())
  // Forwards its arguments to the corresponding vector<mint> constructor.
  template<class...Args> fps(Args...args): vm(args...) {}
  // Builds the series from an explicit list of coefficients.
  fps(initializer_list<mint> a): vm(a.begin(),a.end()) {}
  // Grows the series to at least n coefficients; never shrinks it.
  void rsz(int n) { if (s < n) resize(n);}
  // Truncates or extends this series in place to exactly n coefficients.
  fps& low_(int n) { resize(n); return d;}
  // Returns a copy resized to exactly n coefficients.
  fps low(int n) const { return fps(d).low_(n);}
  // Mutable access; grows the series first so that index i exists.
  mint& operator[](int i) { rsz(i+1); return vm::operator[](i);}
  // Read-only access; indices at or beyond the stored length read as 0.
  mint operator[](int i) const { return i<s ? vm::operator[](i) : 0;}
  // Evaluates the series at x by Horner's rule.
  mint operator()(mint x) const {
    mint r;  // NOTE(review): relies on a default-constructed mint being 0
    for (int i = s-1; i >= 0; --i) r = r*x+d[i];
    return r;
  }
  // Coefficient-wise negation.
  fps operator-() const { fps r(d); rep(i,s) r[i] = -r[i]; return r;}
  // Coefficient-wise addition / subtraction; the left side grows to fit a.
  fps& operator+=(const fps& a) { rsz(a.size()); rep(i,a.size()) d[i] += a[i]; return d;}
  fps& operator-=(const fps& a) { rsz(a.size()); rep(i,a.size()) d[i] -= a[i]; return d;}
  // Full (untruncated) product, delegated to convolution().
  fps& operator*=(const fps& a) { return d = convolution(d, a);}
  // Scaling of every coefficient by a scalar.
  fps& operator*=(mint a) { rep(i,s) d[i] *= a; return d;}
  fps& operator/=(mint a) { rep(i,s) d[i] /= a; return d;}
  // Non-mutating counterparts of the compound operators above.
  fps operator+(const fps& a) const { return fps(d) += a;}
  fps operator-(const fps& a) const { return fps(d) -= a;}
  fps operator*(const fps& a) const { return fps(d) *= a;}
  fps operator*(mint a) const { return fps(d) *= a;}
  fps operator/(mint a) const { return fps(d) /= a;}
  // Multiplicative inverse truncated to the current length. Starts from the
  // inverse of the constant term (which therefore must be invertible) and
  // doubles the number of correct coefficients each round via
  // r <- 2r - r^2 * (first 2i coefficients of this series).
  fps operator~() const {
    fps r({d[0].inv()});
    for (int i = 1; i < s; i <<= 1) r = r*mint(2) - (r*r*low(i<<1)).low(i<<1);
    return r.low_(s);
  }
  // Division by a series: multiplies by ~a and truncates back to the length
  // this series had before the call. ~a itself is computed to a's length.
  fps& operator/=(const fps& a) { int w = s; d *= ~a; return d.low_(w);}
  fps operator/(const fps& a) const { return fps(d) /= a;}
  // Formal integral with constant term 0: coefficient i moves to index i+1
  // and is divided by i+1.
  fps integ() const {
    fps r;
    rep(i,s) r[i+1] = d[i]/(i+1);
    return r;
  }
  // Returns this series raised to the t-th power, truncated to the current
  // length, using binary exponentiation. t must be non-negative; t == 0
  // yields the constant series 1 (truncated likewise, so an empty series
  // stays empty). Without the t == 0 case the recursion would never end,
  // because 0>>1 is 0 again.
  fps pow(int t) const {
    assert(t >= 0);
    if (t == 0) { fps one({mint(1)}); return one.low_(s);}
    if (t == 1) return *this;
    fps r = pow(t>>1);
    (r *= r).low_(s);
    if (t&1) (r *= *this).low_(s);
    return r;
  }
#undef s
#undef d
};
// Writes the coefficients of a as their val() integers, separated by
// single spaces, with no trailing newline.
ostream& operator<<(ostream& o, const fps& a) {
  const int len = static_cast<int>(a.size());
  for (int i = 0; i < len; ++i) {
    const char* sep = i ? " " : "";
    o << sep << a[i].val();
  }
  return o;
}
int main() {
cin.tie(nullptr) -> sync_with_stdio(false);
int n, k;
cin >> n >> k;
int m = n*2-2;
fps f(m+1);
rep(i, k) {
int s;
cin >> s;
f[s] = ifacts(s-1);
}
f.pow(n);
mint ans = f.pow(n)[m]*facts(n-2);
cout << ans.val() << '\n';
return 0;
}