diff --git a/bench/UpperPHY/Turbo/Batch/Decoding/main.cpp b/bench/UpperPHY/Turbo/Batch/Decoding/main.cpp index 1f143d49ed1365e9e45fe4c47947497c0ad1e986..4963cf6ae7c42cad002dc905aa4bdcf38c006bd5 100644 --- a/bench/UpperPHY/Turbo/Batch/Decoding/main.cpp +++ b/bench/UpperPHY/Turbo/Batch/Decoding/main.cpp @@ -53,7 +53,8 @@ void run_turbo_decode_batch_perf(const uint32_t num_prbs, itl_ptr + j * num_blocks * (num_bits + 4), num_bits, ans_ptr + j * num_blocks * num_bytes, 2.F, num_iters, num_blocks, nullptr, allocator, trellis_termination, decode_block_step, - batched_trellis_termination, decode_batch_step); + batched_trellis_termination, decode_batch_step, + decode_butterfly_step); #else heap_allocator allocator{}; armral::turbo::decode( @@ -62,7 +63,8 @@ void run_turbo_decode_batch_perf(const uint32_t num_prbs, itl_ptr + j * num_blocks * (num_bits + 4), num_bits, ans_ptr + j * num_blocks * num_bytes, 2.F, num_iters, num_blocks, nullptr, allocator, trellis_termination, decode_block_step, - batched_trellis_termination, decode_batch_step); + batched_trellis_termination, decode_batch_step, + decode_butterfly_step); #endif } } diff --git a/bench/UpperPHY/Turbo/Single/Decoding/main.cpp b/bench/UpperPHY/Turbo/Single/Decoding/main.cpp index b919566e605fcab2ae962fe44399ff2d9563e271..cba125fd158cbdcd334857d1b936e047246f3947 100644 --- a/bench/UpperPHY/Turbo/Single/Decoding/main.cpp +++ b/bench/UpperPHY/Turbo/Single/Decoding/main.cpp @@ -47,14 +47,14 @@ void run_turbo_decoding_perf(const uint32_t num_prbs, const uint32_t num_bits, sys_ptr + j * (num_bits + 4), par_ptr + j * (num_bits + 4), itl_ptr + j * (num_bits + 4), num_bits, ans_ptr + j * num_bytes, 2.F, num_iters, 1, nullptr, allocator, trellis_termination, - decode_block_step, nullptr, nullptr); + decode_block_step, nullptr, nullptr, decode_butterfly_step); #else heap_allocator allocator{}; armral::turbo::decode( sys_ptr + j * (num_bits + 4), par_ptr + j * (num_bits + 4), itl_ptr + j * (num_bits + 4), num_bits, ans_ptr + j * num_bytes, 2.F, - num_iters, 1, nullptr, allocator, trellis_termination, - decode_block_step, nullptr, nullptr); + num_iters, 1, nullptr, allocator, trellis_termination, + decode_block_step, nullptr, nullptr, decode_butterfly_step); #endif } } diff --git a/simulation/turbo_awgn/turbo_awgn.cpp b/simulation/turbo_awgn/turbo_awgn.cpp index 1c02f387710e65dec55eb7ba13cde4c30328823e..ee91fbf586fe5916644d692a66486b405b1d308a 100644 --- a/simulation/turbo_awgn/turbo_awgn.cpp +++ b/simulation/turbo_awgn/turbo_awgn.cpp @@ -196,6 +196,19 @@ struct turbo_error_counts { } }; +inline int8_t saturating_mul(int8_t a, int8_t b) { + int16_t result = (int16_t)a * (int16_t)b; + + if (result > INT8_MAX) { + return INT8_MAX; + } + if (result < INT8_MIN) { + return INT8_MIN; + } + + return (int8_t)result; +} + // Perform an end-to-end encoding, rate matching, modulation, transmission, // demodulation, rate recovery, and decoding and count the number of errors void run_check(armral::utils::random_state *state, double snr_db, uint32_t ulp, @@ -255,6 +268,13 @@ void run_check(armral::utils::random_state *state, double snr_db, uint32_t ulp, data->data_mod + mod_offset, data->data_demod_soft + demod_soft_offset); + if (data->num_blocks == 1) { + int8_t *data_demod_soft = data->data_demod_soft + demod_soft_offset; + for (uint32_t i = 0; i < data->len_encoded; i++) { + data_demod_soft[i] = saturating_mul(data_demod_soft[i], -1); + } + } + // The LLRs are updated by rate recovery and must be zero the first time // rate recovery is performed.
Since different input data is created for // every loop iteration, we need to reset the LLRs each time. @@ -272,10 +292,16 @@ void run_check(armral::utils::random_state *state, double snr_db, uint32_t ulp, } // Run turbo decoding for num_blocks blocks - armral_turbo_decode_batch(data->num_blocks, data->sys_recovered, - data->par_recovered, data->itl_recovered, - data->len_out, data->data_decoded, iter_max, - data->permutation_indices); + if (data->num_blocks > 1) { + armral_turbo_decode_batch(data->num_blocks, data->sys_recovered, + data->par_recovered, data->itl_recovered, + data->len_out, data->data_decoded, iter_max, + data->permutation_indices); + } else { + armral_turbo_decode_block( + data->sys_recovered, data->par_recovered, data->itl_recovered, + data->len_out, data->data_decoded, iter_max, data->permutation_indices); + } results->num_bit_errors = 0; results->num_block_errors = 0; diff --git a/src/UpperPHY/Turbo/arm_turbo_decoder.cpp b/src/UpperPHY/Turbo/arm_turbo_decoder.cpp index 4938df71880497006362355b74705b1f7885987c..666f9cedb36a9c288fae363839945c7f0df2fa74 100644 --- a/src/UpperPHY/Turbo/arm_turbo_decoder.cpp +++ b/src/UpperPHY/Turbo/arm_turbo_decoder.cpp @@ -17,13 +17,15 @@ template armral_status armral::turbo::decode( const int8_t *sys, const int8_t *par, const int8_t *itl, uint32_t k, uint8_t *dst, float32_t l_c, uint32_t max_iter, uint32_t num_blocks, uint16_t *perm_idxs, heap_allocator &, trellis_term_func_t, - decode_step_func_t, trellis_term_func_t, decode_step_func_t); + decode_step_func_t, trellis_term_func_t, decode_step_func_t, + decode_butterfly_func_t); template armral_status armral::turbo::decode( const int8_t *sys, const int8_t *par, const int8_t *itl, uint32_t k, uint8_t *dst, float32_t l_c, uint32_t max_iter, uint32_t num_blocks, uint16_t *perm_idxs, buffer_bump_allocator &, trellis_term_func_t, - decode_step_func_t, trellis_term_func_t, decode_step_func_t); + decode_step_func_t, trellis_term_func_t, decode_step_func_t, + decode_butterfly_func_t); // Permutation indices armral_status armral_turbo_perm_idx_init(uint16_t *buffer) { @@ -37,25 +39,25 @@ armral_status armral_turbo_decode_block(const int8_t *sys, const int8_t *par, uint8_t *dst, uint32_t max_iter, uint16_t *perm_idxs) { heap_allocator allocator{}; - return armral::turbo::decode(sys, par, itl, k, dst, 2.F, max_iter, 1, - perm_idxs, allocator, trellis_termination, - decode_block_step, nullptr, nullptr); + return armral::turbo::decode( + sys, par, itl, k, dst, 2.F, max_iter, 1, perm_idxs, allocator, + trellis_termination, nullptr, nullptr, nullptr, decode_butterfly_step); } armral_status armral_turbo_decode_block_noalloc( const int8_t *sys, const int8_t *par, const int8_t *itl, uint32_t k, uint8_t *dst, uint32_t max_iter, uint16_t *perm_idxs, void *buffer) { buffer_bump_allocator allocator{buffer}; - return armral::turbo::decode(sys, par, itl, k, dst, 2.F, max_iter, 1, - perm_idxs, allocator, trellis_termination, - decode_block_step, nullptr, nullptr); + return armral::turbo::decode( + sys, par, itl, k, dst, 2.F, max_iter, 1, perm_idxs, allocator, + trellis_termination, nullptr, nullptr, nullptr, decode_butterfly_step); } uint32_t armral_turbo_decode_block_noalloc_buffer_size(uint32_t k) { counting_allocator allocator{}; (void)armral::turbo::decode( nullptr, nullptr, nullptr, k, nullptr, 2.F, 1, 1, nullptr, allocator, - trellis_termination, decode_block_step, nullptr, nullptr); + trellis_termination, nullptr, nullptr, nullptr, decode_butterfly_step); return allocator.required_bytes(); } @@ -70,7 +72,7
@@ armral_status armral_turbo_decode_batch(uint32_t num_blocks, const int8_t *sys, return armral::turbo::decode( sys, par, itl, k, dst, 2.F, max_iter, num_blocks, perm_idxs, allocator, trellis_termination, decode_block_step, batched_trellis_termination, - decode_batch_step); + decode_batch_step, nullptr); } armral_status @@ -79,10 +81,11 @@ armral_turbo_decode_batch_noalloc(uint32_t num_blocks, const int8_t *sys, uint32_t k, uint8_t *dst, uint32_t max_iter, uint16_t *perm_idxs, void *buffer) { buffer_bump_allocator allocator{buffer}; + return armral::turbo::decode( sys, par, itl, k, dst, 2.F, max_iter, num_blocks, perm_idxs, allocator, trellis_termination, decode_block_step, batched_trellis_termination, - decode_batch_step); + decode_batch_step, nullptr); } uint32_t armral_turbo_decode_batch_noalloc_buffer_size(uint32_t k) { @@ -90,6 +93,6 @@ uint32_t armral_turbo_decode_batch_noalloc_buffer_size(uint32_t k) { (void)armral::turbo::decode( nullptr, nullptr, nullptr, k, nullptr, 2.F, 0, 8, nullptr, allocator, trellis_termination, decode_block_step, batched_trellis_termination, - decode_batch_step); + decode_batch_step, nullptr); return allocator.required_bytes(); } diff --git a/src/UpperPHY/Turbo/arm_turbo_decoder_single.hpp b/src/UpperPHY/Turbo/arm_turbo_decoder_single.hpp index 8e36f5aae501b6c312b6c89a586be833e4f2ddfb..7ced9a089500e53981ce36f4dc8e261d3a898ab6 100644 --- a/src/UpperPHY/Turbo/arm_turbo_decoder_single.hpp +++ b/src/UpperPHY/Turbo/arm_turbo_decoder_single.hpp @@ -261,4 +261,238 @@ void decode_block_step(const int16x8_t *sys, const int16x8_t *par, } } +// A single max-log-MAP decoder that works on an array of systematic bits (sys), +// an array of parity bits (par), and an array of extrinsic values from a +// previous decoding stage (extrinsic) +void decode_butterfly_step(uint8_t siso_flag, uint32_t length, uint32_t k, + int16x8_t *extr, int16x8_t *sys, int16x8_t *par, + void *gamma, int16x8_t *beta, int16x8_t *apos) { + + int16x8_t *intr; + // backward recursion + uint8x16_t beta_state_idx_1 = {0, 1, 8, 9, 10, 11, 2, 3, + 4, 5, 12, 13, 14, 15, 6, 7}; + uint8x16_t beta_state_idx_2 = {8, 9, 0, 1, 2, 3, 10, 11, + 12, 13, 4, 5, 6, 7, 14, 15}; + + // initialize last stage of trellis + int16x8_t beta_tmp = {0, -16384, -16384, -16384, + -16384, -16384, -16384, -16384}; + beta[0] = beta_tmp; + + int8x16_t beta_branch_indices_1 = {0, 1, 0, 1, 2, 3, 2, 3, + 2, 3, 2, 3, 0, 1, 0, 1}; + int8x16_t beta_branch_indices_2 = {6, 7, 6, 7, 4, 5, 4, 5, + 4, 5, 4, 5, 6, 7, 6, 7}; + + int16x4x8_t *gamma4 = static_cast<int16x4x8_t *>(gamma); + int32_t l = k - length * 8; // fractional bits + + int32_t beta_idx = 1; + + if (siso_flag == 0) { + intr = extr; + + for (int32_t trellis_idx = length, gamma_idx = length; trellis_idx >= 0; + trellis_idx--, gamma_idx--) { + + int16x8_t intr_info = + vqaddq_s16(extr[trellis_idx] >> 1, sys[trellis_idx]); // 8 elements + extr[trellis_idx] = intr_info; + int16x8_t gamma_00 = vqsubq_s16(vqnegq_s16(intr_info), par[trellis_idx]); + int16x8_t gamma_10 = vqsubq_s16(par[trellis_idx], intr_info); + int16x8_t gamma_01 = vqsubq_s16(intr_info, par[trellis_idx]); + int16x8_t gamma_11 = vqaddq_s16(intr_info, par[trellis_idx]); + + vst4q_s16((int16_t *)&gamma4[gamma_idx], + int16x8x4_t({gamma_00, gamma_10, gamma_01, gamma_11})); + + for (int32_t j = l - 1; j >= 0; j--, beta_idx++) { + int8x8_t gamma_tmp = vreinterpret_s8_s16(gamma4[gamma_idx].val[j]); + + int16x8_t gamma_at_state_1 =
vreinterpretq_s16_s8(vtbl1q_s8(gamma_tmp, beta_branch_indices_1)); + + int16x8_t gamma_at_state_2 = + vreinterpretq_s16_s8(vtbl1q_s8(gamma_tmp, beta_branch_indices_2)); + + int16x8_t prev_state = beta_tmp; + + uint8x16_t tmp1 = vreinterpretq_u8_s16(prev_state); + uint8x16_t tmp2 = vqtbl1q_u8(tmp1, beta_state_idx_1); + uint8x16_t tmp3 = vqtbl1q_u8(tmp1, beta_state_idx_2); + + int16x8_t prev_state_1 = vreinterpretq_s16_u8(tmp2); + int16x8_t prev_state_2 = vreinterpretq_s16_u8(tmp3); + + int16x8_t prev_state_prob_1 = + vqaddq_s16(gamma_at_state_1, prev_state_1); + int16x8_t prev_state_prob_2 = + vqaddq_s16(gamma_at_state_2, prev_state_2); + + int16x8_t beta_max_vec = + vmaxq_s16(prev_state_prob_1, prev_state_prob_2); + + // normalization + int16_t beta_max = vmaxvq_s16(beta_max_vec); + + int16x8_t beta_norm_vec = vdupq_n_s16(beta_max); + beta_tmp = vqsubq_s16(beta_max_vec, beta_norm_vec); + beta[beta_idx] = beta_tmp; + } + + l = 8; + } + } else { + + intr = sys; + + for (int32_t trellis_idx = length, gamma_idx = length; trellis_idx >= 0; + trellis_idx--, gamma_idx--) { + int16x8_t intr_info = + vqaddq_s16(extr[trellis_idx] >> 1, sys[trellis_idx]); + int16x8_t gamma_00 = vqsubq_s16(vqnegq_s16(intr_info), par[trellis_idx]); + int16x8_t gamma_10 = vqsubq_s16(par[trellis_idx], intr_info); + int16x8_t gamma_01 = vqsubq_s16(intr_info, par[trellis_idx]); + int16x8_t gamma_11 = vqaddq_s16(intr_info, par[trellis_idx]); + + vst4q_s16((int16_t *)&gamma4[gamma_idx], + int16x8x4_t({gamma_00, gamma_10, gamma_01, gamma_11})); + + for (int32_t j = l - 1; j >= 0; j--, beta_idx++) { + + int8x8_t gamma_tmp = vreinterpret_s8_s16(gamma4[gamma_idx].val[j]); + + int16x8_t gamma_at_state_1 = + vreinterpretq_s16_s8(vtbl1q_s8(gamma_tmp, beta_branch_indices_1)); + + int16x8_t gamma_at_state_2 = + vreinterpretq_s16_s8(vtbl1q_s8(gamma_tmp, beta_branch_indices_2)); + + int16x8_t prev_state = beta_tmp; + + uint8x16_t tmp1 = vreinterpretq_u8_s16(prev_state); + uint8x16_t tmp2 = vqtbl1q_u8(tmp1, beta_state_idx_1); + uint8x16_t tmp3 = vqtbl1q_u8(tmp1, beta_state_idx_2); + + int16x8_t prev_state_1 = vreinterpretq_s16_u8(tmp2); + int16x8_t prev_state_2 = vreinterpretq_s16_u8(tmp3); + + int16x8_t prev_state_prob_1 = + vqaddq_s16(gamma_at_state_1, prev_state_1); + int16x8_t prev_state_prob_2 = + vqaddq_s16(gamma_at_state_2, prev_state_2); + + int16x8_t beta_max_vec = + vmaxq_s16(prev_state_prob_1, prev_state_prob_2); + + // normalization + int16_t beta_max = vmaxvq_s16(beta_max_vec); + + int16x8_t beta_norm_vec = vdupq_n_s16(beta_max); + beta_tmp = vqsubq_s16(beta_max_vec, beta_norm_vec); + + beta[beta_idx] = beta_tmp; + } + + l = 8; + } + } + + // forward recursion + uint8x16_t alpha_prev_state_idx_1 = {0, 1, 6, 7, 8, 9, 14, 15, + 2, 3, 4, 5, 10, 11, 12, 13}; + uint8x16_t alpha_prev_state_idx_2 = {2, 3, 4, 5, 10, 11, 12, 13, + 0, 1, 6, 7, 8, 9, 14, 15}; + + int8x16_t alpha_branch_idx_1 = {0, 1, 2, 3, 2, 3, 0, 1, + 0, 1, 2, 3, 2, 3, 0, 1}; + int8x16_t alpha_branch_idx_2 = {6, 7, 4, 5, 4, 5, 6, 7, + 6, 7, 4, 5, 4, 5, 6, 7}; + + uint8x16_t alpha_state_idx = {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; + + int8x16_t gamma_idx_1 = {0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 0, 1}; + int8x16_t gamma_idx_2 = {6, 7, 6, 7, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7}; + + // init alpha + int16x8_t alpha_next = {0, -16384, -16384, -16384, + -16384, -16384, -16384, -16384}; + + beta_idx = k - 1; + for (uint32_t trellis_idx = 0, gamma_idx = 0; trellis_idx <= length; + trellis_idx++, gamma_idx++) { + + int16x8_t gamma_apos_1; + int16x8_t 
gamma_apos_2; + + for (int32_t j = 0; (j < 8) && (beta_idx >= 0); j++, beta_idx--) { + + // calculate alpha state + int16x8_t curr_state = alpha_next; + uint8x16_t tmp1 = vreinterpretq_u8_s16(curr_state); + uint8x16_t tmp2 = vqtbl1q_u8(tmp1, alpha_prev_state_idx_1); + uint8x16_t tmp3 = vqtbl1q_u8(tmp1, alpha_prev_state_idx_2); + + int16x8_t curr_state_1 = vreinterpretq_s16_u8(tmp2); + int16x8_t curr_state_2 = vreinterpretq_s16_u8(tmp3); + + int8x8_t gamma_tmp = vreinterpret_s8_s16(gamma4[gamma_idx].val[j]); + + int16x8_t gamma_at_state_1 = + vreinterpretq_s16_s8(vtbl1q_s8(gamma_tmp, alpha_branch_idx_1)); + + int16x8_t gamma_at_state_2 = + vreinterpretq_s16_s8(vtbl1q_s8(gamma_tmp, alpha_branch_idx_2)); + + int16x8_t next_state_prob_1 = vqaddq_s16(gamma_at_state_1, curr_state_1); + int16x8_t next_state_prob_2 = vqaddq_s16(gamma_at_state_2, curr_state_2); + + // next alpha state + alpha_next = vmaxq_s16(next_state_prob_1, next_state_prob_2); + // normalization + int16_t alpha_max = vmaxvq_s16(alpha_next); + int16x8_t alpha_norm_vec = vdupq_n_s16(alpha_max); + alpha_next = vqsubq_s16(alpha_next, alpha_norm_vec); + + // calculates llr - alpha + gamma + beta + tmp1 = vreinterpretq_u8_s16(curr_state); + tmp2 = vqtbl1q_u8(tmp1, alpha_state_idx); + + curr_state_1 = vreinterpretq_s16_u8(tmp2); + + int16x8_t next_state = beta[beta_idx]; + tmp1 = vreinterpretq_u8_s16(next_state); + tmp2 = vqtbl1q_u8(tmp1, beta_state_idx_1); + tmp3 = vqtbl1q_u8(tmp1, beta_state_idx_2); + + int16x8_t next_state_1 = vreinterpretq_s16_u8(tmp2); + int16x8_t next_state_2 = vreinterpretq_s16_u8(tmp3); + + gamma_at_state_1 = + vreinterpretq_s16_s8(vtbl1q_s8(gamma_tmp, gamma_idx_1)); + gamma_at_state_2 = + vreinterpretq_s16_s8(vtbl1q_s8(gamma_tmp, gamma_idx_2)); + + int16x8_t apos_prob_1 = vqaddq_s16(gamma_at_state_1, next_state_1); + int16x8_t apos_prob_2 = vqaddq_s16(gamma_at_state_2, next_state_2); + + apos_prob_1 = vqaddq_s16(apos_prob_1, curr_state_1); + apos_prob_2 = vqaddq_s16(apos_prob_2, curr_state_1); + + // max across all states + gamma_apos_1[j] = vmaxvq_s16(apos_prob_1); + gamma_apos_2[j] = vmaxvq_s16(apos_prob_2); + } + + // llr extr = ratio of (probability of 1/ prob of 0) + int16x8_t llr = vqsubq_s16(gamma_apos_2, gamma_apos_1); + apos[trellis_idx] = llr; + + // extrinsic info update + extr[trellis_idx] = vqsubq_s16(llr, intr[trellis_idx]); + } +} + } // namespace diff --git a/src/UpperPHY/Turbo/turbo_code_common.hpp b/src/UpperPHY/Turbo/turbo_code_common.hpp index 027d081cc772a629ad8f9fc29cacbb3d8a6099f5..0cd9be908c016808105444583b2678d57f40eeaa 100644 --- a/src/UpperPHY/Turbo/turbo_code_common.hpp +++ b/src/UpperPHY/Turbo/turbo_code_common.hpp @@ -125,6 +125,21 @@ inline void turbo_llrs_to_bits(uint32_t n, const int16x8_t *llr, } } +// Reusing the turbo_llrs_to_bits func, but comparing with greater than +// instead of lesser than, as the Demod signs have been flipped +inline void turbo_hard_decisions(uint32_t n, const int16x8_t *llr, + uint8_t *data_out) { + uint32_t full_bytes = n >> 3; + constexpr uint16x8_t ones = {128, 64, 32, 16, 8, 4, 2, 1}; + + for (uint32_t i = 0; i < full_bytes; ++i) { + // The first bit to write in the byte is the most significant + uint16x8_t pred = vcgtzq_s16(llr[i]); + uint16x8_t mask = vandq_u16(pred, ones); + data_out[i] = (uint8_t)vaddvq_u16(mask); + } +} + template void permute(perm_idx_lookup *perm_lookup, Vec *from, Vec *to, uint32_t arr_length, uint32_t vec_length = 8) { @@ -140,6 +155,17 @@ void permute(perm_idx_lookup *perm_lookup, Vec *from, Vec *to, } } +template 
+uint32_t compare_hard_decisions(Vec *d1, Vec *d2, uint32_t n) { + uint32_t sum = 0; + for (uint32_t i = 0; i < n; i++) { + Vec v0 = d1[i] == d2[i]; + sum = sum + v0 * 8; + } + + return sum; +} + template void unpermute(perm_idx_lookup *perm_lookup, Vec *from, Vec *to, uint32_t arr_length, uint32_t vec_length = 8) { @@ -164,6 +190,12 @@ void interleave(const T *src, uint32_t ldsrc, T *dst, uint32_t lddst) { } } +using decode_butterfly_func_t = void (*)(uint8_t siso_flag, uint32_t length, + uint32_t k, int16x8_t *extr, + int16x8_t *sys, int16x8_t *par, + void *gamma, int16x8_t *beta, + int16x8_t *apos); + using trellis_term_func_t = void (*)(const int16x8_t *sys, const int16x8_t *par, uint32_t k, int16x8_t *beta_tail, int16x8_t l_c); @@ -203,13 +235,14 @@ decode_loop(int16x8_t *sys, int16x8_t *pys, int16x8_t *par, int16x8_t *itl, // permute sys into pys // perm_idx is < k, so we can do this before handling trellis termination permute(perm_lookup, sys, pys, kv); + // Unperturb the trellis termination bits. They are transmitted as: // X0 Z1 X'0 Z'1 Z0 X2 Z'0 X'2 X1 Z2 X'1 Z'2 // but need to appended to the inputs as: - // X0 X1 X2 - // Z0 Z1 Z2 - // X'0 X'1 X'2 - // Z'0 Z'1 Z'2 + // X0 X1 X2 => sys_nat + // Z0 Z1 Z2 =>par_nat + // X'0 X'1 X'2 => sys_interlv + // Z'0 Z'1 Z'2 => par_interlv // Order like this so we can copy inplace if constexpr (batched) { // sys[k] = sys[k]; @@ -238,6 +271,7 @@ decode_loop(int16x8_t *sys, int16x8_t *pys, int16x8_t *par, int16x8_t *itl, itl[kv][2] = itl[kv][3]; pys[kv][2] = par[kv][3]; } + // Prescale l_c to avoid doing it a bunch later const int16x8_t channel_reliability = vdupq_n_s16((int16_t)l_c / 2); // Initialize alpha (= zero or min) and beta (from the trellis termination) @@ -308,6 +342,131 @@ decode_loop(int16x8_t *sys, int16x8_t *pys, int16x8_t *par, int16x8_t *itl, return ARMRAL_SUCCESS; } +inline void update_apos(uint32_t length, const int16x8_t *a, const int16x8_t *b, + int16x8_t *c) { + for (uint32_t i = 0; i < length; i++) { + c[i] = vqaddq_s16(a[i], b[i]); + } +} + +template +armral_status decode_butterfly_loop( + int16x8_t *sys_nat, int16x8_t *sys_interlv, int16x8_t *par_nat, + int16x8_t *par_interlv, int16x8_t *intr_nat, int16x8_t *intr_interlv, + int16x8_t *extr_nat, int16x8_t *apos_nat, int16x4_t *gamma, int16x8_t *beta, + uint32_t k, uint8_t **hd_2d, uint8_t *dst, uint32_t max_iter, + perm_idx_lookup *perm_lookup, + decode_butterfly_func_t decode_map_butterfly) { + + constexpr uint32_t vec_len = 8; + uint32_t kv = batched ? k : k / vec_len; + uint32_t k_hd = (k + vec_len - 1) / vec_len; + if (k % vec_len) { + kv = kv + 1; + } + + // permute sys into pys + // perm_idx is < k, so we can do this before handling trellis termination + permute(perm_lookup, sys_nat, sys_interlv, kv); + + // Unperturb the trellis termination bits. 
They are transmitted as: + // X0 Z1 X'0 Z'1 Z0 X2 Z'0 X'2 X1 Z2 X'1 Z'2 + // but need to be appended to the inputs as: + // X0 X1 X2 + // Z0 Z1 Z2 + // X'0 X'1 X'2 + // Z'0 Z'1 Z'2 + // Order like this so we can copy inplace + if constexpr (batched) { + + sys_interlv[k] = sys_nat[k + 2]; + sys_nat[k + 2] = par_nat[k + 1]; + par_nat[k + 1] = sys_nat[k + 1]; + sys_nat[k + 1] = par_interlv[k]; + par_interlv[k] = par_nat[k + 2]; + par_nat[k + 2] = par_interlv[k + 1]; + par_interlv[k + 1] = sys_nat[k + 3]; + sys_interlv[k + 1] = par_interlv[k + 2]; + par_interlv[k + 2] = par_interlv[k + 3]; + sys_interlv[k + 2] = par_nat[k + 3]; + } else { + + sys_interlv[kv][0] = sys_nat[kv][2]; + sys_nat[kv][2] = par_nat[kv][1]; + par_nat[kv][1] = sys_nat[kv][1]; + sys_nat[kv][1] = par_interlv[kv][0]; + par_interlv[kv][0] = par_nat[kv][2]; + par_nat[kv][2] = par_interlv[kv][1]; + par_interlv[kv][1] = sys_nat[kv][3]; + sys_interlv[kv][1] = par_interlv[kv][2]; + par_interlv[kv][2] = par_interlv[kv][3]; + sys_interlv[kv][2] = par_nat[kv][3]; + } + + // DECODE + for (uint32_t num_iter = 0, count = 0; num_iter < max_iter; ++num_iter) { + + // Map decoder + decode_map_butterfly(0, kv, k + 3, extr_nat, sys_nat, par_nat, + (void *)gamma, beta, apos_nat); + + if constexpr (check_convergence) { + // hard_decision Map1 + turbo_hard_decisions(k_hd << 3, apos_nat, hd_2d[0]); + + // Early Termination + if (k == compare_hard_decisions((uint8_t *)hd_2d[0], + (uint8_t *)hd_2d[1], k >> 3)) { + count++; + // wait until at least 2 consecutive half iterations have the same HD result + if (count >= 2) { + break; + } + } else { + count = 0; + } + } + + // permute ext info + permute(perm_lookup, extr_nat, intr_interlv, kv); + + // Map decoder 2 + decode_map_butterfly(1, kv, k + 3, intr_interlv, sys_interlv, par_interlv, + (void *)gamma, beta, apos_nat); + + // undo permutation + unpermute(perm_lookup, intr_interlv, extr_nat, kv); + + // a posteriori update + update_apos(kv, extr_nat, sys_nat, apos_nat); + + if constexpr (check_convergence) { + // hard_decision Map2 + turbo_hard_decisions(k_hd << 3, apos_nat, hd_2d[1]); + + // Early Termination + if (k == compare_hard_decisions((uint8_t *)hd_2d[0], + (uint8_t *)hd_2d[1], k >> 3)) { + count++; + // wait until at least 2 consecutive half iterations have the same HD result + if (count >= 2) { + break; + } + + } else { + count = 0; + } + } + } + + // copy the packed hard-decision bytes to the output + if constexpr (check_convergence) { + memcpy(dst, hd_2d[1], k >> 3); + } + + return ARMRAL_SUCCESS; +} + template <typename Allocator> armral_status decode(const int8_t *sys_i8, const int8_t *par_i8, const int8_t *itl_i8, uint32_t k, uint8_t *dst, @@ -316,7 +475,8 @@ armral_status decode(const int8_t *sys_i8, const int8_t *par_i8, trellis_term_func_t trellis_termination_single, decode_step_func_t decode_step_single, trellis_term_func_t trellis_termination_batched, - decode_step_func_t decode_step_batched) { + decode_step_func_t decode_step_batched, + decode_butterfly_func_t decode_butterfly_single) { // Outer decoder function for a single block or batch of `num_blocks` blocks. // The inputs and the output to this function are uninterleaved.
// @@ -344,47 +504,16 @@ armral_status decode(const int8_t *sys_i8, const int8_t *par_i8, auto par = allocate_uninitialized(allocator, len); auto pys = allocate_uninitialized(allocator, len); // permuted sys auto itl = allocate_uninitialized(allocator, len); - auto extrinsic = allocate_zeroed(allocator, kv); - auto perm_extrinsic = allocate_uninitialized(allocator, kv); - // Allocate space for log likelihood ratios from both stages of decoding - auto llr = allocate_uninitialized(allocator, kv); - auto perm_llr = allocate_uninitialized(allocator, kv); - auto prev_perm_llr = allocate_zeroed(allocator, kv); - // Allocate space to hold alpha and gamma - // alpha stores the forward-accumulated state probabilities for each decoded - // bit, where the LTE encoder has 8 states and there are k bits to decode - // plus the starting condition (no alpha for the 3 trellis termination bits) - auto alpha = allocate_uninitialized(allocator, 8 * (kv + 1)); - auto beta_tail = allocate_uninitialized(allocator, beta_tail_len); - auto perm_beta_tail = - allocate_uninitialized(allocator, beta_tail_len); - // gamma stores the conditional state transition probabilities for each of the - // k bits to decode. There are 16 transitions per k but only 4 unique values. - // Use int16_t so decode_step has same signature for batched and single. - auto gamma = allocate_uninitialized(allocator, kv * vec_len * 4); // PERM_IDXS // Get the permutation vector for the input value of k. // If perm_idxs is uninitialized (==nullptr) then generate indices here. unique_ptr perm_lookup_unique; perm_idx_lookup *perm_lookup = nullptr; + // Find the index into the array of parameter arrays corresponding // to the current k. Subtract 40 because k=40 is the lowest value. uint32_t param_idx = perm_params_lookup[(k - 40) >> 3]; - if (perm_idxs != nullptr) { - if constexpr (Allocator::is_counting) { // NOTE: All allocations done. - return ARMRAL_SUCCESS; - } - perm_lookup = (perm_idx_lookup *)perm_idxs + perm_lookup_offset[param_idx]; - } else { - perm_lookup_unique = allocate_uninitialized(allocator, k); - if constexpr (Allocator::is_counting) { // NOTE: All allocations done. - return ARMRAL_SUCCESS; - } - perm_lookup = perm_lookup_unique.get(); - // Generate the permutation vector for the input value of k. - k_perm_idx_init(k, param_idx, perm_lookup); - } // How many elements in input data? 
// This is different than `len` which accounts for vectorizing @@ -392,56 +521,160 @@ armral_status decode(const int8_t *sys_i8, const int8_t *par_i8, uint32_t b = 0; // block index uint32_t dat_offset = 0; uint32_t dst_offset = 0; - for (; batched && b < num_blocks - vec_len + 1; - b += vec_len, dat_offset += vec_len * dat_len, dst_offset += k) { - // Decode 8 blocks - - if (b > 0) { // Re-zero buffers which should start at 0 - for (uint32_t i = 0; i < k; i++) { - extrinsic[i] = vdupq_n_s16(0); - prev_perm_llr[i] = vdupq_n_s16(0); + + if (num_blocks > 1) { + + auto extrinsic = allocate_zeroed(allocator, kv); + auto perm_extrinsic = allocate_uninitialized(allocator, kv); + // Allocate space for log likelihood ratios from both stages of decoding + auto llr = allocate_uninitialized(allocator, kv); + auto perm_llr = allocate_uninitialized(allocator, kv); + auto prev_perm_llr = allocate_zeroed(allocator, kv); + // Allocate space to hold alpha and gamma + // alpha stores the forward-accumulated state probabilities for each decoded + // bit, where the LTE encoder has 8 states and there are k bits to decode + // plus the starting condition (no alpha for the 3 trellis termination bits) + auto alpha = allocate_uninitialized(allocator, 8 * (kv + 1)); + auto beta_tail = + allocate_uninitialized(allocator, beta_tail_len); + auto perm_beta_tail = + allocate_uninitialized(allocator, beta_tail_len); + // gamma stores the conditional state transition probabilities for each of + // the k bits to decode. There are 16 transitions per k but only 4 unique + // values. Use int16_t so decode_step has same signature for batched and + // single. + auto gamma = allocate_uninitialized(allocator, kv * vec_len * 4); + + if (perm_idxs != nullptr) { + if constexpr (Allocator::is_counting) { // NOTE: All allocations done. + return ARMRAL_SUCCESS; + } + perm_lookup = + (perm_idx_lookup *)perm_idxs + perm_lookup_offset[param_idx]; + } else { + perm_lookup_unique = + allocate_uninitialized(allocator, k); + if constexpr (Allocator::is_counting) { // NOTE: All allocations done. + return ARMRAL_SUCCESS; + } + perm_lookup = perm_lookup_unique.get(); + // Generate the permutation vector for the input value of k. 
+ k_perm_idx_init(k, param_idx, perm_lookup); + } + + for (; batched && b < num_blocks - vec_len + 1; + b += vec_len, dat_offset += vec_len * dat_len, dst_offset += k) { + // Decode 8 blocks + + if (b > 0) { // Re-zero buffers which should start at 0 + for (uint32_t i = 0; i < k; i++) { + extrinsic[i] = vdupq_n_s16(0); + prev_perm_llr[i] = vdupq_n_s16(0); + } } + + // Convert type and vectorize, then interleave (use pys as a buffer) + convert_llrs(dat_len * vec_len, sys_i8 + dat_offset, pys.get()); + interleave((int16_t *)pys.get(), dat_len, (int16_t *)sys.get(), vec_len); + + convert_llrs(dat_len * vec_len, par_i8 + dat_offset, pys.get()); + interleave((int16_t *)pys.get(), dat_len, (int16_t *)par.get(), vec_len); + + convert_llrs(dat_len * vec_len, itl_i8 + dat_offset, pys.get()); + interleave((int16_t *)pys.get(), dat_len, (int16_t *)itl.get(), vec_len); + + decode_loop( + sys.get(), pys.get(), par.get(), itl.get(), extrinsic.get(), + perm_extrinsic.get(), llr.get(), perm_llr.get(), prev_perm_llr.get(), + alpha.get(), beta_tail.get(), perm_beta_tail.get(), gamma.get(), k, + dst + dst_offset, l_c, max_iter, perm_lookup, + trellis_termination_batched, decode_step_batched); } - // Convert type and vectorize, then interleave (use pys as a buffer) - convert_llrs(dat_len * vec_len, sys_i8 + dat_offset, pys.get()); - interleave((int16_t *)pys.get(), dat_len, (int16_t *)sys.get(), vec_len); + for (; b < num_blocks; + ++b, dat_offset += dat_len, dst_offset += k / vec_len) { + // Decode 1 block + + if (b > 0) { // Re-zero buffers which should start at 0 + for (uint32_t i = 0; i < k / vec_len; i++) { + extrinsic[i] = vdupq_n_s16(0); + prev_perm_llr[i] = vdupq_n_s16(0); + } + } + + // Convert type and vectorize + convert_llrs(dat_len, sys_i8 + dat_offset, sys.get()); + convert_llrs(dat_len, par_i8 + dat_offset, par.get()); + convert_llrs(dat_len, itl_i8 + dat_offset, itl.get()); + + decode_loop( + sys.get(), pys.get(), par.get(), itl.get(), extrinsic.get(), + perm_extrinsic.get(), llr.get(), perm_llr.get(), prev_perm_llr.get(), + alpha.get(), beta_tail.get(), perm_beta_tail.get(), gamma.get(), k, + dst + dst_offset, l_c, max_iter, perm_lookup, + trellis_termination_single, decode_step_single); + } - convert_llrs(dat_len * vec_len, par_i8 + dat_offset, pys.get()); - interleave((int16_t *)pys.get(), dat_len, (int16_t *)par.get(), vec_len); + } else { // New Max Map Decoder butterfly scheme - convert_llrs(dat_len * vec_len, itl_i8 + dat_offset, pys.get()); - interleave((int16_t *)pys.get(), dat_len, (int16_t *)itl.get(), vec_len); + auto intr_nat = allocate_zeroed(allocator, len); // k+3 + auto intr_interlv = allocate_zeroed(allocator, len); - decode_loop( - sys.get(), pys.get(), par.get(), itl.get(), extrinsic.get(), - perm_extrinsic.get(), llr.get(), perm_llr.get(), prev_perm_llr.get(), - alpha.get(), beta_tail.get(), perm_beta_tail.get(), gamma.get(), k, - dst + dst_offset, l_c, max_iter, perm_lookup, - trellis_termination_batched, decode_step_batched); - } - for (; b < num_blocks; - ++b, dat_offset += dat_len, dst_offset += k / vec_len) { - // Decode 1 block - - if (b > 0) { // Re-zero buffers which should start at 0 - for (uint32_t i = 0; i < k / vec_len; i++) { - extrinsic[i] = vdupq_n_s16(0); - prev_perm_llr[i] = vdupq_n_s16(0); + auto extr_nat = allocate_zeroed(allocator, len); // k+3 + auto apos_nat = allocate_zeroed(allocator, len); + + auto gamma = allocate_zeroed(allocator, ((k + 3 + 7) >> 3) << 3); + auto beta = allocate_uninitialized( + allocator, ((k + 3 + 7) >> 3) << 3); //(k + 3) * 8 
states + + uint8_t *hd_2d[2]; + // buffer for the both map decoder HD + auto dst1 = + allocate_uninitialized(allocator, ((k + 15) >> 4) << 4); + auto dst2 = + allocate_uninitialized(allocator, ((k + 15) >> 4) << 4); + + hd_2d[0] = dst1.get(); + hd_2d[1] = dst2.get(); + + if (perm_idxs != nullptr) { + if constexpr (Allocator::is_counting) { // NOTE: All allocations done. + return ARMRAL_SUCCESS; + } + perm_lookup = + (perm_idx_lookup *)perm_idxs + perm_lookup_offset[param_idx]; + } else { + perm_lookup_unique = + allocate_uninitialized(allocator, k); + if constexpr (Allocator::is_counting) { // NOTE: All allocations done. + return ARMRAL_SUCCESS; } + perm_lookup = perm_lookup_unique.get(); + // Generate the permutation vector for the input value of k. + k_perm_idx_init(k, param_idx, perm_lookup); } - // Convert type and vectorize - convert_llrs(dat_len, sys_i8 + dat_offset, sys.get()); - convert_llrs(dat_len, par_i8 + dat_offset, par.get()); - convert_llrs(dat_len, itl_i8 + dat_offset, itl.get()); - - decode_loop( - sys.get(), pys.get(), par.get(), itl.get(), extrinsic.get(), - perm_extrinsic.get(), llr.get(), perm_llr.get(), prev_perm_llr.get(), - alpha.get(), beta_tail.get(), perm_beta_tail.get(), gamma.get(), k, - dst + dst_offset, l_c, max_iter, perm_lookup, - trellis_termination_single, decode_step_single); + // Decode 1 block at a time + for (; b < num_blocks; + ++b, dat_offset += dat_len, dst_offset += k / vec_len) { + + // initialize 2nd Map Dec hard bits + int8x8_t *d = (int8x8_t *)dst2.get(); + for (uint32_t i = 0; i < ((k + 7) >> 3); i++) { + d[i] = vdup_n_s8(-1); + } + + // Convert type and vectorize + convert_llrs(dat_len, sys_i8 + dat_offset, sys.get()); + convert_llrs(dat_len, par_i8 + dat_offset, par.get()); + convert_llrs(dat_len, itl_i8 + dat_offset, itl.get()); + + decode_butterfly_loop( + sys.get(), pys.get(), par.get(), itl.get(), intr_nat.get(), + intr_interlv.get(), extr_nat.get(), apos_nat.get(), gamma.get(), + beta.get(), k, hd_2d, dst + dst_offset, max_iter, perm_lookup, + decode_butterfly_single); + } } return ARMRAL_SUCCESS; } diff --git a/test/UpperPHY/Turbo/turbo_decode_test_utils.hpp b/test/UpperPHY/Turbo/turbo_decode_test_utils.hpp index 19020be65a15c7b78a8eead5f94200b80d16c6a7..9291c445f5a17bea251779098701186aedadd47f 100644 --- a/test/UpperPHY/Turbo/turbo_decode_test_utils.hpp +++ b/test/UpperPHY/Turbo/turbo_decode_test_utils.hpp @@ -25,8 +25,22 @@ void interleave(T *src, uint32_t ldsrc, T *dst, uint32_t lddst) { } } +inline int8_t saturating_mul(int8_t a, int8_t b) { + int16_t result = (int16_t)a * (int16_t)b; + + if (result > INT8_MAX) { + return INT8_MAX; + } + if (result < INT8_MIN) { + return INT8_MIN; + } + + return (int8_t)result; +} + static inline void setup_block_data(uint32_t k, uint32_t b, uint8_t *src, - int8_t *sys, int8_t *par, int8_t *itl) { + int8_t *sys, int8_t *par, int8_t *itl, + uint32_t num_blocks) { auto k_bytes = k / 8; std::vector sys_encode(k_bytes + 1, 255); @@ -58,6 +72,15 @@ static inline void setup_block_data(uint32_t k, uint32_t b, uint8_t *src, armral_demodulation(mod_num_symbols, ulp, mod_type, sys_mod.data(), sys); armral_demodulation(mod_num_symbols, ulp, mod_type, par_mod.data(), par); armral_demodulation(mod_num_symbols, ulp, mod_type, itl_mod.data(), itl); + + // demod is flipping bits, undo + if (num_blocks == 1) { + for (uint32_t i = 0; i < (k + 4); i++) { + sys[i] = saturating_mul(sys[i], -1); + par[i] = saturating_mul(par[i], -1); + itl[i] = saturating_mul(itl[i], -1); + } + } } // Check that the decoder returns the 
original @@ -79,7 +102,7 @@ run_one_turbo_decoding_test(uint32_t num_blocks, char const *name, uint32_t k, for (uint32_t b = 0; b < num_blocks; ++b) { setup_block_data(k, b, &src[b * (k_bytes + 1)], &sys[b * len], - &par[b * len], &itl[b * len]); + &par[b * len], &itl[b * len], num_blocks); } // Decode the encoded data. We set the maximum number of decoder iterations to @@ -117,6 +140,7 @@ run_one_turbo_decoding_test(uint32_t num_blocks, char const *name, uint32_t k, if (passed) { printf("[%s_BatchSize-%u_k-%u] - check result: OK\n", name, num_blocks, k); } + return passed; }
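For reviewers, a minimal sketch (not part of the patch) of how a caller drives the new single-block butterfly path end to end, mirroring the sign handling the simulation and test changes above introduce for `num_blocks == 1`. Only `armral_turbo_decode_block`, its argument order, and `ARMRAL_SUCCESS` come from this patch; the header name, the block length `k = 40`, the iteration count, and the `saturating_neg` helper are illustrative assumptions.

```cpp
// Hypothetical driver for the single-block (butterfly) decoder path.
// The simulation/test changes negate the demodulator LLRs with saturation
// before decoding a single block, so a caller appears to need the same step.
#include "armral.h" // assumed public ArmRAL header

#include <cstdint>
#include <vector>

// Saturating negation: INT8_MIN has no int8_t negative, so clamp to INT8_MAX.
static int8_t saturating_neg(int8_t a) {
  return a == INT8_MIN ? INT8_MAX : static_cast<int8_t>(-a);
}

int main() {
  const uint32_t k = 40;      // assumed block length (smallest LTE turbo size)
  const uint32_t len = k + 4; // each stream carries 4 trellis-termination LLRs
  std::vector<int8_t> sys(len), par(len), itl(len); // filled by rate recovery
  std::vector<uint8_t> dst(k / 8);                  // packed decoded bits

  // Undo the demodulator's sign convention before the single-block decode.
  for (uint32_t i = 0; i < len; i++) {
    sys[i] = saturating_neg(sys[i]);
    par[i] = saturating_neg(par[i]);
    itl[i] = saturating_neg(itl[i]);
  }

  // perm_idxs == nullptr: the decoder generates permutation indices itself.
  armral_status status = armral_turbo_decode_block(
      sys.data(), par.data(), itl.data(), k, dst.data(), /*max_iter=*/5,
      nullptr);
  return status == ARMRAL_SUCCESS ? 0 : 1;
}
```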