算法分析
初看这道题时以为是道模拟题，然后仔细看了一下数据范围……开始陷入沉思：直接计算矩阵乘积 $A \times B$ 需要 $\mathcal{O}(n^3)$ 的时间，在该数据范围下不可行。
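正解思路是在等式 $A \times B = C$ 的左右两边同时左乘同一个随机行向量 $X$（也就是 $1 \times n$ 的矩阵），然后比较 $(XA)B$ 与 $XC$ 是否相等即可（Freivalds 算法）。若 $AB \ne C$，随机选取的 $X$ 恰好使两边相等的概率极小。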
由于向量乘矩阵只需 $\mathcal{O}(n^2)$ 的时间，复杂度便从 $\mathcal{O}(n^3)$ 降到了 $\mathcal{O}(n^2)$。
C++ 代码
#include <bits/stdc++.h>
// rep(idx, cnt): loop idx over 0, 1, ..., cnt - 1 (idx is an int).
#define rep(idx, cnt) for (int idx = 0; idx < (cnt); ++idx)
using namespace std;
// 64-bit signed integer used by mint for storage and intermediate products.
using ll = long long;
// Prime modulus used by mint; the commented line is an alternative prime.
constexpr int mod = 998244353;
// constexpr int mod = 1000000007;
struct mint {
ll x;
mint(ll x=0):x((x%mod+mod)%mod) {}
mint operator-() const {
return mint(-x);
}
mint& operator+=(const mint a) {
if ((x += a.x) >= mod) x -= mod;
return *this;
}
mint& operator-=(const mint a) {
if ((x += mod-a.x) >= mod) x -= mod;
return *this;
}
mint& operator*=(const mint a) {
(x *= a.x) %= mod;
return *this;
}
mint operator+(const mint a) const {
return mint(*this) += a;
}
mint operator-(const mint a) const {
return mint(*this) -= a;
}
mint operator*(const mint a) const {
return mint(*this) *= a;
}
mint pow(ll t) const {
if (!t) return 1;
mint a = pow(t>>1);
a *= a;
if (t&1) a *= *this;
return a;
}
// for prime mod
mint inv() const {
return pow(mod-2);
}
mint& operator/=(const mint a) {
return *this *= a.inv();
}
mint operator/(const mint a) const {
return mint(*this) /= a;
}
};
// Reads one integer and stores it as a mint.
// The value is routed through mint's constructor so that negative input or
// input >= mod is reduced into [0, mod). Writing directly into a.x would
// break the invariant that operator+= / operator-= rely on.
// If extraction fails, the stream's failure state is left for the caller to
// check; the stored value is then whatever `>>` left in v (0 on plain failure).
istream& operator>>(istream& is, mint& a) {
    ll v = 0;
    is >> v;
    a = mint(v);
    return is;
}
// Writes the stored residue a.x as a plain integer and returns the stream
// so that insertions can be chained.
ostream& operator<<(ostream& os, const mint& a) {
    os << a.x;
    return os;
}
/// Reads the next integer token from stdin with getchar().
///
/// Characters that are neither digits nor '-' are skipped. A '-' makes the
/// result negative. The character right after the number is consumed.
/// Returns 0 if end of input is reached before any digit or sign.
template<class T=int>
inline T read() {
    T num = 0;
    // Keep the getchar() result as int: a char cannot represent EOF
    // distinctly, and isdigit() requires an unsigned-char value or EOF.
    int c = getchar();
    while (c != EOF && !isdigit(c) && c != '-') {
        c = getchar();
    }
    if (c == EOF) return num;
    const bool neg = (c == '-');
    if (!neg) num = c - '0';
    c = getchar();
    while (c != EOF && isdigit(c)) {
        const T digit = c - '0';
        // Accumulate on the negative side for negative input so that the
        // type's minimum value (e.g. INT_MIN) is representable without overflow.
        num = neg ? num * 10 - digit : num * 10 + digit;
        c = getchar();
    }
    return num;
}
using vi = vector<mint>;
using vvi = vector<vi>;
/// Row vector times matrix: returns X * A.
///
/// X has length n and A is stored row-major with n rows (A[j] is row j) of
/// equal length m; the result has length m. For a square A (m == n) this is
/// exactly the n-element product. Precondition: A.size() == X.size().
/// Both arguments are taken by const reference so the matrix is not copied.
vi mul(const vi& X, const vvi& A) {
    const size_t n = X.size();
    const size_t m = A.empty() ? 0 : A[0].size();
    vi res(m);
    // Rows of A in the outer loop, so the inner loop walks A[j] contiguously
    // (cache-friendly) instead of striding down a column.
    for (size_t j = 0; j < n; ++j) {
        const mint xj = X[j];
        const vi& row = A[j];
        for (size_t i = 0; i < m; ++i) res[i] += xj * row[i];
    }
    return res;
}
void solve() {
int n = read();
vvi A(n, vi(n));
rep(i, n)rep(j, n) A[i][j] = read();
vvi B(n, vi(n));
rep(i, n)rep(j, n) B[i][j] = read();
vvi C(n, vi(n));
rep(i, n)rep(j, n) C[i][j] = read();
mt19937 rnd(114514);
vi X(n);
rep(i, n) X[i] = rnd();
vi X1 = mul(X, A);
X1 = mul(X1, B);
vi X2 = mul(X, C);
rep(i, n) {
if (X1[i].x != X2[i].x) {
puts("No");
return;
}
}
puts("Yes");
}
int main() {
    // The first token is the number of test cases; each one is handled by
    // solve(). The loop runs exactly as many times as `while (t--)` would.
    for (int remaining = read(); remaining != 0; --remaining) {
        solve();
    }
    return 0;
}