1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
| void CalcAtMA(const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor> &A, const Eigen::VectorXf &M, Eigen::Matrix<float, 8, 8> &At_M_A) {
int At_idx; int A_idx; int At_M_A_idx;
float32x4_t At0; float32x4_t At1; float32x4_t At2; float32x4_t At3;
float32x4_t A0; float32x4_t A1; float32x4_t A2; float32x4_t A3;
float32x4_t At_M_A0; float32x4_t At_M_A1; float32x4_t At_M_A2; float32x4_t At_M_A3;
uint32_t n = A.cols(); uint32_t k = A.rows(); uint32_t k_step = k - 4;
const float *A_ptr = A.data(); const float *M_ptr = M.data(); float *At_M_A_ptr = At_M_A.data();
for (int i_idx = 0; i_idx < n; i_idx += 4) { for (int j_idx = i_idx; j_idx < n; j_idx += 4) { At_M_A0 = vmovq_n_f32(0); At_M_A1 = vmovq_n_f32(0); At_M_A2 = vmovq_n_f32(0); At_M_A3 = vmovq_n_f32(0); int k_idx = 0; for (; k_idx <= k_step; k_idx += 4) {
At_idx = k * i_idx + k_idx; A_idx = k * j_idx + k_idx;
At0 = vld1q_f32(A_ptr + At_idx); At1 = vld1q_f32(A_ptr + At_idx + k); At2 = vld1q_f32(A_ptr + At_idx + 2 * k); At3 = vld1q_f32(A_ptr + At_idx + 3 * k);
MatTransposeInp4x4NeonF32(At0, At1, At2, At3, At0, At1, At2, At3);
At0 = vmulq_n_f32(At0, M_ptr[k_idx]); At1 = vmulq_n_f32(At1, M_ptr[k_idx + 1]); At2 = vmulq_n_f32(At2, M_ptr[k_idx + 2]); At3 = vmulq_n_f32(At3, M_ptr[k_idx + 3]);
A0 = vld1q_f32(A_ptr + A_idx); At_M_A0 = vfmaq_laneq_f32(At_M_A0, At0, A0, 0); At_M_A0 = vfmaq_laneq_f32(At_M_A0, At1, A0, 1); At_M_A0 = vfmaq_laneq_f32(At_M_A0, At2, A0, 2); At_M_A0 = vfmaq_laneq_f32(At_M_A0, At3, A0, 3);
A1 = vld1q_f32(A_ptr + A_idx + k); At_M_A1 = vfmaq_laneq_f32(At_M_A1, At0, A1, 0); At_M_A1 = vfmaq_laneq_f32(At_M_A1, At1, A1, 1); At_M_A1 = vfmaq_laneq_f32(At_M_A1, At2, A1, 2); At_M_A1 = vfmaq_laneq_f32(At_M_A1, At3, A1, 3);
A2 = vld1q_f32(A_ptr + A_idx + 2 * k); At_M_A2 = vfmaq_laneq_f32(At_M_A2, At0, A2, 0); At_M_A2 = vfmaq_laneq_f32(At_M_A2, At1, A2, 1); At_M_A2 = vfmaq_laneq_f32(At_M_A2, At2, A2, 2); At_M_A2 = vfmaq_laneq_f32(At_M_A2, At3, A2, 3);
A3 = vld1q_f32(A_ptr + A_idx + 3 * k); At_M_A3 = vfmaq_laneq_f32(At_M_A3, At0, A3, 0); At_M_A3 = vfmaq_laneq_f32(At_M_A3, At1, A3, 1); At_M_A3 = vfmaq_laneq_f32(At_M_A3, At2, A3, 2); At_M_A3 = vfmaq_laneq_f32(At_M_A3, At3, A3, 3); } At_M_A_idx = n * j_idx + i_idx; vst1q_f32(At_M_A_ptr + At_M_A_idx, At_M_A0); vst1q_f32(At_M_A_ptr + At_M_A_idx + n, At_M_A1); vst1q_f32(At_M_A_ptr + At_M_A_idx + 2 * n, At_M_A2); vst1q_f32(At_M_A_ptr + At_M_A_idx + 3 * n, At_M_A3);
for (; k_idx < k; k_idx++) { for (int jp_idx = 0; jp_idx < 4; jp_idx++) { for (int ip_idx = 0; ip_idx < 4; ip_idx++) { At_M_A_ptr[At_M_A_idx + jp_idx * n + ip_idx] += A_ptr[(i_idx + ip_idx) * k + k_idx] * A_ptr[(j_idx + jp_idx) * k + k_idx]; } } } } } for (int i_idx = 0; i_idx < n; i_idx ++) { At_M_A_idx = i_idx * n; for (int j_idx = i_idx + 1; j_idx < n; j_idx ++) { At_M_A_ptr[At_M_A_idx + j_idx] = At_M_A_ptr[j_idx * n + i_idx]; } } }
|