From fa224bbfe33ad3597a5c6501c71315b8c25cee1d Mon Sep 17 00:00:00 2001 From: Nick Dingle Date: Thu, 5 Dec 2024 10:35:23 +0000 Subject: [PATCH] Fix LDPC decoding accuracy A poster on our forum observed that the accuracy of our LDPC decoder was very much worse than that of other implementations. The cause is that we do all our calculations in int8_t fixed-point, which quickly saturates. The fix is to do all intermediate calculations in int16_t fixed-point. On average this causes a performance regression of between 1.2x and 1.3x across all ArmRAL benchmarks. --- CHANGELOG.md | 5 + src/UpperPHY/LDPC/ldpc_decoder.cpp | 934 +++++++++++++++-------------- 2 files changed, 475 insertions(+), 464 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 915c20f..925bc9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,11 @@ documented in this file. ### Fixed +- LDPC decoding (`armral_ldpc_decode_block`) now achieves the expected error + correction performance in the presence of channel noise. The function now uses + `int16_t` internally rather than `int8_t`, which can be slower for certain + input sizes. + ### Security ## [24.10] - 2024-10-17 diff --git a/src/UpperPHY/LDPC/ldpc_decoder.cpp b/src/UpperPHY/LDPC/ldpc_decoder.cpp index 7cb98bd..78bda92 100644 --- a/src/UpperPHY/LDPC/ldpc_decoder.cpp +++ b/src/UpperPHY/LDPC/ldpc_decoder.cpp @@ -74,18 +74,18 @@ public: m_buffer_size = m_total_bits >> 3; } - m_llrs = allocate_uninitialized(allocator, m_total_bits + m_z - 1); + m_llrs = allocate_uninitialized(allocator, m_total_bits + m_z - 1); m_buffer = allocate_uninitialized(allocator, m_buffer_size); } - bool check(const int8_t *new_llrs) { + bool check(const int16_t *new_llrs) { // Copy the LLRs corresponding to the bits we need to do the CRC check after // the padding bits - memset(m_llrs.get(), 0, m_num_pad_bits * sizeof(int8_t)); + memset(m_llrs.get(), 0, m_num_pad_bits * sizeof(int16_t)); for (uint32_t num_block = 0; num_block < ((m_k_prime + m_z - 1) / m_z); num_block++) { memcpy(m_llrs.get() + m_num_pad_bits + (num_block * m_z), - new_llrs + (2 * num_block * m_z), m_z * sizeof(int8_t)); + new_llrs + (2 * num_block * m_z), m_z * sizeof(int16_t)); } // Hard decode @@ -106,20 +106,20 @@ private: uint32_t m_buffer_size{0}; uint32_t m_num_pad_bits{0}; uint32_t m_total_bits{0}; - unique_ptr m_llrs; + unique_ptr m_llrs; unique_ptr m_buffer; }; template -bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, +bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, int8_t *check); + int32_t full_vec, uint32_t tail_size, int16_t *check); template<> -bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, +bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *check) { + uint32_t tail_size, int16_t *check) { // Loop through the rows in the base graph bool passed = true; for (uint32_t row = 0; row < graph->nrows && passed; ++row) { @@ -132,7 +132,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, // Loop through the rows in the block for (uint32_t zb = 0; zb < z && passed; ++zb) { // Loop through the columns in the row - int8_t scal_check = 0; + int16_t scal_check = 0; for (uint32_t col = 0; col < num_cols; ++col) { auto shift = (shift_ptr[col] + zb) % z; auto codeword_ind = col_ptr[col] * z + shift; @@ -145,14 +145,14 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, } template<> -bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, +bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *check) { + uint32_t tail_size, int16_t *check) { // Loop through the rows in the base graph bool passed = true; #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg_tail = svwhilelt_b8(0, (int)tail_size); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); for (uint32_t row = 0; row < graph->nrows && passed; ++row) { auto row_start_ind = graph->row_start_inds[row]; auto num_cols = graph->row_start_inds[row + 1] - row_start_ind; @@ -160,7 +160,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const auto *shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + lsi * num_cols; - memset(check, 0, z * sizeof(int8_t)); + memset(check, 0, z * sizeof(int16_t)); // Loop through the columns for (uint32_t col = 0; col < num_cols; ++col) { @@ -168,13 +168,13 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, auto codeword_ind = col_ptr[col] * (2 * z) + shift; // No need to loop here, as there is only a tail - const int8_t *llrs_ptr = llrs + codeword_ind; - int8_t *check_ptr = check; + const int16_t *llrs_ptr = llrs + codeword_ind; + int16_t *check_ptr = check; - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t check_reg = svld1_s8(pg_tail, check_ptr); - svint8_t result_reg = sveor_s8_x(pg_tail, check_reg, llrs_reg); - svst1_s8(pg_tail, check_ptr, result_reg); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t check_reg = svld1_s16(pg_tail, check_ptr); + svint16_t result_reg = sveor_s16_x(pg_tail, check_reg, llrs_reg); + svst1_s16(pg_tail, check_ptr, result_reg); } for (uint32_t zb = 0; zb < z && passed; ++zb) { passed &= check[zb] >= 0; @@ -188,7 +188,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const auto *shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + lsi * num_cols; - memset(check, 0, z * sizeof(int8_t)); + memset(check, 0, z * sizeof(int16_t)); // Loop through the columns for (uint32_t col = 0; col < num_cols; ++col) { @@ -196,13 +196,13 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, auto codeword_ind = col_ptr[col] * (2 * z) + shift; // Loop through the rows in the block - const int8_t *llrs_ptr = llrs + codeword_ind; - int8_t *check_ptr = check; + const int16_t *llrs_ptr = llrs + codeword_ind; + int16_t *check_ptr = check; - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t check_reg = vld1_s8(check_ptr); - int8x8_t result_reg = veor_s8(check_reg, llrs_reg); - vst1_s8(check_ptr, result_reg); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t check_reg = vld1_s16(check_ptr); + int16x4_t result_reg = veor_s16(check_reg, llrs_reg); + vst1_s16(check_ptr, result_reg); // Deal with a tail for (uint32_t zb = z - tail_size; zb < z; ++zb) { @@ -218,13 +218,13 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, } template<> -bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, +bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *check) { + uint32_t tail_size, int16_t *check) { #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b8(); - svbool_t pg_tail = svwhilelt_b8(0, tail_size); + svbool_t pg = svptrue_b16(); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); // Loop through the rows in the base graph bool passed = true; @@ -235,7 +235,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const auto *shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + lsi * num_cols; - memset(check, 0, z * sizeof(int8_t)); + memset(check, 0, z * sizeof(int16_t)); // Loop through the columns for (uint32_t col = 0; col < num_cols; ++col) { @@ -248,14 +248,14 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, // represent a hard decision for the bit to be one, and non-negative // values represent a zero. Hence the check needs to xor all LLRs // and then assert that the result is non-negative. - const int8_t *llrs_ptr = llrs + codeword_ind; - int8_t *check_ptr = check; + const int16_t *llrs_ptr = llrs + codeword_ind; + int16_t *check_ptr = check; for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint8_t llrs_reg = svld1_s8(pg, llrs_ptr); - svint8_t check_reg = svld1_s8(pg, check_ptr); - svint8_t result_reg = sveor_s8_x(pg, check_reg, llrs_reg); - svst1_s8(pg, check_ptr, result_reg); + svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); + svint16_t check_reg = svld1_s16(pg, check_ptr); + svint16_t result_reg = sveor_s16_x(pg, check_reg, llrs_reg); + svst1_s16(pg, check_ptr, result_reg); // Increment pointers llrs_ptr += num_lanes; @@ -263,10 +263,10 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, } // Process tail if (tail_size != 0) { - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t check_reg = svld1_s8(pg_tail, check_ptr); - svint8_t result_reg = sveor_s8_x(pg_tail, check_reg, llrs_reg); - svst1_s8(pg_tail, check_ptr, result_reg); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t check_reg = svld1_s16(pg_tail, check_ptr); + svint16_t result_reg = sveor_s16_x(pg_tail, check_reg, llrs_reg); + svst1_s16(pg_tail, check_ptr, result_reg); } } for (uint32_t zb = 0; zb < z && passed; ++zb) { @@ -284,7 +284,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const auto *shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + lsi * num_cols; - memset(check, 0, z * sizeof(int8_t)); + memset(check, 0, z * sizeof(int16_t)); // Loop through the columns for (uint32_t col = 0; col < num_cols; ++col) { @@ -298,27 +298,27 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, // represent a hard decision for the bit to be one, and non-negative // values represent a zero. Hence the check needs to xor all LLRs // and then assert that the result is non-negative. - const int8_t *llrs_ptr = llrs + codeword_ind; - int8_t *check_ptr = check; + const int16_t *llrs_ptr = llrs + codeword_ind; + int16_t *check_ptr = check; - // Process 16 entries at a time + // Process 8 entries at a time for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - int8x16_t llrs_reg = vld1q_s8(llrs_ptr); - int8x16_t check_reg = vld1q_s8(check_ptr); - int8x16_t result_reg = veorq_s8(check_reg, llrs_reg); - vst1q_s8(check_ptr, result_reg); + int16x8_t llrs_reg = vld1q_s16(llrs_ptr); + int16x8_t check_reg = vld1q_s16(check_ptr); + int16x8_t result_reg = veorq_s16(check_reg, llrs_reg); + vst1q_s16(check_ptr, result_reg); // Increment pointers llrs_ptr += num_lanes; check_ptr += num_lanes; } - // Process a group of 8 elts - if (tail_cnt > 7) { - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t check_reg = vld1_s8(check_ptr); - int8x8_t result_reg = veor_s8(check_reg, llrs_reg); - vst1_s8(check_ptr, result_reg); - tail_cnt = z & 0x7; + // Process a group of 4 elts + if (tail_cnt > 3U) { + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t check_reg = vld1_s16(check_ptr); + int16x4_t result_reg = veor_s16(check_reg, llrs_reg); + vst1_s16(check_ptr, result_reg); + tail_cnt = z & 0x3; } // Deal with a tail for (uint32_t zb = z - tail_cnt; zb < z; ++zb) { @@ -342,42 +342,40 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, // - \min_{n \in \psi(m)} |L(n,m)| and the second minimum (they will be used to // compute |R(n,m)| in a second step) template -void compute_l_product_min1_and_min2(int8_t *l, const int8_t *__restrict__ llrs, - const int8_t *__restrict__ r, - const ldpc_layer_data *d, - int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *row_min_array, - int8_t *row_min2_array, - int8_t *row_sign_array); +void compute_l_product_min1_and_min2( + int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, + const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, + uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, + int16_t *row_sign_array); template<> void compute_l_product_min1_and_min2( - int8_t *l, const int8_t *__restrict__ llrs, const int8_t *__restrict__ r, + int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *row_min_array, int8_t *row_min2_array, - int8_t *row_sign_array) { + uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, + int16_t *row_sign_array) { const auto *r_ptr = r; // Loop through the Z rows in the layer (check node m) for (uint32_t zb = 0; zb < d->z; ++zb) { // Loop through the columns in the row (variable node n in psi(m)) // Column 0 auto shift = (d->shift_ptr[0] + zb) % d->z; - int8_t l_val = llrs[d->col_ptr[0] * d->z + shift] - *(r_ptr++); + int16_t l_val = vqsubh_s16(llrs[d->col_ptr[0] * d->z + shift], *(r_ptr++)); - int8_t row_sign = l_val; + int16_t row_sign = l_val; - int8_t row_min = vqabsb_s8(l_val); + int16_t row_min = vqabsh_s16(l_val); *(l++) = l_val; // Column 1 shift = (d->shift_ptr[1] + zb) % d->z; - l_val = llrs[d->col_ptr[1] * d->z + shift] - *(r_ptr++); + l_val = vqsubh_s16(llrs[d->col_ptr[1] * d->z + shift], *(r_ptr++)); row_sign ^= l_val; - int8_t abs_val = vqabsb_s8(l_val); - int8_t row_min2 = max(row_min, abs_val); + int16_t abs_val = vqabsh_s16(l_val); + int16_t row_min2 = max(row_min, abs_val); row_min = min(row_min, abs_val); *(l++) = l_val; @@ -386,13 +384,13 @@ void compute_l_product_min1_and_min2( for (uint32_t col = 2; col < d->num_cols; ++col) { // Compute L(n,m) = LLR(n) - R(n,m) shift = (d->shift_ptr[col] + zb) % d->z; - l_val = llrs[d->col_ptr[col] * d->z + shift] - *(r_ptr++); + l_val = vqsubh_s16(llrs[d->col_ptr[col] * d->z + shift], *(r_ptr++)); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) row_sign ^= l_val; // Compute the min(|L(n,m)|) and the second minimum - abs_val = vqabsb_s8(l_val); + abs_val = vqabsh_s16(l_val); row_min2 = max(row_min, min(row_min2, abs_val)); row_min = min(row_min, abs_val); @@ -409,30 +407,30 @@ void compute_l_product_min1_and_min2( template<> void compute_l_product_min1_and_min2( - int8_t *l, const int8_t *__restrict__ llrs, const int8_t *__restrict__ r, + int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *row_min_array, int8_t *row_min2_array, - int8_t *row_sign_array) { - // Case for lifting sizes Z such as 8 <= Z < 16 + uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, + int16_t *row_sign_array) { + // Case for lifting sizes Z such as 4 <= Z < 8 #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg_tail = svwhilelt_b8(0, (int)tail_size); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); // Loop through the columns in the row (variable node n in psi(m)) // Column 0 - int8_t *l_ptr = l; + int16_t *l_ptr = l; auto shift = d->shift_ptr[0] % d->z; - const int8_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int8_t *r_ptr = r; + const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; + const int16_t *r_ptr = r; - svint8_t r_reg = svld1_s8(pg_tail, r_ptr); - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg_tail, r_ptr); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svint8_t row_sign = l_reg; + svint16_t row_sign = l_reg; - svint8_t row_min = svqabs_s8_x(pg_tail, l_reg); + svint16_t row_min = svqabs_s16_x(pg_tail, l_reg); - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); // Column 1 l_ptr = l + d->z; @@ -440,17 +438,17 @@ void compute_l_product_min1_and_min2( llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; r_ptr = r + d->z; - r_reg = svld1_s8(pg_tail, r_ptr); - llrs_reg = svld1_s8(pg_tail, llrs_ptr); - l_reg = svqsub_s8(llrs_reg, r_reg); + r_reg = svld1_s16(pg_tail, r_ptr); + llrs_reg = svld1_s16(pg_tail, llrs_ptr); + l_reg = svqsub_s16(llrs_reg, r_reg); - row_sign = sveor_s8_x(pg_tail, row_sign, l_reg); + row_sign = sveor_s16_x(pg_tail, row_sign, l_reg); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svint8_t row_min2 = svmax_s8_x(pg_tail, row_min, abs_reg); - row_min = svmin_s8_x(pg_tail, row_min, abs_reg); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svint16_t row_min2 = svmax_s16_x(pg_tail, row_min, abs_reg); + row_min = svmin_s16_x(pg_tail, row_min, abs_reg); - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); // Columns n >= 2 for (uint32_t col = 2; col < d->num_cols; ++col) { @@ -460,52 +458,52 @@ void compute_l_product_min1_and_min2( r_ptr = r + d->z * col; // Compute L(n,m) = LLR(n) - R(n,m) - r_reg = svld1_s8(pg_tail, r_ptr); - llrs_reg = svld1_s8(pg_tail, llrs_ptr); - l_reg = svqsub_s8(llrs_reg, r_reg); + r_reg = svld1_s16(pg_tail, r_ptr); + llrs_reg = svld1_s16(pg_tail, llrs_ptr); + l_reg = svqsub_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - row_sign = sveor_s8_x(pg_tail, row_sign, l_reg); + row_sign = sveor_s16_x(pg_tail, row_sign, l_reg); // Compute the min(|L(n,m)|) and the second minimum - abs_reg = svqabs_s8_x(pg_tail, l_reg); + abs_reg = svqabs_s16_x(pg_tail, l_reg); row_min2 = - svmax_s8_x(pg_tail, row_min, svmin_s8_x(pg_tail, row_min2, abs_reg)); - row_min = svmin_s8_x(pg_tail, row_min, abs_reg); + svmax_s16_x(pg_tail, row_min, svmin_s16_x(pg_tail, row_min2, abs_reg)); + row_min = svmin_s16_x(pg_tail, row_min, abs_reg); // Store L(n,m) - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); } // Store the two minima and the product for Z rows - svst1_s8(pg_tail, row_min_array, row_min); - svst1_s8(pg_tail, row_min2_array, row_min2); - svst1_s8(pg_tail, row_sign_array, row_sign); + svst1_s16(pg_tail, row_min_array, row_min); + svst1_s16(pg_tail, row_min2_array, row_min2); + svst1_s16(pg_tail, row_sign_array, row_sign); #else // Loop through the columns in the row (variable node n in psi(m)) // Column 0 - int8_t *l_ptr = l; + int16_t *l_ptr = l; auto shift = d->shift_ptr[0] % d->z; - const int8_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int8_t *r_ptr = r; + const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; + const int16_t *r_ptr = r; - int8x8_t r_reg = vld1_s8(r_ptr); - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t l_reg = vqsub_s8(llrs_reg, r_reg); + int16x4_t r_reg = vld1_s16(r_ptr); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); - int8x8_t row_sign = l_reg; + int16x4_t row_sign = l_reg; - int8x8_t row_min = vqabs_s8(row_sign); + int16x4_t row_min = vqabs_s16(row_sign); - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); - int8_t l_val; + int16_t l_val; for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = llrs_ptr[zb] - r_ptr[zb]; + l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); row_sign_array[zb] = l_val; - row_min_array[zb] = vqabsb_s8(l_val); + row_min_array[zb] = vqabsh_s16(l_val); l_ptr[zb] = l_val; } @@ -516,24 +514,24 @@ void compute_l_product_min1_and_min2( llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; r_ptr = r + d->z; - r_reg = vld1_s8(r_ptr); - llrs_reg = vld1_s8(llrs_ptr); - l_reg = vqsub_s8(llrs_reg, r_reg); + r_reg = vld1_s16(r_ptr); + llrs_reg = vld1_s16(llrs_ptr); + l_reg = vqsub_s16(llrs_reg, r_reg); - row_sign = veor_s8(row_sign, l_reg); + row_sign = veor_s16(row_sign, l_reg); - int8x8_t abs_reg = vqabs_s8(l_reg); - int8x8_t row_min2 = vmax_s8(row_min, abs_reg); - row_min = vmin_s8(row_min, abs_reg); + int16x4_t abs_reg = vqabs_s16(l_reg); + int16x4_t row_min2 = vmax_s16(row_min, abs_reg); + row_min = vmin_s16(row_min, abs_reg); - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = llrs_ptr[zb] - r_ptr[zb]; + l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); row_sign_array[zb] ^= l_val; - int8_t abs_val = vqabsb_s8(l_val); + int16_t abs_val = vqabsh_s16(l_val); row_min2_array[zb] = max(row_min_array[zb], abs_val); row_min_array[zb] = min(row_min_array[zb], abs_val); @@ -548,28 +546,28 @@ void compute_l_product_min1_and_min2( r_ptr = r + d->z * col; // Compute L(n,m) = LLR(n) - R(n,m) - r_reg = vld1_s8(r_ptr); - llrs_reg = vld1_s8(llrs_ptr); - l_reg = vqsub_s8(llrs_reg, r_reg); + r_reg = vld1_s16(r_ptr); + llrs_reg = vld1_s16(llrs_ptr); + l_reg = vqsub_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - row_sign = veor_s8(row_sign, l_reg); + row_sign = veor_s16(row_sign, l_reg); // Compute the min(|L(n,m)|) and the second minimum - abs_reg = vqabs_s8(l_reg); - row_min2 = vmax_s8(row_min, vmin_s8(row_min2, abs_reg)); - row_min = vmin_s8(row_min, abs_reg); + abs_reg = vqabs_s16(l_reg); + row_min2 = vmax_s16(row_min, vmin_s16(row_min2, abs_reg)); + row_min = vmin_s16(row_min, abs_reg); // Store L(n,m) - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); // Process tail for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = llrs_ptr[zb] - r_ptr[zb]; + l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); row_sign_array[zb] ^= l_val; - int8_t abs_val = vqabsb_s8(l_val); + int16_t abs_val = vqabsh_s16(l_val); row_min2_array[zb] = max(row_min_array[zb], min(row_min2_array[zb], abs_val)); row_min_array[zb] = min(row_min_array[zb], abs_val); @@ -579,41 +577,41 @@ void compute_l_product_min1_and_min2( } // Store the two minima and the product for Z rows - vst1_s8(row_min_array, row_min); - vst1_s8(row_min2_array, row_min2); - vst1_s8(row_sign_array, row_sign); + vst1_s16(row_min_array, row_min); + vst1_s16(row_min2_array, row_min2); + vst1_s16(row_sign_array, row_sign); #endif } template<> void compute_l_product_min1_and_min2( - int8_t *l, const int8_t *__restrict__ llrs, const int8_t *__restrict__ r, + int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *row_min_array, int8_t *row_min2_array, - int8_t *row_sign_array) { + uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, + int16_t *row_sign_array) { #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b8(); - svbool_t pg_tail = svwhilelt_b8(0, tail_size); + svbool_t pg = svptrue_b16(); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); // Loop through the columns in the row (variable node n in psi(m)) // Column 0 - int8_t *l_ptr = l; + int16_t *l_ptr = l; auto shift = d->shift_ptr[0] % d->z; - const int8_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int8_t *r_ptr = r; - int8_t *sign_ptr = row_sign_array; - int8_t *min_ptr = row_min_array; + const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; + const int16_t *r_ptr = r; + int16_t *sign_ptr = row_sign_array; + int16_t *min_ptr = row_min_array; for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint8_t r_reg = svld1_s8(pg, r_ptr); - svint8_t llrs_reg = svld1_s8(pg, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg, r_ptr); + svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svst1_s8(pg, sign_ptr, l_reg); + svst1_s16(pg, sign_ptr, l_reg); - svst1_s8(pg, min_ptr, svqabs_s8_x(pg, l_reg)); + svst1_s16(pg, min_ptr, svqabs_s16_x(pg, l_reg)); - svst1_s8(pg, l_ptr, l_reg); + svst1_s16(pg, l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -623,15 +621,15 @@ void compute_l_product_min1_and_min2( } if (tail_size != 0) { - svint8_t r_reg = svld1_s8(pg_tail, r_ptr); - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg_tail, r_ptr); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svst1_s8(pg_tail, sign_ptr, l_reg); + svst1_s16(pg_tail, sign_ptr, l_reg); - svst1_s8(pg_tail, min_ptr, svqabs_s8_x(pg_tail, l_reg)); + svst1_s16(pg_tail, min_ptr, svqabs_s16_x(pg_tail, l_reg)); - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); } // Column 1 @@ -641,22 +639,22 @@ void compute_l_product_min1_and_min2( r_ptr = r + d->z; sign_ptr = row_sign_array; min_ptr = row_min_array; - int8_t *min2_ptr = row_min2_array; + int16_t *min2_ptr = row_min2_array; for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint8_t r_reg = svld1_s8(pg, r_ptr); - svint8_t llrs_reg = svld1_s8(pg, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg, r_ptr); + svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svint8_t sign_reg = svld1_s8(pg, sign_ptr); - svst1_s8(pg, sign_ptr, sveor_s8_x(pg, sign_reg, l_reg)); + svint16_t sign_reg = svld1_s16(pg, sign_ptr); + svst1_s16(pg, sign_ptr, sveor_s16_x(pg, sign_reg, l_reg)); - svint8_t min_reg = svld1_s8(pg, min_ptr); - svint8_t abs_reg = svqabs_s8_x(pg, l_reg); - svst1_s8(pg, min2_ptr, svmax_s8_x(pg, min_reg, abs_reg)); - svst1_s8(pg, min_ptr, svmin_s8_x(pg, min_reg, abs_reg)); + svint16_t min_reg = svld1_s16(pg, min_ptr); + svint16_t abs_reg = svqabs_s16_x(pg, l_reg); + svst1_s16(pg, min2_ptr, svmax_s16_x(pg, min_reg, abs_reg)); + svst1_s16(pg, min_ptr, svmin_s16_x(pg, min_reg, abs_reg)); - svst1_s8(pg, l_ptr, l_reg); + svst1_s16(pg, l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -667,19 +665,19 @@ void compute_l_product_min1_and_min2( } if (tail_size != 0) { - svint8_t r_reg = svld1_s8(pg_tail, r_ptr); - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg_tail, r_ptr); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svint8_t sign_reg = svld1_s8(pg_tail, sign_ptr); - svst1_s8(pg_tail, sign_ptr, sveor_s8_x(pg_tail, sign_reg, l_reg)); + svint16_t sign_reg = svld1_s16(pg_tail, sign_ptr); + svst1_s16(pg_tail, sign_ptr, sveor_s16_x(pg_tail, sign_reg, l_reg)); - svint8_t min_reg = svld1_s8(pg_tail, min_ptr); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svst1_s8(pg_tail, min2_ptr, svmax_s8_x(pg_tail, min_reg, abs_reg)); - svst1_s8(pg_tail, min_ptr, svmin_s8_x(pg_tail, min_reg, abs_reg)); + svint16_t min_reg = svld1_s16(pg_tail, min_ptr); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svst1_s16(pg_tail, min2_ptr, svmax_s16_x(pg_tail, min_reg, abs_reg)); + svst1_s16(pg_tail, min_ptr, svmin_s16_x(pg_tail, min_reg, abs_reg)); - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); } // Columns n >= 2 @@ -695,24 +693,24 @@ void compute_l_product_min1_and_min2( // Loop through the Z rows in the layer (check node m) for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { // Compute L(n,m) = LLR(n) - R(n,m) - svint8_t r_reg = svld1_s8(pg, r_ptr); - svint8_t llrs_reg = svld1_s8(pg, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg, r_ptr); + svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - svint8_t sign_reg = svld1_s8(pg, sign_ptr); - svst1_s8(pg, sign_ptr, sveor_s8_x(pg, sign_reg, l_reg)); + svint16_t sign_reg = svld1_s16(pg, sign_ptr); + svst1_s16(pg, sign_ptr, sveor_s16_x(pg, sign_reg, l_reg)); // Compute the min(|L(n,m)|) and the second minimum - svint8_t min_reg = svld1_s8(pg, min_ptr); - svint8_t min2_reg = svld1_s8(pg, min2_ptr); - svint8_t abs_reg = svqabs_s8_x(pg, l_reg); - svst1_s8(pg, min2_ptr, - svmax_s8_x(pg, min_reg, svmin_s8_x(pg, min2_reg, abs_reg))); - svst1_s8(pg, min_ptr, svmin_s8_x(pg, min_reg, abs_reg)); + svint16_t min_reg = svld1_s16(pg, min_ptr); + svint16_t min2_reg = svld1_s16(pg, min2_ptr); + svint16_t abs_reg = svqabs_s16_x(pg, l_reg); + svst1_s16(pg, min2_ptr, + svmax_s16_x(pg, min_reg, svmin_s16_x(pg, min2_reg, abs_reg))); + svst1_s16(pg, min_ptr, svmin_s16_x(pg, min_reg, abs_reg)); // Store L(n,m) - svst1_s8(pg, l_ptr, l_reg); + svst1_s16(pg, l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -724,45 +722,45 @@ void compute_l_product_min1_and_min2( // Process tail if (tail_size != 0) { - svint8_t r_reg = svld1_s8(pg_tail, r_ptr); - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); - - svint8_t sign_reg = svld1_s8(pg_tail, sign_ptr); - svst1_s8(pg_tail, sign_ptr, sveor_s8_x(pg_tail, sign_reg, l_reg)); - - svint8_t min_reg = svld1_s8(pg_tail, min_ptr); - svint8_t min2_reg = svld1_s8(pg_tail, min2_ptr); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svst1_s8( - pg_tail, min2_ptr, - svmax_s8_x(pg_tail, min_reg, svmin_s8_x(pg_tail, min2_reg, abs_reg))); - svst1_s8(pg_tail, min_ptr, svmin_s8_x(pg_tail, min_reg, abs_reg)); - - svst1_s8(pg_tail, l_ptr, l_reg); + svint16_t r_reg = svld1_s16(pg_tail, r_ptr); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); + + svint16_t sign_reg = svld1_s16(pg_tail, sign_ptr); + svst1_s16(pg_tail, sign_ptr, sveor_s16_x(pg_tail, sign_reg, l_reg)); + + svint16_t min_reg = svld1_s16(pg_tail, min_ptr); + svint16_t min2_reg = svld1_s16(pg_tail, min2_ptr); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svst1_s16(pg_tail, min2_ptr, + svmax_s16_x(pg_tail, min_reg, + svmin_s16_x(pg_tail, min2_reg, abs_reg))); + svst1_s16(pg_tail, min_ptr, svmin_s16_x(pg_tail, min_reg, abs_reg)); + + svst1_s16(pg_tail, l_ptr, l_reg); } } #else // Loop through the columns in the row (variable node n in psi(m)) // Column 0 - int8_t *l_ptr = l; + int16_t *l_ptr = l; auto shift = d->shift_ptr[0] % d->z; - const int8_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int8_t *r_ptr = r; - int8_t *sign_ptr = row_sign_array; - int8_t *min_ptr = row_min_array; + const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; + const int16_t *r_ptr = r; + int16_t *sign_ptr = row_sign_array; + int16_t *min_ptr = row_min_array; uint32_t tail_cnt = tail_size; for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - int8x16_t r_reg = vld1q_s8(r_ptr); - int8x16_t llrs_reg = vld1q_s8(llrs_ptr); - int8x16_t l_reg = vqsubq_s8(llrs_reg, r_reg); + int16x8_t r_reg = vld1q_s16(r_ptr); + int16x8_t llrs_reg = vld1q_s16(llrs_ptr); + int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); - vst1q_s8(sign_ptr, l_reg); + vst1q_s16(sign_ptr, l_reg); - vst1q_s8(min_ptr, vqabsq_s8(l_reg)); + vst1q_s16(min_ptr, vqabsq_s16(l_reg)); - vst1q_s8(l_ptr, l_reg); + vst1q_s16(l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -771,26 +769,26 @@ void compute_l_product_min1_and_min2( llrs_ptr += num_lanes; } - if (tail_cnt > 7U) { - int8x8_t r_reg = vld1_s8(r_ptr); - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t l_reg = vqsub_s8(llrs_reg, r_reg); + if (tail_cnt > 3U) { + int16x4_t r_reg = vld1_s16(r_ptr); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); - vst1_s8(sign_ptr, l_reg); + vst1_s16(sign_ptr, l_reg); - vst1_s8(min_ptr, vqabs_s8(l_reg)); + vst1_s16(min_ptr, vqabs_s16(l_reg)); - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); - tail_cnt = d->z & 0x7; + tail_cnt = d->z & 0x3; } for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { - l[zb] = llrs[d->col_ptr[0] * (2 * d->z) + shift + zb] - r[zb]; + l[zb] = vqsubh_s16(llrs[d->col_ptr[0] * (2 * d->z) + shift + zb], r[zb]); row_sign_array[zb] = l[zb]; - row_min_array[zb] = vqabsb_s8(l[zb]); + row_min_array[zb] = vqabsh_s16(l[zb]); } // Column 1 @@ -798,25 +796,25 @@ void compute_l_product_min1_and_min2( llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; sign_ptr = row_sign_array; min_ptr = row_min_array; - int8_t *min2_ptr = row_min2_array; + int16_t *min2_ptr = row_min2_array; l_ptr = l + d->z; r_ptr = r + d->z; tail_cnt = tail_size; for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - int8x16_t r_reg = vld1q_s8(r_ptr); - int8x16_t llrs_reg = vld1q_s8(llrs_ptr); - int8x16_t l_reg = vqsubq_s8(llrs_reg, r_reg); + int16x8_t r_reg = vld1q_s16(r_ptr); + int16x8_t llrs_reg = vld1q_s16(llrs_ptr); + int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); - int8x16_t sign_reg = vld1q_s8(sign_ptr); - vst1q_s8(sign_ptr, veorq_s8(sign_reg, l_reg)); + int16x8_t sign_reg = vld1q_s16(sign_ptr); + vst1q_s16(sign_ptr, veorq_s16(sign_reg, l_reg)); - int8x16_t min_reg = vld1q_s8(min_ptr); - int8x16_t abs_reg = vqabsq_s8(l_reg); - vst1q_s8(min2_ptr, vmaxq_s8(min_reg, abs_reg)); - vst1q_s8(min_ptr, vminq_s8(min_reg, abs_reg)); + int16x8_t min_reg = vld1q_s16(min_ptr); + int16x8_t abs_reg = vqabsq_s16(l_reg); + vst1q_s16(min2_ptr, vmaxq_s16(min_reg, abs_reg)); + vst1q_s16(min_ptr, vminq_s16(min_reg, abs_reg)); - vst1q_s8(l_ptr, l_reg); + vst1q_s16(l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -826,30 +824,31 @@ void compute_l_product_min1_and_min2( llrs_ptr += num_lanes; } - if (tail_cnt > 7U) { - int8x8_t r_reg = vld1_s8(r_ptr); - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t l_reg = vqsub_s8(llrs_reg, r_reg); + if (tail_cnt > 3U) { + int16x4_t r_reg = vld1_s16(r_ptr); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); - int8x8_t sign_reg = vld1_s8(sign_ptr); - vst1_s8(sign_ptr, veor_s8(sign_reg, l_reg)); + int16x4_t sign_reg = vld1_s16(sign_ptr); + vst1_s16(sign_ptr, veor_s16(sign_reg, l_reg)); - int8x8_t min_reg = vld1_s8(min_ptr); - int8x8_t abs_reg = vqabs_s8(l_reg); - vst1_s8(min2_ptr, vmax_s8(min_reg, abs_reg)); - vst1_s8(min_ptr, vmin_s8(min_reg, abs_reg)); + int16x4_t min_reg = vld1_s16(min_ptr); + int16x4_t abs_reg = vqabs_s16(l_reg); + vst1_s16(min2_ptr, vmax_s16(min_reg, abs_reg)); + vst1_s16(min_ptr, vmin_s16(min_reg, abs_reg)); - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); - tail_cnt = d->z & 0x7; + tail_cnt = d->z & 0x3; } for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { - l[d->z + zb] = llrs[d->col_ptr[1] * (2 * d->z) + shift + zb] - r[d->z + zb]; + l[d->z + zb] = + vqsubh_s16(llrs[d->col_ptr[1] * (2 * d->z) + shift + zb], r[d->z + zb]); row_sign_array[zb] ^= l[d->z + zb]; - int8_t abs_val = vqabsb_s8(l[d->z + zb]); + int16_t abs_val = vqabsh_s16(l[d->z + zb]); row_min2_array[zb] = max(row_min_array[zb], abs_val); row_min_array[zb] = min(row_min_array[zb], abs_val); } @@ -866,26 +865,26 @@ void compute_l_product_min1_and_min2( tail_cnt = tail_size; // Loop through the Z rows in the layer (check node m) - // Process 16 entries at a time + // Process 8 entries at a time for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { // Compute L(n,m) = LLR(n) - R(n,m) - int8x16_t r_reg = vld1q_s8(r_ptr); - int8x16_t llrs_reg = vld1q_s8(llrs_ptr); - int8x16_t l_reg = vqsubq_s8(llrs_reg, r_reg); + int16x8_t r_reg = vld1q_s16(r_ptr); + int16x8_t llrs_reg = vld1q_s16(llrs_ptr); + int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - int8x16_t sign_reg = vld1q_s8(sign_ptr); - vst1q_s8(sign_ptr, veorq_s8(sign_reg, l_reg)); + int16x8_t sign_reg = vld1q_s16(sign_ptr); + vst1q_s16(sign_ptr, veorq_s16(sign_reg, l_reg)); // Compute the min(|L(n,m)|) and the second minimum - int8x16_t min_reg = vld1q_s8(min_ptr); - int8x16_t min2_reg = vld1q_s8(min2_ptr); - int8x16_t abs_reg = vqabsq_s8(l_reg); - vst1q_s8(min2_ptr, vmaxq_s8(min_reg, vminq_s8(min2_reg, abs_reg))); - vst1q_s8(min_ptr, vminq_s8(min_reg, abs_reg)); + int16x8_t min_reg = vld1q_s16(min_ptr); + int16x8_t min2_reg = vld1q_s16(min2_ptr); + int16x8_t abs_reg = vqabsq_s16(l_reg); + vst1q_s16(min2_ptr, vmaxq_s16(min_reg, vminq_s16(min2_reg, abs_reg))); + vst1q_s16(min_ptr, vminq_s16(min_reg, abs_reg)); // Store L(n,m) - vst1q_s8(l_ptr, l_reg); + vst1q_s16(l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -895,41 +894,41 @@ void compute_l_product_min1_and_min2( llrs_ptr += num_lanes; } - // Process a group of 8 elts - if (tail_cnt > 7U) { + // Process a group of 4 elts + if (tail_cnt > 3U) { // Compute L(n,m) = LLR(n) - R(n,m) - int8x8_t r_reg = vld1_s8(r_ptr); - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t l_reg = vqsub_s8(llrs_reg, r_reg); + int16x4_t r_reg = vld1_s16(r_ptr); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - int8x8_t sign_reg = vld1_s8(sign_ptr); - vst1_s8(sign_ptr, veor_s8(sign_reg, l_reg)); + int16x4_t sign_reg = vld1_s16(sign_ptr); + vst1_s16(sign_ptr, veor_s16(sign_reg, l_reg)); // Compute the min(|L(n,m)|) and the second minimum - int8x8_t min_reg = vld1_s8(min_ptr); - int8x8_t min2_reg = vld1_s8(min2_ptr); - int8x8_t abs_reg = vqabs_s8(l_reg); - vst1_s8(min2_ptr, vmax_s8(min_reg, vmin_s8(min2_reg, abs_reg))); - vst1_s8(min_ptr, vmin_s8(min_reg, abs_reg)); + int16x4_t min_reg = vld1_s16(min_ptr); + int16x4_t min2_reg = vld1_s16(min2_ptr); + int16x4_t abs_reg = vqabs_s16(l_reg); + vst1_s16(min2_ptr, vmax_s16(min_reg, vmin_s16(min2_reg, abs_reg))); + vst1_s16(min_ptr, vmin_s16(min_reg, abs_reg)); // Store L(n,m) - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); - tail_cnt = d->z & 0x7; + tail_cnt = d->z & 0x3; } // Process tail for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { // Compute L(n,m) = LLR(n) - R(n,m) - l[d->z * col + zb] = - llrs[d->col_ptr[col] * (2 * d->z) + shift + zb] - r[d->z * col + zb]; + l[d->z * col + zb] = vqsubh_s16( + llrs[d->col_ptr[col] * (2 * d->z) + shift + zb], r[d->z * col + zb]); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) row_sign_array[zb] ^= l[d->z * col + zb]; // Compute the min(|L(n,m)|) and the second minimum - int8_t abs_val = vqabsb_s8(l[d->z * col + zb]); + int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); row_min2_array[zb] = max(row_min_array[zb], min(row_min2_array[zb], abs_val)); row_min_array[zb] = min(row_min_array[zb], abs_val); @@ -948,38 +947,38 @@ void compute_l_product_min1_and_min2( // - The log likelihood ratios for each n in \psi(m): // LLR(n) = R(n,m) + L(n,m) template -void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, +void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, - const int8_t *row_min_array, - const int8_t *row_min2_array, - const int8_t *row_sign_array); + int32_t full_vec, uint32_t tail_size, + const int16_t *row_min_array, + const int16_t *row_min2_array, + const int16_t *row_sign_array); template<> -void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, +void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, - const int8_t *row_min_array, - const int8_t *row_min2_array, - const int8_t *row_sign_array) { + int32_t full_vec, uint32_t tail_size, + const int16_t *row_min_array, + const int16_t *row_min2_array, + const int16_t *row_sign_array) { // Loop through the Z rows in the layer (check node m) for (uint32_t zb = 0; zb < d->z; ++zb) { - const int8_t *l_ptr = l + zb * d->num_cols; + const int16_t *l_ptr = l + zb * d->num_cols; // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8_t col_sign = (row_sign_array[zb] ^ l_ptr[col]) < 0 ? -1 : 1; + int16_t col_sign = (row_sign_array[zb] ^ l_ptr[col]) < 0 ? -1 : 1; // Compute R(n,m) - int8_t abs_val = vqabsb_s8(l_ptr[col]); - int8_t r_val = + int16_t abs_val = vqabsh_s16(l_ptr[col]); + int16_t r_val = col_sign * (abs_val == row_min_array[zb] ? row_min2_array[zb] : row_min_array[zb]); // Compute LLR(n) = R(n,m) + L(n,m) auto shift = (d->shift_ptr[col] + zb) % d->z; auto col_ind = d->col_ptr[col] * d->z + shift; - llrs[col_ind] = vqaddb_s8(r_val, l_ptr[col]); + llrs[col_ind] = vqaddh_s16(r_val, l_ptr[col]); // Store R(n,m) for the next iteration r[col] = r_val; @@ -988,158 +987,158 @@ void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, } template<> -void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, +void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, - const int8_t *row_min_array, - const int8_t *row_min2_array, - const int8_t *row_sign_array) { - // Case for lifting sizes 8 <= Z < 16 (rows in the layer) + int32_t full_vec, uint32_t tail_size, + const int16_t *row_min_array, + const int16_t *row_min2_array, + const int16_t *row_sign_array) { + // Case for lifting sizes 4 <= Z < 8 (rows in the layer) #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg_tail = svwhilelt_b8(0, (int)tail_size); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); - svint8_t row_min = svld1_s8(pg_tail, row_min_array); - svint8_t row_min2 = svld1_s8(pg_tail, row_min2_array); - svint8_t row_sign = svld1_s8(pg_tail, row_sign_array); + svint16_t row_min = svld1_s16(pg_tail, row_min_array); + svint16_t row_min2 = svld1_s16(pg_tail, row_min2_array); + svint16_t row_sign = svld1_s16(pg_tail, row_sign_array); // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { auto shift = d->shift_ptr[col] % d->z; auto col_ind = d->col_ptr[col] * (2 * d->z); - int8_t *r_ptr = r + d->z * col; - const int8_t *l_ptr = l + d->z * col; - int8_t *llrs_ptr = llrs + col_ind + shift; + int16_t *r_ptr = r + d->z * col; + const int16_t *l_ptr = l + d->z * col; + int16_t *llrs_ptr = llrs + col_ind + shift; // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - svint8_t l_reg = svld1_s8(pg_tail, l_ptr); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svint8_t eor_reg = sveor_s8_x(pg_tail, row_sign, l_reg); - svbool_t pg_tail_neg = svcmplt_n_s8(pg_tail, eor_reg, 0); + svint16_t l_reg = svld1_s16(pg_tail, l_ptr); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svint16_t eor_reg = sveor_s16_x(pg_tail, row_sign, l_reg); + svbool_t pg_tail_neg = svcmplt_n_s16(pg_tail, eor_reg, 0); // Compute R(n,m) - svbool_t pg_tail_eq = svcmpeq_s8(pg_tail, abs_reg, row_min); - svint8_t tmp_reg = svsel_s8(pg_tail_eq, row_min2, row_min); - svint8_t r_reg = svneg_s8_m(tmp_reg, pg_tail_neg, tmp_reg); + svbool_t pg_tail_eq = svcmpeq_s16(pg_tail, abs_reg, row_min); + svint16_t tmp_reg = svsel_s16(pg_tail_eq, row_min2, row_min); + svint16_t r_reg = svneg_s16_m(tmp_reg, pg_tail_neg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - svint8_t result = svqadd_s8_x(pg_tail, r_reg, l_reg); - svst1_s8(pg_tail, llrs_ptr, result); + svint16_t result = svqadd_s16_x(pg_tail, r_reg, l_reg); + svst1_s16(pg_tail, llrs_ptr, result); // Store R(n,m) for the next iteration - svst1_s8(pg_tail, r_ptr, r_reg); + svst1_s16(pg_tail, r_ptr, r_reg); // Rearrange LLRs - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int8_t)); + memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); // copy (z - shift) elts in the main block to the replicated block memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int8_t)); + (d->z - shift) * sizeof(int16_t)); } #else - int8x8_t row_min = vld1_s8(row_min_array); - int8x8_t row_min2 = vld1_s8(row_min2_array); - int8x8_t row_sign = vld1_s8(row_sign_array); + int16x4_t row_min = vld1_s16(row_min_array); + int16x4_t row_min2 = vld1_s16(row_min2_array); + int16x4_t row_sign = vld1_s16(row_sign_array); // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { auto shift = d->shift_ptr[col] % d->z; auto col_ind = d->col_ptr[col] * (2 * d->z); - int8_t *r_ptr = r + d->z * col; - const int8_t *l_ptr = l + d->z * col; - int8_t *llrs_ptr = llrs + col_ind + shift; + int16_t *r_ptr = r + d->z * col; + const int16_t *l_ptr = l + d->z * col; + int16_t *llrs_ptr = llrs + col_ind + shift; // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8x8_t l_reg = vld1_s8(l_ptr); - int8x8_t abs_reg = vqabs_s8(l_reg); - int8x8_t eor_reg = veor_s8(row_sign, l_reg); + int16x4_t l_reg = vld1_s16(l_ptr); + int16x4_t abs_reg = vqabs_s16(l_reg); + int16x4_t eor_reg = veor_s16(row_sign, l_reg); // sign_col_reg will contain a 1 in all lanes for negative values - uint8x8_t sign_col_reg = vcltz_s8(eor_reg); + uint16x4_t sign_col_reg = vcltz_s16(eor_reg); // Compute R(n,m) // Get a mask for the minimum value in a lane - uint8x8_t check_eq = vceq_s8(abs_reg, row_min); - int8x8_t tmp_reg = vbsl_s8(check_eq, row_min2, row_min); + uint16x4_t check_eq = vceq_s16(abs_reg, row_min); + int16x4_t tmp_reg = vbsl_s16(check_eq, row_min2, row_min); // Negate the absolute values - int8x8_t neg_abs_reg = vneg_s8(tmp_reg); - int8x8_t r_reg = vbsl_s8(sign_col_reg, neg_abs_reg, tmp_reg); + int16x4_t neg_abs_reg = vneg_s16(tmp_reg); + int16x4_t r_reg = vbsl_s16(sign_col_reg, neg_abs_reg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - int8x8_t result = vqadd_s8(r_reg, l_reg); - vst1_s8(llrs_ptr, result); + int16x4_t result = vqadd_s16(r_reg, l_reg); + vst1_s16(llrs_ptr, result); // Store R(n,m) for the next iteration - vst1_s8(r_ptr, r_reg); + vst1_s16(r_ptr, r_reg); // Process tail for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? -1 : 1; + int16_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? -1 : 1; // Compute R(n,m) - int8_t abs_val = vqabsb_s8(l[d->z * col + zb]); - int8_t r_val = + int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); + int16_t r_val = col_sign * (abs_val == row_min_array[zb] ? row_min2_array[zb] : row_min_array[zb]); // Compute LLR(n) = R(n,m) + L(n,m) - llrs[col_ind + shift + zb] = vqaddb_s8(r_val, *(l + d->z * col + zb)); + llrs[col_ind + shift + zb] = vqaddh_s16(r_val, *(l + d->z * col + zb)); // Store R(n,m) for the next iteration r[d->z * col + zb] = r_val; } // Rearrange LLRs - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int8_t)); + memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); // copy (z - shift) elts in the main block to the replicated block memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int8_t)); + (d->z - shift) * sizeof(int16_t)); } #endif } template<> -void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, +void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, - const int8_t *row_min_array, - const int8_t *row_min2_array, - const int8_t *row_sign_array) { + int32_t full_vec, uint32_t tail_size, + const int16_t *row_min_array, + const int16_t *row_min2_array, + const int16_t *row_sign_array) { #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b8(); - svbool_t pg_tail = svwhilelt_b8(0, tail_size); + svbool_t pg = svptrue_b16(); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { auto shift = d->shift_ptr[col] % d->z; auto col_ind = d->col_ptr[col] * (2 * d->z); - int8_t *llrs_ptr = llrs + col_ind + shift; - const int8_t *l_ptr = l + d->z * col; - int8_t *r_ptr = r + d->z * col; - const int8_t *sign_ptr = row_sign_array; - const int8_t *min_ptr = row_min_array; - const int8_t *min2_ptr = row_min2_array; + int16_t *llrs_ptr = llrs + col_ind + shift; + const int16_t *l_ptr = l + d->z * col; + int16_t *r_ptr = r + d->z * col; + const int16_t *sign_ptr = row_sign_array; + const int16_t *min_ptr = row_min_array; + const int16_t *min2_ptr = row_min2_array; // Loop through the Z rows in the layer (check node m) for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - svint8_t l_reg = svld1_s8(pg, l_ptr); - svint8_t sign_reg = svld1_s8(pg, sign_ptr); - svint8_t eor_reg = sveor_s8_x(pg, sign_reg, l_reg); - svbool_t pg_neg = svcmplt_n_s8(pg, eor_reg, 0); + svint16_t l_reg = svld1_s16(pg, l_ptr); + svint16_t sign_reg = svld1_s16(pg, sign_ptr); + svint16_t eor_reg = sveor_s16_x(pg, sign_reg, l_reg); + svbool_t pg_neg = svcmplt_n_s16(pg, eor_reg, 0); // Compute R(n,m) - svint8_t min_reg = svld1_s8(pg, min_ptr); - svint8_t min2_reg = svld1_s8(pg, min2_ptr); - svint8_t abs_reg = svqabs_s8_x(pg, l_reg); - svbool_t pg_eq = svcmpeq_s8(pg, abs_reg, min_reg); - svint8_t tmp_reg = svsel_s8(pg_eq, min2_reg, min_reg); - svint8_t r_reg = svneg_s8_m(tmp_reg, pg_neg, tmp_reg); + svint16_t min_reg = svld1_s16(pg, min_ptr); + svint16_t min2_reg = svld1_s16(pg, min2_ptr); + svint16_t abs_reg = svqabs_s16_x(pg, l_reg); + svbool_t pg_eq = svcmpeq_s16(pg, abs_reg, min_reg); + svint16_t tmp_reg = svsel_s16(pg_eq, min2_reg, min_reg); + svint16_t r_reg = svneg_s16_m(tmp_reg, pg_neg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - svint8_t result = svqadd_s8_x(pg, r_reg, l_reg); - svst1_s8(pg, llrs_ptr, result); + svint16_t result = svqadd_s16_x(pg, r_reg, l_reg); + svst1_s16(pg, llrs_ptr, result); // Store R(n,m) for the next iteration - svst1_s8(pg, r_ptr, r_reg); + svst1_s16(pg, r_ptr, r_reg); // Increment pointers l_ptr += num_lanes; @@ -1152,75 +1151,75 @@ void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, if (tail_size != 0) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - svint8_t l_reg = svld1_s8(pg_tail, l_ptr); - svint8_t sign_reg = svld1_s8(pg_tail, sign_ptr); - svint8_t eor_reg = sveor_s8_x(pg_tail, sign_reg, l_reg); - svbool_t pg_tail_neg = svcmplt_n_s8(pg_tail, eor_reg, 0); + svint16_t l_reg = svld1_s16(pg_tail, l_ptr); + svint16_t sign_reg = svld1_s16(pg_tail, sign_ptr); + svint16_t eor_reg = sveor_s16_x(pg_tail, sign_reg, l_reg); + svbool_t pg_tail_neg = svcmplt_n_s16(pg_tail, eor_reg, 0); // Compute R(n,m) - svint8_t min_reg = svld1_s8(pg_tail, min_ptr); - svint8_t min2_reg = svld1_s8(pg_tail, min2_ptr); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svbool_t pg_tail_eq = svcmpeq_s8(pg_tail, abs_reg, min_reg); - svint8_t tmp_reg = svsel_s8(pg_tail_eq, min2_reg, min_reg); - svint8_t r_reg = svneg_s8_m(tmp_reg, pg_tail_neg, tmp_reg); + svint16_t min_reg = svld1_s16(pg_tail, min_ptr); + svint16_t min2_reg = svld1_s16(pg_tail, min2_ptr); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svbool_t pg_tail_eq = svcmpeq_s16(pg_tail, abs_reg, min_reg); + svint16_t tmp_reg = svsel_s16(pg_tail_eq, min2_reg, min_reg); + svint16_t r_reg = svneg_s16_m(tmp_reg, pg_tail_neg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - svint8_t result = svqadd_s8_x(pg_tail, r_reg, l_reg); - svst1_s8(pg_tail, llrs_ptr, result); + svint16_t result = svqadd_s16_x(pg_tail, r_reg, l_reg); + svst1_s16(pg_tail, llrs_ptr, result); // Store R(n,m) for the next iteration - svst1_s8(pg_tail, r_ptr, r_reg); + svst1_s16(pg_tail, r_ptr, r_reg); } // Rearrange LLRs // copy shifted elements in the replicated block // back to the beginning of the main block - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int8_t)); + memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); // copy (z - shift) elts in the main block to the replicated block memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int8_t)); + (d->z - shift) * sizeof(int16_t)); } #else // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { - int8_t *r_ptr = r + d->z * col; - const int8_t *sign_ptr = row_sign_array; - const int8_t *min_ptr = row_min_array; - const int8_t *min2_ptr = row_min2_array; + int16_t *r_ptr = r + d->z * col; + const int16_t *sign_ptr = row_sign_array; + const int16_t *min_ptr = row_min_array; + const int16_t *min2_ptr = row_min2_array; auto shift = d->shift_ptr[col] % d->z; auto col_ind = d->col_ptr[col] * (2 * d->z); - const int8_t *l_ptr = l + d->z * col; - int8_t *llrs_ptr = llrs + col_ind + shift; + const int16_t *l_ptr = l + d->z * col; + int16_t *llrs_ptr = llrs + col_ind + shift; uint32_t tail_cnt = tail_size; // Loop through the Z rows in the layer (check node m) - // Process 16 entries at a time + // Process 8 entries at a time for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8x16_t l_reg = vld1q_s8(l_ptr); - int8x16_t sign_reg = vld1q_s8(sign_ptr); - int8x16_t eor_reg = veorq_s8(sign_reg, l_reg); + int16x8_t l_reg = vld1q_s16(l_ptr); + int16x8_t sign_reg = vld1q_s16(sign_ptr); + int16x8_t eor_reg = veorq_s16(sign_reg, l_reg); // sign_col_reg will contain a 1 in all lanes for negative values - uint8x16_t sign_col_reg = vcltzq_s8(eor_reg); + uint16x8_t sign_col_reg = vcltzq_s16(eor_reg); // Compute R(n,m) - int8x16_t min_reg = vld1q_s8(min_ptr); - int8x16_t min2_reg = vld1q_s8(min2_ptr); - int8x16_t abs_reg = vqabsq_s8(l_reg); + int16x8_t min_reg = vld1q_s16(min_ptr); + int16x8_t min2_reg = vld1q_s16(min2_ptr); + int16x8_t abs_reg = vqabsq_s16(l_reg); // Get a mask for the minimum value in a lane - uint8x16_t check_eq = vceqq_s8(abs_reg, min_reg); - int8x16_t tmp_reg = vbslq_s8(check_eq, min2_reg, min_reg); + uint16x8_t check_eq = vceqq_s16(abs_reg, min_reg); + int16x8_t tmp_reg = vbslq_s16(check_eq, min2_reg, min_reg); // Negate the absolute values - int8x16_t neg_abs_reg = vnegq_s8(tmp_reg); - int8x16_t r_reg = vbslq_s8(sign_col_reg, neg_abs_reg, tmp_reg); + int16x8_t neg_abs_reg = vnegq_s16(tmp_reg); + int16x8_t r_reg = vbslq_s16(sign_col_reg, neg_abs_reg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - int8x16_t result = vqaddq_s8(r_reg, l_reg); - vst1q_s8(llrs_ptr, result); + int16x8_t result = vqaddq_s16(r_reg, l_reg); + vst1q_s16(llrs_ptr, result); // Store R(n,m) for the next iteration - vst1q_s8(r_ptr, r_reg); + vst1q_s16(r_ptr, r_reg); // Increment pointers r_ptr += num_lanes; @@ -1231,49 +1230,49 @@ void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, min2_ptr += num_lanes; } - // Process a group of 8 elts - if (tail_cnt > 7U) { + // Process a group of 4 elts + if (tail_cnt > 3U) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8x8_t l_reg = vld1_s8(l_ptr); - int8x8_t sign_reg = vld1_s8(sign_ptr); - int8x8_t eor_reg = veor_s8(sign_reg, l_reg); + int16x4_t l_reg = vld1_s16(l_ptr); + int16x4_t sign_reg = vld1_s16(sign_ptr); + int16x4_t eor_reg = veor_s16(sign_reg, l_reg); // sign_col_reg will contain a 1 in all lanes for negative values - uint8x8_t sign_col_reg = vcltz_s8(eor_reg); + uint16x4_t sign_col_reg = vcltz_s16(eor_reg); // Compute R(n,m) - int8x8_t min_reg = vld1_s8(min_ptr); - int8x8_t min2_reg = vld1_s8(min2_ptr); - int8x8_t abs_reg = vqabs_s8(l_reg); + int16x4_t min_reg = vld1_s16(min_ptr); + int16x4_t min2_reg = vld1_s16(min2_ptr); + int16x4_t abs_reg = vqabs_s16(l_reg); // Get a mask for the minimum value in a lane - uint8x8_t check_eq = vceq_s8(abs_reg, min_reg); - int8x8_t tmp_reg = vbsl_s8(check_eq, min2_reg, min_reg); + uint16x4_t check_eq = vceq_s16(abs_reg, min_reg); + int16x4_t tmp_reg = vbsl_s16(check_eq, min2_reg, min_reg); // Negate the absolute values - int8x8_t neg_abs_reg = vneg_s8(tmp_reg); - int8x8_t r_reg = vbsl_s8(sign_col_reg, neg_abs_reg, tmp_reg); + int16x4_t neg_abs_reg = vneg_s16(tmp_reg); + int16x4_t r_reg = vbsl_s16(sign_col_reg, neg_abs_reg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - int8x8_t result = vqadd_s8(r_reg, l_reg); - vst1_s8(llrs_ptr, result); + int16x4_t result = vqadd_s16(r_reg, l_reg); + vst1_s16(llrs_ptr, result); // Store R(n,m) for the next iteration - vst1_s8(r_ptr, r_reg); + vst1_s16(r_ptr, r_reg); - tail_cnt = d->z & 0x7; + tail_cnt = d->z & 0x3; } // Process tail for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? -1 : 1; + int16_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? -1 : 1; // Compute R(n,m) - int8_t abs_val = vqabsb_s8(l[d->z * col + zb]); - int8_t r_val = + int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); + int16_t r_val = col_sign * (abs_val == row_min_array[zb] ? row_min2_array[zb] : row_min_array[zb]); // Compute LLR(n) = R(n,m) + L(n,m) - llrs[col_ind + shift + zb] = vqaddb_s8(r_val, *(l + d->z * col + zb)); + llrs[col_ind + shift + zb] = vqaddh_s16(r_val, *(l + d->z * col + zb)); // Store R(n,m) for the next iteration r[d->z * col + zb] = r_val; @@ -1282,21 +1281,22 @@ void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, // Rearrange LLRs // copy shifted elements in the replicated block // back to the beginning of the main block - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int8_t)); + memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); // copy (z - shift) elts in the main block to the replicated block memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int8_t)); + (d->z - shift) * sizeof(int16_t)); } #endif } template void __attribute__((flatten)) -run_iterations(uint32_t num_its, int z, int lsi, - const armral_ldpc_base_graph_t *graph, int8_t *r, int8_t *l, - int8_t *new_llrs, int num_lanes, int full_vec, int tail_size, - int8_t *row_min_array, int8_t *row_min2_array, - int8_t *row_sign_array, int8_t *check, bool check_convergence, +run_iterations(uint32_t num_its, uint32_t z, uint32_t lsi, + const armral_ldpc_base_graph_t *graph, int16_t *r, int16_t *l, + int16_t *new_llrs, int32_t num_lanes, int32_t full_vec, + uint32_t tail_size, int16_t *row_min_array, + int16_t *row_min2_array, int16_t *row_sign_array, int16_t *check, + bool check_convergence, std::optional> &crc_checker) { for (uint32_t i = 0; i < num_its; ++i) { ldpc_layer_data d(z, lsi, graph); @@ -1360,34 +1360,34 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t layer_size = (graph->row_start_inds[2] - graph->row_start_inds[1]) * z; // We need to keep a record of matrix L (variable-to-check-node messages) - auto l = allocate_uninitialized(allocator, layer_size); + auto l = allocate_uninitialized(allocator, layer_size); // We need to keep a record of matrix R (check-to-variable-node messages) - auto r = allocate_zeroed(allocator, mat_size); + auto r = allocate_zeroed(allocator, mat_size); - auto row_min_array = allocate_zeroed(allocator, z); - auto row_min2_array = allocate_zeroed(allocator, z); - auto row_sign_array = allocate_zeroed(allocator, z); + auto row_min_array = allocate_zeroed(allocator, z); + auto row_min2_array = allocate_zeroed(allocator, z); + auto row_sign_array = allocate_zeroed(allocator, z); - auto check = allocate_zeroed(allocator, z); + auto check = allocate_zeroed(allocator, z); #if ARMRAL_ARCH_SVE >= 2 bool z_is_tiny = (z == 2); #else - bool z_is_tiny = (z < 8); + bool z_is_tiny = (z < 4); #endif // Keep a record of the current, and previous values of the LLRs // Copy the inputs LLRs const auto *llrs_ptr = llrs; size_t new_llrs_size = num_llrs; - std::optional> maybe_out_llrs; + std::optional> maybe_out_llrs; if (!z_is_tiny) { // Double the storage required to replicate LLRs for optimization new_llrs_size *= 2; // Extra buffer to pack the LLRs again - maybe_out_llrs = allocate_uninitialized(allocator, num_llrs); + maybe_out_llrs = allocate_uninitialized(allocator, num_llrs); } - auto new_llrs = allocate_uninitialized(allocator, new_llrs_size); + auto new_llrs = allocate_uninitialized(allocator, new_llrs_size); // NOTE: All allocations are now done! if constexpr (Allocator::is_counting) { @@ -1396,18 +1396,24 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, if (z_is_tiny) { // Set the value of the current LLRs from the ones passed in. - // We need to take account of the punctured columns - memset(new_llrs.get(), 0, 2 * z * sizeof(int8_t)); - memcpy(&new_llrs[2 * z], llrs, z * graph->ncodeword_bits * sizeof(int8_t)); + // We need to take account of the punctured columns. + // Also widen to int16_t for use in intermediate calculations. + memset(new_llrs.get(), 0, 2 * z * sizeof(int16_t)); + for (uint32_t i = 0; i < z * graph->ncodeword_bits; i++) { + new_llrs[2 * z + i] = (int16_t)llrs[i]; + } } else { // Each block of Z elements replicated b1|b1|b2|b2 ... - // We need to take account of the punctured columns - memset(new_llrs.get(), 0, 4 * z * sizeof(int8_t)); + // We need to take account of the punctured columns. + // Also widen to int16_t for use in intermediate calculations. + memset(new_llrs.get(), 0, 4 * z * sizeof(int16_t)); auto *new_llrs_ptr = &new_llrs[4 * z]; for (uint32_t num_block = 0; num_block < graph->ncodeword_bits; num_block++) { - memcpy(new_llrs_ptr, llrs_ptr, z * sizeof(int8_t)); - memcpy(new_llrs_ptr + z, llrs_ptr, z * sizeof(int8_t)); + for (uint32_t i = 0; i < z; i++) { + new_llrs_ptr[i] = (int16_t)llrs_ptr[i]; + new_llrs_ptr[z + i] = (int16_t)llrs_ptr[i]; + } new_llrs_ptr += 2 * z; llrs_ptr += z; } @@ -1415,19 +1421,19 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, // Precompute number of full vector and tail #if ARMRAL_ARCH_SVE >= 2 - int32_t num_lanes = svcntb(); + int32_t num_lanes = svcnth(); int32_t full_vec = z / num_lanes; uint32_t tail_size = z % num_lanes; bool is_tail_only = (tail_size == z && !z_is_tiny); #else - int32_t num_lanes = 16; + int32_t num_lanes = 8; int32_t full_vec = z / num_lanes; uint32_t tail_size = z % num_lanes; bool is_tail_only = (tail_size == z && !z_is_tiny); if (is_tail_only) { - // tail size = Z - 8 - tail_size -= 8; + // tail size = Z - 4 + tail_size -= 4; } #endif @@ -1459,7 +1465,7 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, for (uint32_t num_block = 0; num_block < graph->ncodeword_bits + 2; num_block++) { memcpy(out_llrs + num_block * z, &new_llrs[2 * num_block * z], - z * sizeof(int8_t)); + z * sizeof(int16_t)); } // Hard decode into the output variable -- GitLab