diff --git a/bench/MatrixPseudoInv/Direct/bench.py b/bench/MatrixPseudoInv/Direct/bench.py index 9be0c4927fdf15c5669cabaf8c042c5a52949756..1cd4980e065c908397759173a7eb14c59c6cac68 100755 --- a/bench/MatrixPseudoInv/Direct/bench.py +++ b/bench/MatrixPseudoInv/Direct/bench.py @@ -20,7 +20,7 @@ j = { "cases": [] } -size1 = [2, 3, 4, 8, 16] +size1 = [1, 2, 3, 4, 8, 16] size2 = [32, 64, 128, 256] for (m, n) in itertools.chain(zip(size1, size2), zip(size2, size1)): diff --git a/include/armral.h b/include/armral.h index 402161edaad983eb4c8627185111be0385ba11b1..efc2484d14f7125f8b45545f133a3de28ecddf8a 100644 --- a/include/armral.h +++ b/include/armral.h @@ -1650,8 +1650,8 @@ armral_cmplx_mat_inverse_batch_f32_pa(uint32_t num_mats, uint32_t size, * * The input matrix `p_src` and output matrix `p_dst` are stored contiguously * in memory, in row-major order. The number of rows `m` in the input matrix - * must be 2, 3, 4, 8 or 16 if `m <= n`. The number of columns `n` in the input - * matrix must be 2, 3, 4, 8 or 16 if `m > n`. + * must be 1, 2, 3, 4, 8 or 16 if `m <= n`. The number of columns `n` in the + * input matrix must be 1, 2, 3, 4, 8 or 16 if `m > n`. * * @param[in] m The number of rows in input matrix `A`. * @param[in] n The number of columns in input matrix `A`. @@ -1693,8 +1693,8 @@ armral_cmplx_pseudo_inverse_direct_f32(uint16_t m, uint16_t n, float32_t lambda, * * The input matrix `p_src` and output matrix `p_dst` are stored contiguously * in memory, in row-major order. The number of rows `m` in the input matrix - * must be 2, 3, 4, 8 or 16 if `m <= n`. The number of columns `n` in the input - * matrix must be 2, 3, 4, 8 or 16 if `m > n`. + * must be 1, 2, 3, 4, 8 or 16 if `m <= n`. The number of columns `n` in the + * input matrix must be 1, 2, 3, 4, 8 or 16 if `m > n`. * * This function takes a pre-allocated buffer (`buffer`) to use internally. * This variant will not call any system memory allocators. @@ -1907,7 +1907,8 @@ armral_status armral_demodulation(uint32_t n_symbols, uint16_t ulp, * @param[in] p_src_h Points to the estimated channel matrix `H`. * @param[in] p_src_y Points to the received signal matrix `y`. * @param[in] noise_pwr_est Points to the estimated noise power `σ^2`. - * @param[out] p_dst_x_tild Points to the transmitted signal estimate `x_tild`. + * @param[out] p_dst_x_tild Points to the transmitted signal estimate + * `x_tild`. * @return An `armral_status` value that indicates success or failure. */ armral_status armral_cmplx_channel_equalization_f32( @@ -1968,7 +1969,8 @@ armral_status armral_cmplx_channel_equalization_f32( * @param[in] p_src_h Points to the estimated channel matrix `H`. * @param[in] p_src_y Points to the received signal matrix `y`. * @param[in] noise_pwr_est Points to the estimated noise power `σ^2`. - * @param[out] p_dst_x_tild Points to the transmitted signal estimate `x_tild`. + * @param[out] p_dst_x_tild Points to the transmitted signal estimate + * `x_tild`. * @param[in] buffer Workspace buffer to be used internally. * @return An `armral_status` value that indicates success or failure. */ @@ -2929,8 +2931,8 @@ armral_status armral_polar_decode_block(uint32_t n, const uint8_t *frozen, /** * Matches the rate of the Polar encoded code block to the rate of the channel - * using sub-block interleaving, bit selection, and channel interleaving based on - * Downlink or Uplink direction. This is as described in the 3GPP Technical + * using sub-block interleaving, bit selection, and channel interleaving based + * on Downlink or Uplink direction. This is as described in the 3GPP Technical * Specification (TS) 38.212 section 5.4.1. * * The code rate of the code block is defined by the ratio of the rate-matched @@ -2942,7 +2944,8 @@ armral_status armral_polar_decode_block(uint32_t n, const uint8_t *frozen, * @param[in] n The number of bits in the code block. * @param[in] e The number of bits in the rate-matched message. * @param[in] k The number of information bits in the code block. - * @param[in] i_bil Flag to enable/disable the interleaving of coded bits. + * @param[in] i_bil Flag to enable/disable the interleaving of coded + * bits. * @param[in] p_d_seq_in Points to `n` bits representing the Polar encoded * message. * @param[out] p_f_seq_out Points to `e` bits representing the rate-matched @@ -2979,7 +2982,8 @@ armral_status armral_polar_rate_matching(uint32_t n, uint32_t e, uint32_t k, * @param[in] n The number of bits in the code block. * @param[in] e The number of bits in the rate-matched message. * @param[in] k The number of information bits in the code block. - * @param[in] i_bil Flag to enable/disable the interleaving of coded bits. + * @param[in] i_bil Flag to enable/disable the interleaving of coded + * bits. * @param[in] p_d_seq_in Points to `n` bits representing the Polar encoded * message. * @param[out] p_f_seq_out Points to `e` bits representing the rate-matched diff --git a/src/BasicMathFun/MatrixPseudoInv/cmplx_mat_pseudo_inverse.hpp b/src/BasicMathFun/MatrixPseudoInv/cmplx_mat_pseudo_inverse.hpp index 007cfc3dd2a0b9b6a1b58a25c15aa500dcd693ea..b82db7bf28d106081330b8d5edb6a13c2db354cc 100644 --- a/src/BasicMathFun/MatrixPseudoInv/cmplx_mat_pseudo_inverse.hpp +++ b/src/BasicMathFun/MatrixPseudoInv/cmplx_mat_pseudo_inverse.hpp @@ -82,12 +82,17 @@ void left_pseudo_inverse(uint16_t m, const float32_t lambda, armral_cmplx_mat_mult_ahb_f32(m, n, n, p_src, p_src, mat_aha); // Compute C += lambda * I - add_lambda(lambda, p_dst); + add_lambda(lambda, mat_aha); // Compute B = C^(-1) auto mat_inv = allocate_uninitialized(allocator, n * n); - armral::cmplx_herm_mat_inv::invert_hermitian_matrix(mat_aha, - mat_inv.get()); + if constexpr (n == 1) { + mat_inv[0].re = 1.F / mat_aha[0].re; + mat_inv[0].im = 0.F; + } else { + armral::cmplx_herm_mat_inv::invert_hermitian_matrix(mat_aha, + mat_inv.get()); + } // Compute B * A^H mat_mult_bah_f32(m, n, p_src, mat_inv.get(), p_dst); @@ -108,8 +113,13 @@ void right_pseudo_inverse(uint16_t n, const float32_t lambda, // Compute B = C^(-1) auto mat_inv = allocate_uninitialized(allocator, m * m); - armral::cmplx_herm_mat_inv::invert_hermitian_matrix(mat_aah, - mat_inv.get()); + if constexpr (m == 1) { + mat_inv[0].re = 1.F / mat_aah[0].re; + mat_inv[0].im = 0.F; + } else { + armral::cmplx_herm_mat_inv::invert_hermitian_matrix(mat_aah, + mat_inv.get()); + } // Compute A^H * B armral_cmplx_mat_mult_ahb_f32(m, n, m, p_src, mat_inv.get(), p_dst); @@ -130,6 +140,10 @@ cmplx_pseudo_inverse_direct(uint16_t m, uint16_t n, const float32_t lambda, // columns then use the left pseudo-inverse if (m > n) { switch (n) { + case 1: { + left_pseudo_inverse<1>(m, lambda, p_src, p_dst, allocator); + break; + } case 2: { left_pseudo_inverse<2>(m, lambda, p_src, p_dst, allocator); break; @@ -160,6 +174,10 @@ cmplx_pseudo_inverse_direct(uint16_t m, uint16_t n, const float32_t lambda, // If the number of rows in the input matrix is less than or equal to the number // of columns then use the right pseudo-inverse switch (m) { + case 1: { + right_pseudo_inverse<1>(n, lambda, p_src, p_dst, allocator); + break; + } case 2: { right_pseudo_inverse<2>(n, lambda, p_src, p_dst, allocator); break; diff --git a/test/MatrixPseudoInv/direct/main.cpp b/test/MatrixPseudoInv/direct/main.cpp index a497da9abbae9d339b85ec0fa2abc38d6fd60511..50034457a35895df93481700c636faccce2e957e 100644 --- a/test/MatrixPseudoInv/direct/main.cpp +++ b/test/MatrixPseudoInv/direct/main.cpp @@ -30,13 +30,14 @@ bool run_all_tests(char const *test_name, char const *function_name, bool passed = true; const std::tuple params[] = { - {2, 5, -0.968591}, {2, 84, 0.191647}, {2, 2, 1.457848}, - {2, 67, 0.0}, {3, 18, -1.218053}, {3, 138, 1.597186}, - {3, 3, -1.2435186}, {3, 161, 0.0}, {4, 20, -0.474817}, - {4, 105, 0.944802}, {4, 4, 1.645646}, {4, 94, 0.0}, - {8, 35, -1.991369}, {8, 200, -1.244298}, {8, 8, 1.445767}, - {8, 190, 0.0}, {16, 32, 0.809352}, {16, 80, 1.810591}, - {16, 16, -0.426745}, {16, 117, 0.0}}; + {1, 1, 0.186745}, {1, 21, -0.314205}, {1, 66, 1.495806}, + {1, 121, 0.0}, {2, 5, -0.968591}, {2, 84, 0.191647}, + {2, 2, 1.457848}, {2, 67, 0.0}, {3, 18, -1.218053}, + {3, 138, 1.597186}, {3, 3, -1.2435186}, {3, 161, 0.0}, + {4, 20, -0.474817}, {4, 105, 0.944802}, {4, 4, 1.645646}, + {4, 94, 0.0}, {8, 35, -1.991369}, {8, 200, -1.244298}, + {8, 8, 1.445767}, {8, 190, 0.0}, {16, 32, 0.809352}, + {16, 80, 1.810591}, {16, 16, -0.426745}, {16, 117, 0.0}}; for (const auto &[dim1, dim2, l] : params) { printf("[%s] m=%d, n=%d, l=%f\n", function_name, dim1, dim2, l); passed &= run_pseudo_inverse_direct_cf32_test(function_name, dim1, dim2, l, diff --git a/utils/reference_linalg.hpp b/utils/reference_linalg.hpp index 8cf476bda13dcf6af9077bde2c566fda1c2dcc6f..417241f12805db9f1039ec3d6a7319099773e32c 100644 --- a/utils/reference_linalg.hpp +++ b/utils/reference_linalg.hpp @@ -485,7 +485,12 @@ reference_left_pseudo_inverse_direct(uint32_t m, uint32_t n, float32_t lambda, // Compute B = C^(-1) std::vector mat_inv(n * n); - reference_matinv_block(n, mat_aha, mat_inv.data()); + if (n == 1) { + mat_inv[0].re = 1.F / mat_aha[0].re; + mat_inv[0].im = 0.F; + } else { + reference_matinv_block(n, mat_aha, mat_inv.data()); + } // Compute B * A^H reference_matmul_bah_cf32(m, n, p_src, mat_inv.data(), p_dst); @@ -512,7 +517,12 @@ static inline void reference_right_pseudo_inverse_direct( // Compute B = C^(-1) std::vector mat_inv(m * m); - reference_matinv_block(m, mat_aah, mat_inv.data()); + if (m == 1) { + mat_inv[0].re = 1.F / mat_aah[0].re; + mat_inv[0].im = 0.F; + } else { + reference_matinv_block(m, mat_aah, mat_inv.data()); + } // Compute A^H * B reference_matmul_ahb_cf32(m, n, m, p_src, mat_inv.data(), p_dst);