diff --git a/bench/UpperPHY/LDPC/Decoding/main.cpp b/bench/UpperPHY/LDPC/Decoding/main.cpp index 67c9fe829c26a74cd710149df04f07a59e7a7f6c..59840decef7bd3a1b98dde54e16ec7cdda5a4661 100755 --- a/bench/UpperPHY/LDPC/Decoding/main.cpp +++ b/bench/UpperPHY/LDPC/Decoding/main.cpp @@ -35,14 +35,14 @@ void run_ldpc_decoding_perf(armral_ldpc_graph_t bg, uint32_t z, std::vector buffer(buffer_size); for (uint32_t r = 0; r < num_reps; ++r) { buffer_bump_allocator allocator{buffer.data()}; - armral::ldpc::decode_block( + armral::ldpc::decode_block( llr_ptr, bg, z, crc_idx, num_its, out_ptr, allocator); } #else for (uint32_t r = 0; r < num_reps; ++r) { heap_allocator allocator{}; - armral::ldpc::decode_block( - llr_ptr, bg, z, crc_idx, num_its, out_ptr, allocator); + armral::ldpc::decode_block(llr_ptr, bg, z, crc_idx, num_its, + out_ptr, allocator); } #endif } diff --git a/include/armral.h b/include/armral.h index 3c8d0b93a9369463aa6d8ec397ce991cf1388fbe..ae446821c2515db1ab446cfc38dd7a512d4fcfe8 100644 --- a/include/armral.h +++ b/include/armral.h @@ -106,7 +106,8 @@ extern "C" { */ typedef enum { ARMRAL_SUCCESS = 0, ///< No error. - ARMRAL_ARGUMENT_ERROR = -1, ///< One or more arguments are incorrect. + ARMRAL_ARGUMENT_ERROR = -1, ///< One or more arguments are incorrect + ARMRAL_RESULT_FAIL = -2 ///< Result failed. } armral_status; /** @@ -3563,7 +3564,7 @@ uint32_t armral_ldpc_encode_block_noalloc_buffer_size(armral_ldpc_graph_t bg, * * @param[in] llrs The initial LLRs to use in the decoding. This is * typically the output after demodulation and rate - * recovery. + * recovery. Supports 8 bit llrs in q1.7. * @param[in] bg The type of base graph to use for the decoding. * @param[in] z The lifting size. Valid values of the lifting size are * described in table 5.3.2-1 in TS 38.212. @@ -3574,8 +3575,8 @@ uint32_t armral_ldpc_encode_block_noalloc_buffer_size(armral_ldpc_graph_t bg, * run. The algorithm may terminate after fewer iterations * if the current candidate codeword passes all the parity * checks, or if it satisfies the CRC check. - * @param[out] data_out The decoded bits. These are of length `68 * z` for base - * graph 1 and `52 * z` for base graph 2. It is assumed + * @param[out] data_out The decoded bits. These are of length `22 * z` for base + * graph 1 and `10 * z` for base graph 2. It is assumed * that the array `data_out` is able to store this many * bits. * @return An `armral_status` value that indicates success or failure. @@ -3613,7 +3614,7 @@ armral_status armral_ldpc_decode_block(const int8_t *llrs, * * @param[in] llrs The initial LLRs to use in the decoding. This is * typically the output after demodulation and rate - * recovery. + * recovery. Supports 8 bit llrs in q1.7. * @param[in] bg The type of base graph to use for the decoding. * @param[in] z The lifting size. Valid values of the lifting size are * described in table 5.3.2-1 in TS 38.212. @@ -3624,8 +3625,8 @@ armral_status armral_ldpc_decode_block(const int8_t *llrs, * run. The algorithm may terminate after fewer iterations * if the current candidate codeword passes all the parity * checks, or if it satisfies the CRC check. - * @param[out] data_out The decoded bits. These are of length `68 * z` for base - * graph 1 and `52 * z` for base graph 2. It is assumed + * @param[out] data_out The decoded bits. These are of length `22 * z` for base + * graph 1 and `10 * z` for base graph 2. It is assumed * that the array `data_out` is able to store this many * bits. * @param[in] buffer Workspace buffer to be used internally. 
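Note on the corrected `data_out` sizes above: the decoder writes only the systematic bits, so the output buffer is sized from `22 * z` bits (base graph 1) or `10 * z` bits (base graph 2), not from the full codeword length. A minimal caller sketch, assuming the parameter order used in the benchmark call at the top of this patch and that `data_out` is bit-packed; the enumerator `LDPC_BASE_GRAPH_1` and the exact prototype should be checked against the installed armral.h:

    #include <cstdint>
    #include <vector>
    #include "armral.h"

    // Hypothetical helper, not part of the patch: decode one code block for
    // base graph 1 and return the status reported by the library.
    armral_status decode_one_block(const int8_t *llrs, uint32_t z,
                                   uint32_t crc_idx, uint32_t num_its) {
      // 22 * z decoded bits for base graph 1, stored bit-packed in bytes.
      std::vector<uint8_t> data_out((22 * z + 7) / 8);
      return armral_ldpc_decode_block(llrs, LDPC_BASE_GRAPH_1, z, crc_idx,
                                      num_its, data_out.data());
    }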
diff --git a/simulation/ldpc_awgn/ldpc_awgn.cpp b/simulation/ldpc_awgn/ldpc_awgn.cpp index 13f59cd9f76718bd94946fc4ee3a9f7917f3bc6e..06331567aad3451bbf0597c3fb7d3b07e74c9384 100644 --- a/simulation/ldpc_awgn/ldpc_awgn.cpp +++ b/simulation/ldpc_awgn/ldpc_awgn.cpp @@ -228,7 +228,7 @@ int run_check(armral::utils::random_state *state, uint32_t z, // To make it easier to compare the values, convert the bit array to a byte // array - armral::bits_to_bytes(data->len_out, data->data_decoded, + armral::bits_to_bytes(data->len_in, data->data_decoded, data->data_decoded_bytes); // Check the number of errors in decoding @@ -243,7 +243,7 @@ int run_check(armral::utils::random_state *state, uint32_t z, // For the remainder of the columns check that the data is the same // as what came out of the encoding const uint8_t *out_ptr = data->data_decoded_bytes + 2 * z; - for (uint32_t i = 0; i < data->len_out - 2 * z; ++i) { + for (uint32_t i = 0; i < data->len_in - 2 * z; ++i) { if (out_ptr[i] != data->data_encoded_bytes[i]) { num_errors++; } diff --git a/src/UpperPHY/LDPC/arm_ldpc_decoder.cpp b/src/UpperPHY/LDPC/arm_ldpc_decoder.cpp index 9eaffd97d8c7058eec49b173d846b43d48d7fa4c..c00af85d01272c9cc9cc2e705885d2b596d949c4 100644 --- a/src/UpperPHY/LDPC/arm_ldpc_decoder.cpp +++ b/src/UpperPHY/LDPC/arm_ldpc_decoder.cpp @@ -17,1478 +17,1600 @@ #include #include -namespace { - -struct ldpc_layer_data { - uint32_t z; - uint32_t lsi; - uint32_t row; - uint32_t row_start_ind; - const armral_ldpc_base_graph_t *graph; - uint32_t num_cols; - const uint32_t *shift_ptr; - const uint32_t *col_ptr; - - ldpc_layer_data(uint32_t z_in, uint32_t lsi_in, - const armral_ldpc_base_graph_t *graph_in) - : z(z_in), lsi(lsi_in), row(0), row_start_ind(0), graph(graph_in), - num_cols(graph->row_start_inds[1]), - shift_ptr(graph->shifts + lsi * num_cols), col_ptr(graph->col_inds) {} - - void next() { - row++; - row_start_ind = graph->row_start_inds[row]; - col_ptr += num_cols; - num_cols = graph->row_start_inds[row + 1] - row_start_ind; - shift_ptr = graph->shifts + row_start_ind * armral::ldpc::num_lifting_sets + - lsi * num_cols; - } -}; +namespace armral::ldpc { -template -inline T max(T a, T b) { - return a > b ? a : b; +// compute number of half words supported +inline uint32_t get_num_lanes() { +#if ARMRAL_ARCH_SVE >= 2 + return svcnth(); +#else + return 8; +#endif } -template -inline T min(T a, T b) { - return a < b ? a : b; -} +// Check nodes process the received information, update it, and send it back to +// the connected variable nodes. +// l is updated belief. +// r is extrinsic information. +// min values, signs and sign products are passed as input argument to update +// the belief and store the extrinsic information. 
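Before the vectorised `update_l_and_r` that follows, a scalar reference of the check-node update it performs may help. This sketch is illustrative only (names are hypothetical) and mirrors the per-element arithmetic of the SVE2 and Neon paths below: the new extrinsic value takes magnitude `min2` when the variable node itself produced the row minimum and `min1` otherwise, is signed by the row sign product and the element's own sign, and is then added back into the belief.

    // Scalar sketch (not part of the patch) of one element of the
    // check-node update performed by update_l_and_r.
    static inline void check_node_update_elem(int16_t &l, int16_t &r,
                                              int16_t min1, int16_t min2,
                                              uint16_t min_pos, uint16_t col,
                                              int16_t sign,       // +1 or -1, sign of L(n,m)
                                              int16_t sign_prod)  // +1 or -1, product over the row
    {
      // Exclude this node's own contribution from the row minimum.
      int16_t mag = (min_pos == col) ? min2 : min1;
      r = static_cast<int16_t>(sign_prod * sign * mag); // new R(n,m)
      l = static_cast<int16_t>(l + r);                  // updated belief LLR(n)
    }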
+void update_l_and_r(int16_t *__restrict__ l, int16_t *__restrict__ r, + const armral_ldpc_base_graph_t *graph, uint16_t z, + uint32_t lsi, uint16_t layer, + const int16_t *__restrict__ row_min1_array, + const int16_t *__restrict__ row_min2_array, + const int16_t *__restrict__ row_sign_array, + const uint16_t *__restrict__ row_pos_array, + const int16_t *__restrict__ sign_scratch, + uint32_t *__restrict__ r_index) { + + const uint32_t *col_indices; + uint32_t i; + uint32_t j; + uint32_t r_i = *r_index; + uint32_t num_lanes = get_num_lanes(); + + i = graph->row_start_inds[layer]; + // Get the number of nonzero entries in the row + j = graph->row_start_inds[layer + 1] - i; + col_indices = graph->col_inds + i; + const uint32_t *shift_ptr = graph->shifts + i * 8 + lsi * j; + + const int16_t *sgn_scratch_buf = sign_scratch; -enum lifting_size_category { CAT_TINY, CAT_TAIL, CAT_LARGE }; +#if ARMRAL_ARCH_SVE >= 2 -template -class crc_checker { -public: - crc_checker(uint32_t z, uint32_t crc_idx, Allocator &allocator) : m_z(z) { - // Calculate K', which is the number of info bits + CRC bits (i.e. the - // non-filler bits of the code block) - m_k_prime = crc_idx + 24; - - // The CRC calculation routine expects a particular size of input - // (n % 16 = 0 where n is the number of bytes), which requires padding - // the input to the required size - m_buffer_size = (m_k_prime + 7) / 8; - m_total_bits = m_k_prime; - if (m_k_prime % 128 != 0) { - m_num_pad_bits = 128 - (m_k_prime % 128); - m_total_bits = m_k_prime + m_num_pad_bits; - m_buffer_size = m_total_bits >> 3; - } + svbool_t pg = svptrue_b16(); - m_llrs = allocate_uninitialized(allocator, m_total_bits + m_z - 1); - m_buffer = allocate_uninitialized(allocator, m_buffer_size); - } + // for each column i.e only non -1's + for (uint16_t col = 0; col < j; col++) { + uint32_t col_block = col_indices[col]; + + int16_t *ptr_r = &r[r_i * z]; + uint32_t shift = shift_ptr[col] % z; + + const int16_t *min1_buf = row_min1_array; + const int16_t *min2_buf = row_min2_array; + const int16_t *sgn_buf = row_sign_array; + const uint16_t *pos_buf = row_pos_array; + + svuint16_t pos_current = svdup_n_u16(col); + + uint32_t blk1 = (z - shift) / num_lanes; + uint32_t blk2 = shift / num_lanes; + uint32_t tail1 = (z - shift) & (num_lanes - 1); + uint32_t tail2 = (shift) & (num_lanes - 1); + svbool_t pg_tail1 = svwhilelt_b16(0U, tail1); + svbool_t pg_tail2 = svwhilelt_b16(0U, tail2); + + // Loop over z + // shift to z-1 + int16_t *ptr_l = &l[col_block * z + shift]; // Input,point to shift3 + + for (uint32_t v_cnt = 0; v_cnt < blk1; v_cnt++) { + + svint16_t min1 = svld1_s16(pg, min1_buf); + svint16_t min2 = svld1_s16(pg, min2_buf); + svuint16_t pos = svld1_u16(pg, pos_buf); + + // check if this the column matching position for the min1 + svbool_t pos_mask = svcmpeq_u16(pg, pos, pos_current); - bool check(const int16_t *new_llrs) { - // Copy the LLRs corresponding to the bits we need to do the CRC check after - // the padding bits - memset(m_llrs.get(), 0, m_num_pad_bits * sizeof(int16_t)); - for (uint32_t num_block = 0; num_block < ((m_k_prime + m_z - 1) / m_z); - num_block++) { - memcpy(m_llrs.get() + m_num_pad_bits + (num_block * m_z), - new_llrs + (2 * num_block * m_z), m_z * sizeof(int16_t)); + // if yes replace min1 with min2, otherwise min1 + svint16_t merged_mins = svsel_s16(pos_mask, min2, min1); + + // apply sign + svint16_t signs = svld1_s16(pg, sgn_scratch_buf); + merged_mins = svmul_s16_x(pg, merged_mins, signs); + + // apply sign product + svint16_t sign_prod 
= svld1_s16(pg, sgn_buf); + merged_mins = svmul_s16_x(pg, merged_mins, sign_prod); + + // update r + svst1_s16(pg, ptr_r, merged_mins); + + // update l + svint16_t llrs_reg = svld1_s16(pg, ptr_l); + llrs_reg = svadd_s16_x(pg, llrs_reg, merged_mins); + svst1_s16(pg, ptr_l, llrs_reg); + + ptr_l += num_lanes; + ptr_r += num_lanes; + sgn_scratch_buf += num_lanes; + sgn_buf += num_lanes; + min1_buf += num_lanes; + min2_buf += num_lanes; + pos_buf += num_lanes; } - // Hard decode - armral::llrs_to_bits(m_total_bits, m_llrs.get(), m_buffer.get()); + if (tail1 > 0U) { + svint16_t min1 = svld1_s16(pg_tail1, min1_buf); + svint16_t min2 = svld1_s16(pg_tail1, min2_buf); + svuint16_t pos = svld1_u16(pg_tail1, pos_buf); - // Generate the CRC parity bits - uint64_t crc; - armral_crc24_b_be(m_buffer_size, (const uint64_t *)m_buffer.get(), &crc); + // check if this the column matching position for the min1 + svbool_t pos_mask = svcmpeq_u16(pg_tail1, pos, pos_current); - // If the CRC is zero then the code block has been correctly decoded and we - // can terminate the iterations early - return (crc == 0); - } + // if yes replace min1 with min2, otherwise min1 + svint16_t merged_mins = svsel_s16(pos_mask, min2, min1); -private: - uint32_t m_z{0}; - uint32_t m_k_prime{0}; - uint32_t m_buffer_size{0}; - uint32_t m_num_pad_bits{0}; - uint32_t m_total_bits{0}; - unique_ptr m_llrs; - unique_ptr m_buffer; -}; - -template -bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, - const armral_ldpc_base_graph_t *graph, int32_t num_lanes, - int32_t full_vec, uint32_t tail_size, int16_t *check); - -template<> -bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, - const armral_ldpc_base_graph_t *graph, - int32_t num_lanes, int32_t full_vec, - uint32_t tail_size, int16_t *check) { - // Loop through the rows in the base graph - bool passed = true; - for (uint32_t row = 0; row < graph->nrows && passed; ++row) { - auto row_start_ind = graph->row_start_inds[row]; - auto num_cols = graph->row_start_inds[row + 1] - row_start_ind; - const auto *col_ptr = graph->col_inds + row_start_ind; - const auto *shift_ptr = graph->shifts + - row_start_ind * armral::ldpc::num_lifting_sets + - lsi * num_cols; - // Loop through the rows in the block - for (uint32_t zb = 0; zb < z && passed; ++zb) { - // Loop through the columns in the row - int16_t scal_check = 0; - for (uint32_t col = 0; col < num_cols; ++col) { - auto shift = (shift_ptr[col] + zb) % z; - auto codeword_ind = col_ptr[col] * z + shift; - scal_check ^= llrs[codeword_ind]; - } - passed &= scal_check >= 0; + // apply sign + svint16_t signs = svld1_s16(pg_tail1, sgn_scratch_buf); + merged_mins = svmul_s16_x(pg_tail1, merged_mins, signs); + + // apply sign product + svint16_t sign_prod = svld1_s16(pg_tail1, sgn_buf); + merged_mins = svmul_s16_x(pg_tail1, merged_mins, sign_prod); + + // update r + svst1_s16(pg_tail1, ptr_r, merged_mins); + + // update l + svint16_t llrs_reg = svld1_s16(pg_tail1, ptr_l); + llrs_reg = svadd_s16_x(pg_tail1, llrs_reg, merged_mins); + svst1_s16(pg_tail1, ptr_l, llrs_reg); + + ptr_l += tail1; + ptr_r += tail1; + sgn_scratch_buf += tail1; + sgn_buf += tail1; + min1_buf += tail1; + min2_buf += tail1; + pos_buf += tail1; } - } - return passed; -} -template<> -bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, - const armral_ldpc_base_graph_t *graph, - int32_t num_lanes, int32_t full_vec, - uint32_t tail_size, int16_t *check) { - // Loop through the rows in the base graph - bool passed = true; -#if ARMRAL_ARCH_SVE >= 2 - 
svbool_t pg_tail = svwhilelt_b16(0U, tail_size); - for (uint32_t row = 0; row < graph->nrows && passed; ++row) { - auto row_start_ind = graph->row_start_inds[row]; - auto num_cols = graph->row_start_inds[row + 1] - row_start_ind; - const auto *col_ptr = graph->col_inds + row_start_ind; - const auto *shift_ptr = graph->shifts + - row_start_ind * armral::ldpc::num_lifting_sets + - lsi * num_cols; - memset(check, 0, z * sizeof(int16_t)); - - // Loop through the columns - for (uint32_t col = 0; col < num_cols; ++col) { - auto shift = (shift_ptr[col] % z); - auto codeword_ind = col_ptr[col] * (2 * z) + shift; - - // No need to loop here, as there is only a tail - const int16_t *llrs_ptr = llrs + codeword_ind; - int16_t *check_ptr = check; - - svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); - svint16_t check_reg = svld1_s16(pg_tail, check_ptr); - svint16_t result_reg = sveor_s16_x(pg_tail, check_reg, llrs_reg); - svst1_s16(pg_tail, check_ptr, result_reg); + // 0 to shift-1 + ptr_l = &l[col_block * z]; // point to start + + for (uint32_t v_cnt = 0; v_cnt < blk2; v_cnt++) { + + svint16_t min1 = svld1_s16(pg, min1_buf); + svint16_t min2 = svld1_s16(pg, min2_buf); + svuint16_t pos = svld1_u16(pg, pos_buf); + + // check if this the column matching position for the min1 + svbool_t pos_mask = svcmpeq_u16(pg, pos, pos_current); + + // if yes replace min1 with min2, otherwise min1 + svint16_t merged_mins = svsel_s16(pos_mask, min2, min1); + + // apply sign + svint16_t signs = svld1_s16(pg, sgn_scratch_buf); + merged_mins = svmul_s16_x(pg, merged_mins, signs); + + // apply sign product + svint16_t sign_prod = svld1_s16(pg, sgn_buf); + merged_mins = svmul_s16_x(pg, merged_mins, sign_prod); + + // update r + svst1_s16(pg, ptr_r, merged_mins); + + // update l + svint16_t llrs_reg = svld1_s16(pg, ptr_l); + llrs_reg = svadd_s16_x(pg, llrs_reg, merged_mins); + svst1_s16(pg, ptr_l, llrs_reg); + + ptr_l += num_lanes; + ptr_r += num_lanes; + sgn_scratch_buf += num_lanes; + sgn_buf += num_lanes; + min1_buf += num_lanes; + min2_buf += num_lanes; + pos_buf += num_lanes; } - for (uint32_t zb = 0; zb < z && passed; ++zb) { - passed &= check[zb] >= 0; + + if (tail2 > 0U) { + svint16_t min1 = svld1_s16(pg_tail2, min1_buf); + svint16_t min2 = svld1_s16(pg_tail2, min2_buf); + svuint16_t pos = svld1_u16(pg_tail2, pos_buf); + + // check if this the column matching position for the min1 + svbool_t pos_mask = svcmpeq_u16(pg_tail2, pos, pos_current); + + // if yes replace min1 with min2, otherwise min1 + svint16_t merged_mins = svsel_s16(pos_mask, min2, min1); + + // apply sign + svint16_t signs = svld1_s16(pg_tail2, sgn_scratch_buf); + merged_mins = svmul_s16_x(pg_tail2, merged_mins, signs); + + // apply sign product + svint16_t sign_prod = svld1_s16(pg_tail2, sgn_buf); + merged_mins = svmul_s16_x(pg_tail2, merged_mins, sign_prod); + + // update r + svst1_s16(pg_tail2, ptr_r, merged_mins); + + // update l + svint16_t llrs_reg = svld1_s16(pg_tail2, ptr_l); + llrs_reg = svadd_s16_x(pg_tail2, llrs_reg, merged_mins); + svst1_s16(pg_tail2, ptr_l, llrs_reg); + + ptr_l += tail2; + ptr_r += tail2; + sgn_scratch_buf += tail2; + sgn_buf += tail2; + min1_buf += tail2; + min2_buf += tail2; + pos_buf += tail2; } + + r_i++; } + #else - for (uint32_t row = 0; row < graph->nrows && passed; ++row) { - auto row_start_ind = graph->row_start_inds[row]; - auto num_cols = graph->row_start_inds[row + 1] - row_start_ind; - const auto *col_ptr = graph->col_inds + row_start_ind; - const auto *shift_ptr = graph->shifts + - row_start_ind * 
armral::ldpc::num_lifting_sets + - lsi * num_cols; - memset(check, 0, z * sizeof(int16_t)); - - // Loop through the columns - for (uint32_t col = 0; col < num_cols; ++col) { - auto shift = (shift_ptr[col] % z); - auto codeword_ind = col_ptr[col] * (2 * z) + shift; - - // Loop through the rows in the block - const int16_t *llrs_ptr = llrs + codeword_ind; - int16_t *check_ptr = check; - - int16x4_t llrs_reg = vld1_s16(llrs_ptr); - int16x4_t check_reg = vld1_s16(check_ptr); - int16x4_t result_reg = veor_s16(check_reg, llrs_reg); - vst1_s16(check_ptr, result_reg); - - // Deal with a tail - for (uint32_t zb = z - tail_size; zb < z; ++zb) { - check[zb] ^= llrs_ptr[zb]; - } - } - for (uint32_t zb = 0; zb < z && passed; ++zb) { - passed &= check[zb] >= 0; + + // for each column i.e only non -1's + for (uint16_t col = 0; col < j; col++) { + uint32_t col_block = col_indices[col]; + + int16_t *ptr_r = &r[r_i * z]; + uint32_t shift = shift_ptr[col] % z; + + const int16_t *min1_buf = row_min1_array; + const int16_t *min2_buf = row_min2_array; + const int16_t *sgn_buf = row_sign_array; // set to 0 + const uint16_t *pos_buf = row_pos_array; + + uint16x8_t pos_current = {col, col, col, col, col, col, col, col}; + uint16x4_t pos_current_4 = {col, col, col, col}; + + uint32_t blk1 = (z - shift) / num_lanes; + uint32_t blk2 = shift / num_lanes; + uint32_t tail1 = (z - shift) & (num_lanes - 1); + uint32_t tail2 = (shift) & (num_lanes - 1); + + // Loop over z + // shift to z-1 + int16_t *ptr_l = &l[col_block * z + shift]; // Input,point to shift3 + + for (uint32_t v_cnt = 0; v_cnt < blk1; v_cnt++) { + + int16x8_t min1 = vld1q_s16(min1_buf); + int16x8_t min2 = vld1q_s16(min2_buf); + uint16x8_t pos = vld1q_u16(pos_buf); + + // check if this the column matching position for the min1 + uint16x8_t pos_mask = vceqq_u16(pos, pos_current); + + // if yes replace min1 with min2, otherwise min1 + int16x8_t merged_mins = vbslq_s16(pos_mask, min2, min1); + + // apply sign + int16x8_t signs = vld1q_s16(sgn_scratch_buf); + merged_mins = vmulq_s16(merged_mins, signs); + + // apply sign product + int16x8_t sign_prod = vld1q_s16(sgn_buf); + merged_mins = vmulq_s16(merged_mins, sign_prod); + + // update r + vst1q_s16(ptr_r, merged_mins); + + // update l + int16x8_t llrs_reg = vld1q_s16(ptr_l); + llrs_reg = vqaddq_s16(llrs_reg, merged_mins); + vst1q_s16(ptr_l, llrs_reg); + + ptr_l += num_lanes; + ptr_r += num_lanes; + sgn_scratch_buf += num_lanes; + sgn_buf += num_lanes; + min1_buf += num_lanes; + min2_buf += num_lanes; + pos_buf += num_lanes; } - } -#endif - return passed; -} -template<> -bool parity_check(const int16_t *llrs, uint32_t z, uint32_t lsi, - const armral_ldpc_base_graph_t *graph, - int32_t num_lanes, int32_t full_vec, - uint32_t tail_size, int16_t *check) { -#if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b16(); - svbool_t pg_tail = svwhilelt_b16(0U, tail_size); - - // Loop through the rows in the base graph - bool passed = true; - for (uint32_t row = 0; row < graph->nrows && passed; ++row) { - auto row_start_ind = graph->row_start_inds[row]; - auto num_cols = graph->row_start_inds[row + 1] - row_start_ind; - const auto *col_ptr = graph->col_inds + row_start_ind; - const auto *shift_ptr = graph->shifts + - row_start_ind * armral::ldpc::num_lifting_sets + - lsi * num_cols; - memset(check, 0, z * sizeof(int16_t)); - - // Loop through the columns - for (uint32_t col = 0; col < num_cols; ++col) { - auto shift = (shift_ptr[col] % z); - auto codeword_ind = col_ptr[col] * (2 * z) + shift; - // Loop through the rows in the 
block - - // The check can be done on the LLRs instead of on the bit values, as - // there is a one-to-one transform between LLRs and bit. Negative LLRs - // represent a hard decision for the bit to be one, and non-negative - // values represent a zero. Hence the check needs to xor all LLRs - // and then assert that the result is non-negative. - const int16_t *llrs_ptr = llrs + codeword_ind; - int16_t *check_ptr = check; - - for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); - svint16_t check_reg = svld1_s16(pg, check_ptr); - svint16_t result_reg = sveor_s16_x(pg, check_reg, llrs_reg); - svst1_s16(pg, check_ptr, result_reg); - - // Increment pointers - llrs_ptr += num_lanes; - check_ptr += num_lanes; + if (tail1 > 0U) { + + if (tail1 > 3U) { + + int16x4_t min1 = vld1_s16(min1_buf); + int16x4_t min2 = vld1_s16(min2_buf); + uint16x4_t pos = vld1_u16(pos_buf); + + // check if this the column matching position for the min1 + uint16x4_t pos_mask = vceq_u16(pos, pos_current_4); + + // if yes replace min1 with min2, otherwise min1 + int16x4_t merged_mins = vbsl_s16(pos_mask, min2, min1); + + // apply sign + int16x4_t signs = vld1_s16(sgn_scratch_buf); + merged_mins = vmul_s16(merged_mins, signs); + + // apply sign product + int16x4_t sign_prod = vld1_s16(sgn_buf); + merged_mins = vmul_s16(merged_mins, sign_prod); + + // update r + vst1_s16(ptr_r, merged_mins); + + // update l + int16x4_t llrs_reg = vld1_s16(ptr_l); + llrs_reg = vqadd_s16(llrs_reg, merged_mins); + vst1_s16(ptr_l, llrs_reg); + + ptr_l += 4; + ptr_r += 4; + sgn_scratch_buf += 4; + sgn_buf += 4; + min1_buf += 4; + min2_buf += 4; + pos_buf += 4; + tail1 = (z - shift) & 0x3; } - // Process tail - if (tail_size != 0) { - svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); - svint16_t check_reg = svld1_s16(pg_tail, check_ptr); - svint16_t result_reg = sveor_s16_x(pg_tail, check_reg, llrs_reg); - svst1_s16(pg_tail, check_ptr, result_reg); + + if (tail1 > 0U) { + for (uint32_t t_cnt = 0; t_cnt < tail1; t_cnt++) { + uint16_t pos = pos_buf[t_cnt]; + int16_t min1 = min1_buf[t_cnt]; + int16_t min2 = min2_buf[t_cnt]; + int16_t val = (pos == col) ? 
min2 : min1; + val = sgn_scratch_buf[t_cnt] * val; + ptr_r[t_cnt] = sgn_buf[t_cnt] * val; + ptr_l[t_cnt] = ptr_l[t_cnt] + ptr_r[t_cnt]; + } + + ptr_l += tail1; + ptr_r += tail1; + sgn_scratch_buf += tail1; + sgn_buf += tail1; + min1_buf += tail1; + min2_buf += tail1; + pos_buf += tail1; } } - for (uint32_t zb = 0; zb < z && passed; ++zb) { - passed &= check[zb] >= 0; + + // 0 to shift-1 + ptr_l = &l[col_block * z]; // point to start + for (uint32_t v_cnt = 0; v_cnt < blk2; v_cnt++) { + + int16x8_t min1 = vld1q_s16(min1_buf); + int16x8_t min2 = vld1q_s16(min2_buf); + uint16x8_t pos = vld1q_u16(pos_buf); + + // check if this the column matching position for the min1 + uint16x8_t pos_mask = vceqq_u16(pos, pos_current); + + // if yes replace min1 with min2, otherwise min1 + int16x8_t merged_mins = vbslq_s16(pos_mask, min2, min1); + + // apply sign + int16x8_t signs = vld1q_s16(sgn_scratch_buf); + merged_mins = vmulq_s16(merged_mins, signs); + + // apply sign product + int16x8_t sign_prod = vld1q_s16(sgn_buf); + merged_mins = vmulq_s16(merged_mins, sign_prod); + + // update r + vst1q_s16(ptr_r, merged_mins); + + // update l + int16x8_t llrs_reg = vld1q_s16(ptr_l); + llrs_reg = vqaddq_s16(llrs_reg, merged_mins); + vst1q_s16(ptr_l, llrs_reg); + + ptr_l += num_lanes; + ptr_r += num_lanes; + sgn_scratch_buf += num_lanes; + sgn_buf += num_lanes; + min1_buf += num_lanes; + min2_buf += num_lanes; + pos_buf += num_lanes; } - } - return passed; -#else - // Loop through the rows in the base graph - bool passed = true; - for (uint32_t row = 0; row < graph->nrows && passed; ++row) { - auto row_start_ind = graph->row_start_inds[row]; - auto num_cols = graph->row_start_inds[row + 1] - row_start_ind; - const auto *col_ptr = graph->col_inds + row_start_ind; - const auto *shift_ptr = graph->shifts + - row_start_ind * armral::ldpc::num_lifting_sets + - lsi * num_cols; - memset(check, 0, z * sizeof(int16_t)); - - // Loop through the columns - for (uint32_t col = 0; col < num_cols; ++col) { - auto shift = (shift_ptr[col] % z); - auto codeword_ind = col_ptr[col] * (2 * z) + shift; - uint32_t tail_cnt = tail_size; - // Loop through the rows in the block - - // The check can be done on the LLRs instead of on the bit values, as - // there is a one-to-one transform between LLRs and bit. Negative LLRs - // represent a hard decision for the bit to be one, and non-negative - // values represent a zero. Hence the check needs to xor all LLRs - // and then assert that the result is non-negative. 
- const int16_t *llrs_ptr = llrs + codeword_ind; - int16_t *check_ptr = check; - - // Process 8 entries at a time - for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - int16x8_t llrs_reg = vld1q_s16(llrs_ptr); - int16x8_t check_reg = vld1q_s16(check_ptr); - int16x8_t result_reg = veorq_s16(check_reg, llrs_reg); - vst1q_s16(check_ptr, result_reg); - - // Increment pointers - llrs_ptr += num_lanes; - check_ptr += num_lanes; - } - // Process a group of 4 elts - if (tail_cnt > 3U) { - int16x4_t llrs_reg = vld1_s16(llrs_ptr); - int16x4_t check_reg = vld1_s16(check_ptr); - int16x4_t result_reg = veor_s16(check_reg, llrs_reg); - vst1_s16(check_ptr, result_reg); - tail_cnt = z & 0x3; + + if (tail2 > 0U) { + + if (tail2 > 3U) { + + int16x4_t min1 = vld1_s16(min1_buf); + int16x4_t min2 = vld1_s16(min2_buf); + uint16x4_t pos = vld1_u16(pos_buf); + + // check if this the column matches position for the min1 + uint16x4_t pos_mask = vceq_u16(pos, pos_current_4); + + // if yes replace min1 with min2, otherwise min1 + int16x4_t merged_mins = vbsl_s16(pos_mask, min2, min1); + + // apply sign + int16x4_t signs = vld1_s16(sgn_scratch_buf); + merged_mins = vmul_s16(merged_mins, signs); + + // apply sign product + int16x4_t sign_prod = vld1_s16(sgn_buf); + merged_mins = vmul_s16(merged_mins, sign_prod); + + // update r + vst1_s16(ptr_r, merged_mins); + + // update l + int16x4_t llrs_reg = vld1_s16(ptr_l); + llrs_reg = vqadd_s16(llrs_reg, merged_mins); + vst1_s16(ptr_l, llrs_reg); + + ptr_l += 4; + ptr_r += 4; + sgn_scratch_buf += 4; + sgn_buf += 4; + min1_buf += 4; + min2_buf += 4; + pos_buf += 4; + tail2 = (shift) & 0x3; } - // Deal with a tail - for (uint32_t zb = z - tail_cnt; zb < z; ++zb) { - check[zb] ^= llrs[codeword_ind + zb]; + + if (tail2 > 0U) { + for (uint32_t t_cnt = 0; t_cnt < tail2; t_cnt++) { + uint16_t pos = pos_buf[t_cnt]; + int16_t min1 = min1_buf[t_cnt]; + int16_t min2 = min2_buf[t_cnt]; + int16_t val = (pos == col) ? 
min2 : min1; + val = sgn_scratch_buf[t_cnt] * val; + ptr_r[t_cnt] = sgn_buf[t_cnt] * val; + ptr_l[t_cnt] = ptr_l[t_cnt] + ptr_r[t_cnt]; + } + + ptr_l += tail2; + ptr_r += tail2; + sgn_scratch_buf += tail2; + sgn_buf += tail2; + min1_buf += tail2; + min2_buf += tail2; + pos_buf += tail2; } } - for (uint32_t zb = 0; zb < z && passed; ++zb) { - passed &= check[zb] >= 0; - } + + r_i++; } - return passed; #endif -} -// For each check node m in the layer, compute: -// - the variable-to-check-node messages L(n,m) for each variable node n in -// \psi(m), where \psi(m) is the set of variable nodes connected to m: -// L(n,m) = LLR(n) - R(n,m) -// - the products \prod_{n' \in \psi(m)} L(n',m) (they will be used to compute -// sign(R(n,m)) in a second step) -// - \min_{n \in \psi(m)} |L(n,m)| and the second minimum (they will be used to -// compute |R(n,m)| in a second step) -template -void compute_l_product_min1_and_min2( - int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, - const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, - int16_t *row_sign_array); - -template<> -void compute_l_product_min1_and_min2( - int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, - const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, - int16_t *row_sign_array) { - const auto *r_ptr = r; - // Loop through the Z rows in the layer (check node m) - for (uint32_t zb = 0; zb < d->z; ++zb) { - // Loop through the columns in the row (variable node n in psi(m)) - // Column 0 - auto shift = (d->shift_ptr[0] + zb) % d->z; - int16_t l_val = vqsubh_s16(llrs[d->col_ptr[0] * d->z + shift], *(r_ptr++)); - - int16_t row_sign = l_val; - - int16_t row_min = vqabsh_s16(l_val); - - *(l++) = l_val; - - // Column 1 - shift = (d->shift_ptr[1] + zb) % d->z; - l_val = vqsubh_s16(llrs[d->col_ptr[1] * d->z + shift], *(r_ptr++)); - - row_sign ^= l_val; - - int16_t abs_val = vqabsh_s16(l_val); - int16_t row_min2 = max(row_min, abs_val); - row_min = min(row_min, abs_val); - - *(l++) = l_val; - - // Columns n >= 2 - for (uint32_t col = 2; col < d->num_cols; ++col) { - // Compute L(n,m) = LLR(n) - R(n,m) - shift = (d->shift_ptr[col] + zb) % d->z; - l_val = vqsubh_s16(llrs[d->col_ptr[col] * d->z + shift], *(r_ptr++)); - - // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - row_sign ^= l_val; - - // Compute the min(|L(n,m)|) and the second minimum - abs_val = vqabsh_s16(l_val); - row_min2 = max(row_min, min(row_min2, abs_val)); - row_min = min(row_min, abs_val); - - // Store L(n,m) - *(l++) = l_val; - } - - // Store the two minima and the product for Z rows - row_min_array[zb] = row_min; - row_min2_array[zb] = row_min2; - row_sign_array[zb] = row_sign; - } + // update r index for next layer + *r_index = r_i; } -template<> -void compute_l_product_min1_and_min2( - int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, - const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, - int16_t *row_sign_array) { - // Case for lifting sizes Z such as 4 <= Z < 8 +// Variable nodes transmit their belief information to the connected check +// nodes. Decoding alogrithm implemented is scaled offset min-sum. outputs mins, +// signs and sign product to update the total belief. 
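The following `compute_l_r_and_mins` is the variable-node half of the scaled offset min-sum iteration. As a scalar reference (illustrative only; the patch itself uses saturating SVE2/Neon arithmetic and Q15 fixed point), each element computes L(n,m) = LLR(n) - R(n,m), tracks the row's two smallest magnitudes and the column that produced the smallest, and accumulates the sign product; once the row is complete the minima are offset by 2 and scaled by 0.75 (24576 in Q15), matching the constants in the code below.

    #include <algorithm>
    #include <cstdint>

    // Scalar sketch (not part of the patch) of the per-element work.
    static inline void variable_node_elem(int16_t &l, int16_t r,
                                          int16_t &min1, int16_t &min2,
                                          uint16_t &min_pos, uint16_t col,
                                          int16_t &sign, int16_t &sign_prod) {
      int16_t v = static_cast<int16_t>(l - r);       // L(n,m) = LLR(n) - R(n,m)
      l = v;                                         // store updated belief
      sign = (v >= 0) ? 1 : -1;
      sign_prod = static_cast<int16_t>(sign_prod * sign);
      int16_t a = static_cast<int16_t>(v < 0 ? -v : v);
      min2 = std::max(min1, std::min(min2, a));      // second minimum so far
      if (a < min1) { min1 = a; min_pos = col; }     // minimum and its column
    }

    // Offset and scale applied to min1 and min2 once the row is complete:
    // subtract the offset (2), clamp at zero, multiply by 0.75 with rounding.
    static inline int16_t offset_and_scale(int16_t m) {
      int32_t t = std::max<int32_t>(m - 2, 0);
      return static_cast<int16_t>((t * 24576 + (1 << 14)) >> 15);
    }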
+void compute_l_r_and_mins(int16_t *__restrict__ l, int16_t *__restrict__ r, + const armral_ldpc_base_graph_t *graph, uint16_t z, + uint32_t lsi, uint16_t layer, + int16_t *__restrict__ row_min1_array, + int16_t *__restrict__ row_min2_array, + int16_t *__restrict__ row_sign_array, + uint16_t *__restrict__ row_pos_array, + int16_t *__restrict__ sign_scratch, + uint32_t *__restrict__ r_index) { + + const uint32_t *col_indices; + uint32_t i; + uint32_t j; + uint32_t r_i = *r_index; + uint32_t t_i = 0; + uint32_t num_lanes = get_num_lanes(); + + i = graph->row_start_inds[layer]; + // Get the number of nonzero entries in the row + j = graph->row_start_inds[layer + 1] - i; + col_indices = graph->col_inds + i; + const uint32_t *shift_ptr = graph->shifts + i * 8 + lsi * j; + + int16_t *sgn_scratch_buf = sign_scratch; + #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg_tail = svwhilelt_b16(0U, tail_size); + svbool_t pg = svptrue_b16(); + svint16_t offset = svdup_n_s16(2); - // Loop through the columns in the row (variable node n in psi(m)) - // Column 0 - int16_t *l_ptr = l; - auto shift = d->shift_ptr[0] % d->z; - const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int16_t *r_ptr = r; + svint16_t plus1 = svdup_n_s16(1); + svint16_t minus1 = svdup_n_s16(-1); - svint16_t r_reg = svld1_s16(pg_tail, r_ptr); - svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); - svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); + // for each column i.e only non -1's + for (uint32_t col = 0; col < j; col++) { + int16_t *min1_buf = row_min1_array; + int16_t *min2_buf = row_min2_array; + int16_t *sgn_buf = row_sign_array; // set to 0 + uint16_t *pos_buf = row_pos_array; - svint16_t row_sign = l_reg; + uint32_t col_block = col_indices[col]; - svint16_t row_min = svqabs_s16_x(pg_tail, l_reg); + int16_t *ptr_r = &r[r_i * z]; + uint32_t shift = shift_ptr[col] % z; - svst1_s16(pg_tail, l_ptr, l_reg); + uint32_t blk1 = (z - shift) / num_lanes; + uint32_t blk2 = shift / num_lanes; - // Column 1 - l_ptr = l + d->z; - shift = d->shift_ptr[1] % d->z; - llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; - r_ptr = r + d->z; + uint32_t tail1 = (z - shift) & (num_lanes - 1); + uint32_t tail2 = (shift) & (num_lanes - 1); - r_reg = svld1_s16(pg_tail, r_ptr); - llrs_reg = svld1_s16(pg_tail, llrs_ptr); - l_reg = svqsub_s16(llrs_reg, r_reg); + svbool_t pg_tail1 = svwhilelt_b16(0U, tail1); + svbool_t pg_tail2 = svwhilelt_b16(0U, tail2); - row_sign = sveor_s16_x(pg_tail, row_sign, l_reg); + // Loop over z + // shift to z-1 + int16_t *ptr_l = &l[col_block * z + shift]; // Input,point to shift - svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); - svint16_t row_min2 = svmax_s16_x(pg_tail, row_min, abs_reg); - row_min = svmin_s16_x(pg_tail, row_min, abs_reg); + for (uint32_t v_cnt = 0; v_cnt < blk1; v_cnt++) { + svint16_t llrs_reg = svld1_s16(pg, ptr_l); + svint16_t r_reg = svld1_s16(pg, ptr_r); - svst1_s16(pg_tail, l_ptr, l_reg); + // Subtraction + svint16_t vec = svqsub_s16(llrs_reg, r_reg); - // Columns n >= 2 - for (uint32_t col = 2; col < d->num_cols; ++col) { - l_ptr = l + d->z * col; - shift = d->shift_ptr[col] % d->z; - llrs_ptr = llrs + d->col_ptr[col] * (2 * d->z) + shift; - r_ptr = r + d->z * col; + // Absolute + svint16_t abs_vec = svqabs_s16_x(pg, vec); - // Compute L(n,m) = LLR(n) - R(n,m) - r_reg = svld1_s16(pg_tail, r_ptr); - llrs_reg = svld1_s16(pg_tail, llrs_ptr); - l_reg = svqsub_s16(llrs_reg, r_reg); + // Sign product + svint16_t sgn = svld1_s16(pg, sgn_buf); + sgn = sveor_s16_x(pg, vec, sgn); + // store updated 
sign + svst1_s16(pg, sgn_buf, sgn); - // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - row_sign = sveor_s16_x(pg_tail, row_sign, l_reg); + // store signs + svbool_t is_positive = svcmpgt_s16(pg, vec, svdup_n_s16(-1)); + svint16_t signs = svsel_s16(is_positive, plus1, minus1); - // Compute the min(|L(n,m)|) and the second minimum - abs_reg = svqabs_s16_x(pg_tail, l_reg); - row_min2 = - svmax_s16_x(pg_tail, row_min, svmin_s16_x(pg_tail, row_min2, abs_reg)); - row_min = svmin_s16_x(pg_tail, row_min, abs_reg); + svst1_s16(pg, sgn_scratch_buf, signs); - // Store L(n,m) - svst1_s16(pg_tail, l_ptr, l_reg); - } + // store updated L + svst1_s16(pg, ptr_l, vec); - // Store the two minima and the product for Z rows - svst1_s16(pg_tail, row_min_array, row_min); - svst1_s16(pg_tail, row_min2_array, row_min2); - svst1_s16(pg_tail, row_sign_array, row_sign); -#else - // Loop through the columns in the row (variable node n in psi(m)) - // Column 0 - int16_t *l_ptr = l; - auto shift = d->shift_ptr[0] % d->z; - const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int16_t *r_ptr = r; + // Find min1 and min2 + svint16_t min1_old = svld1_s16(pg, min1_buf); + svint16_t min2_old = svld1_s16(pg, min2_buf); - int16x4_t r_reg = vld1_s16(r_ptr); - int16x4_t llrs_reg = vld1_s16(llrs_ptr); - int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); + svint16_t min2 = + svmax_s16_x(pg, min1_old, svmin_s16_x(pg, min2_old, abs_vec)); + svint16_t min1 = svmin_s16_x(pg, abs_vec, min1_old); - int16x4_t row_sign = l_reg; + // find min1 position + // check if the current min1 has changed w.r.t previous + // if it has changed, then update the index to current pos + svbool_t pos_mask = svcmpeq_s16(pg, min1, min1_old); + svuint16_t pos_old = svld1_u16(pg, pos_buf); + svuint16_t pos_cur = svdup_n_u16(col); + svuint16_t pos_updt = svsel_u16(pos_mask, pos_old, pos_cur); - int16x4_t row_min = vqabs_s16(row_sign); + svst1_s16(pg, min2_buf, min2); + svst1_s16(pg, min1_buf, min1); + svst1_u16(pg, pos_buf, pos_updt); - vst1_s16(l_ptr, l_reg); + ptr_l += num_lanes; + ptr_r += num_lanes; + min1_buf += num_lanes; + min2_buf += num_lanes; + sgn_buf += num_lanes; + pos_buf += num_lanes; + sgn_scratch_buf += num_lanes; + } - int16_t l_val; - for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); + if (tail1 != 0U) { - row_sign_array[zb] = l_val; + svint16_t llrs_reg = svld1_s16(pg_tail1, ptr_l); + svint16_t r_reg = svld1_s16(pg_tail1, ptr_r); - row_min_array[zb] = vqabsh_s16(l_val); + // Subtraction + svint16_t vec = svqsub_s16(llrs_reg, r_reg); - l_ptr[zb] = l_val; - } + // Absolute + svint16_t abs_vec = svqabs_s16_x(pg_tail1, vec); - // Column 1 - l_ptr = l + d->z; - shift = d->shift_ptr[1] % d->z; - llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; - r_ptr = r + d->z; + // Sign product + svint16_t sgn = svld1_s16(pg_tail1, sgn_buf); + sgn = sveor_s16_x(pg_tail1, vec, sgn); + // store updated sign + svst1_s16(pg_tail1, sgn_buf, sgn); - r_reg = vld1_s16(r_ptr); - llrs_reg = vld1_s16(llrs_ptr); - l_reg = vqsub_s16(llrs_reg, r_reg); + // store signs + svbool_t is_positive = svcmpgt_s16(pg_tail1, vec, svdup_n_s16(-1)); + svint16_t signs = svsel_s16(is_positive, plus1, minus1); - row_sign = veor_s16(row_sign, l_reg); + svst1_s16(pg_tail1, sgn_scratch_buf, signs); - int16x4_t abs_reg = vqabs_s16(l_reg); - int16x4_t row_min2 = vmax_s16(row_min, abs_reg); - row_min = vmin_s16(row_min, abs_reg); + // store updated L + svst1_s16(pg_tail1, ptr_l, vec); - 
vst1_s16(l_ptr, l_reg); + // Find min1 and min2 + svint16_t min1_old = svld1_s16(pg_tail1, min1_buf); + svint16_t min2_old = svld1_s16(pg_tail1, min2_buf); - for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); + svint16_t min2 = svmax_s16_x(pg_tail1, min1_old, + svmin_s16_x(pg_tail1, min2_old, abs_vec)); + svint16_t min1 = svmin_s16_x(pg_tail1, abs_vec, min1_old); - row_sign_array[zb] ^= l_val; + // find min1 position + // check if the current min1 has changed w.r.t previous + // if it has changed, then update the index to current pos + svbool_t pos_mask = svcmpeq_s16(pg_tail1, min1, min1_old); + svuint16_t pos_old = svld1_u16(pg_tail1, pos_buf); + svuint16_t pos_cur = svdup_n_u16(col); + svuint16_t pos_updt = svsel_u16(pos_mask, pos_old, pos_cur); - int16_t abs_val = vqabsh_s16(l_val); - row_min2_array[zb] = max(row_min_array[zb], abs_val); - row_min_array[zb] = min(row_min_array[zb], abs_val); + svst1_s16(pg_tail1, min2_buf, min2); + svst1_s16(pg_tail1, min1_buf, min1); + svst1_u16(pg_tail1, pos_buf, pos_updt); - l_ptr[zb] = l_val; - } + ptr_l += tail1; + ptr_r += tail1; + min1_buf += tail1; + min2_buf += tail1; + sgn_buf += tail1; + pos_buf += tail1; + sgn_scratch_buf += tail1; + } + + // 0 to shift-1 + + ptr_l = &l[col_block * z]; // point to start + for (uint32_t v_cnt = 0; v_cnt < blk2; v_cnt++) { + svint16_t llrs_reg = svld1_s16(pg, ptr_l); + svint16_t r_reg = svld1_s16(pg, ptr_r); + + // Subtraction + svint16_t vec = svqsub_s16(llrs_reg, r_reg); + + // Absolute + svint16_t abs_vec = svqabs_s16_x(pg, vec); + + // Sign product + svint16_t sgn = svld1_s16(pg, sgn_buf); + sgn = sveor_s16_x(pg, vec, sgn); + // store updated sign + svst1_s16(pg, sgn_buf, sgn); + + // store signs + svbool_t is_positive = svcmpgt_s16(pg, vec, svdup_n_s16(-1)); + svint16_t signs = svsel_s16(is_positive, plus1, minus1); + + svst1_s16(pg, sgn_scratch_buf, signs); + + // store updated L + svst1_s16(pg, ptr_l, vec); + + // Find min1 and min2 + svint16_t min1_old = svld1_s16(pg, min1_buf); + svint16_t min2_old = svld1_s16(pg, min2_buf); + + svint16_t min2 = + svmax_s16_x(pg, min1_old, svmin_s16_x(pg, min2_old, abs_vec)); + svint16_t min1 = svmin_s16_x(pg, abs_vec, min1_old); + + // find min1 position + // check if the current min1 has changed w.r.t previous + // if it has changed, then update the index to current pos + svbool_t pos_mask = svcmpeq_s16(pg, min1, min1_old); + svuint16_t pos_old = svld1_u16(pg, pos_buf); + svuint16_t pos_cur = svdup_n_u16(col); + svuint16_t pos_updt = svsel_u16(pos_mask, pos_old, pos_cur); + + svst1_s16(pg, min2_buf, min2); + svst1_s16(pg, min1_buf, min1); + svst1_u16(pg, pos_buf, pos_updt); + + ptr_l += num_lanes; + ptr_r += num_lanes; + min1_buf += num_lanes; + min2_buf += num_lanes; + sgn_buf += num_lanes; + pos_buf += num_lanes; + sgn_scratch_buf += num_lanes; + } + + if (tail2 != 0U) { - // Columns n >=2 - for (uint32_t col = 2; col < d->num_cols; ++col) { - l_ptr = l + d->z * col; - shift = d->shift_ptr[col] % d->z; - llrs_ptr = llrs + d->col_ptr[col] * (2 * d->z) + shift; - r_ptr = r + d->z * col; + svint16_t llrs_reg = svld1_s16(pg_tail2, ptr_l); + svint16_t r_reg = svld1_s16(pg_tail2, ptr_r); - // Compute L(n,m) = LLR(n) - R(n,m) - r_reg = vld1_s16(r_ptr); - llrs_reg = vld1_s16(llrs_ptr); - l_reg = vqsub_s16(llrs_reg, r_reg); + // Subtraction + svint16_t vec = svqsub_s16(llrs_reg, r_reg); - // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - row_sign = veor_s16(row_sign, l_reg); + // Absolute + 
svint16_t abs_vec = svqabs_s16_x(pg_tail2, vec); - // Compute the min(|L(n,m)|) and the second minimum - abs_reg = vqabs_s16(l_reg); - row_min2 = vmax_s16(row_min, vmin_s16(row_min2, abs_reg)); - row_min = vmin_s16(row_min, abs_reg); + // Sign product + svint16_t sgn = svld1_s16(pg_tail2, sgn_buf); + sgn = sveor_s16_x(pg_tail2, vec, sgn); + // store updated sign + svst1_s16(pg_tail2, sgn_buf, sgn); - // Store L(n,m) - vst1_s16(l_ptr, l_reg); + // store signs + svbool_t is_positive = svcmpgt_s16(pg_tail2, vec, svdup_n_s16(-1)); + svint16_t signs = svsel_s16(is_positive, plus1, minus1); - // Process tail - for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - l_val = vqsubh_s16(llrs_ptr[zb], r_ptr[zb]); + svst1_s16(pg_tail2, sgn_scratch_buf, signs); - row_sign_array[zb] ^= l_val; + // store updated L + svst1_s16(pg_tail2, ptr_l, vec); - int16_t abs_val = vqabsh_s16(l_val); - row_min2_array[zb] = - max(row_min_array[zb], min(row_min2_array[zb], abs_val)); - row_min_array[zb] = min(row_min_array[zb], abs_val); + // Find min1 and min2 + svint16_t min1_old = svld1_s16(pg_tail2, min1_buf); + svint16_t min2_old = svld1_s16(pg_tail2, min2_buf); - l_ptr[zb] = l_val; + svint16_t min2 = svmax_s16_x(pg_tail2, min1_old, + svmin_s16_x(pg_tail2, min2_old, abs_vec)); + svint16_t min1 = svmin_s16_x(pg_tail2, abs_vec, min1_old); + + // find min1 position + // check if the current min1 has changed w.r.t previous + // if it has changed, then update the index to current pos + svbool_t pos_mask = svcmpeq_s16(pg_tail2, min1, min1_old); + svuint16_t pos_old = svld1_u16(pg_tail2, pos_buf); + svuint16_t pos_cur = svdup_n_u16(col); + svuint16_t pos_updt = svsel_u16(pos_mask, pos_old, pos_cur); + + svst1_s16(pg_tail2, min2_buf, min2); + svst1_s16(pg_tail2, min1_buf, min1); + svst1_u16(pg_tail2, pos_buf, pos_updt); + + ptr_l += tail2; + ptr_r += tail2; + min1_buf += tail2; + min2_buf += tail2; + sgn_buf += tail2; + pos_buf += tail2; + sgn_scratch_buf += tail2; } + + r_i++; + t_i++; } - // Store the two minima and the product for Z rows - vst1_s16(row_min_array, row_min); - vst1_s16(row_min2_array, row_min2); - vst1_s16(row_sign_array, row_sign); -#endif -} + *r_index = r_i - t_i; -template<> -void compute_l_product_min1_and_min2( - int16_t *l, const int16_t *__restrict__ llrs, const int16_t *__restrict__ r, - const ldpc_layer_data *d, int32_t num_lanes, int32_t full_vec, - uint32_t tail_size, int16_t *row_min_array, int16_t *row_min2_array, - int16_t *row_sign_array) { -#if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b16(); - svbool_t pg_tail = svwhilelt_b16(0U, tail_size); + // offset and scale min1 and min2 + // in the same loop, adjust sign product + uint32_t blk = z / num_lanes; + uint32_t tail = z & (num_lanes - 1); + svbool_t pg_tail = svwhilelt_b16(0U, tail); - // Loop through the columns in the row (variable node n in psi(m)) - // Column 0 - int16_t *l_ptr = l; - auto shift = d->shift_ptr[0] % d->z; - const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int16_t *r_ptr = r; - int16_t *sign_ptr = row_sign_array; - int16_t *min_ptr = row_min_array; + svint16_t scale = svdup_n_s16(24576); // 0.75 - for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint16_t r_reg = svld1_s16(pg, r_ptr); - svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); - svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); + int16_t *min1_buf = row_min1_array; + int16_t *min2_buf = row_min2_array; + int16_t *sgn_buf = row_sign_array; // set to 0 - svst1_s16(pg, sign_ptr, l_reg); + for (uint16_t z1 = 0; z1 < blk; 
z1++) { - svst1_s16(pg, min_ptr, svqabs_s16_x(pg, l_reg)); + svint16_t sgn = svld1_s16(pg, sgn_buf); - svst1_s16(pg, l_ptr, l_reg); + svbool_t is_positive = svcmpgt_s16(pg, sgn, svdup_n_s16(-1)); + sgn = svsel_s16(is_positive, plus1, minus1); + svst1_s16(pg, sgn_buf, sgn); - sign_ptr += num_lanes; - min_ptr += num_lanes; - r_ptr += num_lanes; - l_ptr += num_lanes; - llrs_ptr += num_lanes; - } + svint16_t min1 = svld1_s16(pg, min1_buf); + svint16_t min2 = svld1_s16(pg, min2_buf); - if (tail_size != 0) { - svint16_t r_reg = svld1_s16(pg_tail, r_ptr); - svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); - svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); + // apply offset + svint16_t vec1 = svqsub_s16(min1, offset); + svint16_t vec2 = svqsub_s16(min2, offset); - svst1_s16(pg_tail, sign_ptr, l_reg); + // if min1 < 0, then min1 = 0; + is_positive = svcmpgt_s16(pg, vec1, svdup_n_s16(-1)); + vec1 = svsel_s16(is_positive, vec1, svdup_n_s16(0)); - svst1_s16(pg_tail, min_ptr, svqabs_s16_x(pg_tail, l_reg)); + // apply scale + svint32_t res = svmullt_s32(vec1, scale); + svint16_t mint = svqrshrnt_n_s32(svdup_n_s16(0), res, 15); + res = svmullb_s32(vec1, scale); + svint16_t minb = svqrshrnt_n_s32(svdup_n_s16(0), res, 15); + min1 = svtrn2_s16(minb, mint); - svst1_s16(pg_tail, l_ptr, l_reg); - } + // if min2 < 0, then min1 = 0; + is_positive = svcmpgt_s16(pg, vec2, svdup_n_s16(-1)); + vec2 = svsel_s16(is_positive, vec2, svdup_n_s16(0)); + + // apply scale + res = svmullt_s32(vec2, scale); + mint = svqrshrnt_n_s32(svdup_n_s16(0), res, 15); + res = svmullb_s32(vec2, scale); + minb = svqrshrnt_n_s32(svdup_n_s16(0), res, 15); + min2 = svtrn2_s16(minb, mint); - // Column 1 - shift = d->shift_ptr[1] % d->z; - l_ptr = l + d->z; - llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; - r_ptr = r + d->z; - sign_ptr = row_sign_array; - min_ptr = row_min_array; - int16_t *min2_ptr = row_min2_array; - - for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - svint16_t r_reg = svld1_s16(pg, r_ptr); - svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); - svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - - svint16_t sign_reg = svld1_s16(pg, sign_ptr); - svst1_s16(pg, sign_ptr, sveor_s16_x(pg, sign_reg, l_reg)); - - svint16_t min_reg = svld1_s16(pg, min_ptr); - svint16_t abs_reg = svqabs_s16_x(pg, l_reg); - svst1_s16(pg, min2_ptr, svmax_s16_x(pg, min_reg, abs_reg)); - svst1_s16(pg, min_ptr, svmin_s16_x(pg, min_reg, abs_reg)); - - svst1_s16(pg, l_ptr, l_reg); - - sign_ptr += num_lanes; - min_ptr += num_lanes; - min2_ptr += num_lanes; - r_ptr += num_lanes; - l_ptr += num_lanes; - llrs_ptr += num_lanes; + // store scaled, offseted min's + svst1_s16(pg, min1_buf, min1); + svst1_s16(pg, min2_buf, min2); + + min1_buf += num_lanes; + min2_buf += num_lanes; + sgn_buf += num_lanes; } - if (tail_size != 0) { - svint16_t r_reg = svld1_s16(pg_tail, r_ptr); - svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); - svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); + if (tail > 0U) { + + svint16_t sgn = svld1_s16(pg, sgn_buf); + + svbool_t is_positive = svcmpgt_s16(pg_tail, sgn, svdup_n_s16(-1)); + sgn = svsel_s16(is_positive, plus1, minus1); + svst1_s16(pg_tail, sgn_buf, sgn); + + svint16_t min1 = svld1_s16(pg_tail, min1_buf); + svint16_t min2 = svld1_s16(pg_tail, min2_buf); + + // apply offset + svint16_t vec1 = svqsub_s16(min1, offset); + svint16_t vec2 = svqsub_s16(min2, offset); + + // if min1 < 0, then min1 = 0; + is_positive = svcmpgt_s16(pg_tail, vec1, svdup_n_s16(-1)); + vec1 = svsel_s16(is_positive, vec1, svdup_n_s16(0)); - svint16_t 
sign_reg = svld1_s16(pg_tail, sign_ptr); - svst1_s16(pg_tail, sign_ptr, sveor_s16_x(pg_tail, sign_reg, l_reg)); + // apply scale + svint32_t res = svmullt_s32(vec1, scale); + svint16_t mint = svqrshrnt_n_s32(svdup_n_s16(0), res, 15); + res = svmullb_s32(vec1, scale); + svint16_t minb = svqrshrnt_n_s32(svdup_n_s16(0), res, 15); + min1 = svtrn2_s16(minb, mint); - svint16_t min_reg = svld1_s16(pg_tail, min_ptr); - svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); - svst1_s16(pg_tail, min2_ptr, svmax_s16_x(pg_tail, min_reg, abs_reg)); - svst1_s16(pg_tail, min_ptr, svmin_s16_x(pg_tail, min_reg, abs_reg)); + // if min2 < 0, then min1 = 0; + is_positive = svcmpgt_s16(pg_tail, vec2, svdup_n_s16(-1)); + vec2 = svsel_s16(is_positive, vec2, svdup_n_s16(0)); - svst1_s16(pg_tail, l_ptr, l_reg); + // apply scale + res = svmullt_s32(vec2, scale); + mint = svqrshrnt_n_s32(svdup_n_s16(0), res, 15); + res = svmullb_s32(vec2, scale); + minb = svqrshrnt_n_s32(svdup_n_s16(0), res, 15); + min2 = svtrn2_s16(minb, mint); + + // store scaled, offseted min's + svst1_s16(pg_tail, min1_buf, min1); + svst1_s16(pg_tail, min2_buf, min2); + + min1_buf += tail; + min2_buf += tail; + sgn_buf += tail; } - // Columns n >= 2 - for (uint32_t col = 2; col < d->num_cols; ++col) { - l_ptr = l + d->z * col; - shift = d->shift_ptr[col] % d->z; - llrs_ptr = llrs + d->col_ptr[col] * (2 * d->z) + shift; - r_ptr = r + d->z * col; - sign_ptr = row_sign_array; - min_ptr = row_min_array; - min2_ptr = row_min2_array; - - // Loop through the Z rows in the layer (check node m) - for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - // Compute L(n,m) = LLR(n) - R(n,m) - svint16_t r_reg = svld1_s16(pg, r_ptr); - svint16_t llrs_reg = svld1_s16(pg, llrs_ptr); - svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); - - // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - svint16_t sign_reg = svld1_s16(pg, sign_ptr); - svst1_s16(pg, sign_ptr, sveor_s16_x(pg, sign_reg, l_reg)); - - // Compute the min(|L(n,m)|) and the second minimum - svint16_t min_reg = svld1_s16(pg, min_ptr); - svint16_t min2_reg = svld1_s16(pg, min2_ptr); - svint16_t abs_reg = svqabs_s16_x(pg, l_reg); - svst1_s16(pg, min2_ptr, - svmax_s16_x(pg, min_reg, svmin_s16_x(pg, min2_reg, abs_reg))); - svst1_s16(pg, min_ptr, svmin_s16_x(pg, min_reg, abs_reg)); - - // Store L(n,m) - svst1_s16(pg, l_ptr, l_reg); - - sign_ptr += num_lanes; - min_ptr += num_lanes; - min2_ptr += num_lanes; - r_ptr += num_lanes; - l_ptr += num_lanes; - llrs_ptr += num_lanes; +#else + int16x8_t offset8 = vdupq_n_s16(2); + int16x4_t offset4 = vdup_n_s16(2); + + int16x8_t plus1 = vdupq_n_s16(1); + int16x8_t minus1 = vdupq_n_s16(-1); + + int16x4_t plus1_4 = vdup_n_s16(1); + int16x4_t minus1_4 = vdup_n_s16(-1); + + int16_t offset = 2; + + // for each column i.e only non -1's + for (uint32_t col = 0; col < j; col++) { + uint32_t col_block = col_indices[col]; + + int16_t *ptr_r = &r[r_i * z]; + uint32_t shift = shift_ptr[col] % z; + + uint32_t blk1 = (z - shift) / num_lanes; + uint32_t blk2 = shift / num_lanes; + uint32_t tail1 = (z - shift) & (num_lanes - 1); + uint32_t tail2 = (shift) & (num_lanes - 1); + + int16_t *min1_buf = row_min1_array; + int16_t *min2_buf = row_min2_array; + int16_t *sgn_buf = row_sign_array; // set to 0 + uint16_t *pos_buf = row_pos_array; + + // Loop over z + // shift to z-1 + int16_t *ptr_l = &l[col_block * z + shift]; // Input,point to shift + + for (uint32_t v_cnt = 0; v_cnt < blk1; v_cnt++) { + int16x8_t llrs_reg = vld1q_s16(ptr_l); + int16x8_t r_reg = 
vld1q_s16(ptr_r); + + // Subtraction + int16x8_t vec16 = vqsubq_s16(llrs_reg, r_reg); + + // Absoluate + int16x8_t abs_vec16 = vqabsq_s16(vec16); + + // Store signs + uint16x8_t is_positive = vcgtq_s16(vec16, vdupq_n_s16(-1)); + int16x8_t signs = vbslq_s16(is_positive, plus1, minus1); + vst1q_s16(sgn_scratch_buf, signs); + + // Sign product + int16x8_t old_sgn = vld1q_s16(sgn_buf); + int16x8_t sgn = vmulq_s16(signs, old_sgn); + // store updated sign + vst1q_s16(sgn_buf, sgn); + + // store updated L + vst1q_s16(ptr_l, vec16); + + // Find min1 and min2 + int16x8_t min1_old = vld1q_s16(min1_buf); + int16x8_t min2_old = vld1q_s16(min2_buf); + + int16x8_t min2 = vmaxq_s16(min1_old, vminq_s16(min2_old, abs_vec16)); + int16x8_t min1 = vminq_s16(abs_vec16, min1_old); + + // find min1 position + // check if the current min1 has changed w.r.t previous + // if it has changed, then update the index to current pos + uint16x8_t pos_mask = vceqq_s16(min1, min1_old); + uint16x8_t pos_old = vld1q_u16(pos_buf); + uint16x8_t pos_cur = vdupq_n_u16(col); + uint16x8_t pos_updt = vbslq_u16(pos_mask, pos_old, pos_cur); + + vst1q_s16(min2_buf, min2); + vst1q_s16(min1_buf, min1); + vst1q_u16(pos_buf, pos_updt); + + ptr_l += num_lanes; + ptr_r += num_lanes; + min1_buf += num_lanes; + min2_buf += num_lanes; + sgn_buf += num_lanes; + pos_buf += num_lanes; + sgn_scratch_buf += num_lanes; } - // Process tail - if (tail_size != 0) { - svint16_t r_reg = svld1_s16(pg_tail, r_ptr); - svint16_t llrs_reg = svld1_s16(pg_tail, llrs_ptr); - svint16_t l_reg = svqsub_s16(llrs_reg, r_reg); + if (tail1 > 0U) { + + if (tail1 > 3U) { + + int16x4_t llrs_reg = vld1_s16(ptr_l); + int16x4_t r_reg = vld1_s16(ptr_r); + + // Subtraction + int16x4_t vec16 = vqsub_s16(llrs_reg, r_reg); + + // Absolute + int16x4_t abs_vec16 = vqabs_s16(vec16); + + // Store signs + uint16x4_t is_positive = vcgt_s16(vec16, vdup_n_s16(-1)); + int16x4_t signs = vbsl_s16(is_positive, plus1_4, minus1_4); + vst1_s16(sgn_scratch_buf, signs); + + // Sign product + int16x4_t old_sgn = vld1_s16(sgn_buf); + int16x4_t sgn = vmul_s16(signs, old_sgn); + // store updated sign + vst1_s16(sgn_buf, sgn); + + // store updated L + vst1_s16(ptr_l, vec16); + + // Find min1 and min2 + int16x4_t min1_old = vld1_s16(min1_buf); + int16x4_t min2 = vld1_s16(min2_buf); + + min2 = vmax_s16(min1_old, vmin_s16(min2, abs_vec16)); + int16x4_t min1 = vmin_s16(abs_vec16, min1_old); + + // Find min1 pos + // check if the current min1 has changed w.r.t previous + // if it has changed, then update the index to current pos + uint16x4_t pos_mask = vceq_s16(min1, min1_old); + uint16x4_t pos_old = vld1_u16(pos_buf); + uint16x4_t pos_cur = vdup_n_u16(col); + uint16x4_t pos_updt = vbsl_u16(pos_mask, pos_old, pos_cur); + + vst1_s16(min2_buf, min2); + vst1_s16(min1_buf, min1); + vst1_u16(pos_buf, pos_updt); + + ptr_l += 4; + ptr_r += 4; + min1_buf += 4; + min2_buf += 4; + sgn_buf += 4; + pos_buf += 4; + sgn_scratch_buf += 4; + tail1 = (z - shift) & 0x3; + } + + if (tail1 > 0U) { + + int16_t val; + int8_t sign; + for (uint32_t t_cnt = 0; t_cnt < tail1; t_cnt++) { - svint16_t sign_reg = svld1_s16(pg_tail, sign_ptr); - svst1_s16(pg_tail, sign_ptr, sveor_s16_x(pg_tail, sign_reg, l_reg)); + val = vqsubh_s16(ptr_l[t_cnt], ptr_r[t_cnt]); + ptr_l[t_cnt] = val; - svint16_t min_reg = svld1_s16(pg_tail, min_ptr); - svint16_t min2_reg = svld1_s16(pg_tail, min2_ptr); - svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); - svst1_s16(pg_tail, min2_ptr, - svmax_s16_x(pg_tail, min_reg, - svmin_s16_x(pg_tail, min2_reg, 
abs_reg))); - svst1_s16(pg_tail, min_ptr, svmin_s16_x(pg_tail, min_reg, abs_reg)); + sign = (int8_t)(val >= 0); - svst1_s16(pg_tail, l_ptr, l_reg); + sgn_scratch_buf[t_cnt] = vqsubh_s16(2 * sign, 1); + sgn_buf[t_cnt] = sgn_scratch_buf[t_cnt] * sgn_buf[t_cnt]; + + val = vqabsh_s16(val); + + min2_buf[t_cnt] = + std::max(min1_buf[t_cnt], std::min(min2_buf[t_cnt], val)); + + if (min1_buf[t_cnt] > val) { + pos_buf[t_cnt] = col; + min1_buf[t_cnt] = val; + } + } + + ptr_r += tail1; + min1_buf += tail1; + min2_buf += tail1; + sgn_buf += tail1; + pos_buf += tail1; + sgn_scratch_buf += tail1; + } } - } -#else - // Loop through the columns in the row (variable node n in psi(m)) - // Column 0 - int16_t *l_ptr = l; - auto shift = d->shift_ptr[0] % d->z; - const int16_t *llrs_ptr = llrs + d->col_ptr[0] * (2 * d->z) + shift; - const int16_t *r_ptr = r; - int16_t *sign_ptr = row_sign_array; - int16_t *min_ptr = row_min_array; - uint32_t tail_cnt = tail_size; - - for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - int16x8_t r_reg = vld1q_s16(r_ptr); - int16x8_t llrs_reg = vld1q_s16(llrs_ptr); - int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); - - vst1q_s16(sign_ptr, l_reg); - - vst1q_s16(min_ptr, vqabsq_s16(l_reg)); - - vst1q_s16(l_ptr, l_reg); - - sign_ptr += num_lanes; - min_ptr += num_lanes; - r_ptr += num_lanes; - l_ptr += num_lanes; - llrs_ptr += num_lanes; - } - if (tail_cnt > 3U) { - int16x4_t r_reg = vld1_s16(r_ptr); - int16x4_t llrs_reg = vld1_s16(llrs_ptr); - int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); + // 0 to shift-1 + ptr_l = &l[col_block * z]; // point to start + for (uint32_t v_cnt = 0; v_cnt < blk2; v_cnt++) { + int16x8_t llrs_reg = vld1q_s16(ptr_l); + int16x8_t r_reg = vld1q_s16(ptr_r); + + // Subtraction + int16x8_t vec16 = vqsubq_s16(llrs_reg, r_reg); + + // Absoluate + int16x8_t abs_vec16 = vqabsq_s16(vec16); + + // Store signs + uint16x8_t is_positive = vcgtq_s16(vec16, vdupq_n_s16(-1)); + int16x8_t signs = vbslq_s16(is_positive, plus1, minus1); + vst1q_s16(sgn_scratch_buf, signs); + + // Sign product + int16x8_t old_sgn = vld1q_s16(sgn_buf); + int16x8_t sgn = vmulq_s16(signs, old_sgn); + // store updated sign + vst1q_s16(sgn_buf, sgn); + + // store updated L + vst1q_s16(ptr_l, vec16); + + // Find min1 and min2 + int16x8_t min1_old = vld1q_s16(min1_buf); + int16x8_t min2 = vld1q_s16(min2_buf); + + min2 = vmaxq_s16(min1_old, vminq_s16(min2, abs_vec16)); + int16x8_t min1 = vminq_s16(abs_vec16, min1_old); + + // find min1 position + // check if the current min1 has changed w.r.t previous + // if it has changed, then update the index to current pos + uint16x8_t pos_mask = vceqq_s16(min1, min1_old); + uint16x8_t pos_old = vld1q_u16(pos_buf); + uint16x8_t pos_cur = vdupq_n_u16(col); + uint16x8_t pos_updt = vbslq_u16(pos_mask, pos_old, pos_cur); + + vst1q_s16(min2_buf, min2); + vst1q_s16(min1_buf, min1); + vst1q_u16(pos_buf, pos_updt); + + ptr_l += num_lanes; + ptr_r += num_lanes; + min1_buf += num_lanes; + min2_buf += num_lanes; + sgn_buf += num_lanes; + pos_buf += num_lanes; + sgn_scratch_buf += num_lanes; + } - vst1_s16(sign_ptr, l_reg); + if (tail2 > 0U) { - vst1_s16(min_ptr, vqabs_s16(l_reg)); + if (tail2 > 3U) { - vst1_s16(l_ptr, l_reg); + int16x4_t llrs_reg = vld1_s16(ptr_l); + int16x4_t r_reg = vld1_s16(ptr_r); - tail_cnt = d->z & 0x3; - } + // Subtraction + int16x4_t vec16 = vqsub_s16(llrs_reg, r_reg); - for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { - l[zb] = vqsubh_s16(llrs[d->col_ptr[0] * (2 * d->z) + shift + zb], r[zb]); + // Absolute + int16x4_t abs_vec16 = 
vqabs_s16(vec16); - row_sign_array[zb] = l[zb]; + // Store signs + uint16x4_t is_positive = vcgt_s16(vec16, vdup_n_s16(-1)); + int16x4_t signs = vbsl_s16(is_positive, plus1_4, minus1_4); + vst1_s16(sgn_scratch_buf, signs); - row_min_array[zb] = vqabsh_s16(l[zb]); - } + // Sign product + int16x4_t old_sgn = vld1_s16(sgn_buf); + int16x4_t sgn = vmul_s16(signs, old_sgn); + // store updated sign + vst1_s16(sgn_buf, sgn); - // Column 1 - shift = d->shift_ptr[1] % d->z; - llrs_ptr = llrs + d->col_ptr[1] * (2 * d->z) + shift; - sign_ptr = row_sign_array; - min_ptr = row_min_array; - int16_t *min2_ptr = row_min2_array; - l_ptr = l + d->z; - r_ptr = r + d->z; - tail_cnt = tail_size; - - for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - int16x8_t r_reg = vld1q_s16(r_ptr); - int16x8_t llrs_reg = vld1q_s16(llrs_ptr); - int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); - - int16x8_t sign_reg = vld1q_s16(sign_ptr); - vst1q_s16(sign_ptr, veorq_s16(sign_reg, l_reg)); - - int16x8_t min_reg = vld1q_s16(min_ptr); - int16x8_t abs_reg = vqabsq_s16(l_reg); - vst1q_s16(min2_ptr, vmaxq_s16(min_reg, abs_reg)); - vst1q_s16(min_ptr, vminq_s16(min_reg, abs_reg)); - - vst1q_s16(l_ptr, l_reg); - - sign_ptr += num_lanes; - min_ptr += num_lanes; - min2_ptr += num_lanes; - r_ptr += num_lanes; - l_ptr += num_lanes; - llrs_ptr += num_lanes; - } + // store updated L + vst1_s16(ptr_l, vec16); + + // Find min1 and min2 + int16x4_t min1_old = vld1_s16(min1_buf); + int16x4_t min2 = vld1_s16(min2_buf); + + min2 = vmax_s16(min1_old, vmin_s16(min2, abs_vec16)); + int16x4_t min1 = vmin_s16(abs_vec16, min1_old); + + // Find min1 pos + // check if the current min1 has changed w.r.t previous + // if it has changed, then update the index to current pos + uint16x4_t pos_mask = vceq_s16(min1, min1_old); + uint16x4_t pos_old = vld1_u16(pos_buf); + uint16x4_t pos_cur = vdup_n_u16(col); + uint16x4_t pos_updt = vbsl_u16(pos_mask, pos_old, pos_cur); + + vst1_s16(min2_buf, min2); + vst1_s16(min1_buf, min1); + vst1_u16(pos_buf, pos_updt); + + ptr_l += 4; + ptr_r += 4; + min1_buf += 4; + min2_buf += 4; + sgn_buf += 4; + pos_buf += 4; + sgn_scratch_buf += 4; + tail2 = (shift) & 0x3; + } + + if (tail2 > 0U) { + int16_t val; + int8_t sign; + for (uint32_t t_cnt = 0; t_cnt < tail2; t_cnt++) { + + val = vqsubh_s16(ptr_l[t_cnt], ptr_r[t_cnt]); + ptr_l[t_cnt] = val; - if (tail_cnt > 3U) { - int16x4_t r_reg = vld1_s16(r_ptr); - int16x4_t llrs_reg = vld1_s16(llrs_ptr); - int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); + sign = (int8_t)(val >= 0); - int16x4_t sign_reg = vld1_s16(sign_ptr); - vst1_s16(sign_ptr, veor_s16(sign_reg, l_reg)); + sgn_scratch_buf[t_cnt] = vqsubh_s16(2 * sign, 1); + sgn_buf[t_cnt] = sgn_scratch_buf[t_cnt] * sgn_buf[t_cnt]; - int16x4_t min_reg = vld1_s16(min_ptr); - int16x4_t abs_reg = vqabs_s16(l_reg); - vst1_s16(min2_ptr, vmax_s16(min_reg, abs_reg)); - vst1_s16(min_ptr, vmin_s16(min_reg, abs_reg)); + val = vqabsh_s16(val); - vst1_s16(l_ptr, l_reg); + min2_buf[t_cnt] = + std::max(min1_buf[t_cnt], std::min(min2_buf[t_cnt], val)); - tail_cnt = d->z & 0x3; + if (min1_buf[t_cnt] > val) { + pos_buf[t_cnt] = col; + min1_buf[t_cnt] = val; + } + } + + ptr_r += tail2; + min1_buf += tail2; + min2_buf += tail2; + sgn_buf += tail2; + pos_buf += tail2; + sgn_scratch_buf += tail2; + } + } + + r_i++; + t_i++; } - for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { - l[d->z + zb] = - vqsubh_s16(llrs[d->col_ptr[1] * (2 * d->z) + shift + zb], r[d->z + zb]); + *r_index = r_i - t_i; + + // offset and scale min1 and min2 + // in the same loop, 
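// Reference scalar model of what the vector loops above accumulate for one
// lifted check node (one lane zb) across the columns of a layer. The flat
// indexing and the helper name check_node_accumulate are illustrative only;
// saturation is elided for brevity (the real code uses vqsubh_s16/vqabsh_s16).
static inline void check_node_accumulate(const int16_t *llr, const int16_t *r_old,
                                         int16_t *l, int16_t *col_sign,
                                         int16_t &sign_prod, int16_t &min1,
                                         int16_t &min2, uint16_t &pos,
                                         uint32_t num_cols) {
  for (uint32_t col = 0; col < num_cols; ++col) {
    int16_t l_val = (int16_t)(llr[col] - r_old[col]); // L(n,m) = LLR(n) - R(n,m)
    l[col] = l_val;
    int16_t sgn = (l_val >= 0) ? 1 : -1;
    col_sign[col] = sgn;                    // kept for the R reconstruction
    sign_prod = (int16_t)(sign_prod * sgn); // running product of signs
    int16_t mag = (int16_t)(l_val < 0 ? -l_val : l_val);
    if (mag < min1) {                       // track the two smallest |L| values
      min2 = min1;                          // and the column of the first minimum
      min1 = mag;
      pos = (uint16_t)col;
    } else if (mag < min2) {
      min2 = mag;
    }
  }
}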
adjust sign product + uint32_t blk = z / num_lanes; + uint32_t tail = z & (num_lanes - 1); + + int16x4_t scale = vdup_n_s16(24576); // 0.75 - row_sign_array[zb] ^= l[d->z + zb]; + int16_t *min1_buf = row_min1_array; + int16_t *min2_buf = row_min2_array; - int16_t abs_val = vqabsh_s16(l[d->z + zb]); - row_min2_array[zb] = max(row_min_array[zb], abs_val); - row_min_array[zb] = min(row_min_array[zb], abs_val); + for (uint16_t z1 = 0; z1 < blk; z1++) { + + int16x8_t min1 = vld1q_s16(min1_buf); + int16x8_t min2 = vld1q_s16(min2_buf); + + // apply offset + int16x8_t vec1 = vqsubq_s16(min1, offset8); + int16x8_t vec2 = vqsubq_s16(min2, offset8); + + // if min1 < 0, then min1 = 0; + uint16x8_t is_positive = vcgtq_s16(vec1, vdupq_n_s16(-1)); + vec1 = vbslq_s16(is_positive, vec1, vdupq_n_s16(0)); + + // apply scale + int32x4_t mul_low = vmull_s16(vget_low_s16(vec1), scale); + int32x4_t mul_high = vmull_s16(vget_high_s16(vec1), scale); + int32x4_t shifted_low = vrshrq_n_s32(mul_low, 15); + int32x4_t shifted_high = vrshrq_n_s32(mul_high, 15); + min1 = vcombine_s16(vqmovn_s32(shifted_low), vqmovn_s32(shifted_high)); + + // if min2 < 0, then min1 = 0; + is_positive = vcgtq_s16(vec2, vdupq_n_s16(-1)); + vec2 = vbslq_s16(is_positive, vec2, vdupq_n_s16(0)); + + // apply scale + mul_low = vmull_s16(vget_low_s16(vec2), scale); + mul_high = vmull_s16(vget_high_s16(vec2), scale); + shifted_low = vrshrq_n_s32(mul_low, 15); + shifted_high = vrshrq_n_s32(mul_high, 15); + min2 = vcombine_s16(vqmovn_s32(shifted_low), vqmovn_s32(shifted_high)); + + // store scaled, offseted min's + vst1q_s16(min1_buf, min1); + vst1q_s16(min2_buf, min2); + + min1_buf += num_lanes; + min2_buf += num_lanes; } - // Columns n >= 2 - for (uint32_t col = 2; col < d->num_cols; ++col) { - sign_ptr = row_sign_array; - min_ptr = row_min_array; - min2_ptr = row_min2_array; - shift = d->shift_ptr[col] % d->z; - llrs_ptr = llrs + d->col_ptr[col] * (2 * d->z) + shift; - l_ptr = l + d->z * col; - r_ptr = r + d->z * col; - tail_cnt = tail_size; - - // Loop through the Z rows in the layer (check node m) - // Process 8 entries at a time - for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - // Compute L(n,m) = LLR(n) - R(n,m) - int16x8_t r_reg = vld1q_s16(r_ptr); - int16x8_t llrs_reg = vld1q_s16(llrs_ptr); - int16x8_t l_reg = vqsubq_s16(llrs_reg, r_reg); - - // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - int16x8_t sign_reg = vld1q_s16(sign_ptr); - vst1q_s16(sign_ptr, veorq_s16(sign_reg, l_reg)); - - // Compute the min(|L(n,m)|) and the second minimum - int16x8_t min_reg = vld1q_s16(min_ptr); - int16x8_t min2_reg = vld1q_s16(min2_ptr); - int16x8_t abs_reg = vqabsq_s16(l_reg); - vst1q_s16(min2_ptr, vmaxq_s16(min_reg, vminq_s16(min2_reg, abs_reg))); - vst1q_s16(min_ptr, vminq_s16(min_reg, abs_reg)); - - // Store L(n,m) - vst1q_s16(l_ptr, l_reg); - - sign_ptr += num_lanes; - min_ptr += num_lanes; - min2_ptr += num_lanes; - r_ptr += num_lanes; - l_ptr += num_lanes; - llrs_ptr += num_lanes; - } + if (tail > 0U) { - // Process a group of 4 elts - if (tail_cnt > 3U) { - // Compute L(n,m) = LLR(n) - R(n,m) - int16x4_t r_reg = vld1_s16(r_ptr); - int16x4_t llrs_reg = vld1_s16(llrs_ptr); - int16x4_t l_reg = vqsub_s16(llrs_reg, r_reg); + if (tail > 3U) { - // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - int16x4_t sign_reg = vld1_s16(sign_ptr); - vst1_s16(sign_ptr, veor_s16(sign_reg, l_reg)); + int16x4_t min1 = vld1_s16(min1_buf); + int16x4_t min2 = vld1_s16(min2_buf); - // Compute the min(|L(n,m)|) and the 
second minimum - int16x4_t min_reg = vld1_s16(min_ptr); - int16x4_t min2_reg = vld1_s16(min2_ptr); - int16x4_t abs_reg = vqabs_s16(l_reg); - vst1_s16(min2_ptr, vmax_s16(min_reg, vmin_s16(min2_reg, abs_reg))); - vst1_s16(min_ptr, vmin_s16(min_reg, abs_reg)); + int16x4_t vec1 = vqsub_s16(min1, offset4); + int16x4_t vec2 = vqsub_s16(min2, offset4); - // Store L(n,m) - vst1_s16(l_ptr, l_reg); + uint16x4_t is_positive = vcgt_s16(vec1, vdup_n_s16(-1)); + vec1 = vbsl_s16(is_positive, vec1, vdup_n_s16(0)); - tail_cnt = d->z & 0x3; - } + int32x4_t res = vmull_s16(vec1, scale); + res = vrshrq_n_s32(res, 15); + min1 = vqmovn_s32(res); - // Process tail - for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { - // Compute L(n,m) = LLR(n) - R(n,m) - l[d->z * col + zb] = vqsubh_s16( - llrs[d->col_ptr[col] * (2 * d->z) + shift + zb], r[d->z * col + zb]); + is_positive = vcgt_s16(vec2, vdup_n_s16(-1)); + vec2 = vbsl_s16(is_positive, vec2, vdup_n_s16(0)); - // Compute the product of L(n',m), for all the columns (all n' in psi(m)) - row_sign_array[zb] ^= l[d->z * col + zb]; + res = vmull_s16(vec2, scale); + res = vrshrq_n_s32(res, 15); + min2 = vqmovn_s32(res); - // Compute the min(|L(n,m)|) and the second minimum - int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); - row_min2_array[zb] = - max(row_min_array[zb], min(row_min2_array[zb], abs_val)); - row_min_array[zb] = min(row_min_array[zb], abs_val); + vst1_s16(min1_buf, min1); + vst1_s16(min2_buf, min2); + min1_buf += 4; + min2_buf += 4; + tail = (tail) & 0x3; + } + + if (tail > 0U) { + for (uint32_t t_cnt = 0; t_cnt < tail; t_cnt++) { + + int16_t min1 = min1_buf[t_cnt]; + int16_t min2 = min2_buf[t_cnt]; + + min1 -= offset; + min2 -= offset; + + if (min1 < 0) { + min1 = 0; + } + if (min2 < 0) { + min2 = 0; + } + int32_t t1 = (min1 * 24576) >> 15; + int32_t t2 = (min2 * 24576) >> 15; + min1_buf[t_cnt] = (int16_t)t1; + min2_buf[t_cnt] = (int16_t)t2; + } } } + #endif } -// For each check node m in the layer, compute: -// - The check-to-variable-node messages R(n,m) for each n in \psi(m), where -// \psi(m) is the set of variable nodes connected to check node m: -// sign(R(n,m)) = \prod_{n' \in \psi(m)/n} sign(L(n',m)) = -// = \prod_{n' \in \psi(m)} sign(L(n',m)) / sign(L(n,m)) -// |R(n,m)| = \min_{n' \in \psi(m)/n} |L(n',m)| = -// = the first minimum when n' != n, the second minimum otherwise -// - The log likelihood ratios for each n in \psi(m): -// LLR(n) = R(n,m) + L(n,m) -template -void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, - const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, uint32_t tail_size, - const int16_t *row_min_array, - const int16_t *row_min2_array, - const int16_t *row_sign_array); - -template<> -void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, - const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, uint32_t tail_size, - const int16_t *row_min_array, - const int16_t *row_min2_array, - const int16_t *row_sign_array) { - // Loop through the Z rows in the layer (check node m) - for (uint32_t zb = 0; zb < d->z; ++zb) { - const int16_t *l_ptr = l + zb * d->num_cols; - // Loop through the columns in the row (variable node n in psi(m)) - for (uint32_t col = 0; col < d->num_cols; ++col) { - // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the - // product) - int16_t col_sign = (row_sign_array[zb] ^ l_ptr[col]) < 0 ? -1 : 1; - - // Compute R(n,m) - int16_t abs_val = vqabsh_s16(l_ptr[col]); - int16_t r_val = - col_sign * (abs_val == row_min_array[zb] ? 
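// Scalar form of the normalisation applied to both minima above: subtract the
// offset, clamp at zero, then scale by 24576/32768 = 0.75 in Q1.15. The NEON
// path uses a rounding shift (vrshrq_n_s32) whereas this, like the scalar
// tail, truncates; normalise_min is an illustrative helper name.
static inline int16_t normalise_min(int16_t m, int16_t offset) {
  int32_t v = (int32_t)m - offset;
  if (v < 0) {
    v = 0;                               // offset-min-sum clamp
  }
  return (int16_t)((v * 24576) >> 15);   // multiply by 0.75 in Q15
}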
row_min2_array[zb] - : row_min_array[zb]); - - // Compute LLR(n) = R(n,m) + L(n,m) - auto shift = (d->shift_ptr[col] + zb) % d->z; - auto col_ind = d->col_ptr[col] * d->z + shift; - llrs[col_ind] = vqaddh_s16(r_val, l_ptr[col]); - - // Store R(n,m) for the next iteration - r[col] = r_val; +bool hard_decision(int16_t *ptr_l, uint8_t *crc_buff, uint8_t *ptr_data, + uint32_t k, uint32_t crc_flag) { + + uint32_t num_lanes = get_num_lanes(); + uint32_t k_prime = k + 24; + uint32_t full_vec = (k_prime) / num_lanes; + uint32_t tail_cnt = (k_prime) & (num_lanes - 1); + uint8_t *data = (uint8_t *)crc_buff; + uint32_t pad_bytes = 0; + + // if the decoded data is less than 8 bytes / not multiple of 8 bytes, prefix + // zero padding + if (crc_flag != 0U) { + if (((k_prime >> 3) % 16) != 0U) { + pad_bytes = 16 - ((k_prime >> 3) % 16); + memset(data, 0, pad_bytes); + data = data + pad_bytes; } } -} -template<> -void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, - const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, uint32_t tail_size, - const int16_t *row_min_array, - const int16_t *row_min2_array, - const int16_t *row_sign_array) { - // Case for lifting sizes 4 <= Z < 8 (rows in the layer) #if ARMRAL_ARCH_SVE >= 2 - svbool_t pg_tail = svwhilelt_b16(0U, tail_size); - - svint16_t row_min = svld1_s16(pg_tail, row_min_array); - svint16_t row_min2 = svld1_s16(pg_tail, row_min2_array); - svint16_t row_sign = svld1_s16(pg_tail, row_sign_array); - - // Loop through the columns in the row (variable node n in psi(m)) - for (uint32_t col = 0; col < d->num_cols; ++col) { - auto shift = d->shift_ptr[col] % d->z; - auto col_ind = d->col_ptr[col] * (2 * d->z); - int16_t *r_ptr = r + d->z * col; - const int16_t *l_ptr = l + d->z * col; - int16_t *llrs_ptr = llrs + col_ind + shift; - - // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the - // product) - svint16_t l_reg = svld1_s16(pg_tail, l_ptr); - svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); - svint16_t eor_reg = sveor_s16_x(pg_tail, row_sign, l_reg); - svbool_t pg_tail_neg = svcmplt_n_s16(pg_tail, eor_reg, 0); - - // Compute R(n,m) - svbool_t pg_tail_eq = svcmpeq_s16(pg_tail, abs_reg, row_min); - svint16_t tmp_reg = svsel_s16(pg_tail_eq, row_min2, row_min); - svint16_t r_reg = svneg_s16_m(tmp_reg, pg_tail_neg, tmp_reg); - - // Compute LLR(n) = R(n,m) + L(n,m) - svint16_t result = svqadd_s16_x(pg_tail, r_reg, l_reg); - svst1_s16(pg_tail, llrs_ptr, result); - - // Store R(n,m) for the next iteration - svst1_s16(pg_tail, r_ptr, r_reg); - - // Rearrange LLRs - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); - // copy (z - shift) elts in the main block to the replicated block - memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int16_t)); - } -#else - int16x4_t row_min = vld1_s16(row_min_array); - int16x4_t row_min2 = vld1_s16(row_min2_array); - int16x4_t row_sign = vld1_s16(row_sign_array); - - // Loop through the columns in the row (variable node n in psi(m)) - for (uint32_t col = 0; col < d->num_cols; ++col) { - auto shift = d->shift_ptr[col] % d->z; - auto col_ind = d->col_ptr[col] * (2 * d->z); - int16_t *r_ptr = r + d->z * col; - const int16_t *l_ptr = l + d->z * col; - int16_t *llrs_ptr = llrs + col_ind + shift; - - // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the - // product) - int16x4_t l_reg = vld1_s16(l_ptr); - int16x4_t abs_reg = vqabs_s16(l_reg); - int16x4_t eor_reg = veor_s16(row_sign, l_reg); - // sign_col_reg will 
contain a 1 in all lanes for negative values - uint16x4_t sign_col_reg = vcltz_s16(eor_reg); - - // Compute R(n,m) - // Get a mask for the minimum value in a lane - uint16x4_t check_eq = vceq_s16(abs_reg, row_min); - int16x4_t tmp_reg = vbsl_s16(check_eq, row_min2, row_min); - // Negate the absolute values - int16x4_t neg_abs_reg = vneg_s16(tmp_reg); - int16x4_t r_reg = vbsl_s16(sign_col_reg, neg_abs_reg, tmp_reg); - - // Compute LLR(n) = R(n,m) + L(n,m) - int16x4_t result = vqadd_s16(r_reg, l_reg); - vst1_s16(llrs_ptr, result); - - // Store R(n,m) for the next iteration - vst1_s16(r_ptr, r_reg); - - // Process tail - for (uint32_t zb = d->z - tail_size; zb < d->z; ++zb) { - // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the - // product) - int16_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? -1 : 1; - - // Compute R(n,m) - int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); - int16_t r_val = - col_sign * (abs_val == row_min_array[zb] ? row_min2_array[zb] - : row_min_array[zb]); - - // Compute LLR(n) = R(n,m) + L(n,m) - llrs[col_ind + shift + zb] = vqaddh_s16(r_val, *(l + d->z * col + zb)); - - // Store R(n,m) for the next iteration - r[d->z * col + zb] = r_val; + svuint16_t ones = svdup_n_u16(1); + svuint16_t zeros = svdup_n_u16(0); + + svbool_t pg = svptrue_b16(); + svbool_t pg_tail = svwhilelt_b16(0U, tail_cnt); + svuint16_t shifts = svindex_u16(num_lanes - 1, -1); + + if (num_lanes == 8) { + for (uint32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { + + svint16_t d = svld1_s16(pg, ptr_l); + svbool_t is_negative = svcmpgt_s16(pg, d, svdup_n_s16(0)); + + svuint16_t bits = svsel_u16(is_negative, zeros, ones); + + svuint16_t byte = svlsl_u16_m(pg, bits, shifts); + *data++ = (uint8_t)svaddv_u16(pg, byte); + ptr_l += num_lanes; } - // Rearrange LLRs - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); - // copy (z - shift) elts in the main block to the replicated block - memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int16_t)); - } -#endif -} + if (tail_cnt != 0U) { + svint16_t d = svld1_s16(pg_tail, ptr_l); + svbool_t is_negative = svcmpgt_s16(pg_tail, d, svdup_n_s16(0)); -template<> -void compute_r_and_llrs(const int16_t *l, int16_t *r, int16_t *llrs, - const ldpc_layer_data *d, int32_t num_lanes, - int32_t full_vec, uint32_t tail_size, - const int16_t *row_min_array, - const int16_t *row_min2_array, - const int16_t *row_sign_array) { -#if ARMRAL_ARCH_SVE >= 2 - svbool_t pg = svptrue_b16(); - svbool_t pg_tail = svwhilelt_b16(0U, tail_size); - - // Loop through the columns in the row (variable node n in psi(m)) - for (uint32_t col = 0; col < d->num_cols; ++col) { - auto shift = d->shift_ptr[col] % d->z; - auto col_ind = d->col_ptr[col] * (2 * d->z); - int16_t *llrs_ptr = llrs + col_ind + shift; - const int16_t *l_ptr = l + d->z * col; - int16_t *r_ptr = r + d->z * col; - const int16_t *sign_ptr = row_sign_array; - const int16_t *min_ptr = row_min_array; - const int16_t *min2_ptr = row_min2_array; - - // Loop through the Z rows in the layer (check node m) - for (int32_t vec_idx = 0; vec_idx < full_vec; ++vec_idx) { - // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the - // product) - svint16_t l_reg = svld1_s16(pg, l_ptr); - svint16_t sign_reg = svld1_s16(pg, sign_ptr); - svint16_t eor_reg = sveor_s16_x(pg, sign_reg, l_reg); - svbool_t pg_neg = svcmplt_n_s16(pg, eor_reg, 0); - - // Compute R(n,m) - svint16_t min_reg = svld1_s16(pg, min_ptr); - svint16_t min2_reg = svld1_s16(pg, 
min2_ptr); - svint16_t abs_reg = svqabs_s16_x(pg, l_reg); - svbool_t pg_eq = svcmpeq_s16(pg, abs_reg, min_reg); - svint16_t tmp_reg = svsel_s16(pg_eq, min2_reg, min_reg); - svint16_t r_reg = svneg_s16_m(tmp_reg, pg_neg, tmp_reg); - - // Compute LLR(n) = R(n,m) + L(n,m) - svint16_t result = svqadd_s16_x(pg, r_reg, l_reg); - svst1_s16(pg, llrs_ptr, result); - - // Store R(n,m) for the next iteration - svst1_s16(pg, r_ptr, r_reg); - - // Increment pointers - l_ptr += num_lanes; - r_ptr += num_lanes; - llrs_ptr += num_lanes; - sign_ptr += num_lanes; - min_ptr += num_lanes; - min2_ptr += num_lanes; + svuint16_t bits = svsel_u16(is_negative, zeros, ones); + + svuint16_t byte = svlsl_u16_m(pg_tail, bits, shifts); + *data++ = (uint8_t)svaddv_u16(pg_tail, byte); + ptr_l += tail_cnt; } + } else if (num_lanes == 16) { + uint16_t *data_ptr = (uint16_t *)data; + + for (uint32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { + + svint16_t d = svld1_s16(pg, ptr_l); + svbool_t is_negative = svcmpgt_s16(pg, d, svdup_n_s16(0)); - if (tail_size != 0) { - // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the - // product) - svint16_t l_reg = svld1_s16(pg_tail, l_ptr); - svint16_t sign_reg = svld1_s16(pg_tail, sign_ptr); - svint16_t eor_reg = sveor_s16_x(pg_tail, sign_reg, l_reg); - svbool_t pg_tail_neg = svcmplt_n_s16(pg_tail, eor_reg, 0); - - // Compute R(n,m) - svint16_t min_reg = svld1_s16(pg_tail, min_ptr); - svint16_t min2_reg = svld1_s16(pg_tail, min2_ptr); - svint16_t abs_reg = svqabs_s16_x(pg_tail, l_reg); - svbool_t pg_tail_eq = svcmpeq_s16(pg_tail, abs_reg, min_reg); - svint16_t tmp_reg = svsel_s16(pg_tail_eq, min2_reg, min_reg); - svint16_t r_reg = svneg_s16_m(tmp_reg, pg_tail_neg, tmp_reg); - - // Compute LLR(n) = R(n,m) + L(n,m) - svint16_t result = svqadd_s16_x(pg_tail, r_reg, l_reg); - svst1_s16(pg_tail, llrs_ptr, result); - - // Store R(n,m) for the next iteration - svst1_s16(pg_tail, r_ptr, r_reg); + svuint16_t bits = svsel_u16(is_negative, zeros, ones); + svuint16_t word = svlsl_u16_m(pg, bits, shifts); + word = svrevb_u16_x(pg, word); + + *data_ptr++ = svaddv_u16(pg, word); + ptr_l += num_lanes; } - // Rearrange LLRs - // copy shifted elements in the replicated block - // back to the beginning of the main block - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); - // copy (z - shift) elts in the main block to the replicated block - memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int16_t)); + if (tail_cnt != 0U) { + svint16_t d = svld1_s16(pg_tail, ptr_l); + svbool_t is_negative = svcmpgt_s16(pg_tail, d, svdup_n_s16(0)); + + svuint16_t bits = svsel_u16(is_negative, zeros, ones); + svuint16_t word = svlsl_u16_m(pg_tail, bits, shifts); + word = svrevb_u16_x(pg_tail, word); + + *data_ptr++ = svaddv_u16(pg_tail, word); + } } + #else - // Loop through the columns in the row (variable node n in psi(m)) - for (uint32_t col = 0; col < d->num_cols; ++col) { - int16_t *r_ptr = r + d->z * col; - const int16_t *sign_ptr = row_sign_array; - const int16_t *min_ptr = row_min_array; - const int16_t *min2_ptr = row_min2_array; - auto shift = d->shift_ptr[col] % d->z; - auto col_ind = d->col_ptr[col] * (2 * d->z); - const int16_t *l_ptr = l + d->z * col; - int16_t *llrs_ptr = llrs + col_ind + shift; - uint32_t tail_cnt = tail_size; - - // Loop through the Z rows in the layer (check node m) - // Process 8 entries at a time - for (int32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - // Compute the product of sign(L(n',m)) without L(n,m) (the 
sign of the - // product) - int16x8_t l_reg = vld1q_s16(l_ptr); - int16x8_t sign_reg = vld1q_s16(sign_ptr); - int16x8_t eor_reg = veorq_s16(sign_reg, l_reg); - // sign_col_reg will contain a 1 in all lanes for negative values - uint16x8_t sign_col_reg = vcltzq_s16(eor_reg); - - // Compute R(n,m) - int16x8_t min_reg = vld1q_s16(min_ptr); - int16x8_t min2_reg = vld1q_s16(min2_ptr); - int16x8_t abs_reg = vqabsq_s16(l_reg); - // Get a mask for the minimum value in a lane - uint16x8_t check_eq = vceqq_s16(abs_reg, min_reg); - int16x8_t tmp_reg = vbslq_s16(check_eq, min2_reg, min_reg); - // Negate the absolute values - int16x8_t neg_abs_reg = vnegq_s16(tmp_reg); - int16x8_t r_reg = vbslq_s16(sign_col_reg, neg_abs_reg, tmp_reg); - - // Compute LLR(n) = R(n,m) + L(n,m) - int16x8_t result = vqaddq_s16(r_reg, l_reg); - vst1q_s16(llrs_ptr, result); - - // Store R(n,m) for the next iteration - vst1q_s16(r_ptr, r_reg); - - // Increment pointers - r_ptr += num_lanes; - l_ptr += num_lanes; - llrs_ptr += num_lanes; - sign_ptr += num_lanes; - min_ptr += num_lanes; - min2_ptr += num_lanes; - } + int8x8_t shifts = {7, 6, 5, 4, 3, 2, 1, 0}; + for (uint32_t v_cnt = 0; v_cnt < full_vec; v_cnt++) { - // Process a group of 4 elts - if (tail_cnt > 3U) { - // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the - // product) - int16x4_t l_reg = vld1_s16(l_ptr); - int16x4_t sign_reg = vld1_s16(sign_ptr); - int16x4_t eor_reg = veor_s16(sign_reg, l_reg); - // sign_col_reg will contain a 1 in all lanes for negative values - uint16x4_t sign_col_reg = vcltz_s16(eor_reg); - - // Compute R(n,m) - int16x4_t min_reg = vld1_s16(min_ptr); - int16x4_t min2_reg = vld1_s16(min2_ptr); - int16x4_t abs_reg = vqabs_s16(l_reg); - // Get a mask for the minimum value in a lane - uint16x4_t check_eq = vceq_s16(abs_reg, min_reg); - int16x4_t tmp_reg = vbsl_s16(check_eq, min2_reg, min_reg); - // Negate the absolute values - int16x4_t neg_abs_reg = vneg_s16(tmp_reg); - int16x4_t r_reg = vbsl_s16(sign_col_reg, neg_abs_reg, tmp_reg); - - // Compute LLR(n) = R(n,m) + L(n,m) - int16x4_t result = vqadd_s16(r_reg, l_reg); - vst1_s16(llrs_ptr, result); - - // Store R(n,m) for the next iteration - vst1_s16(r_ptr, r_reg); - - tail_cnt = d->z & 0x3; - } + int16x8_t d = vld1q_s16(ptr_l); - // Process tail - for (uint32_t zb = d->z - tail_cnt; zb < d->z; ++zb) { - // Compute the product of sign(L(n',m)) without L(n,m) (the sign of the - // product) - int16_t col_sign = (row_sign_array[zb] ^ l[d->z * col + zb]) < 0 ? -1 : 1; + uint16x4_t is_negative_high = vclt_s16(vget_high_s16(d), vdup_n_s16(0)); + is_negative_high = vand_u16(is_negative_high, vdup_n_u16(1)); - // Compute R(n,m) - int16_t abs_val = vqabsh_s16(l[d->z * col + zb]); - int16_t r_val = - col_sign * (abs_val == row_min_array[zb] ? 
row_min2_array[zb] - : row_min_array[zb]); + uint16x4_t is_negative_low = vclt_s16(vget_low_s16(d), vdup_n_s16(0)); + is_negative_low = vand_u16(is_negative_low, vdup_n_u16(1)); - // Compute LLR(n) = R(n,m) + L(n,m) - llrs[col_ind + shift + zb] = vqaddh_s16(r_val, *(l + d->z * col + zb)); + uint16x8_t vec16 = vcombine_u16(is_negative_low, is_negative_high); + uint8x8_t byte = vqmovn_u16(vec16); - // Store R(n,m) for the next iteration - r[d->z * col + zb] = r_val; + byte = vshl_u8(byte, shifts); + uint8_t byte1 = vaddv_u8(byte); + + *data++ = byte1; + ptr_l += num_lanes; + } + + if (tail_cnt != 0U) { + uint8_t tail_byte = 0; + uint8_t i = 0; + while (i < tail_cnt) { + tail_byte |= ((uint8_t)(*ptr_l++ < 0)) << (7 - i); + i++; } - // Rearrange LLRs - // copy shifted elements in the replicated block - // back to the beginning of the main block - memcpy(llrs + col_ind, llrs + col_ind + d->z, shift * sizeof(int16_t)); - // copy (z - shift) elts in the main block to the replicated block - memcpy(llrs + col_ind + d->z + shift, llrs + col_ind + shift, - (d->z - shift) * sizeof(int16_t)); + *data = tail_byte; } #endif + + // Generate the CRC parity bits + uint64_t crc = 0; + if (crc_flag != 0U) { + armral_crc24_b_be((k_prime >> 3) + pad_bytes, (const uint64_t *)crc_buff, + &crc); + // Removing the Zero padding + if (pad_bytes != 0U) { + for (uint32_t i = 0; i < (k_prime >> 3); i++) { + ptr_data[i] = crc_buff[i + pad_bytes]; + } + } + } else { + memcpy(ptr_data, crc_buff, (k + 7) >> 3); + } + + return (crc == 0U); } -template -void __attribute__((flatten)) -run_iterations(uint32_t num_its, uint32_t z, uint32_t lsi, - const armral_ldpc_base_graph_t *graph, int16_t *r, int16_t *l, - int16_t *new_llrs, int32_t num_lanes, int32_t full_vec, - uint32_t tail_size, int16_t *row_min_array, - int16_t *row_min2_array, int16_t *row_sign_array, int16_t *check, - bool check_convergence, - std::optional> &crc_checker) { - for (uint32_t i = 0; i < num_its; ++i) { - ldpc_layer_data d(z, lsi, graph); - auto *r_ptr = r; - - // Loop through the layers (groups of Z rows) - compute_l_product_min1_and_min2(l, new_llrs, r_ptr, &d, num_lanes, - full_vec, tail_size, row_min_array, - row_min2_array, row_sign_array); - compute_r_and_llrs(l, r_ptr, new_llrs, &d, num_lanes, full_vec, - tail_size, row_min_array, row_min2_array, - row_sign_array); - - for (uint32_t row = 1; row < graph->nrows; ++row) { - d.next(); - r_ptr = r + d.row_start_ind * z; - - // Variable-to-check node messages update - compute_l_product_min1_and_min2(l, new_llrs, r_ptr, &d, num_lanes, - full_vec, tail_size, row_min_array, - row_min2_array, row_sign_array); - // LLRs update - compute_r_and_llrs(l, r_ptr, new_llrs, &d, num_lanes, full_vec, - tail_size, row_min_array, row_min2_array, - row_sign_array); +inline void load_ptr_l(int16_t *ptr_l, const int8_t *llrs_ptr, + uint32_t len_in) { +#if ARMRAL_ARCH_SVE >= 2 + svint8_t vec8; + svbool_t pg = svptrue_b8(); + + uint32_t num_lanes = get_num_lanes(); + uint32_t full_blk = len_in / (2 * num_lanes); + uint32_t tail_cnt = len_in % (2 * num_lanes); + + for (uint32_t num_block = 0; num_block < full_blk; num_block++) { + vec8 = svld1_s8(pg, llrs_ptr); + + svint16_t t1 = svmovlb_s16(vec8); + svint16_t t2 = svmovlt_s16(vec8); + + svint16_t result1 = svzip1_s16(t1, t2); + svint16_t result2 = svzip2_s16(t1, t2); + + svst1_s16(pg, ptr_l, result1); + ptr_l += num_lanes; + svst1_s16(pg, ptr_l, result2); + ptr_l += num_lanes; + llrs_ptr += 2 * num_lanes; + } + + if (tail_cnt != 0U) { + for (uint32_t i = 0; i < tail_cnt; 
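// Scalar sketch of the hard decision above: each 16-bit LLR yields one hard
// bit (1 for a negative LLR), packed MSB-first into bytes, mirroring the tail
// loop of the Neon path; the CRC24B check is then run over the zero-prefixed,
// packed buffer. pack_hard_bits is an illustrative name only.
static inline void pack_hard_bits(const int16_t *llr, uint32_t nbits, uint8_t *out) {
  for (uint32_t i = 0; i < nbits; i += 8) {
    uint8_t byte = 0;
    for (uint32_t b = 0; b < 8 && i + b < nbits; ++b) {
      byte |= (uint8_t)(llr[i + b] < 0) << (7 - b);  // bit 7 holds the first LLR
    }
    out[i / 8] = byte;
  }
}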
i++) { + ptr_l[i] = (int16_t)llrs_ptr[i]; } + } + +#else + + uint32_t full_blk = len_in / 16; + uint32_t tail_cnt = len_in % 16; + + for (uint32_t num_block = 0; num_block < full_blk; num_block++) { + int8x16_t vec = vld1q_s8(llrs_ptr); + int8x8_t vec_h = vget_high_s8(vec); + int16x8_t vec_h_16 = vmovl_s8(vec_h); + int8x8_t vec_l = vget_low_s8(vec); + int16x8_t vec_l_16 = vmovl_s8(vec_l); + vst1q_s16(ptr_l, vec_l_16); + ptr_l += 8; + vst1q_s16(ptr_l, vec_h_16); + llrs_ptr += 16; + ptr_l += 8; + } - // CRC check and early termination - bool crc_passed = crc_checker.has_value() && crc_checker->check(new_llrs); - if (check_convergence && - (crc_passed || parity_check(new_llrs, z, lsi, graph, num_lanes, - full_vec, tail_size, check))) { - break; + if (tail_cnt != 0U) { + for (uint32_t i = 0; i < tail_cnt; i++) { + ptr_l[i] = (int16_t)llrs_ptr[i]; } } + +#endif } -} // anonymous namespace +template +bool decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t z, + uint32_t crc_idx, uint32_t num_its, uint8_t *data_out, + Allocator &allocator) { + + bool crc_passed = false; -template -void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, - uint32_t z, uint32_t crc_idx, uint32_t num_its, - uint8_t *data_out, Allocator &allocator) { // Get the base graph and the lifting size const auto *graph = armral_ldpc_get_base_graph(bg); uint32_t lsi = get_lifting_index(z); - // Only allocate the CRC checker if necessary. - std::optional> maybe_crc_checker; - if (crc_idx != ARMRAL_LDPC_NO_CRC) { - maybe_crc_checker = crc_checker{z, crc_idx, allocator}; - } + // (graph->row_start_inds[2] - graph->row_start_inds[1]) Max no of non -1 in + // columns is 19. Note: Min is calculated over 8 byte block lengths, so need + // lesser memory + uint32_t layer_size = 19 * z; - const uint32_t num_llrs = (graph->ncodeword_bits + 2) * z; + uint32_t num_llrs = (graph->ncodeword_bits + 2) * z; + // max no of non zeros enteries in H Matrix is 316 + uint32_t r_mat_size = 316 * z; - // Assign memory for the things that we need - // We know that the first block rows have the largest number of non-zero - // entries, so the largest layer will be for the first block rows. In - // particular, for both base graphs, the second row is of longest length. - uint32_t mat_size = graph->row_start_inds[graph->nrows] * z; - uint32_t layer_size = - (graph->row_start_inds[2] - graph->row_start_inds[1]) * z; - // We need to keep a record of matrix L (variable-to-check-node messages) - auto l = allocate_uninitialized(allocator, layer_size); - // We need to keep a record of matrix R (check-to-variable-node messages) - auto r = allocate_zeroed(allocator, mat_size); + auto l = allocate_uninitialized(allocator, num_llrs); + // matrix R (check-to-variable-node messages) + auto r = allocate_zeroed(allocator, r_mat_size); - auto row_min_array = allocate_zeroed(allocator, z); - auto row_min2_array = allocate_zeroed(allocator, z); - auto row_sign_array = allocate_zeroed(allocator, z); + uint32_t num_lanes = get_num_lanes(); + uint32_t z_len = z / num_lanes; + uint32_t offset = (z % num_lanes) ? 1 : 0; + z_len = (z_len + offset) * num_lanes; + uint32_t k = (bg == 0) ? 
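// Scalar equivalent of load_ptr_l: the 8-bit channel LLRs are sign-extended
// unchanged into the 16-bit working buffer (the first 2*z punctured positions
// are zeroed separately by the caller), which leaves headroom for the
// saturating 16-bit arithmetic used during decoding. widen_llrs is an
// illustrative name for this reference version.
static inline void widen_llrs(const int8_t *llrs_in, int16_t *l_out, uint32_t len) {
  for (uint32_t i = 0; i < len; ++i) {
    l_out[i] = (int16_t)llrs_in[i];  // sign-extend int8 -> int16
  }
}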
(z * 22) : (z * 10); - auto check = allocate_zeroed(allocator, z); - -#if ARMRAL_ARCH_SVE >= 2 - bool z_is_tiny = (z == 2); -#else - bool z_is_tiny = (z < 4); -#endif - - // Keep a record of the current, and previous values of the LLRs - // Copy the inputs LLRs - const auto *llrs_ptr = llrs; - size_t new_llrs_size = num_llrs; - std::optional> maybe_out_llrs; - if (!z_is_tiny) { - // Double the storage required to replicate LLRs for optimization - new_llrs_size *= 2; - // Extra buffer to pack the LLRs again - maybe_out_llrs = allocate_uninitialized(allocator, num_llrs); - } - auto new_llrs = allocate_uninitialized(allocator, new_llrs_size); + auto row_min1_array = allocate_uninitialized(allocator, z_len); + auto row_min2_array = allocate_uninitialized(allocator, z_len); + auto row_sign_array = allocate_zeroed(allocator, z_len); + auto row_pos_array = allocate_zeroed(allocator, z); + auto sign_scratch = allocate_uninitialized(allocator, layer_size); + auto crc_buff = allocate_zeroed(allocator, ((k + 7) >> 3) + 15); // NOTE: All allocations are now done! if constexpr (Allocator::is_counting) { - return; + return false; } - if (z_is_tiny) { - // Set the value of the current LLRs from the ones passed in. - // We need to take account of the punctured columns. - // Also widen to int16_t for use in intermediate calculations. - memset(new_llrs.get(), 0, 2 * z * sizeof(int16_t)); - for (uint32_t i = 0; i < z * graph->ncodeword_bits; i++) { - new_llrs[2 * z + i] = (int16_t)llrs[i]; - } - } else { - // Each block of Z elements replicated b1|b1|b2|b2 ... - // We need to take account of the punctured columns. - // Also widen to int16_t for use in intermediate calculations. - memset(new_llrs.get(), 0, 4 * z * sizeof(int16_t)); - auto *new_llrs_ptr = &new_llrs[4 * z]; - for (uint32_t num_block = 0; num_block < graph->ncodeword_bits; - num_block++) { - for (uint32_t i = 0; i < z; i++) { - new_llrs_ptr[i] = (int16_t)llrs_ptr[i]; - new_llrs_ptr[z + i] = (int16_t)llrs_ptr[i]; - } - new_llrs_ptr += 2 * z; - llrs_ptr += z; - } - } + uint32_t r_index = 0; + + // initialization with channel LLRs. 
16-bit buffer "l" will be used for + // in-place calculations + int16_t *ptr_l = l.get(); + const auto *llrs_ptr = llrs; + + // 0 memset 2z LLRs from input to fill the punctured bits + memset(ptr_l, 0, sizeof(int16_t) * 2 * z); + ptr_l = ptr_l + 2 * z; + + load_ptr_l(ptr_l, llrs_ptr, graph->ncodeword_bits * z); + + uint32_t full_blk = z_len / num_lanes; + + for (uint32_t it = 0; it < num_its; ++it) { + r_index = 0; + for (uint32_t layer = 0; layer < graph->nrows; layer++) { - // Precompute number of full vector and tail + // reset the sign buffer + memset(row_sign_array.get(), 0, sizeof(int16_t) * z); + + // reset the min1 min2 buf to max + int16_t *ptr1 = row_min1_array.get(); + int16_t *ptr2 = row_min2_array.get(); + int16_t *ptr3 = row_sign_array.get(); + + for (uint32_t i = 0; i < full_blk; i++) { #if ARMRAL_ARCH_SVE >= 2 - int32_t num_lanes = svcnth(); - int32_t full_vec = z / num_lanes; - uint32_t tail_size = z % num_lanes; - bool is_tail_only = (tail_size == z && !z_is_tiny); + svbool_t pg = svptrue_b16(); + svint16_t vec = svdup_n_s16(0x7FFF); + svint16_t vec_sign = svdup_n_s16(0x1); + svst1_s16(pg, ptr1, vec); + svst1_s16(pg, ptr2, vec); + svst1_s16(pg, ptr3, vec_sign); #else - int32_t num_lanes = 8; - int32_t full_vec = z / num_lanes; - uint32_t tail_size = z % num_lanes; - bool is_tail_only = (tail_size == z && !z_is_tiny); - - if (is_tail_only) { - // tail size = Z - 4 - tail_size -= 4; - } + int16x8_t v8 = vdupq_n_s16(0x7FFF); + int16x8_t v_sign8 = vdupq_n_s16(0x1); + vst1q_s16(ptr1, v8); + vst1q_s16(ptr2, v8); + vst1q_s16(ptr3, v_sign8); #endif + ptr1 += num_lanes; + ptr2 += num_lanes; + ptr3 += num_lanes; + } - if (z_is_tiny) { - run_iterations(num_its, z, lsi, graph, r.get(), l.get(), - new_llrs.get(), num_lanes, full_vec, tail_size, - row_min_array.get(), row_min2_array.get(), - row_sign_array.get(), check.get(), - check_convergence, maybe_crc_checker); + compute_l_r_and_mins(l.get(), r.get(), graph, z, lsi, layer, + row_min1_array.get(), row_min2_array.get(), + row_sign_array.get(), row_pos_array.get(), + sign_scratch.get(), &r_index); - // Hard decode into the output variable - llrs_to_bits(num_llrs, new_llrs.get(), data_out); - } else { - if (is_tail_only) { - run_iterations(num_its, z, lsi, graph, r.get(), l.get(), - new_llrs.get(), num_lanes, full_vec, tail_size, - row_min_array.get(), row_min2_array.get(), - row_sign_array.get(), check.get(), - check_convergence, maybe_crc_checker); - } else { - run_iterations(num_its, z, lsi, graph, r.get(), l.get(), - new_llrs.get(), num_lanes, full_vec, tail_size, - row_min_array.get(), row_min2_array.get(), - row_sign_array.get(), check.get(), - check_convergence, maybe_crc_checker); + update_l_and_r(l.get(), r.get(), graph, z, lsi, layer, + row_min1_array.get(), row_min2_array.get(), + row_sign_array.get(), row_pos_array.get(), + sign_scratch.get(), &r_index); } - // Pack LLRs, copy back to original storage - auto *out_llrs = maybe_out_llrs.value().get(); - for (uint32_t num_block = 0; num_block < graph->ncodeword_bits + 2; - num_block++) { - memcpy(out_llrs + num_block * z, &new_llrs[2 * num_block * z], - z * sizeof(int16_t)); + + // early exit if crc Passes + if (crc_idx) { + if (it < (num_its - 1)) { + crc_passed = + hard_decision(l.get(), crc_buff.get(), &data_out[0], crc_idx, true); + if (crc_passed) { + return crc_passed; + } + } } + } - // Hard decode into the output variable - llrs_to_bits(num_llrs, out_llrs, data_out); + if (crc_idx == ARMRAL_LDPC_NO_CRC) { // do only decisions + crc_passed = hard_decision(l.get(), 
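// Scalar sketch of the check-to-variable update that consumes the buffers
// prepared by compute_l_r_and_mins (min1/min2, the argmin column, the sign
// product and the per-column signs). update_l_and_r itself is not shown in
// this hunk, so treat this as an expectation reconstructed from the
// compute_r_and_llrs code it replaces; saturation is elided (vqaddh_s16 in the
// real code) and variable_node_update is an illustrative name.
static inline void variable_node_update(const int16_t *l, const int16_t *col_sign,
                                        int16_t sign_prod, int16_t min1,
                                        int16_t min2, uint16_t pos,
                                        uint32_t num_cols, int16_t *r,
                                        int16_t *llr) {
  for (uint32_t col = 0; col < num_cols; ++col) {
    // Sign of R(n,m): the product over all columns with this column's own sign
    // removed (multiplying by a sign in {-1, +1} is the same as dividing by it).
    int16_t r_sign = (int16_t)(sign_prod * col_sign[col]);
    // Magnitude: the second minimum for the column holding the first minimum,
    // the first minimum for every other column.
    int16_t r_mag = (col == pos) ? min2 : min1;
    int16_t r_new = (int16_t)(r_sign * r_mag);
    r[col] = r_new;                        // R(n,m) kept for the next iteration
    llr[col] = (int16_t)(l[col] + r_new);  // LLR(n) = L(n,m) + R(n,m)
  }
}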
crc_buff.get(), &data_out[0], + graph->nmessage_bits * z, false); + } else { + crc_passed = + hard_decision(l.get(), crc_buff.get(), &data_out[0], crc_idx, true); } + return crc_passed; } -template void armral::ldpc::decode_block( +} // namespace armral::ldpc + +template bool armral::ldpc::decode_block( const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t z, uint32_t crc_idx, uint32_t num_its, uint8_t *data_out, heap_allocator &); -template void armral::ldpc::decode_block( +template bool armral::ldpc::decode_block( const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t z, uint32_t crc_idx, uint32_t num_its, uint8_t *data_out, buffer_bump_allocator &); @@ -1496,20 +1618,22 @@ armral_status armral_ldpc_decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t z, uint32_t crc_idx, uint32_t num_its, uint8_t *data_out) { + heap_allocator allocator{}; - armral::ldpc::decode_block(llrs, bg, z, crc_idx, num_its, data_out, - allocator); - return ARMRAL_SUCCESS; + bool result = armral::ldpc::decode_block(llrs, bg, z, crc_idx, num_its, + data_out, allocator); + return (result) ? ARMRAL_SUCCESS : ARMRAL_RESULT_FAIL; } armral_status armral_ldpc_decode_block_noalloc(const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t z, uint32_t crc_idx, uint32_t num_its, uint8_t *data_out, void *buffer) { + buffer_bump_allocator allocator{buffer}; - armral::ldpc::decode_block(llrs, bg, z, crc_idx, num_its, data_out, - allocator); - return ARMRAL_SUCCESS; + bool result = armral::ldpc::decode_block(llrs, bg, z, crc_idx, num_its, + data_out, allocator); + return (result) ? ARMRAL_SUCCESS : ARMRAL_RESULT_FAIL; } uint32_t armral_ldpc_decode_block_noalloc_buffer_size(armral_ldpc_graph_t bg, @@ -1517,7 +1641,7 @@ uint32_t armral_ldpc_decode_block_noalloc_buffer_size(armral_ldpc_graph_t bg, uint32_t crc_idx, uint32_t num_its) { counting_allocator allocator{}; - armral::ldpc::decode_block(nullptr, bg, z, crc_idx, num_its, nullptr, - allocator); + armral::ldpc::decode_block(nullptr, bg, z, crc_idx, num_its, nullptr, + allocator); return allocator.required_bytes(); } diff --git a/src/UpperPHY/LDPC/ldpc_coding.hpp b/src/UpperPHY/LDPC/ldpc_coding.hpp index 046b0659657641e6ce6286e78a10749d5e481acb..b13f508b0ce5f2692a462c3d3a9972d3be00812e 100644 --- a/src/UpperPHY/LDPC/ldpc_coding.hpp +++ b/src/UpperPHY/LDPC/ldpc_coding.hpp @@ -14,8 +14,8 @@ constexpr uint32_t num_lifting_sets = 8; uint32_t get_lifting_index(uint32_t lifting_size); -template -void decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t z, +template +bool decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t z, uint32_t crc_idx, uint32_t num_its, uint8_t *data_out, Allocator &allocator); diff --git a/test/UpperPHY/LDPC/Decoding/main.cpp b/test/UpperPHY/LDPC/Decoding/main.cpp index 88bd5815abc767b9931ded04b5348c8a66878fe7..a722b4a2df2dcff725ede3de044aac73a75e2a2c 100644 --- a/test/UpperPHY/LDPC/Decoding/main.cpp +++ b/test/UpperPHY/LDPC/Decoding/main.cpp @@ -102,6 +102,8 @@ template bool run_ldpc_decoding_test(uint32_t its, uint32_t z, armral_ldpc_graph_t bg, uint32_t crc_idx, LDPCDecodingFunction ldpc_decoding_under_test) { + + bool passed = true; const auto *graph = armral_ldpc_get_base_graph(bg); // Allocate a random input to be encoded @@ -112,7 +114,8 @@ bool run_ldpc_decoding_test(uint32_t its, uint32_t z, armral_ldpc_graph_t bg, // If we are doing CRC checking, then we need to attach CRC bits to the input if (crc_idx != ARMRAL_LDPC_NO_CRC) { auto info_to_encode = random.vector((len_in + 7) / 8); - 
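// Caller-side sketch of the updated return contract: armral_ldpc_decode_block
// now returns ARMRAL_RESULT_FAIL when the hard-decided bits do not satisfy the
// attached CRC, and ARMRAL_SUCCESS otherwise. Output sizing follows this
// patch (22 * z message bits for base graph 1); the wrapper function name and
// the iteration count are illustrative.
#include "armral.h"
#include <vector>

bool decode_one_bg1_block(const int8_t *llrs, uint32_t z, uint32_t crc_idx) {
  std::vector<uint8_t> data_out((22 * z + 7) / 8);  // packed output bits
  armral_status status = armral_ldpc_decode_block(
      llrs, LDPC_BASE_GRAPH_1, z, crc_idx, /*num_its=*/10, data_out.data());
  return status == ARMRAL_SUCCESS;  // ARMRAL_RESULT_FAIL => CRC check failed
}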
ldpc_crc_attachment(info_to_encode.data(), crc_idx + 24, len_in, + len_in = crc_idx + 24; + ldpc_crc_attachment(info_to_encode.data(), len_in, z * graph->nmessage_bits, to_encode.data()); } @@ -138,14 +141,12 @@ bool run_ldpc_decoding_test(uint32_t its, uint32_t z, armral_ldpc_graph_t bg, armral_demodulation(mod_num_symbols, ulp, mod_type, data_mod.data(), data_demod_soft.data()); - auto decoded = random.vector((encoded_len + 2 * z + 7) / 8); - ldpc_decoding_under_test(data_demod_soft.data(), bg, z, crc_idx, its, - decoded.data()); - auto decoded_bytes = - armral::bits_to_bytes(encoded_len + 2 * z, decoded.data()); - - // Make sure that the codeword passes the parity check - bool passed = perform_parity_check(decoded_bytes.data(), z, bg); + auto decoded = random.vector((len_in + 7) / 8); + if (ldpc_decoding_under_test(data_demod_soft.data(), bg, z, crc_idx, its, + decoded.data()) != ARMRAL_SUCCESS) { + return false; + } + auto decoded_bytes = armral::bits_to_bytes(len_in, decoded.data()); // Also check that the decoded message is equal to the original message auto bytes_in = armral::bits_to_bytes(len_in, to_encode.data()); @@ -163,6 +164,7 @@ bool run_all_tests(char const *name, std::array bgs{LDPC_BASE_GRAPH_1, LDPC_BASE_GRAPH_2}; std::array num_its{1, 2, 5, 10}; std::array zs{2, 6, 13, 20, 30, 56, 144, 208, 224, 256, 320, 384}; + // Crc-index is zero based indexing std::array crc_idx_1{4225, 4553, 5257, 6313, 7721}; std::array crc_idx_2{1921, 2057, 2377, 2857, 3497}; for (auto bg : bgs) { @@ -170,7 +172,7 @@ bool run_all_tests(char const *name, for (uint32_t i = 0; i < zs.size(); i++) { auto z = zs[i]; assert(z < 208 || i >= 7); - auto crc_idx = (z >= 208) ? crc_ids[i - 7] : ARMRAL_LDPC_NO_CRC; + auto crc_idx = (z >= 208) ? (crc_ids[i - 7] - 1) : ARMRAL_LDPC_NO_CRC; for (auto its : num_its) { printf("[%s] z = %d, crc_idx = %u, its = %d\n", name, z, crc_idx, its); auto check = run_ldpc_decoding_test(its, z, bg, crc_idx,