Optimizing a Matrix Computation with the NEON Instruction Set

Consider the matrix computation \(Y = J^{T} M J\), where \(J\) is a [30576, 8] matrix, \(M\) is a [30576, 30576] diagonal matrix, and the output \(Y\) is an [8, 8] symmetric matrix. With the Eigen library, the computation can be written as:

Y.noalias() = J.transpose() * M.asDiagonal() * J; // here M is a vector of length 30576
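For context, a self-contained version of this Eigen baseline might be set up as follows; the variable names and the random initialization are only illustrative:

#include <Eigen/Dense>

int main() {
  // J: [30576, 8] dense matrix; M: the diagonal of the [30576, 30576] matrix, stored as a vector.
  Eigen::MatrixXf J = Eigen::MatrixXf::Random(30576, 8);
  Eigen::VectorXf M = Eigen::VectorXf::Random(30576);
  Eigen::Matrix<float, 8, 8> Y;

  // Y = J^T * diag(M) * J; noalias() avoids a temporary for the 8x8 result.
  Y.noalias() = J.transpose() * M.asDiagonal() * J;
  return 0;
}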

We would like to implement this matrix computation with the NEON instruction set to obtain a shorter runtime.

NEON-optimized version

#include <arm_neon.h>
#include <Eigen/Dense>

void CalcAtMA(const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor> &A,
              const Eigen::VectorXf &M,
              Eigen::Matrix<float, 8, 8> &At_M_A) {
  // Reference: https://developer.arm.com/documentation/102107a/0100/Single-precision-4x4-matrix-multiplication

  int At_idx;
  int A_idx;
  int At_M_A_idx;

  // Columns of a 4x4 sub-matrix of At
  float32x4_t At0;
  float32x4_t At1;
  float32x4_t At2;
  float32x4_t At3;

  // Columns of a 4x4 sub-matrix of A
  float32x4_t A0;
  float32x4_t A1;
  float32x4_t A2;
  float32x4_t A3;

  // Columns of a 4x4 sub-matrix of At_M_A
  float32x4_t At_M_A0;
  float32x4_t At_M_A1;
  float32x4_t At_M_A2;
  float32x4_t At_M_A3;

  // [8, 8] = [8, 30576] * [30576, 30576] * [30576, 8]
  //            n    k        k       k        k    m
  uint32_t n = A.cols(); // = m
  uint32_t k = A.rows();
  uint32_t k_step = k - 4;

  const float *A_ptr = A.data();
  const float *M_ptr = M.data();
  float *At_M_A_ptr = At_M_A.data();

  // Only the upper-triangle blocks (j_idx >= i_idx) are computed; the result is symmetric.
  for (int i_idx = 0; i_idx < n; i_idx += 4) {
    for (int j_idx = i_idx; j_idx < n; j_idx += 4) {
      // Zero accumulators before the matrix op
      At_M_A0 = vmovq_n_f32(0);
      At_M_A1 = vmovq_n_f32(0);
      At_M_A2 = vmovq_n_f32(0);
      At_M_A3 = vmovq_n_f32(0);

      int k_idx = 0;
      for (; k_idx <= k_step; k_idx += 4) {
        // Compute base index of the 4x4 block
        At_idx = k * i_idx + k_idx;
        A_idx = k * j_idx + k_idx;

        // Load the current At block (as columns of A, then transpose in registers)
        At0 = vld1q_f32(A_ptr + At_idx);
        At1 = vld1q_f32(A_ptr + At_idx + k);
        At2 = vld1q_f32(A_ptr + At_idx + 2 * k);
        At3 = vld1q_f32(A_ptr + At_idx + 3 * k);

        MatTransposeInp4x4NeonF32(At0, At1, At2, At3, At0, At1, At2, At3);

        // Fold the diagonal of M into the At block: scale column c by M[k_idx + c]
        At0 = vmulq_n_f32(At0, M_ptr[k_idx]);
        At1 = vmulq_n_f32(At1, M_ptr[k_idx + 1]);
        At2 = vmulq_n_f32(At2, M_ptr[k_idx + 2]);
        At3 = vmulq_n_f32(At3, M_ptr[k_idx + 3]);

        // Multiply-accumulate in 4x1 blocks, i.e. one output column at a time.
        // Load the current A values in each column.
        A0 = vld1q_f32(A_ptr + A_idx);
        At_M_A0 = vfmaq_laneq_f32(At_M_A0, At0, A0, 0);
        At_M_A0 = vfmaq_laneq_f32(At_M_A0, At1, A0, 1);
        At_M_A0 = vfmaq_laneq_f32(At_M_A0, At2, A0, 2);
        At_M_A0 = vfmaq_laneq_f32(At_M_A0, At3, A0, 3);

        A1 = vld1q_f32(A_ptr + A_idx + k);
        At_M_A1 = vfmaq_laneq_f32(At_M_A1, At0, A1, 0);
        At_M_A1 = vfmaq_laneq_f32(At_M_A1, At1, A1, 1);
        At_M_A1 = vfmaq_laneq_f32(At_M_A1, At2, A1, 2);
        At_M_A1 = vfmaq_laneq_f32(At_M_A1, At3, A1, 3);

        A2 = vld1q_f32(A_ptr + A_idx + 2 * k);
        At_M_A2 = vfmaq_laneq_f32(At_M_A2, At0, A2, 0);
        At_M_A2 = vfmaq_laneq_f32(At_M_A2, At1, A2, 1);
        At_M_A2 = vfmaq_laneq_f32(At_M_A2, At2, A2, 2);
        At_M_A2 = vfmaq_laneq_f32(At_M_A2, At3, A2, 3);

        A3 = vld1q_f32(A_ptr + A_idx + 3 * k);
        At_M_A3 = vfmaq_laneq_f32(At_M_A3, At0, A3, 0);
        At_M_A3 = vfmaq_laneq_f32(At_M_A3, At1, A3, 1);
        At_M_A3 = vfmaq_laneq_f32(At_M_A3, At2, A3, 2);
        At_M_A3 = vfmaq_laneq_f32(At_M_A3, At3, A3, 3);
      }

      // Compute base index for stores
      At_M_A_idx = n * j_idx + i_idx;
      vst1q_f32(At_M_A_ptr + At_M_A_idx, At_M_A0);
      vst1q_f32(At_M_A_ptr + At_M_A_idx + n, At_M_A1);
      vst1q_f32(At_M_A_ptr + At_M_A_idx + 2 * n, At_M_A2);
      vst1q_f32(At_M_A_ptr + At_M_A_idx + 3 * n, At_M_A3);

      // Scalar tail for the remaining k % 4 rows (not taken here, since 30576 % 4 == 0);
      // the diagonal weight M[k_idx] must be applied in this path as well.
      for (; k_idx < k; k_idx++) {
        for (int jp_idx = 0; jp_idx < 4; jp_idx++) {
          for (int ip_idx = 0; ip_idx < 4; ip_idx++) {
            At_M_A_ptr[At_M_A_idx + jp_idx * n + ip_idx] +=
                M_ptr[k_idx] * A_ptr[(i_idx + ip_idx) * k + k_idx] * A_ptr[(j_idx + jp_idx) * k + k_idx];
          }
        }
      }
    }
  }

  // Mirror the computed upper triangle into the lower triangle
  for (int i_idx = 0; i_idx < n; i_idx++) {
    At_M_A_idx = i_idx * n;
    for (int j_idx = i_idx + 1; j_idx < n; j_idx++) {
      At_M_A_ptr[At_M_A_idx + j_idx] = At_M_A_ptr[j_idx * n + i_idx];
    }
  }
}
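CalcAtMA relies on a helper MatTransposeInp4x4NeonF32 that transposes a 4x4 block held in four NEON registers; its definition is not shown above. Below is a minimal sketch of such a helper, assuming the arguments are passed by reference and that the outputs may alias the inputs (as they do at the call site); it would need to be declared before CalcAtMA.

#include <arm_neon.h>

// Sketch of the in-register 4x4 transpose (assumed interface: 4 input columns,
// 4 output columns; outputs may alias inputs).
static inline void MatTransposeInp4x4NeonF32(float32x4_t &in0, float32x4_t &in1,
                                             float32x4_t &in2, float32x4_t &in3,
                                             float32x4_t &out0, float32x4_t &out1,
                                             float32x4_t &out2, float32x4_t &out3) {
  // Interleave even/odd lanes of adjacent columns:
  // t01.val[0] = {in0[0], in1[0], in0[2], in1[2]}, t01.val[1] = {in0[1], in1[1], in0[3], in1[3]}
  float32x4x2_t t01 = vtrnq_f32(in0, in1);
  float32x4x2_t t23 = vtrnq_f32(in2, in3);

  // Recombine the 64-bit halves to produce the transposed columns.
  out0 = vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0]));
  out1 = vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1]));
  out2 = vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0]));
  out3 = vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1]));
}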

Timing results

Method    M1 Max, time (ms)    RK3588, time (ms)
Eigen     0.15472              0.604365
NEON      0.114444             0.19703

On the M1 Max the speedup is (0.15472 - 0.114444) / 0.15472 ≈ 26%; on the RK3588 it is (0.604365 - 0.19703) / 0.604365 ≈ 67.4%. The gain is clearly larger on the RK3588.
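The measurement setup is not shown above; a minimal harness along the following lines (iteration count and initialization are illustrative, and it assumes CalcAtMA and its transpose helper are visible in the same translation unit) times both paths and checks that they agree:

#include <arm_neon.h>
#include <Eigen/Dense>
#include <chrono>
#include <cstdio>

int main() {
  const int kRows = 30576, kIters = 100;
  Eigen::MatrixXf J = Eigen::MatrixXf::Random(kRows, 8);
  Eigen::VectorXf M = Eigen::VectorXf::Random(kRows).cwiseAbs();
  Eigen::Matrix<float, 8, 8> Y_eigen, Y_neon;

  // Time the Eigen baseline.
  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < kIters; ++i) {
    Y_eigen.noalias() = J.transpose() * M.asDiagonal() * J;
  }
  auto t1 = std::chrono::steady_clock::now();

  // Time the NEON version.
  for (int i = 0; i < kIters; ++i) {
    CalcAtMA(J, M, Y_neon);
  }
  auto t2 = std::chrono::steady_clock::now();

  double eigen_ms = std::chrono::duration<double, std::milli>(t1 - t0).count() / kIters;
  double neon_ms  = std::chrono::duration<double, std::milli>(t2 - t1).count() / kIters;
  printf("Eigen: %.6f ms, NEON: %.6f ms, max diff: %g\n",
         eigen_ms, neon_ms, (Y_eigen - Y_neon).cwiseAbs().maxCoeff());
  return 0;
}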