#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
// #include <immintrin.h>
// Row-major matrix of long long values kept in one contiguous buffer.
// The constructor only records the shape; storage has to be attached to
// `data` separately (this file does that with mmalloc()/mfree()), and the
// struct has no destructor, so it never frees `data` itself.
struct matrix {
    // Stores the height and width. Every other member is given a defined
    // value so that a matrix without storage has a null data pointer
    // instead of an indeterminate one.
    matrix(int h, int w) : n(0), c(0), h(h), w(w), data(NULL) {}

    // h = number of rows, w = number of columns.
    // n and c are not read anywhere in the visible code
    // (presumably batch/channel counts -- TODO confirm).
    int n, c, h, w;

    // Element buffer; element (r, c) lives at data[r * w + c].
    long long* data;

    // Returns a reference to the element at row r, column c.
    // No bounds checking is performed.
    inline long long& at(int r, int c) { return *atptr(r, c); }

    // Returns a pointer to the element at row r, column c.
    // No bounds checking is performed. The offset is computed in 64-bit
    // arithmetic so that r * w cannot overflow int for large matrices.
    inline long long* atptr(int r, int c) {
        return data + (static_cast<long long>(r) * w + c);
    }
};
// Allocates zero-initialised storage for the m.h x m.w elements of `m` and
// stores the buffer in m.data (any previous value of m.data is overwritten
// without being freed). Release the buffer with free().
//
// m     - matrix whose h and w are already set.
// align - requested alignment in bytes. Currently ignored: the buffer has
//         whatever alignment calloc() provides. Code using AVX2 would need
//         a 128/256-bit aligned allocation here instead.
//
// A matrix with a non-positive dimension gets m.data == NULL. If the
// allocation cannot be satisfied, a message is written to stderr and the
// process exits, so a non-empty matrix never ends up with a null buffer.
void mmalloc(matrix& m, size_t align=8) {
    (void)align;

    m.data = NULL;
    if (m.h <= 0 || m.w <= 0) {
        return;
    }

    const size_t rows = static_cast<size_t>(m.h);
    const size_t cols = static_cast<size_t>(m.w);

    // Guard rows * cols against wrapping before handing it to calloc.
    if (rows <= static_cast<size_t>(-1) / cols) {
        // calloc zeroes the whole buffer. The previous
        // memset(m.data, 0, sizeof(m.data)) only cleared sizeof(pointer)
        // bytes, leaving the rest of the matrix uninitialised.
        m.data = static_cast<long long*>(calloc(rows * cols, sizeof *m.data));
    }

    if (m.data == NULL) {
        fprintf(stderr, "mmalloc: cannot allocate %d x %d matrix\n", m.h, m.w);
        exit(EXIT_FAILURE);
    }
}
// Releases the buffer held in m.data with free() and resets the pointer.
// Clearing m.data makes a repeated mfree() on the same matrix a harmless
// free(NULL) instead of a double free.
void mfree(matrix& m) {
    free(m.data);
    m.data = NULL;
}
// Matrix product: out = lhs * rhs.
//
// lhs - left operand, lhs.h x lhs.w.
// rhs - right operand; rows 0 .. lhs.w-1 and all rhs.w columns are read.
// out - destination; elements (i, j) for i < lhs.h, j < rhs.w are written.
//
// Each result element is summed in a local variable and then stored, so
// the previous contents of `out` do not influence the result and `out`
// does not need to be zeroed beforehand. No dimension or overflow checks
// are performed, and `out` must not share storage with lhs or rhs.
void matMul(matrix& lhs, matrix& rhs, matrix& out) {
    const int rows = lhs.h;
    const int cols = rhs.w;
    const int inner = lhs.w;

    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            long long sum = 0;
            for (int k = 0; k < inner; ++k) {
                sum += lhs.at(i, k) * rhs.at(k, j);
            }
            out.at(i, j) = sum;
        }
    }
}
// Scales every row of `rhs` by the matching entry of the first row of
// `lhs`: out(r, c) = lhs(0, r) * rhs(r, c) for all r < rhs.h, c < rhs.w.
// No dimension checks are performed.
void vecMul(matrix& lhs, matrix& rhs, matrix& out) {
    const int rowCount = rhs.h;
    const int colCount = rhs.w;

    for (int row = 0; row < rowCount; row++) {
        for (int col = 0; col < colCount; col++) {
            long long& target = out.at(row, col);
            target = lhs.at(0, row) * rhs.at(row, col);
        }
    }
}
// Writes the transpose of `in` into `out`: out(c, r) = in(r, c) for every
// r < in.h and c < in.w. No dimension checks are performed.
void transpose(matrix& in, matrix& out) {
    const int rowCount = in.h;
    const int colCount = in.w;

    for (int row = 0; row < rowCount; row++) {
        for (int col = 0; col < colCount; col++) {
            const long long value = in.at(row, col);
            out.at(col, row) = value;
        }
    }
}
// Reads m.h * m.w integers from standard input into `m`, row by row.
//
// If a value cannot be read (end of input or a token that is not an
// integer), a diagnostic is written to stderr and reading stops; the
// elements not yet read keep their previous contents. Once one conversion
// fails, every later "%lld" conversion on the same stream would fail too,
// so stopping early stores exactly the same values.
void mfill(matrix& m) {
    const int rows = m.h;
    const int cols = m.w;

    for (int r = 0; r < rows; ++r) {
        for (int col = 0; col < cols; ++col) {
            if (scanf("%lld", m.atptr(r, col)) != 1) {
                fprintf(stderr, "mfill: failed to read element (%d, %d)\n", r, col);
                return;
            }
        }
    }
}
// Prints `m` to standard output, one matrix row per line. Every element is
// followed by a single space (including the last one in a row).
void mprint(matrix& m) {
    const int rowCount = m.h;
    const int colCount = m.w;

    int row = 0;
    while (row < rowCount) {
        int col = 0;
        while (col < colCount) {
            const long long value = m.at(row, col);
            printf("%lld ", value);
            col++;
        }
        printf("\n");
        row++;
    }
}
// Reads n and d, then the matrices Q, K, V (each n x d) and the row vector
// W (1 x n) from standard input, and prints Res (n x d) where
//     Res(i, j) = W(i) * (Q * (K^T * V))(i, j).
// K^T * V is evaluated first, so the intermediate product is d x d.
// Returns EXIT_FAILURE if the dimensions cannot be read or are negative.
int main() {
    int n = 0, d = 0;
    if (scanf("%d %d", &n, &d) != 2 || n < 0 || d < 0) {
        fprintf(stderr, "expected two non-negative integers: n d\n");
        return EXIT_FAILURE;
    }
    if (n == 0 || d == 0) {
        // Empty result: nothing to read, compute or print.
        return 0;
    }

    matrix W(1, n), Q(n, d), K(n, d), Kt(d, n), V(n, d), KtV(d, d), QKtV(n, d), Res(n, d);

    // alloc
    mmalloc(W);
    mmalloc(Q);
    mmalloc(K);
    mmalloc(Kt);
    mmalloc(V);
    mmalloc(KtV);
    mmalloc(QKtV);
    mmalloc(Res);

    // read (input order: Q, K, V, W)
    mfill(Q);
    mfill(K);
    mfill(V);
    mfill(W);

    // matMul -- each buffer is released as soon as it is no longer needed
    transpose(K, Kt);
    mfree(K);

    matMul(Kt, V, KtV);
    mfree(Kt);
    mfree(V);

    matMul(Q, KtV, QKtV);
    mfree(Q);
    mfree(KtV);

    vecMul(W, QKtV, Res);
    mfree(W);
    mfree(QKtV);

    mprint(Res);
    mfree(Res);
    return 0;
}
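// Possible optimization: the multiplication can be tiled into 4x4 sub-matrices, with the reduction unrolled into FMA instructions.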