diff --git a/CHANGELOG.md b/CHANGELOG.md index 915c20fa31fc22722a6752845813dd80971c2002..925bc9d28762a9628e3c08f012f8435abfcd4d0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,11 @@ documented in this file. ### Fixed +- LDPC decoding (`armral_ldpc_decode_block`) now achieves the expected error + correction performance in the presence of channel noise. The function now uses + `int16_t` internally rather than `int8_t`, which can be slower for certain + input sizes. + ### Security ## [24.10] - 2024-10-17 diff --git a/src/UpperPHY/LDPC/ldpc_decoder.cpp b/src/UpperPHY/LDPC/ldpc_decoder.cpp index 7cb98bd405df60e65d00c7465049c77ef584cf16..78bda9262e1d3365c6dbd21b1f16427f96067ddb 100644 --- a/src/UpperPHY/LDPC/ldpc_decoder.cpp +++ b/src/UpperPHY/LDPC/ldpc_decoder.cpp @@ -74,18 +74,18 @@ public: m_buffer_size = m_total_bits >> 3; } - m_llrs = allocate_uninitialized(allocator, m_total_bits + m_z - 1); + m_llrs = allocate_uninitialized(allocator, m_total_bits + m_z - 1); m_buffer = allocate_uninitialized(allocator, m_buffer_size); } - bool check(const int8_t *new_llrs) { + bool check(const int16_t *new_llrs) { // Copy the LLRs corresponding to the bits we need to do the CRC check after // the padding bits - memset(m_llrs.get(), 0, m_num_pad_bits * sizeof(int8_t)); + memset(m_llrs.get(), 0, m_num_pad_bits * sizeof(int16_t)); for (uint32_t num_block = 0; num_block < ((m_k_prime + m_z - 1) / m_z); num_block++) { memcpy(m_llrs.get() + m_num_pad_bits + (num_block * m_z), - new_llrs + (2 * num_block * m_z), m_z * sizeof(int8_t)); + new_llrs + (2 * num_block * m_z), m_z * sizeof(int16_t)); } // Hard decode @@ -106,20 +106,20 @@ private: uint32_t m_buffer_size{0}; uint32_t m_num_pad_bits{0}; uint32_t m_total_bits{0}; - unique_ptr m_llrs; + unique_ptr m_llrs; unique_ptr m_buffer; }; template -bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, +bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, int8_t *check); + int32_t full_vec, uint32_t tail_size, int16_t *check); template<> -bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, +bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *check) { + uint32_t tail_size, int16_t *check) { // Loop through the rows in the base graph bool passed = true; for (uint32_t row = 0; row < graph->nrows && passed; ++row) { @@ -132,7 +132,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, // Loop through the rows in the block for (uint32_t zb = 0; zb < z && passed; ++zb) { // Loop through the columns in the row - int8_t scal_check = 0; + int16_t scal_check = 0; for (uint32_t col = 0; col < num_cols; ++col) { auto shift = (shift_ptr[col] + zb) % z; auto codeword_ind = col_ptr[col] * z + shift; @@ -145,14 +145,14 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, } template<> -bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, +bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *check) { + uint32_t tail_size, int16_t *check) { // Loop through the rows in the base graph bool passed = true; #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg_tail = svwhilelt_b8(0, (int)tail_size); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); for (uint32_t row = 0; row < graph->nrows && 
passed; ++row) { auto row_start_ind = graph->row_start_inds[row]; auto num_cols = graph->row_start_inds[row + 1] - row_start_ind; @@ -160,7 +160,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const auto *shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + lsi * num_cols; - memset(check, 0, z * sizeof(int8_t)); + memset(check, 0, z * sizeof(int16_t)); // Loop through the columns for (uint32_t col = 0; col < num_cols; ++col) { @@ -168,13 +168,13 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, auto codeword_ind = col_ptr[col] * (2 * z) + shift; // No need to loop here, as there is only a tail - const int8_t *llrs_ptr = llrs + codeword_ind; - int8_t *check_ptr = check; + const int16_t *llrs_ptr = llrs + codeword_ind; + int16_t *check_ptr = check; - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t check_reg = svld1_s8(pg_tail, check_ptr); - svint8_t result_reg = sveor_s8_x(pg_tail, check_reg, llrs_reg); - svst1_s8(pg_tail, check_ptr, result_reg); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t check_reg = svld1_s16(pg_tail, check_ptr); + svint16_t result_reg = sveor_s16_x(pg_tail, check_reg, llrs_reg); + svst1_s16(pg_tail, check_ptr, result_reg); } for (uint32_t zb = 0; zb < z && passed; ++zb) { passed &= check[zb] >= 0; @@ -188,7 +188,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const auto *shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + lsi * num_cols; - memset(check, 0, z * sizeof(int8_t)); + memset(check, 0, z * sizeof(int16_t)); // Loop through the columns for (uint32_t col = 0; col < num_cols; ++col) { @@ -196,13 +196,13 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, auto codeword_ind = col_ptr[col] * (2 * z) + shift; // Loop through the rows in the block - const int8_t *llrs_ptr = llrs + codeword_ind; - int8_t *check_ptr = check; + const int16_t *llrs_ptr = llrs + codeword_ind; + int16_t *check_ptr = check; - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t check_reg = vld1_s8(check_ptr); - int8x8_t result_reg = veor_s8(check_reg, llrs_reg); - vst1_s8(check_ptr, result_reg); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t check_reg = vld1_s16(check_ptr); + int16x4_t result_reg = veor_s16(check_reg, llrs_reg); + vst1_s16(check_ptr, result_reg); // Deal with a tail for (uint32_t zb = z - tail_size; zb < z; ++zb) { @@ -218,13 +218,13 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, } template<> -bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, +bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *check) { + uint32_t tail_size, int16_t *check) { #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b8(); - svbool_t pg_tail = svwhilelt_b8(0, tail_size); + svbool_t pg = svptrue_b16(); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); // Loop through the rows in the base graph bool passed = true; @@ -235,7 +235,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const auto *shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + lsi * num_cols; - memset(check, 0, z * sizeof(int8_t)); + memset(check, 0, z * sizeof(int16_t)); // Loop through the columns for (uint32_t col = 0; col < num_cols; ++col) { @@ -248,14 +248,14 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, // represent a hard decision for the bit to be one, and 
non-negative // values represent a zero. Hence the check needs to xor all LLRs // and then assert that the result is non-negative. - const int8_t *llrs_ptr = llrs + codeword_ind; - int8_t *check_ptr = check; + const int16_t *llrs_ptr = llrs + codeword_ind; + int16_t *check_ptr = check; for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint8_t llrs_reg = svld1_s8(pg, llrs_ptr); - svint8_t check_reg = svld1_s8(pg, check_ptr); - svint8_t result_reg = sveor_s8_x(pg, check_reg, llrs_reg); - svst1_s8(pg, check_ptr, result_reg); + svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); + svint16_t check_reg = svld1_s16(pg, check_ptr); + svint16_t result_reg = sveor_s16_x(pg, check_reg, llrs_reg); + svst1_s16(pg, check_ptr, result_reg); // Increment pointers llrs_ptr += num_lanes; @@ -263,10 +263,10 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, } // Process tail if (tail_size != 0) { - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t check_reg = svld1_s8(pg_tail, check_ptr); - svint8_t result_reg = sveor_s8_x(pg_tail, check_reg, llrs_reg); - svst1_s8(pg_tail, check_ptr, result_reg); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t check_reg = svld1_s16(pg_tail, check_ptr); + svint16_t result_reg = sveor_s16_x(pg_tail, check_reg, llrs_reg); + svst1_s16(pg_tail, check_ptr, result_reg); } } for (uint32_t zb = 0; zb < z && passed; ++zb) { @@ -284,7 +284,7 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const auto *shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + lsi * num_cols; - memset(check, 0, z * sizeof(int8_t)); + memset(check, 0, z * sizeof(int16_t)); // Loop through the columns for (uint32_t col = 0; col < num_cols; ++col) { @@ -298,27 +298,27 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, // represent a hard decision for the bit to be one, and non-negative // values represent a zero. Hence the check needs to xor all LLRs // and then assert that the result is non-negative. 
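For reference, the sign-bit trick described in that comment can be restated as a minimal scalar sketch (illustrative helper only, not part of this patch; it assumes the same convention of a negative LLR meaning a hard-decision one):

#include <cstdint>

// Returns true if the parity check over one row passes. XOR propagates the
// XOR of the sign bits, so the accumulator is non-negative exactly when an
// even number of the row's LLRs are negative (hard-decision ones).
static bool parity_check_row_scalar(const int16_t *row_llrs, uint32_t n) {
  int16_t acc = 0;
  for (uint32_t i = 0; i < n; ++i) {
    acc ^= row_llrs[i];
  }
  return acc >= 0;
}

The vectorized paths below do the same thing lane-wise with veor_s16/sveor_s16_x and only inspect the sign of the accumulated check values at the end of each row.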
- const int8_t *llrs_ptr = llrs + codeword_ind; - int8_t *check_ptr = check; + const int16_t *llrs_ptr = llrs + codeword_ind; + int16_t *check_ptr = check; - // Process 16 entries at a time + // Process 8 entries at a time for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - int8x16_t llrs_reg = vld1q_s8(llrs_ptr); - int8x16_t check_reg = vld1q_s8(check_ptr); - int8x16_t result_reg = veorq_s8(check_reg, llrs_reg); - vst1q_s8(check_ptr, result_reg); + int16x8_t llrs_reg = vld1q_s16(llrs_ptr); + int16x8_t check_reg = vld1q_s16(check_ptr); + int16x8_t result_reg = veorq_s16(check_reg, llrs_reg); + vst1q_s16(check_ptr, result_reg); // Increment pointers llrs_ptr += num_lanes; check_ptr += num_lanes; } - // Process a group of 8 elts - if (tail_cnt > 7) { - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t check_reg = vld1_s8(check_ptr); - int8x8_t result_reg = veor_s8(check_reg, llrs_reg); - vst1_s8(check_ptr, result_reg); - tail_cnt = z & 0x7; + // Process a group of 4 elts + if (tail_cnt > 3U) { + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t check_reg = vld1_s16(check_ptr); + int16x4_t result_reg = veor_s16(check_reg, llrs_reg); + vst1_s16(check_ptr, result_reg); + tail_cnt = z & 0x3; } // Deal with a tail for (uint32_t zb = z - tail_cnt; zb < z; ++zb) { @@ -342,42 +342,40 @@ bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, // - \min_{n \in \psi(m)} |L(n,m)| and the second minimum (they will be used to // compute |R(n,m)| in a second step) template -void compute_l_product_min1_and_min2(int8_t *l, const int8_t *__restrict__ llrs, - const int8_t *__restrict__ r, - const ldpc_layer_data *d, - int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *row_min_array, - int8_t *row_min2_array, - int8_t *row_sign_array); +void compute_l_product_min1_and_min2( + int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, + const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, + uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, + int16_t *row_sign_array); template<> void compute_l_product_min1_and_min2( - int8_t *l, const int8_t *__restrict__ llrs, const int8_t *__restrict__ r, + int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *row_min_array, int8_t *row_min2_array, - int8_t *row_sign_array) { + uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, + int16_t *row_sign_array) { const auto *r_ptr = r; // Loop through the Z rows in the layer (check node m) for (uint32_t zb = 0; zb < d->z; ++zb) { // Loop through the columns in the row (variable node n in psi(m)) // Column 0 auto shift = (d->shift_ptr[0] + zb) % d->z; - int8_t l_val = llrs[d->col_ptr[0] * d->z + shift] - *(r_ptr++); + int16_t l_val = vqsubh_s16(llrs[d->col_ptr[0] * d->z + shift], *(r_ptr++)); - int8_t row_sign = l_val; + int16_t row_sign = l_val; - int8_t row_min = vqabsb_s8(l_val); + int16_t row_min = vqabsh_s16(l_val); *(l++) = l_val; // Column 1 shift = (d->shift_ptr[1] + zb) % d->z; - l_val = llrs[d->col_ptr[1] * d->z + shift] - *(r_ptr++); + l_val = vqsubh_s16(llrs[d->col_ptr[1] * d->z + shift], *(r_ptr++)); row_sign ^= l_val; - int8_t abs_val = vqabsb_s8(l_val); - int8_t row_min2 = max(row_min, abs_val); + int16_t abs_val = vqabsh_s16(l_val); + int16_t row_min2 = max(row_min, abs_val); row_min = min(row_min, abs_val); *(l++) = l_val; @@ -386,13 +384,13 @@ void compute_l_product_min1_and_min2( for (uint32_t col = 2; col 
< d->num_cols; ++col) { // Compute L(n,m) = LLR(n) - R(n,m) shift = (d->shift_ptr[col] + zb) % d->z; - l_val = llrs[d->col_ptr[col] * d->z + shift] - *(r_ptr++); + l_val = vqsubh_s16(llrs[d->col_ptr[col] * d->z + shift], *(r_ptr++)); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) row_sign ^= l_val; // Compute the min(|L(n,m)|) and the second minimum - abs_val = vqabsb_s8(l_val); + abs_val = vqabsh_s16(l_val); row_min2 = max(row_min, min(row_min2, abs_val)); row_min = min(row_min, abs_val); @@ -409,30 +407,30 @@ void compute_l_product_min1_and_min2( template<> void compute_l_product_min1_and_min2( - int8_t *l, const int8_t *__restrict__ llrs, const int8_t *__restrict__ r, + int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *row_min_array, int8_t *row_min2_array, - int8_t *row_sign_array) { - // Case for lifting sizes Z such as 8 <= Z < 16 + uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, + int16_t *row_sign_array) { + // Case for lifting sizes Z such as 4 <= Z < 8 #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg_tail = svwhilelt_b8(0, (int)tail_size); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); // Loop through the columns in the row (variable node n in psi(m)) // Column 0 - int8_t *l_ptr = l; + int16_t *l_ptr = l; auto shift = d->shift_ptr[0] % d->z; - const int8_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int8_t *r_ptr = r; + const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; + const int16_t *r_ptr = r; - svint8_t r_reg = svld1_s8(pg_tail, r_ptr); - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg_tail, r_ptr); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svint8_t row_sign = l_reg; + svint16_t row_sign = l_reg; - svint8_t row_min = svqabs_s8_x(pg_tail, l_reg); + svint16_t row_min = svqabs_s16_x(pg_tail, l_reg); - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); // Column 1 l_ptr = l + d->z; @@ -440,17 +438,17 @@ void compute_l_product_min1_and_min2( llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; r_ptr = r + d->z; - r_reg = svld1_s8(pg_tail, r_ptr); - llrs_reg = svld1_s8(pg_tail, llrs_ptr); - l_reg = svqsub_s8(llrs_reg, r_reg); + r_reg = svld1_s16(pg_tail, r_ptr); + llrs_reg = svld1_s16(pg_tail, llrs_ptr); + l_reg = svqsub_s16(llrs_reg, r_reg); - row_sign = sveor_s8_x(pg_tail, row_sign, l_reg); + row_sign = sveor_s16_x(pg_tail, row_sign, l_reg); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svint8_t row_min2 = svmax_s8_x(pg_tail, row_min, abs_reg); - row_min = svmin_s8_x(pg_tail, row_min, abs_reg); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svint16_t row_min2 = svmax_s16_x(pg_tail, row_min, abs_reg); + row_min = svmin_s16_x(pg_tail, row_min, abs_reg); - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); // Columns n >= 2 for (uint32_t col = 2; col < d->num_cols; ++col) { @@ -460,52 +458,52 @@ void compute_l_product_min1_and_min2( r_ptr = r + d->z * col; // Compute L(n,m) = LLR(n) - R(n,m) - r_reg = svld1_s8(pg_tail, r_ptr); - llrs_reg = svld1_s8(pg_tail, llrs_ptr); - l_reg = svqsub_s8(llrs_reg, r_reg); + r_reg = svld1_s16(pg_tail, r_ptr); + llrs_reg = svld1_s16(pg_tail, llrs_ptr); + l_reg = svqsub_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - 
row_sign = sveor_s8_x(pg_tail, row_sign, l_reg); + row_sign = sveor_s16_x(pg_tail, row_sign, l_reg); // Compute the min(|L(n,m)|) and the second minimum - abs_reg = svqabs_s8_x(pg_tail, l_reg); + abs_reg = svqabs_s16_x(pg_tail, l_reg); row_min2 = - svmax_s8_x(pg_tail, row_min, svmin_s8_x(pg_tail, row_min2, abs_reg)); - row_min = svmin_s8_x(pg_tail, row_min, abs_reg); + svmax_s16_x(pg_tail, row_min, svmin_s16_x(pg_tail, row_min2, abs_reg)); + row_min = svmin_s16_x(pg_tail, row_min, abs_reg); // Store L(n,m) - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); } // Store the two minima and the product for Z rows - svst1_s8(pg_tail, row_min_array, row_min); - svst1_s8(pg_tail, row_min2_array, row_min2); - svst1_s8(pg_tail, row_sign_array, row_sign); + svst1_s16(pg_tail, row_min_array, row_min); + svst1_s16(pg_tail, row_min2_array, row_min2); + svst1_s16(pg_tail, row_sign_array, row_sign); #else // Loop through the columns in the row (variable node n in psi(m)) // Column 0 - int8_t *l_ptr = l; + int16_t *l_ptr = l; auto shift = d->shift_ptr[0] % d->z; - const int8_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int8_t *r_ptr = r; + const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; + const int16_t *r_ptr = r; - int8x8_t r_reg = vld1_s8(r_ptr); - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t l_reg = vqsub_s8(llrs_reg, r_reg); + int16x4_t r_reg = vld1_s16(r_ptr); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); - int8x8_t row_sign = l_reg; + int16x4_t row_sign = l_reg; - int8x8_t row_min = vqabs_s8(row_sign); + int16x4_t row_min = vqabs_s16(row_sign); - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); - int8_t l_val; + int16_t l_val; for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = llrs_ptr[zb] - r_ptr[zb]; + l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); row_sign_array[zb] = l_val; - row_min_array[zb] = vqabsb_s8(l_val); + row_min_array[zb] = vqabsh_s16(l_val); l_ptr[zb] = l_val; } @@ -516,24 +514,24 @@ void compute_l_product_min1_and_min2( llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; r_ptr = r + d->z; - r_reg = vld1_s8(r_ptr); - llrs_reg = vld1_s8(llrs_ptr); - l_reg = vqsub_s8(llrs_reg, r_reg); + r_reg = vld1_s16(r_ptr); + llrs_reg = vld1_s16(llrs_ptr); + l_reg = vqsub_s16(llrs_reg, r_reg); - row_sign = veor_s8(row_sign, l_reg); + row_sign = veor_s16(row_sign, l_reg); - int8x8_t abs_reg = vqabs_s8(l_reg); - int8x8_t row_min2 = vmax_s8(row_min, abs_reg); - row_min = vmin_s8(row_min, abs_reg); + int16x4_t abs_reg = vqabs_s16(l_reg); + int16x4_t row_min2 = vmax_s16(row_min, abs_reg); + row_min = vmin_s16(row_min, abs_reg); - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = llrs_ptr[zb] - r_ptr[zb]; + l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); row_sign_array[zb] ^= l_val; - int8_t abs_val = vqabsb_s8(l_val); + int16_t abs_val = vqabsh_s16(l_val); row_min2_array[zb] = max(row_min_array[zb], abs_val); row_min_array[zb] = min(row_min_array[zb], abs_val); @@ -548,28 +546,28 @@ void compute_l_product_min1_and_min2( r_ptr = r + d->z * col; // Compute L(n,m) = LLR(n) - R(n,m) - r_reg = vld1_s8(r_ptr); - llrs_reg = vld1_s8(llrs_ptr); - l_reg = vqsub_s8(llrs_reg, r_reg); + r_reg = vld1_s16(r_ptr); + llrs_reg = vld1_s16(llrs_ptr); + l_reg = vqsub_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - row_sign = veor_s8(row_sign, l_reg); + row_sign = veor_s16(row_sign, l_reg); 
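The vector code above tracks, per check-node row, the sign product and the two smallest magnitudes of L(n,m). A minimal scalar sketch of those per-row quantities may help (illustrative only, not part of the patch; the real code uses saturating vqsub/vqabs operations rather than the explicit clamps shown here):

#include <cstdint>

struct row_stats {
  int16_t min1;  // min over n in psi(m) of |L(n,m)|
  int16_t min2;  // second smallest |L(n,m)|
  int16_t sign;  // XOR of all L(n,m); its sign bit is the sign of the product
};

static row_stats min_sum_row(const int16_t *llr, const int16_t *r, uint32_t ncols) {
  row_stats s{INT16_MAX, INT16_MAX, 0};
  for (uint32_t col = 0; col < ncols; ++col) {
    // L(n,m) = LLR(n) - R(n,m), saturated to the int16_t range
    int32_t diff = (int32_t)llr[col] - (int32_t)r[col];
    if (diff > INT16_MAX) { diff = INT16_MAX; }
    if (diff < INT16_MIN) { diff = INT16_MIN; }
    int16_t l_val = (int16_t)diff;
    s.sign ^= l_val;
    // Saturating absolute value, as vqabsh_s16 would produce
    int32_t a = l_val < 0 ? -(int32_t)l_val : (int32_t)l_val;
    int16_t abs_val = (int16_t)(a > INT16_MAX ? INT16_MAX : a);
    if (abs_val < s.min1) { s.min2 = s.min1; s.min1 = abs_val; }
    else if (abs_val < s.min2) { s.min2 = abs_val; }
  }
  return s;
}

The max/min combination used in the patch, row_min2 = max(row_min, min(row_min2, abs)), is an equivalent branch-free update of the same pair of minima.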
// Compute the min(|L(n,m)|) and the second minimum - abs_reg = vqabs_s8(l_reg); - row_min2 = vmax_s8(row_min, vmin_s8(row_min2, abs_reg)); - row_min = vmin_s8(row_min, abs_reg); + abs_reg = vqabs_s16(l_reg); + row_min2 = vmax_s16(row_min, vmin_s16(row_min2, abs_reg)); + row_min = vmin_s16(row_min, abs_reg); // Store L(n,m) - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); // Process tail for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = llrs_ptr[zb] - r_ptr[zb]; + l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); row_sign_array[zb] ^= l_val; - int8_t abs_val = vqabsb_s8(l_val); + int16_t abs_val = vqabsh_s16(l_val); row_min2_array[zb] = max(row_min_array[zb], min(row_min2_array[zb], abs_val)); row_min_array[zb] = min(row_min_array[zb], abs_val); @@ -579,41 +577,41 @@ void compute_l_product_min1_and_min2( } // Store the two minima and the product for Z rows - vst1_s8(row_min_array, row_min); - vst1_s8(row_min2_array, row_min2); - vst1_s8(row_sign_array, row_sign); + vst1_s16(row_min_array, row_min); + vst1_s16(row_min2_array, row_min2); + vst1_s16(row_sign_array, row_sign); #endif } template<> void compute_l_product_min1_and_min2( - int8_t *l, const int8_t *__restrict__ llrs, const int8_t *__restrict__ r, + int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *row_min_array, int8_t *row_min2_array, - int8_t *row_sign_array) { + uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, + int16_t *row_sign_array) { #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b8(); - svbool_t pg_tail = svwhilelt_b8(0, tail_size); + svbool_t pg = svptrue_b16(); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); // Loop through the columns in the row (variable node n in psi(m)) // Column 0 - int8_t *l_ptr = l; + int16_t *l_ptr = l; auto shift = d->shift_ptr[0] % d->z; - const int8_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int8_t *r_ptr = r; - int8_t *sign_ptr = row_sign_array; - int8_t *min_ptr = row_min_array; + const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; + const int16_t *r_ptr = r; + int16_t *sign_ptr = row_sign_array; + int16_t *min_ptr = row_min_array; for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint8_t r_reg = svld1_s8(pg, r_ptr); - svint8_t llrs_reg = svld1_s8(pg, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg, r_ptr); + svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svst1_s8(pg, sign_ptr, l_reg); + svst1_s16(pg, sign_ptr, l_reg); - svst1_s8(pg, min_ptr, svqabs_s8_x(pg, l_reg)); + svst1_s16(pg, min_ptr, svqabs_s16_x(pg, l_reg)); - svst1_s8(pg, l_ptr, l_reg); + svst1_s16(pg, l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -623,15 +621,15 @@ void compute_l_product_min1_and_min2( } if (tail_size != 0) { - svint8_t r_reg = svld1_s8(pg_tail, r_ptr); - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg_tail, r_ptr); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svst1_s8(pg_tail, sign_ptr, l_reg); + svst1_s16(pg_tail, sign_ptr, l_reg); - svst1_s8(pg_tail, min_ptr, svqabs_s8_x(pg_tail, l_reg)); + svst1_s16(pg_tail, min_ptr, svqabs_s16_x(pg_tail, l_reg)); - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); } // Column 1 @@ -641,22 +639,22 @@ void 
compute_l_product_min1_and_min2( r_ptr = r + d->z; sign_ptr = row_sign_array; min_ptr = row_min_array; - int8_t *min2_ptr = row_min2_array; + int16_t *min2_ptr = row_min2_array; for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint8_t r_reg = svld1_s8(pg, r_ptr); - svint8_t llrs_reg = svld1_s8(pg, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg, r_ptr); + svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svint8_t sign_reg = svld1_s8(pg, sign_ptr); - svst1_s8(pg, sign_ptr, sveor_s8_x(pg, sign_reg, l_reg)); + svint16_t sign_reg = svld1_s16(pg, sign_ptr); + svst1_s16(pg, sign_ptr, sveor_s16_x(pg, sign_reg, l_reg)); - svint8_t min_reg = svld1_s8(pg, min_ptr); - svint8_t abs_reg = svqabs_s8_x(pg, l_reg); - svst1_s8(pg, min2_ptr, svmax_s8_x(pg, min_reg, abs_reg)); - svst1_s8(pg, min_ptr, svmin_s8_x(pg, min_reg, abs_reg)); + svint16_t min_reg = svld1_s16(pg, min_ptr); + svint16_t abs_reg = svqabs_s16_x(pg, l_reg); + svst1_s16(pg, min2_ptr, svmax_s16_x(pg, min_reg, abs_reg)); + svst1_s16(pg, min_ptr, svmin_s16_x(pg, min_reg, abs_reg)); - svst1_s8(pg, l_ptr, l_reg); + svst1_s16(pg, l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -667,19 +665,19 @@ void compute_l_product_min1_and_min2( } if (tail_size != 0) { - svint8_t r_reg = svld1_s8(pg_tail, r_ptr); - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg_tail, r_ptr); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - svint8_t sign_reg = svld1_s8(pg_tail, sign_ptr); - svst1_s8(pg_tail, sign_ptr, sveor_s8_x(pg_tail, sign_reg, l_reg)); + svint16_t sign_reg = svld1_s16(pg_tail, sign_ptr); + svst1_s16(pg_tail, sign_ptr, sveor_s16_x(pg_tail, sign_reg, l_reg)); - svint8_t min_reg = svld1_s8(pg_tail, min_ptr); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svst1_s8(pg_tail, min2_ptr, svmax_s8_x(pg_tail, min_reg, abs_reg)); - svst1_s8(pg_tail, min_ptr, svmin_s8_x(pg_tail, min_reg, abs_reg)); + svint16_t min_reg = svld1_s16(pg_tail, min_ptr); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svst1_s16(pg_tail, min2_ptr, svmax_s16_x(pg_tail, min_reg, abs_reg)); + svst1_s16(pg_tail, min_ptr, svmin_s16_x(pg_tail, min_reg, abs_reg)); - svst1_s8(pg_tail, l_ptr, l_reg); + svst1_s16(pg_tail, l_ptr, l_reg); } // Columns n >= 2 @@ -695,24 +693,24 @@ void compute_l_product_min1_and_min2( // Loop through the Z rows in the layer (check node m) for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { // Compute L(n,m) = LLR(n) - R(n,m) - svint8_t r_reg = svld1_s8(pg, r_ptr); - svint8_t llrs_reg = svld1_s8(pg, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); + svint16_t r_reg = svld1_s16(pg, r_ptr); + svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - svint8_t sign_reg = svld1_s8(pg, sign_ptr); - svst1_s8(pg, sign_ptr, sveor_s8_x(pg, sign_reg, l_reg)); + svint16_t sign_reg = svld1_s16(pg, sign_ptr); + svst1_s16(pg, sign_ptr, sveor_s16_x(pg, sign_reg, l_reg)); // Compute the min(|L(n,m)|) and the second minimum - svint8_t min_reg = svld1_s8(pg, min_ptr); - svint8_t min2_reg = svld1_s8(pg, min2_ptr); - svint8_t abs_reg = svqabs_s8_x(pg, l_reg); - svst1_s8(pg, min2_ptr, - svmax_s8_x(pg, min_reg, svmin_s8_x(pg, min2_reg, abs_reg))); - svst1_s8(pg, min_ptr, svmin_s8_x(pg, min_reg, abs_reg)); + 
svint16_t min_reg = svld1_s16(pg, min_ptr); + svint16_t min2_reg = svld1_s16(pg, min2_ptr); + svint16_t abs_reg = svqabs_s16_x(pg, l_reg); + svst1_s16(pg, min2_ptr, + svmax_s16_x(pg, min_reg, svmin_s16_x(pg, min2_reg, abs_reg))); + svst1_s16(pg, min_ptr, svmin_s16_x(pg, min_reg, abs_reg)); // Store L(n,m) - svst1_s8(pg, l_ptr, l_reg); + svst1_s16(pg, l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -724,45 +722,45 @@ void compute_l_product_min1_and_min2( // Process tail if (tail_size != 0) { - svint8_t r_reg = svld1_s8(pg_tail, r_ptr); - svint8_t llrs_reg = svld1_s8(pg_tail, llrs_ptr); - svint8_t l_reg = svqsub_s8(llrs_reg, r_reg); - - svint8_t sign_reg = svld1_s8(pg_tail, sign_ptr); - svst1_s8(pg_tail, sign_ptr, sveor_s8_x(pg_tail, sign_reg, l_reg)); - - svint8_t min_reg = svld1_s8(pg_tail, min_ptr); - svint8_t min2_reg = svld1_s8(pg_tail, min2_ptr); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svst1_s8( - pg_tail, min2_ptr, - svmax_s8_x(pg_tail, min_reg, svmin_s8_x(pg_tail, min2_reg, abs_reg))); - svst1_s8(pg_tail, min_ptr, svmin_s8_x(pg_tail, min_reg, abs_reg)); - - svst1_s8(pg_tail, l_ptr, l_reg); + svint16_t r_reg = svld1_s16(pg_tail, r_ptr); + svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); + svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); + + svint16_t sign_reg = svld1_s16(pg_tail, sign_ptr); + svst1_s16(pg_tail, sign_ptr, sveor_s16_x(pg_tail, sign_reg, l_reg)); + + svint16_t min_reg = svld1_s16(pg_tail, min_ptr); + svint16_t min2_reg = svld1_s16(pg_tail, min2_ptr); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svst1_s16(pg_tail, min2_ptr, + svmax_s16_x(pg_tail, min_reg, + svmin_s16_x(pg_tail, min2_reg, abs_reg))); + svst1_s16(pg_tail, min_ptr, svmin_s16_x(pg_tail, min_reg, abs_reg)); + + svst1_s16(pg_tail, l_ptr, l_reg); } } #else // Loop through the columns in the row (variable node n in psi(m)) // Column 0 - int8_t *l_ptr = l; + int16_t *l_ptr = l; auto shift = d->shift_ptr[0] % d->z; - const int8_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int8_t *r_ptr = r; - int8_t *sign_ptr = row_sign_array; - int8_t *min_ptr = row_min_array; + const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; + const int16_t *r_ptr = r; + int16_t *sign_ptr = row_sign_array; + int16_t *min_ptr = row_min_array; uint32_t tail_cnt = tail_size; for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - int8x16_t r_reg = vld1q_s8(r_ptr); - int8x16_t llrs_reg = vld1q_s8(llrs_ptr); - int8x16_t l_reg = vqsubq_s8(llrs_reg, r_reg); + int16x8_t r_reg = vld1q_s16(r_ptr); + int16x8_t llrs_reg = vld1q_s16(llrs_ptr); + int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); - vst1q_s8(sign_ptr, l_reg); + vst1q_s16(sign_ptr, l_reg); - vst1q_s8(min_ptr, vqabsq_s8(l_reg)); + vst1q_s16(min_ptr, vqabsq_s16(l_reg)); - vst1q_s8(l_ptr, l_reg); + vst1q_s16(l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -771,26 +769,26 @@ void compute_l_product_min1_and_min2( llrs_ptr += num_lanes; } - if (tail_cnt > 7U) { - int8x8_t r_reg = vld1_s8(r_ptr); - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t l_reg = vqsub_s8(llrs_reg, r_reg); + if (tail_cnt > 3U) { + int16x4_t r_reg = vld1_s16(r_ptr); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); - vst1_s8(sign_ptr, l_reg); + vst1_s16(sign_ptr, l_reg); - vst1_s8(min_ptr, vqabs_s8(l_reg)); + vst1_s16(min_ptr, vqabs_s16(l_reg)); - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); - tail_cnt = d->z & 0x7; + tail_cnt = d->z & 0x3; } for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { - 
l[zb] = llrs[d->col_ptr[0] * (2 * d->z) + shift + zb] - r[zb]; + l[zb] = vqsubh_s16(llrs[d->col_ptr[0] * (2 * d->z) + shift + zb], r[zb]); row_sign_array[zb] = l[zb]; - row_min_array[zb] = vqabsb_s8(l[zb]); + row_min_array[zb] = vqabsh_s16(l[zb]); } // Column 1 @@ -798,25 +796,25 @@ void compute_l_product_min1_and_min2( llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; sign_ptr = row_sign_array; min_ptr = row_min_array; - int8_t *min2_ptr = row_min2_array; + int16_t *min2_ptr = row_min2_array; l_ptr = l + d->z; r_ptr = r + d->z; tail_cnt = tail_size; for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - int8x16_t r_reg = vld1q_s8(r_ptr); - int8x16_t llrs_reg = vld1q_s8(llrs_ptr); - int8x16_t l_reg = vqsubq_s8(llrs_reg, r_reg); + int16x8_t r_reg = vld1q_s16(r_ptr); + int16x8_t llrs_reg = vld1q_s16(llrs_ptr); + int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); - int8x16_t sign_reg = vld1q_s8(sign_ptr); - vst1q_s8(sign_ptr, veorq_s8(sign_reg, l_reg)); + int16x8_t sign_reg = vld1q_s16(sign_ptr); + vst1q_s16(sign_ptr, veorq_s16(sign_reg, l_reg)); - int8x16_t min_reg = vld1q_s8(min_ptr); - int8x16_t abs_reg = vqabsq_s8(l_reg); - vst1q_s8(min2_ptr, vmaxq_s8(min_reg, abs_reg)); - vst1q_s8(min_ptr, vminq_s8(min_reg, abs_reg)); + int16x8_t min_reg = vld1q_s16(min_ptr); + int16x8_t abs_reg = vqabsq_s16(l_reg); + vst1q_s16(min2_ptr, vmaxq_s16(min_reg, abs_reg)); + vst1q_s16(min_ptr, vminq_s16(min_reg, abs_reg)); - vst1q_s8(l_ptr, l_reg); + vst1q_s16(l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -826,30 +824,31 @@ void compute_l_product_min1_and_min2( llrs_ptr += num_lanes; } - if (tail_cnt > 7U) { - int8x8_t r_reg = vld1_s8(r_ptr); - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t l_reg = vqsub_s8(llrs_reg, r_reg); + if (tail_cnt > 3U) { + int16x4_t r_reg = vld1_s16(r_ptr); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); - int8x8_t sign_reg = vld1_s8(sign_ptr); - vst1_s8(sign_ptr, veor_s8(sign_reg, l_reg)); + int16x4_t sign_reg = vld1_s16(sign_ptr); + vst1_s16(sign_ptr, veor_s16(sign_reg, l_reg)); - int8x8_t min_reg = vld1_s8(min_ptr); - int8x8_t abs_reg = vqabs_s8(l_reg); - vst1_s8(min2_ptr, vmax_s8(min_reg, abs_reg)); - vst1_s8(min_ptr, vmin_s8(min_reg, abs_reg)); + int16x4_t min_reg = vld1_s16(min_ptr); + int16x4_t abs_reg = vqabs_s16(l_reg); + vst1_s16(min2_ptr, vmax_s16(min_reg, abs_reg)); + vst1_s16(min_ptr, vmin_s16(min_reg, abs_reg)); - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); - tail_cnt = d->z & 0x7; + tail_cnt = d->z & 0x3; } for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { - l[d->z + zb] = llrs[d->col_ptr[1] * (2 * d->z) + shift + zb] - r[d->z + zb]; + l[d->z + zb] = + vqsubh_s16(llrs[d->col_ptr[1] * (2 * d->z) + shift + zb], r[d->z + zb]); row_sign_array[zb] ^= l[d->z + zb]; - int8_t abs_val = vqabsb_s8(l[d->z + zb]); + int16_t abs_val = vqabsh_s16(l[d->z + zb]); row_min2_array[zb] = max(row_min_array[zb], abs_val); row_min_array[zb] = min(row_min_array[zb], abs_val); } @@ -866,26 +865,26 @@ void compute_l_product_min1_and_min2( tail_cnt = tail_size; // Loop through the Z rows in the layer (check node m) - // Process 16 entries at a time + // Process 8 entries at a time for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { // Compute L(n,m) = LLR(n) - R(n,m) - int8x16_t r_reg = vld1q_s8(r_ptr); - int8x16_t llrs_reg = vld1q_s8(llrs_ptr); - int8x16_t l_reg = vqsubq_s8(llrs_reg, r_reg); + int16x8_t r_reg = vld1q_s16(r_ptr); + int16x8_t llrs_reg = vld1q_s16(llrs_ptr); + int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); 
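With 16-bit lanes a 128-bit NEON vector holds 8 LLRs instead of 16, which is why the loop structure changes from 16/8/scalar to 8/4/scalar throughout this file. A minimal sketch of that tiered pattern for the saturating subtraction (illustrative helper, not part of the patch; assumes <arm_neon.h> and an AArch64 target):

#include <arm_neon.h>
#include <cstdint>

// out[i] = saturating(a[i] - b[i]) for i in [0, z): full 8-lane vectors first,
// then one 4-lane group if at least 4 elements remain, then scalars.
static void qsub_s16_tiered(int16_t *out, const int16_t *a, const int16_t *b,
                            uint32_t z) {
  uint32_t full_vec = z / 8;
  uint32_t tail = z % 8;
  for (uint32_t v = 0; v < full_vec; ++v) {
    vst1q_s16(out, vqsubq_s16(vld1q_s16(a), vld1q_s16(b)));
    a += 8;
    b += 8;
    out += 8;
  }
  if (tail > 3) {
    vst1_s16(out, vqsub_s16(vld1_s16(a), vld1_s16(b)));
    a += 4;
    b += 4;
    out += 4;
    tail = z & 0x3;
  }
  for (uint32_t i = 0; i < tail; ++i) {
    out[i] = vqsubh_s16(a[i], b[i]);
  }
}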
// Compute the product of L(n',m), for all the columns (all n' in psi(m)) - int8x16_t sign_reg = vld1q_s8(sign_ptr); - vst1q_s8(sign_ptr, veorq_s8(sign_reg, l_reg)); + int16x8_t sign_reg = vld1q_s16(sign_ptr); + vst1q_s16(sign_ptr, veorq_s16(sign_reg, l_reg)); // Compute the min(|L(n,m)|) and the second minimum - int8x16_t min_reg = vld1q_s8(min_ptr); - int8x16_t min2_reg = vld1q_s8(min2_ptr); - int8x16_t abs_reg = vqabsq_s8(l_reg); - vst1q_s8(min2_ptr, vmaxq_s8(min_reg, vminq_s8(min2_reg, abs_reg))); - vst1q_s8(min_ptr, vminq_s8(min_reg, abs_reg)); + int16x8_t min_reg = vld1q_s16(min_ptr); + int16x8_t min2_reg = vld1q_s16(min2_ptr); + int16x8_t abs_reg = vqabsq_s16(l_reg); + vst1q_s16(min2_ptr, vmaxq_s16(min_reg, vminq_s16(min2_reg, abs_reg))); + vst1q_s16(min_ptr, vminq_s16(min_reg, abs_reg)); // Store L(n,m) - vst1q_s8(l_ptr, l_reg); + vst1q_s16(l_ptr, l_reg); sign_ptr += num_lanes; min_ptr += num_lanes; @@ -895,41 +894,41 @@ void compute_l_product_min1_and_min2( llrs_ptr += num_lanes; } - // Process a group of 8 elts - if (tail_cnt > 7U) { + // Process a group of 4 elts + if (tail_cnt > 3U) { // Compute L(n,m) = LLR(n) - R(n,m) - int8x8_t r_reg = vld1_s8(r_ptr); - int8x8_t llrs_reg = vld1_s8(llrs_ptr); - int8x8_t l_reg = vqsub_s8(llrs_reg, r_reg); + int16x4_t r_reg = vld1_s16(r_ptr); + int16x4_t llrs_reg = vld1_s16(llrs_ptr); + int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - int8x8_t sign_reg = vld1_s8(sign_ptr); - vst1_s8(sign_ptr, veor_s8(sign_reg, l_reg)); + int16x4_t sign_reg = vld1_s16(sign_ptr); + vst1_s16(sign_ptr, veor_s16(sign_reg, l_reg)); // Compute the min(|L(n,m)|) and the second minimum - int8x8_t min_reg = vld1_s8(min_ptr); - int8x8_t min2_reg = vld1_s8(min2_ptr); - int8x8_t abs_reg = vqabs_s8(l_reg); - vst1_s8(min2_ptr, vmax_s8(min_reg, vmin_s8(min2_reg, abs_reg))); - vst1_s8(min_ptr, vmin_s8(min_reg, abs_reg)); + int16x4_t min_reg = vld1_s16(min_ptr); + int16x4_t min2_reg = vld1_s16(min2_ptr); + int16x4_t abs_reg = vqabs_s16(l_reg); + vst1_s16(min2_ptr, vmax_s16(min_reg, vmin_s16(min2_reg, abs_reg))); + vst1_s16(min_ptr, vmin_s16(min_reg, abs_reg)); // Store L(n,m) - vst1_s8(l_ptr, l_reg); + vst1_s16(l_ptr, l_reg); - tail_cnt = d->z & 0x7; + tail_cnt = d->z & 0x3; } // Process tail for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { // Compute L(n,m) = LLR(n) - R(n,m) - l[d->z * col + zb] = - llrs[d->col_ptr[col] * (2 * d->z) + shift + zb] - r[d->z * col + zb]; + l[d->z * col + zb] = vqsubh_s16( + llrs[d->col_ptr[col] * (2 * d->z) + shift + zb], r[d->z * col + zb]); // Compute the product of L(n',m), for all the columns (all n' in psi(m)) row_sign_array[zb] ^= l[d->z * col + zb]; // Compute the min(|L(n,m)|) and the second minimum - int8_t abs_val = vqabsb_s8(l[d->z * col + zb]); + int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); row_min2_array[zb] = max(row_min_array[zb], min(row_min2_array[zb], abs_val)); row_min_array[zb] = min(row_min_array[zb], abs_val); @@ -948,38 +947,38 @@ void compute_l_product_min1_and_min2( // - The log likelihood ratios for each n in \psi(m): // LLR(n) = R(n,m) + L(n,m) template -void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, +void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, - const int8_t *row_min_array, - const int8_t *row_min2_array, - const int8_t *row_sign_array); + int32_t full_vec, uint32_t tail_size, + const int16_t *row_min_array, + 
const int16_t *row_min2_array, + const int16_t *row_sign_array); template<> -void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, +void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, - const int8_t *row_min_array, - const int8_t *row_min2_array, - const int8_t *row_sign_array) { + int32_t full_vec, uint32_t tail_size, + const int16_t *row_min_array, + const int16_t *row_min2_array, + const int16_t *row_sign_array) { // Loop through the Z rows in the layer (check node m) for (uint32_t zb = 0; zb < d->z; ++zb) { - const int8_t *l_ptr = l + zb * d->num_cols; + const int16_t *l_ptr = l + zb * d->num_cols; // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8_t col_sign = (row_sign_array[zb] ^ l_ptr[col]) < 0 ? -1 : 1; + int16_t col_sign = (row_sign_array[zb] ^ l_ptr[col]) < 0 ? -1 : 1; // Compute R(n,m) - int8_t abs_val = vqabsb_s8(l_ptr[col]); - int8_t r_val = + int16_t abs_val = vqabsh_s16(l_ptr[col]); + int16_t r_val = col_sign * (abs_val == row_min_array[zb] ? row_min2_array[zb] : row_min_array[zb]); // Compute LLR(n) = R(n,m) + L(n,m) auto shift = (d->shift_ptr[col] + zb) % d->z; auto col_ind = d->col_ptr[col] * d->z + shift; - llrs[col_ind] = vqaddb_s8(r_val, l_ptr[col]); + llrs[col_ind] = vqaddh_s16(r_val, l_ptr[col]); // Store R(n,m) for the next iteration r[col] = r_val; @@ -988,158 +987,158 @@ void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, } template<> -void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, +void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, - const int8_t *row_min_array, - const int8_t *row_min2_array, - const int8_t *row_sign_array) { - // Case for lifting sizes 8 <= Z < 16 (rows in the layer) + int32_t full_vec, uint32_t tail_size, + const int16_t *row_min_array, + const int16_t *row_min2_array, + const int16_t *row_sign_array) { + // Case for lifting sizes 4 <= Z < 8 (rows in the layer) #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg_tail = svwhilelt_b8(0, (int)tail_size); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); - svint8_t row_min = svld1_s8(pg_tail, row_min_array); - svint8_t row_min2 = svld1_s8(pg_tail, row_min2_array); - svint8_t row_sign = svld1_s8(pg_tail, row_sign_array); + svint16_t row_min = svld1_s16(pg_tail, row_min_array); + svint16_t row_min2 = svld1_s16(pg_tail, row_min2_array); + svint16_t row_sign = svld1_s16(pg_tail, row_sign_array); // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { auto shift = d->shift_ptr[col] % d->z; auto col_ind = d->col_ptr[col] * (2 * d->z); - int8_t *r_ptr = r + d->z * col; - const int8_t *l_ptr = l + d->z * col; - int8_t *llrs_ptr = llrs + col_ind + shift; + int16_t *r_ptr = r + d->z * col; + const int16_t *l_ptr = l + d->z * col; + int16_t *llrs_ptr = llrs + col_ind + shift; // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - svint8_t l_reg = svld1_s8(pg_tail, l_ptr); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svint8_t eor_reg = sveor_s8_x(pg_tail, row_sign, l_reg); - svbool_t pg_tail_neg = svcmplt_n_s8(pg_tail, eor_reg, 0); + svint16_t l_reg = svld1_s16(pg_tail, l_ptr); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); 
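compute_r_and_llrs turns the per-row minima and sign product back into check-to-variable messages and updated LLRs. The scalar core of that update, restated for reference (illustrative helper, not part of the patch; the real code uses vqabsh_s16/vqaddh_s16 for the saturating steps):

#include <cstdint>

static void update_r_and_llr(int16_t l_val, int16_t min1, int16_t min2,
                             int16_t row_sign, int16_t *r_out, int16_t *llr_out) {
  // Sign of the product of the other L(n',m): remove this column's sign by XOR
  int16_t col_sign = ((row_sign ^ l_val) < 0) ? -1 : 1;
  // Saturating |L(n,m)|
  int32_t a = l_val < 0 ? -(int32_t)l_val : (int32_t)l_val;
  int16_t abs_val = (int16_t)(a > INT16_MAX ? INT16_MAX : a);
  // |R(n,m)| is min1, unless this column attained min1, in which case min2
  int16_t r_val = (int16_t)(col_sign * (abs_val == min1 ? min2 : min1));
  // LLR(n) = R(n,m) + L(n,m), saturated
  int32_t sum = (int32_t)r_val + (int32_t)l_val;
  if (sum > INT16_MAX) { sum = INT16_MAX; }
  if (sum < INT16_MIN) { sum = INT16_MIN; }
  *r_out = r_val;  // stored for the next iteration
  *llr_out = (int16_t)sum;
}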
+ svint16_t eor_reg = sveor_s16_x(pg_tail, row_sign, l_reg); + svbool_t pg_tail_neg = svcmplt_n_s16(pg_tail, eor_reg, 0); // Compute R(n,m) - svbool_t pg_tail_eq = svcmpeq_s8(pg_tail, abs_reg, row_min); - svint8_t tmp_reg = svsel_s8(pg_tail_eq, row_min2, row_min); - svint8_t r_reg = svneg_s8_m(tmp_reg, pg_tail_neg, tmp_reg); + svbool_t pg_tail_eq = svcmpeq_s16(pg_tail, abs_reg, row_min); + svint16_t tmp_reg = svsel_s16(pg_tail_eq, row_min2, row_min); + svint16_t r_reg = svneg_s16_m(tmp_reg, pg_tail_neg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - svint8_t result = svqadd_s8_x(pg_tail, r_reg, l_reg); - svst1_s8(pg_tail, llrs_ptr, result); + svint16_t result = svqadd_s16_x(pg_tail, r_reg, l_reg); + svst1_s16(pg_tail, llrs_ptr, result); // Store R(n,m) for the next iteration - svst1_s8(pg_tail, r_ptr, r_reg); + svst1_s16(pg_tail, r_ptr, r_reg); // Rearrange LLRs - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int8_t)); + memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); // copy (z - shift) elts in the main block to the replicated block memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int8_t)); + (d->z - shift) * sizeof(int16_t)); } #else - int8x8_t row_min = vld1_s8(row_min_array); - int8x8_t row_min2 = vld1_s8(row_min2_array); - int8x8_t row_sign = vld1_s8(row_sign_array); + int16x4_t row_min = vld1_s16(row_min_array); + int16x4_t row_min2 = vld1_s16(row_min2_array); + int16x4_t row_sign = vld1_s16(row_sign_array); // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { auto shift = d->shift_ptr[col] % d->z; auto col_ind = d->col_ptr[col] * (2 * d->z); - int8_t *r_ptr = r + d->z * col; - const int8_t *l_ptr = l + d->z * col; - int8_t *llrs_ptr = llrs + col_ind + shift; + int16_t *r_ptr = r + d->z * col; + const int16_t *l_ptr = l + d->z * col; + int16_t *llrs_ptr = llrs + col_ind + shift; // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8x8_t l_reg = vld1_s8(l_ptr); - int8x8_t abs_reg = vqabs_s8(l_reg); - int8x8_t eor_reg = veor_s8(row_sign, l_reg); + int16x4_t l_reg = vld1_s16(l_ptr); + int16x4_t abs_reg = vqabs_s16(l_reg); + int16x4_t eor_reg = veor_s16(row_sign, l_reg); // sign_col_reg will contain a 1 in all lanes for negative values - uint8x8_t sign_col_reg = vcltz_s8(eor_reg); + uint16x4_t sign_col_reg = vcltz_s16(eor_reg); // Compute R(n,m) // Get a mask for the minimum value in a lane - uint8x8_t check_eq = vceq_s8(abs_reg, row_min); - int8x8_t tmp_reg = vbsl_s8(check_eq, row_min2, row_min); + uint16x4_t check_eq = vceq_s16(abs_reg, row_min); + int16x4_t tmp_reg = vbsl_s16(check_eq, row_min2, row_min); // Negate the absolute values - int8x8_t neg_abs_reg = vneg_s8(tmp_reg); - int8x8_t r_reg = vbsl_s8(sign_col_reg, neg_abs_reg, tmp_reg); + int16x4_t neg_abs_reg = vneg_s16(tmp_reg); + int16x4_t r_reg = vbsl_s16(sign_col_reg, neg_abs_reg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - int8x8_t result = vqadd_s8(r_reg, l_reg); - vst1_s8(llrs_ptr, result); + int16x4_t result = vqadd_s16(r_reg, l_reg); + vst1_s16(llrs_ptr, result); // Store R(n,m) for the next iteration - vst1_s8(r_ptr, r_reg); + vst1_s16(r_ptr, r_reg); // Process tail for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? -1 : 1; + int16_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? 
-1 : 1; // Compute R(n,m) - int8_t abs_val = vqabsb_s8(l[d->z * col + zb]); - int8_t r_val = + int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); + int16_t r_val = col_sign * (abs_val == row_min_array[zb] ? row_min2_array[zb] : row_min_array[zb]); // Compute LLR(n) = R(n,m) + L(n,m) - llrs[col_ind + shift + zb] = vqaddb_s8(r_val, *(l + d->z * col + zb)); + llrs[col_ind + shift + zb] = vqaddh_s16(r_val, *(l + d->z * col + zb)); // Store R(n,m) for the next iteration r[d->z * col + zb] = r_val; } // Rearrange LLRs - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int8_t)); + memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); // copy (z - shift) elts in the main block to the replicated block memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int8_t)); + (d->z - shift) * sizeof(int16_t)); } #endif } template<> -void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, +void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, int32_t tail_size, - const int8_t *row_min_array, - const int8_t *row_min2_array, - const int8_t *row_sign_array) { + int32_t full_vec, uint32_t tail_size, + const int16_t *row_min_array, + const int16_t *row_min2_array, + const int16_t *row_sign_array) { #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b8(); - svbool_t pg_tail = svwhilelt_b8(0, tail_size); + svbool_t pg = svptrue_b16(); + svbool_t pg_tail = svwhilelt_b16(0U, tail_size); // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { auto shift = d->shift_ptr[col] % d->z; auto col_ind = d->col_ptr[col] * (2 * d->z); - int8_t *llrs_ptr = llrs + col_ind + shift; - const int8_t *l_ptr = l + d->z * col; - int8_t *r_ptr = r + d->z * col; - const int8_t *sign_ptr = row_sign_array; - const int8_t *min_ptr = row_min_array; - const int8_t *min2_ptr = row_min2_array; + int16_t *llrs_ptr = llrs + col_ind + shift; + const int16_t *l_ptr = l + d->z * col; + int16_t *r_ptr = r + d->z * col; + const int16_t *sign_ptr = row_sign_array; + const int16_t *min_ptr = row_min_array; + const int16_t *min2_ptr = row_min2_array; // Loop through the Z rows in the layer (check node m) for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - svint8_t l_reg = svld1_s8(pg, l_ptr); - svint8_t sign_reg = svld1_s8(pg, sign_ptr); - svint8_t eor_reg = sveor_s8_x(pg, sign_reg, l_reg); - svbool_t pg_neg = svcmplt_n_s8(pg, eor_reg, 0); + svint16_t l_reg = svld1_s16(pg, l_ptr); + svint16_t sign_reg = svld1_s16(pg, sign_ptr); + svint16_t eor_reg = sveor_s16_x(pg, sign_reg, l_reg); + svbool_t pg_neg = svcmplt_n_s16(pg, eor_reg, 0); // Compute R(n,m) - svint8_t min_reg = svld1_s8(pg, min_ptr); - svint8_t min2_reg = svld1_s8(pg, min2_ptr); - svint8_t abs_reg = svqabs_s8_x(pg, l_reg); - svbool_t pg_eq = svcmpeq_s8(pg, abs_reg, min_reg); - svint8_t tmp_reg = svsel_s8(pg_eq, min2_reg, min_reg); - svint8_t r_reg = svneg_s8_m(tmp_reg, pg_neg, tmp_reg); + svint16_t min_reg = svld1_s16(pg, min_ptr); + svint16_t min2_reg = svld1_s16(pg, min2_ptr); + svint16_t abs_reg = svqabs_s16_x(pg, l_reg); + svbool_t pg_eq = svcmpeq_s16(pg, abs_reg, min_reg); + svint16_t tmp_reg = svsel_s16(pg_eq, min2_reg, min_reg); + svint16_t r_reg = svneg_s16_m(tmp_reg, pg_neg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - svint8_t result = svqadd_s8_x(pg, r_reg, l_reg); - svst1_s8(pg, 
llrs_ptr, result); + svint16_t result = svqadd_s16_x(pg, r_reg, l_reg); + svst1_s16(pg, llrs_ptr, result); // Store R(n,m) for the next iteration - svst1_s8(pg, r_ptr, r_reg); + svst1_s16(pg, r_ptr, r_reg); // Increment pointers l_ptr += num_lanes; @@ -1152,75 +1151,75 @@ void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, if (tail_size != 0) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - svint8_t l_reg = svld1_s8(pg_tail, l_ptr); - svint8_t sign_reg = svld1_s8(pg_tail, sign_ptr); - svint8_t eor_reg = sveor_s8_x(pg_tail, sign_reg, l_reg); - svbool_t pg_tail_neg = svcmplt_n_s8(pg_tail, eor_reg, 0); + svint16_t l_reg = svld1_s16(pg_tail, l_ptr); + svint16_t sign_reg = svld1_s16(pg_tail, sign_ptr); + svint16_t eor_reg = sveor_s16_x(pg_tail, sign_reg, l_reg); + svbool_t pg_tail_neg = svcmplt_n_s16(pg_tail, eor_reg, 0); // Compute R(n,m) - svint8_t min_reg = svld1_s8(pg_tail, min_ptr); - svint8_t min2_reg = svld1_s8(pg_tail, min2_ptr); - svint8_t abs_reg = svqabs_s8_x(pg_tail, l_reg); - svbool_t pg_tail_eq = svcmpeq_s8(pg_tail, abs_reg, min_reg); - svint8_t tmp_reg = svsel_s8(pg_tail_eq, min2_reg, min_reg); - svint8_t r_reg = svneg_s8_m(tmp_reg, pg_tail_neg, tmp_reg); + svint16_t min_reg = svld1_s16(pg_tail, min_ptr); + svint16_t min2_reg = svld1_s16(pg_tail, min2_ptr); + svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); + svbool_t pg_tail_eq = svcmpeq_s16(pg_tail, abs_reg, min_reg); + svint16_t tmp_reg = svsel_s16(pg_tail_eq, min2_reg, min_reg); + svint16_t r_reg = svneg_s16_m(tmp_reg, pg_tail_neg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - svint8_t result = svqadd_s8_x(pg_tail, r_reg, l_reg); - svst1_s8(pg_tail, llrs_ptr, result); + svint16_t result = svqadd_s16_x(pg_tail, r_reg, l_reg); + svst1_s16(pg_tail, llrs_ptr, result); // Store R(n,m) for the next iteration - svst1_s8(pg_tail, r_ptr, r_reg); + svst1_s16(pg_tail, r_ptr, r_reg); } // Rearrange LLRs // copy shifted elements in the replicated block // back to the beginning of the main block - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int8_t)); + memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); // copy (z - shift) elts in the main block to the replicated block memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int8_t)); + (d->z - shift) * sizeof(int16_t)); } #else // Loop through the columns in the row (variable node n in psi(m)) for (uint32_t col = 0; col < d->num_cols; ++col) { - int8_t *r_ptr = r + d->z * col; - const int8_t *sign_ptr = row_sign_array; - const int8_t *min_ptr = row_min_array; - const int8_t *min2_ptr = row_min2_array; + int16_t *r_ptr = r + d->z * col; + const int16_t *sign_ptr = row_sign_array; + const int16_t *min_ptr = row_min_array; + const int16_t *min2_ptr = row_min2_array; auto shift = d->shift_ptr[col] % d->z; auto col_ind = d->col_ptr[col] * (2 * d->z); - const int8_t *l_ptr = l + d->z * col; - int8_t *llrs_ptr = llrs + col_ind + shift; + const int16_t *l_ptr = l + d->z * col; + int16_t *llrs_ptr = llrs + col_ind + shift; uint32_t tail_cnt = tail_size; // Loop through the Z rows in the layer (check node m) - // Process 16 entries at a time + // Process 8 entries at a time for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8x16_t l_reg = vld1q_s8(l_ptr); - int8x16_t sign_reg = vld1q_s8(sign_ptr); - int8x16_t eor_reg = veorq_s8(sign_reg, l_reg); + int16x8_t l_reg = vld1q_s16(l_ptr); + 
int16x8_t sign_reg = vld1q_s16(sign_ptr); + int16x8_t eor_reg = veorq_s16(sign_reg, l_reg); // sign_col_reg will contain a 1 in all lanes for negative values - uint8x16_t sign_col_reg = vcltzq_s8(eor_reg); + uint16x8_t sign_col_reg = vcltzq_s16(eor_reg); // Compute R(n,m) - int8x16_t min_reg = vld1q_s8(min_ptr); - int8x16_t min2_reg = vld1q_s8(min2_ptr); - int8x16_t abs_reg = vqabsq_s8(l_reg); + int16x8_t min_reg = vld1q_s16(min_ptr); + int16x8_t min2_reg = vld1q_s16(min2_ptr); + int16x8_t abs_reg = vqabsq_s16(l_reg); // Get a mask for the minimum value in a lane - uint8x16_t check_eq = vceqq_s8(abs_reg, min_reg); - int8x16_t tmp_reg = vbslq_s8(check_eq, min2_reg, min_reg); + uint16x8_t check_eq = vceqq_s16(abs_reg, min_reg); + int16x8_t tmp_reg = vbslq_s16(check_eq, min2_reg, min_reg); // Negate the absolute values - int8x16_t neg_abs_reg = vnegq_s8(tmp_reg); - int8x16_t r_reg = vbslq_s8(sign_col_reg, neg_abs_reg, tmp_reg); + int16x8_t neg_abs_reg = vnegq_s16(tmp_reg); + int16x8_t r_reg = vbslq_s16(sign_col_reg, neg_abs_reg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - int8x16_t result = vqaddq_s8(r_reg, l_reg); - vst1q_s8(llrs_ptr, result); + int16x8_t result = vqaddq_s16(r_reg, l_reg); + vst1q_s16(llrs_ptr, result); // Store R(n,m) for the next iteration - vst1q_s8(r_ptr, r_reg); + vst1q_s16(r_ptr, r_reg); // Increment pointers r_ptr += num_lanes; @@ -1231,49 +1230,49 @@ void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, min2_ptr += num_lanes; } - // Process a group of 8 elts - if (tail_cnt > 7U) { + // Process a group of 4 elts + if (tail_cnt > 3U) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8x8_t l_reg = vld1_s8(l_ptr); - int8x8_t sign_reg = vld1_s8(sign_ptr); - int8x8_t eor_reg = veor_s8(sign_reg, l_reg); + int16x4_t l_reg = vld1_s16(l_ptr); + int16x4_t sign_reg = vld1_s16(sign_ptr); + int16x4_t eor_reg = veor_s16(sign_reg, l_reg); // sign_col_reg will contain a 1 in all lanes for negative values - uint8x8_t sign_col_reg = vcltz_s8(eor_reg); + uint16x4_t sign_col_reg = vcltz_s16(eor_reg); // Compute R(n,m) - int8x8_t min_reg = vld1_s8(min_ptr); - int8x8_t min2_reg = vld1_s8(min2_ptr); - int8x8_t abs_reg = vqabs_s8(l_reg); + int16x4_t min_reg = vld1_s16(min_ptr); + int16x4_t min2_reg = vld1_s16(min2_ptr); + int16x4_t abs_reg = vqabs_s16(l_reg); // Get a mask for the minimum value in a lane - uint8x8_t check_eq = vceq_s8(abs_reg, min_reg); - int8x8_t tmp_reg = vbsl_s8(check_eq, min2_reg, min_reg); + uint16x4_t check_eq = vceq_s16(abs_reg, min_reg); + int16x4_t tmp_reg = vbsl_s16(check_eq, min2_reg, min_reg); // Negate the absolute values - int8x8_t neg_abs_reg = vneg_s8(tmp_reg); - int8x8_t r_reg = vbsl_s8(sign_col_reg, neg_abs_reg, tmp_reg); + int16x4_t neg_abs_reg = vneg_s16(tmp_reg); + int16x4_t r_reg = vbsl_s16(sign_col_reg, neg_abs_reg, tmp_reg); // Compute LLR(n) = R(n,m) + L(n,m) - int8x8_t result = vqadd_s8(r_reg, l_reg); - vst1_s8(llrs_ptr, result); + int16x4_t result = vqadd_s16(r_reg, l_reg); + vst1_s16(llrs_ptr, result); // Store R(n,m) for the next iteration - vst1_s8(r_ptr, r_reg); + vst1_s16(r_ptr, r_reg); - tail_cnt = d->z & 0x7; + tail_cnt = d->z & 0x3; } // Process tail for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the product) - int8_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? -1 : 1; + int16_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? 
-1 : 1; // Compute R(n,m) - int8_t abs_val = vqabsb_s8(l[d->z * col + zb]); - int8_t r_val = + int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); + int16_t r_val = col_sign * (abs_val == row_min_array[zb] ? row_min2_array[zb] : row_min_array[zb]); // Compute LLR(n) = R(n,m) + L(n,m) - llrs[col_ind + shift + zb] = vqaddb_s8(r_val, *(l + d->z * col + zb)); + llrs[col_ind + shift + zb] = vqaddh_s16(r_val, *(l + d->z * col + zb)); // Store R(n,m) for the next iteration r[d->z * col + zb] = r_val; @@ -1282,21 +1281,22 @@ void compute_r_and_llrs(const int8_t *l, int8_t *r, int8_t *llrs, // Rearrange LLRs // copy shifted elements in the replicated block // back to the beginning of the main block - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int8_t)); + memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); // copy (z - shift) elts in the main block to the replicated block memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int8_t)); + (d->z - shift) * sizeof(int16_t)); } #endif } template void __attribute__((flatten)) -run_iterations(uint32_t num_its, int z, int lsi, - const armral_ldpc_base_graph_t *graph, int8_t *r, int8_t *l, - int8_t *new_llrs, int num_lanes, int full_vec, int tail_size, - int8_t *row_min_array, int8_t *row_min2_array, - int8_t *row_sign_array, int8_t *check, bool check_convergence, +run_iterations(uint32_t num_its, uint32_t z, uint32_t lsi, + const armral_ldpc_base_graph_t *graph, int16_t *r, int16_t *l, + int16_t *new_llrs, int32_t num_lanes, int32_t full_vec, + uint32_t tail_size, int16_t *row_min_array, + int16_t *row_min2_array, int16_t *row_sign_array, int16_t *check, + bool check_convergence, std::optional> &crc_checker) { for (uint32_t i = 0; i < num_its; ++i) { ldpc_layer_data d(z, lsi, graph); @@ -1360,34 +1360,34 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t layer_size = (graph->row_start_inds[2] - graph->row_start_inds[1]) * z; // We need to keep a record of matrix L (variable-to-check-node messages) - auto l = allocate_uninitialized(allocator, layer_size); + auto l = allocate_uninitialized(allocator, layer_size); // We need to keep a record of matrix R (check-to-variable-node messages) - auto r = allocate_zeroed(allocator, mat_size); + auto r = allocate_zeroed(allocator, mat_size); - auto row_min_array = allocate_zeroed(allocator, z); - auto row_min2_array = allocate_zeroed(allocator, z); - auto row_sign_array = allocate_zeroed(allocator, z); + auto row_min_array = allocate_zeroed(allocator, z); + auto row_min2_array = allocate_zeroed(allocator, z); + auto row_sign_array = allocate_zeroed(allocator, z); - auto check = allocate_zeroed(allocator, z); + auto check = allocate_zeroed(allocator, z); #if ARMRAL_ARCH_SVE >= 2 bool z_is_tiny = (z == 2); #else - bool z_is_tiny = (z < 8); + bool z_is_tiny = (z < 4); #endif // Keep a record of the current, and previous values of the LLRs // Copy the inputs LLRs const auto *llrs_ptr = llrs; size_t new_llrs_size = num_llrs; - std::optional> maybe_out_llrs; + std::optional> maybe_out_llrs; if (!z_is_tiny) { // Double the storage required to replicate LLRs for optimization new_llrs_size *= 2; // Extra buffer to pack the LLRs again - maybe_out_llrs = allocate_uninitialized(allocator, num_llrs); + maybe_out_llrs = allocate_uninitialized(allocator, num_llrs); } - auto new_llrs = allocate_uninitialized(allocator, new_llrs_size); + auto new_llrs = allocate_uninitialized(allocator, new_llrs_size); // NOTE: All allocations are 
now done! if constexpr (Allocator::is_counting) { @@ -1396,18 +1396,24 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, if (z_is_tiny) { // Set the value of the current LLRs from the ones passed in. - // We need to take account of the punctured columns - memset(new_llrs.get(), 0, 2 * z * sizeof(int8_t)); - memcpy(&new_llrs[2 * z], llrs, z * graph->ncodeword_bits * sizeof(int8_t)); + // We need to take account of the punctured columns. + // Also widen to int16_t for use in intermediate calculations. + memset(new_llrs.get(), 0, 2 * z * sizeof(int16_t)); + for (uint32_t i = 0; i < z * graph->ncodeword_bits; i++) { + new_llrs[2 * z + i] = (int16_t)llrs[i]; + } } else { // Each block of Z elements replicated b1|b1|b2|b2 ... - // We need to take account of the punctured columns - memset(new_llrs.get(), 0, 4 * z * sizeof(int8_t)); + // We need to take account of the punctured columns. + // Also widen to int16_t for use in intermediate calculations. + memset(new_llrs.get(), 0, 4 * z * sizeof(int16_t)); auto *new_llrs_ptr = &new_llrs[4 * z]; for (uint32_t num_block = 0; num_block < graph->ncodeword_bits; num_block++) { - memcpy(new_llrs_ptr, llrs_ptr, z * sizeof(int8_t)); - memcpy(new_llrs_ptr + z, llrs_ptr, z * sizeof(int8_t)); + for (uint32_t i = 0; i < z; i++) { + new_llrs_ptr[i] = (int16_t)llrs_ptr[i]; + new_llrs_ptr[z + i] = (int16_t)llrs_ptr[i]; + } new_llrs_ptr += 2 * z; llrs_ptr += z; } @@ -1415,19 +1421,19 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, // Precompute number of full vector and tail #if ARMRAL_ARCH_SVE >= 2 - int32_t num_lanes = svcntb(); + int32_t num_lanes = svcnth(); int32_t full_vec = z / num_lanes; uint32_t tail_size = z % num_lanes; bool is_tail_only = (tail_size == z && !z_is_tiny); #else - int32_t num_lanes = 16; + int32_t num_lanes = 8; int32_t full_vec = z / num_lanes; uint32_t tail_size = z % num_lanes; bool is_tail_only = (tail_size == z && !z_is_tiny); if (is_tail_only) { - // tail size = Z - 8 - tail_size -= 8; + // tail size = Z - 4 + tail_size -= 4; } #endif @@ -1459,7 +1465,7 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, for (uint32_t num_block = 0; num_block < graph->ncodeword_bits + 2; num_block++) { memcpy(out_llrs + num_block * z, &new_llrs[2 * num_block * z], - z * sizeof(int8_t)); + z * sizeof(int16_t)); } // Hard decode into the output variable
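The decode_block changes above boil down to widening the int8_t API LLRs into the int16_t working buffer before iterating: two blocks of zeros for the punctured systematic columns, and, in the non-tiny-Z path, each Z-block written twice (b1|b1|b2|b2|...) so that cyclically shifted accesses stay contiguous. A minimal sketch of that widening step (illustrative free function, not part of the patch; nblocks corresponds to graph->ncodeword_bits):

#include <cstdint>
#include <cstring>

static void widen_and_replicate(const int8_t *in, int16_t *out, uint32_t z,
                                uint32_t nblocks) {
  // Two punctured blocks, replicated: 4 * z zero LLRs at the front
  memset(out, 0, 4 * z * sizeof(int16_t));
  int16_t *dst = out + 4 * z;
  for (uint32_t b = 0; b < nblocks; ++b) {
    for (uint32_t i = 0; i < z; ++i) {
      dst[i] = (int16_t)in[i];      // widen to 16 bits
      dst[z + i] = (int16_t)in[i];  // duplicate the block
    }
    dst += 2 * z;
    in += z;
  }
}

The final repacking loop then copies only the first copy of each block (reading at strides of 2 * z) back into a contiguous buffer before the hard decision.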