From 5e0844415f2cb4dbc03ee6225515844784987233 Mon Sep 17 00:00:00 2001 From: Rosie Sumpter Date: Wed, 17 Jul 2024 10:24:11 +0000 Subject: [PATCH] Release 24.07 Co-Authored-By: Nick Dingle Co-Authored-By: Amy Wignall Co-Authored-By: Radu Salavat --- .gitlab-ci.yml | 17 +- CHANGELOG.md | 86 +- CMakeLists.txt | 32 +- Doxyfile.in | 2 +- README.md | 17 +- RELEASE_NOTES.md | 202 +- THIRD_PARTY_LICENSES.md | 3 - .../Batch/GeneralMatInv/NonPA/main.cpp | 2 +- .../MatrixInv/Batch/GeneralMatInv/PA/main.cpp | 8 +- .../Batch/HermitianMatInv/NonPA/main.cpp | 10 +- .../Batch/HermitianMatInv/PA/main.cpp | 10 +- .../MatrixInv/Single/GeneralMatInv/main.cpp | 2 +- .../MatrixInv/Single/HermitianMatInv/main.cpp | 3 +- .../Single/MatrixMult16/32b/main.cpp | 2 +- .../Single/MatrixMult16/64b/main.cpp | 2 +- .../Single/MatrixMult32/2x2/IQ/main.cpp | 12 +- .../Single/MatrixMult32/4x4/IQ/main.cpp | 12 +- .../Single/MatrixMult32/general/bench.py | 17 +- .../Single/MatrixMult32/general/main.cpp | 17 +- .../Single/MatrixMultAAH32/main.cpp | 2 +- .../Single/MatrixMultAHB32/bench.py | 16 +- .../Single/MatrixMultAHB32/main.cpp | 18 +- .../MuLaw/Compression/14bit/main.cpp | 2 +- .../MuLaw/Compression/8bit/main.cpp | 2 +- .../MuLaw/Compression/9bit/main.cpp | 2 +- .../MuLaw/Decompression/14bit/main.cpp | 2 +- .../MuLaw/Decompression/8bit/main.cpp | 2 +- .../MuLaw/Decompression/9bit/main.cpp | 2 +- .../ORanBlockFloat/Compression/12bit/main.cpp | 2 +- .../ORanBlockFloat/Compression/14bit/main.cpp | 2 +- .../ORanBlockFloat/Compression/8bit/main.cpp | 2 +- .../ORanBlockFloat/Compression/9bit/main.cpp | 2 +- .../Decompression/12bit/main.cpp | 2 +- .../Decompression/14bit/main.cpp | 2 +- .../Decompression/8bit/main.cpp | 2 +- .../Decompression/9bit/main.cpp | 2 +- .../Compression/14bit/main.cpp | 2 +- .../Compression/8bit/main.cpp | 2 +- .../Compression/9bit/main.cpp | 2 +- .../Decompression/14bit/main.cpp | 2 +- .../Decompression/8bit/main.cpp | 2 +- .../Decompression/9bit/main.cpp | 2 +- bench/LowerPHY/SeqGenerator/main.cpp | 2 +- bench/UpperPHY/LDPC/Decoding/main.cpp | 2 +- bench/UpperPHY/LDPC/Encoding/main.cpp | 2 +- bench/UpperPHY/Turbo/Decoding/main.cpp | 2 +- bench/UpperPHY/Turbo/RateMatching/main.cpp | 2 +- bench/UpperPHY/Turbo/RateRecovery/main.cpp | 2 +- cmake_uninstall.cmake.in | 2 +- docs/examples.md | 2 +- docs/frontmatter.md | 3 +- include/armral.h | 875 ++++---- simulation/CMakeLists.txt | 22 - .../convolutional_awgn/convolutional_awgn.cpp | 7 +- simulation/ldpc_awgn/ldpc_awgn.cpp | 11 +- .../modulation_awgn/modulation_awgn.cpp | 4 +- simulation/polar_awgn/polar_awgn.cpp | 9 +- simulation/turbo_awgn/turbo_awgn.cpp | 7 +- .../arm_cmplx_hermitian_mat_inversion_f32.cpp | 20 +- .../MatrixMult/arm_cmplx_mat_mult_f32.c | 1165 ---------- .../MatrixMult/arm_cmplx_mat_vec_mult_f32.c | 4 +- .../MatrixMult/arm_cmplx_mat_vec_mult_i16.c | 18 +- ...h_f32.cpp => arm_cmplx_matmul_aah_f32.cpp} | 72 +- ...ahb_f32.c => arm_cmplx_matmul_ahb_f32.cpp} | 278 +-- .../MatrixMult/arm_cmplx_matmul_f32.cpp | 1985 +++++++++++++++++ .../MatrixMult/arm_cmplx_matmul_i16.cpp | 14 + .../MatrixMult/arm_cmplx_matmul_i16_32bit.cpp | 15 + ...lx_mat_mult_i16.c => cmplx_matmul_i16.hpp} | 325 +-- ...i16_32bit.c => cmplx_matmul_i16_32bit.hpp} | 400 ++-- .../arm_cmplx_pseudo_inverse_direct_f32.cpp | 6 +- .../VectorDotProd/arm_cmplx_vecdot_f32.c | 8 +- .../VectorDotProd/arm_cmplx_vecdot_f32_2.c | 12 +- .../VectorDotProd/arm_cmplx_vecdot_i16.c | 13 +- .../VectorDotProd/arm_cmplx_vecdot_i16_2.c | 12 +- .../arm_cmplx_vecdot_i16_2_32bit.c | 12 +- .../arm_cmplx_vecdot_i16_32bit.c | 12 +- .../VectorMult/arm_cmplx_vecmul_f32.c | 12 +- .../VectorMult/arm_cmplx_vecmul_f32_2.c | 8 +- .../VectorMult/arm_cmplx_vecmul_i16.cpp | 12 +- .../VectorMult/arm_cmplx_vecmul_i16_2.c | 13 +- .../arm_mu_law_decompression.cpp | 16 +- .../arm_block_float_decompression.cpp | 14 +- .../arm_block_scaling_decompression.cpp | 16 +- src/LowerPHY/Correlation/arm_correlation.c | 24 +- src/LowerPHY/FFT/fft_plan.cpp | 4 +- src/LowerPHY/FFT/rader.cpp | 6 +- .../SeqGenerator/arm_mat_seq_generator.cpp | 2 +- src/MatrixFactorizations/SVD/arm_svd.cpp | 297 +-- src/MatrixFactorizations/SVD/matrix_view.hpp | 14 +- .../arm_convolutional_decoder.cpp | 6 +- src/UpperPHY/LDPC/ldpc_coding.hpp | 2 +- src/UpperPHY/LDPC/ldpc_decoder.cpp | 8 +- src/UpperPHY/LDPC/ldpc_encoder.cpp | 73 +- src/UpperPHY/LDPC/ldpc_rate_matching.cpp | 2 +- src/UpperPHY/Polar/arm_polar_decoder.cpp | 4 +- src/UpperPHY/Polar/arm_polar_encoder.c | 12 +- src/UpperPHY/Polar/arm_polar_frozen_bits.cpp | 6 +- src/UpperPHY/Turbo/arm_turbo_decoder.cpp | 6 +- ...decoder_fp16.hpp => arm_turbo_decoder.hpp} | 335 ++- .../Turbo/arm_turbo_rate_matching.cpp | 2 +- .../Turbo/arm_turbo_rate_recovery.cpp | 2 +- src/UpperPHY/Turbo/turbo_decoder_fp32.hpp | 533 ----- .../utils/bits_to_bytes.hpp | 13 +- src/utils/vec_mul.hpp | 20 +- test/BasicMathFun/MatrixInv/Batch/main.cpp | 15 +- test/BasicMathFun/MatrixInv/Single/main.cpp | 8 + .../MatrixMult/Batch/ArmSolve/main.cpp | 32 +- .../Batch/MatrixVectorMult16/main.cpp | 44 +- .../Batch/MatrixVectorMult32/main.cpp | 35 +- .../MatrixMult/Single/MatrixMult16/main.cpp | 32 +- .../MatrixMult/Single/MatrixMult32/main.cpp | 54 +- .../Single/MatrixMultAAH32/main.cpp | 14 +- .../Single/MatrixMultAHB32/main.cpp | 33 +- .../Single/MatrixVectorMult16/main.cpp | 28 +- .../Single/MatrixVectorMult32/main.cpp | 13 +- .../MatrixPseudoInv/Direct/main.cpp | 19 +- .../VectorDotProd/VecDot16/main.cpp | 21 +- .../VectorDotProd/VecDot16_2/main.cpp | 29 +- .../VectorDotProd/VecDot16_2_32bit/main.cpp | 43 +- .../VectorDotProd/VecDot16_32bit/main.cpp | 25 +- .../VectorDotProd/VecDot32/main.cpp | 15 +- .../VectorDotProd/VecDot32_2/main.cpp | 27 +- .../BasicMathFun/VectorMult/VecMul16/main.cpp | 25 +- .../VectorMult/VecMul16_2/main.cpp | 54 +- .../BasicMathFun/VectorMult/VecMul32/main.cpp | 14 +- .../VectorMult/VecMul32_2/main.cpp | 28 +- test/DuRuInterface/MuLaw/Compression/main.cpp | 20 +- .../MuLaw/Decompression/main.cpp | 20 +- .../ORanBlockFloat/Compression/main.cpp | 18 +- .../ORanBlockFloat/Decompression/main.cpp | 23 +- .../ORanBlockScaling/Compression/main.cpp | 8 +- .../ORanBlockScaling/Decompression/main.cpp | 7 +- test/LowerPHY/Correlation/main.cpp | 29 +- test/LowerPHY/FFT/FFT16/main.cpp | 33 +- test/LowerPHY/FFT/FFT32/main.cpp | 32 +- test/LowerPHY/FIR/FIR16/main.cpp | 28 +- test/LowerPHY/FIR/FIR16Decimate2/main.cpp | 14 +- test/LowerPHY/FIR/FIR32/main.cpp | 14 +- test/LowerPHY/FIR/FIR32Decimate2/main.cpp | 14 +- test/LowerPHY/Scrambling/main.cpp | 11 +- test/LowerPHY/SeqGenerator/main.cpp | 8 +- test/MatrixFactorizations/SVD/main.cpp | 32 +- test/MatrixFactorizations/SVD/svd_test.hpp | 136 +- test/UpperPHY/CRC/main.cpp | 4 +- test/UpperPHY/ConvolutionalDecoder/main.cpp | 18 +- test/UpperPHY/ConvolutionalEncoder/main.cpp | 32 +- test/UpperPHY/Demodulation/main.cpp | 15 +- test/UpperPHY/LDPC/Decoding/main.cpp | 18 +- test/UpperPHY/LDPC/Encoding/main.cpp | 36 +- test/UpperPHY/LDPC/RateMatching/main.cpp | 28 +- test/UpperPHY/LDPC/RateRecovery/main.cpp | 13 +- test/UpperPHY/Modulation/main.cpp | 32 +- test/UpperPHY/Polar/CrcAttachment/main.cpp | 7 +- test/UpperPHY/Polar/Decoding/main.cpp | 19 +- test/UpperPHY/Polar/Encoding/main.cpp | 10 +- test/UpperPHY/Polar/Frozen/main.cpp | 7 +- test/UpperPHY/Polar/RateMatching/main.cpp | 9 +- test/UpperPHY/Polar/RateRecovery/main.cpp | 6 +- .../Polar/SubchannelDeinterleave/main.cpp | 11 +- .../Polar/SubchannelInterleave/main.cpp | 11 +- test/UpperPHY/Turbo/RateMatching/main.cpp | 18 +- test/UpperPHY/Turbo/RateRecovery/main.cpp | 14 +- utils/cf32_utils.hpp | 64 +- utils/cs16_utils.hpp | 133 +- utils/fft_utils.hpp | 9 +- utils/int8_utils.hpp | 86 - utils/int_utils.hpp | 149 ++ utils/matrix_utils.hpp | 44 +- utils/qint64.hpp | 6 +- utils/reference_linalg.hpp | 125 +- utils/rng.cpp | 3 +- utils/rng.hpp | 51 +- 172 files changed, 5055 insertions(+), 4297 deletions(-) delete mode 100644 src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_f32.c rename src/BasicMathFun/MatrixMult/{arm_cmplx_mat_mult_aah_f32.cpp => arm_cmplx_matmul_aah_f32.cpp} (90%) rename src/BasicMathFun/MatrixMult/{arm_cmplx_mat_mult_ahb_f32.c => arm_cmplx_matmul_ahb_f32.cpp} (70%) create mode 100644 src/BasicMathFun/MatrixMult/arm_cmplx_matmul_f32.cpp create mode 100644 src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16.cpp create mode 100644 src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16_32bit.cpp rename src/BasicMathFun/MatrixMult/{arm_cmplx_mat_mult_i16.c => cmplx_matmul_i16.hpp} (79%) rename src/BasicMathFun/MatrixMult/{arm_cmplx_mat_mult_i16_32bit.c => cmplx_matmul_i16_32bit.hpp} (82%) rename src/UpperPHY/Turbo/{turbo_decoder_fp16.hpp => arm_turbo_decoder.hpp} (56%) delete mode 100644 src/UpperPHY/Turbo/turbo_decoder_fp32.hpp rename utils/bit_utils.hpp => src/utils/bits_to_bytes.hpp (91%) delete mode 100644 utils/int8_utils.hpp create mode 100644 utils/int_utils.hpp diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 86609f5..7a58224 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,14 +1,13 @@ -image: gcc - -before_script: - - apt-get update --yes - - apt-get install --yes cmake gcovr - -build: +default: + image: ubuntu:22.04 tags: - arm64 + before_script: + - apt-get update --yes + - apt-get install --yes cmake g++ gcovr + +build: script: - mkdir build - - cd build && cmake -DBUILD_TESTING=On -DARMRAL_ENABLE_COVERAGE=On .. && make -j check && gcovr --gcov-ignore-parse-errors -r .. + - cd build && cmake -DBUILD_TESTING=On -DARMRAL_ENABLE_COVERAGE=On .. && make -j check && gcovr -r .. coverage: /^TOTAL.*\s+(\d+\%)$/ - diff --git a/CHANGELOG.md b/CHANGELOG.md index 4318feb..500fdf5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,6 @@ documented in this file. ### Added ### Changed -- Documentation is now installed by the `make install` target, if it has been -built. ### Deprecated @@ -19,6 +17,90 @@ built. ### Security +## [24.07] - 2024-07-18 + +### Added +- CMake option `ARMRAL_ENABLE_WEXTRA` to add the compiler flag `-Wextra` when +building the library and tests. + +### Changed +- Documentation is now installed by the `make install` target, if it has been +built. + +- Improved performance of `armral_cmplx_matmul_f32`. For complex 32-bit +floating point matrix multiplication, we recommend you use this function for +all cases. This function calls existing optimized special cases with minimal +overhead and has new optimizations for larger cases. + +- Improved performance of `armral_turbo_decode_block` and +`armral_turbo_decode_block_noalloc`. These functions now operate internally on +16-bit integer values rather than 16-bit or 32-bit floating point values. + +- The following functions now use unsigned integers in their interfaces to +represent the lengths of vectors and the dimensions of matrices: + - `armral_cmplx_vecdot_f32` + - `armral_cmplx_vecdot_f32_2` + - `armral_cmplx_vecdot_i16` + - `armral_cmplx_vecdot_i16_2` + - `armral_cmplx_vecdot_i16_32bit` + - `armral_cmplx_vecdot_i16_2_32bit` + - `armral_cmplx_vecmul_f32` + - `armral_cmplx_vecmul_f32_2` + - `armral_cmplx_vecmul_i16` + - `armral_cmplx_vecmul_i16_2` + - `armral_corr_coeff_i16` + - `armral_svd_cf32` + - `armral_svd_cf32_noalloc` + - `armral_svd_cf32_noalloc_buffer_size` + +- Renamed `armral_cmplx_mat_mult_aah_f32` to be `armral_cmplx_matmul_aah_f32`. +All arguments are in the same order and have the same meaning. + +- Replaced `armral_cmplx_mat_mult_ahb_f32` with `armral_cmplx_matmul_ahb_f32`. +Note that the meanings of the parameters `m`, `n`, and `k` differ between the +old function and the new; a call to the old function of the form + + `armral_cmplx_mat_mult_ahb_f32(dim1, dim2, dim3, a, b, c);` + +becomes + + `armral_cmplx_matmul_ahb_f32(dim2, dim3, dim1, a, b, c);` + +- Replaced `armral_cmplx_mat_mult_i16` with `armral_cmplx_matmul_i16`. +Note that the meanings of the parameters `m`, `n`, and `k` differ between the +old function and the new; a call to the old function of the form + + `armral_cmplx_mat_mult_i16(dim1, dim2, dim3, a, b, c);` + +becomes + + `armral_cmplx_matmul_i16(dim1, dim3, dim2, a, b, c);` + +- Replaced `armral_cmplx_mat_mult_i16_32bit` with `armral_cmplx_matmul_i16_32bit`. +Note that the meanings of the parameters `m`, `n`, and `k` differ between the +old function and the new; a call to the old function of the form + + `armral_cmplx_mat_mult_i16_32bit(dim1, dim2, dim3, a, b, c);` + +becomes + + `armral_cmplx_matmul_i16_32bit(dim1, dim3, dim2, a, b, c);` + +- Replaced `armral_cmplx_matmul_f32` with `armral_cmplx_matmul_f32`. +Note that the meanings of the parameters `m`, `n`, and `k` differ between the +old function and the new; a call to the old function of the form + + `armral_cmplx_mat_mult_f32(dim1, dim2, dim3, a, b, c);` + +becomes + + `armral_cmplx_matmul_f32(dim1, dim3, dim2, a, b, c);` + +### Fixed +- Corrected documentation for `armral_cmplx_mat_inverse_batch_f32` and +`armral_cmplx_mat_inverse_batch_f32_pa` to clarify that these functions have no +restriction on batch sizes. + ## [24.04] - 2024-04-19 ### Added diff --git a/CMakeLists.txt b/CMakeLists.txt index a69e666..10df7fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.3) -project(armral VERSION 24.04) +project(armral VERSION 24.07) if(CMAKE_VERSION VERSION_GREATER 3.4) # Stop CMake from automatically adding -rdynamic to linker flags because it @@ -16,6 +16,8 @@ if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release") endif() +option(ARMRAL_ENABLE_WEXTRA + "Enable -Wextra when building the library and tests" OFF) option(ARMRAL_ENABLE_WERROR "Enable -Werror when building the library and tests" OFF) option(ARMRAL_ENABLE_ASAN @@ -33,11 +35,11 @@ set_property(CACHE ARMRAL_ARCH PROPERTY STRINGS "NEON" "SVE2") set(ARMRAL_LIB_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixInv/arm_cmplx_hermitian_mat_inversion_f32.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixInv/arm_cmplx_mat_inversion_f32.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_aah_f32.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_ahb_f32.c - ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_f32.c - ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_i16.c - ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_i16_32bit.c + ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_aah_f32.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_ahb_f32.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_f32.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16_32bit.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_f32.c ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_i16.c ${CMAKE_CURRENT_SOURCE_DIR}/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_i16_32bit.c @@ -184,6 +186,21 @@ endif() set(ARMRAL_COMPILER_FLAGS "") set(ARMRAL_LINKER_FLAGS "") +if(ARMRAL_ENABLE_WEXTRA) + if(ARMRAL_OVERRIDE_COMPILE_FLAGS) + message( + WARNING + "CMAKE_C_FLAGS and CMAKE_CXX_FLAGS manually specified. Ignoring option ARMRAL_ENABLE_WEXTRA" + ) + else() + # We have the same interfaces for Neon and SVE implementations of functions, + # and sometimes we pass in parameters that are only used in one or the + # other. We therefore disable warnings for unused parameters. + set(ARMRAL_COMPILER_FLAGS ${ARMRAL_COMPILER_FLAGS} -Wextra + -Wno-unused-parameter) + endif() +endif() + if(ARMRAL_ENABLE_WERROR) if(ARMRAL_OVERRIDE_COMPILE_FLAGS) message( @@ -359,7 +376,6 @@ install(FILES LICENSE.md THIRD_PARTY_LICENSES.md if(BUILD_TESTING) include(CTest) - enable_testing() if(NOT DEFINED BENCHMARKER_SOURCE_DIR) set(BENCHMARKER_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) endif() @@ -603,7 +619,7 @@ if(BUILD_TESTING) matrix_mult_f32_4x4 bench/BasicMathFun/MatrixMult/Single/MatrixMult32/4x4/NonIQ/main.cpp) add_armral_bench( - matrix_mult_f32_general + matmul_f32_general bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/main.cpp) add_armral_bench( matrix_mult_aah_32 diff --git a/Doxyfile.in b/Doxyfile.in index f571d32..06c4468 100644 --- a/Doxyfile.in +++ b/Doxyfile.in @@ -38,7 +38,7 @@ PROJECT_NAME = "Arm RAN Acceleration Library Reference Guide" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = "24.04" +PROJECT_NUMBER = "24.07" # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/README.md b/README.md index 83a9a05..80cf1dc 100644 --- a/README.md +++ b/README.md @@ -69,8 +69,8 @@ including: Notes: * The `-DBUILD_TESTING=On` and `-DBUILD_EXAMPLES=On` options are required - if you want to run the library tests (`-DBUILD_TESTING`) and benchmarks - (`-DBUILD_EXAMPLES`). + if you want to run the library tests and benchmarks (`-DBUILD_TESTING`) + and examples (`-DBUILD_EXAMPLES`). * The `-DCMAKE_INSTALL_PREFIX=` option specifies the base directory used to install the library. The library archive is installed to @@ -157,6 +157,15 @@ including: Default is `Off`. + * `-DARMRAL_ENABLE_WEXTRA={On|Off}` + + Use (`On`), or do not use (`Off`), `-Wextra` to build the library and + tests. `-Wextra` enables additional compiler warnings over the default + `-Wall`. Disabled by default to aid compatibility with untested and future + compiler releases. + + Default is `Off`. + * `-DARMRAL_ENABLE_WERROR={On|Off}` Use (`On`), or do not use (`Off`), `-Werror` to build the library and @@ -319,7 +328,7 @@ directory. More information about the examples that are available in Arm RAN Acceleration Library, and how to use the library in general, is available in -**Use Arm RAN Acceleration Library (ArmRAL)** (see `examples.md`). +**Use Arm RAN Acceleration Library (ArmRAL)**, see `docs/examples.md`. ## Run the simulations @@ -385,7 +394,7 @@ file. The Arm RAN Acceleration Library Reference Guide is available online at: - https://developer.arm.com/documentation/102249/2404 + https://developer.arm.com/documentation/102249/2407 If you have Doxygen installed on your system, you can build a local HTML version of the Arm RAN Acceleration Library documentation using CMake. diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 7830b73..3cfec74 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,4 +1,4 @@ -# Arm RAN Acceleration Library 24.04 Release Note +# Arm RAN Acceleration Library 24.07 Release Notes Non-Confidential Copyright © 2020-2024 Arm Limited (or its affiliates). All rights reserved. @@ -18,8 +18,8 @@ this document. ## Release overview -The following sections describe the product that this release note describes and -its quality status at time of release. +The following sections describe the product to which these release notes relate +and its quality status at time of release. ### Product description @@ -28,7 +28,7 @@ accelerating telecommunications applications such as, but not limited to, 5G Radio Access Networks (RANs). These functions are optimized for Arm AArch64-based processors. -Arm RAN Acceleration Library provides: +ArmRAL provides: - Vector functions - Matrix functions @@ -36,12 +36,12 @@ Arm RAN Acceleration Library provides: - Upper physical layer (Upper PHY) support functions - Distributed Unit-Radio Unit (DU-RU) Interface support functions -Arm RAN Acceleration Library includes functions that operate on 16-bit signed -integers and 16-bit and 32-bit floating-point values. +ArmRAL includes functions that operate on 16-bit signed integers and 16-bit and +32-bit floating-point values. ### Release status -This is the 24.04 release of Arm RAN Acceleration Library. +This is the 24.07 release of ArmRAL. These deliverables are being released under the terms of the agreement between Arm and each licensee (the "Agreement"). All planned verification and @@ -51,42 +51,40 @@ The release is suitable for volume production under the terms of the Agreement. ### Licensing information -Use of Arm RAN Acceleration Library is subject to a BSD-3-Clause license, the -text of which can be found in the `LICENSE.md` file in your product -installation. We will receive inbound contributions under the same license. +Use of ArmRAL is subject to a BSD-3-Clause license, the text of which can be +found in the `LICENSE.md` file in your product installation. We will receive +inbound contributions under the same license. If you require a different license than BSD-3-Clause for compatibility with your end product, please get in contact. ## Release contents -Arm RAN Acceleration Library releases contain documentation and source files. +ArmRAL releases contain documentation and source files. The following subsections describe: -- Cloning the product's git repository from Arm's Gitlab. -- The contents of this release. -- Any changes since the previous release. -- Any known issues and limitations that exist at the time of this release. +- Cloning the product's git repository from Arm's GitLab +- The contents of this release +- Any changes since the previous release +- Any known issues and limitations that exist at the time of this release ### Cloning the source repository -**To obtain the 24.04 release of Arm RAN Acceleration Library by cloning - the repository via HTTPS:** +**To access this release, clone the following repository using HTTPS:** - git clone -b armral-24.04 https://git.gitlab.arm.com/networking/ral + git clone -b armral-24.07 https://git.gitlab.arm.com/networking/ral ### Deliverables -The downloaded product includes the deliverables listed in this section. +The downloaded product includes the following deliverables: -- Arm RAN Acceleration Library 24.04 +- ArmRAL 24.07 - Release Notes (this document) - Documentation - Product documentation is available on the Arm Developer website at: - - + Product documentation is available on the + [Arm Developer website](https://developer.arm.com/documentation/102249/2407). **Note:** Documentation, errata and release notes might change between product releases. For the latest documentation bundle, check the product download @@ -94,90 +92,146 @@ The downloaded product includes the deliverables listed in this section. **Note:** Arm tests its PDFs only in Adobe Acrobat and Acrobat Reader. Arm cannot guarantee the quality of this document when used with any other PDF - reader. A suitable PDF reader can be downloaded from Adobe at - . + reader. A suitable PDF reader can be downloaded from + [Adobe](http://www.adobe.com). ### Differences from previous release The following subsections describe differences from the previous release of -Arm RAN Acceleration Library. +ArmRAL. #### Additions and functionality changes -Describes new features or any technical changes to features or +This section describes new features or any technical changes to features or components in this release. -- Extended `armral_cmplx_pseudo_inverse_direct_f32` and - `armral_cmplx_pseudo_inverse_direct_f32_noalloc` to compute the - regularized pseudo-inverse of a complex 32-bit matrix of size - `M-by-N` for the case where `M` and/or `N == 1`. +- For complex 32-bit floating point matrix multiplication, we recommend that +you use `armral_cmplx_matmul_f32` for all cases. This function calls existing +optimized special cases with minimal overhead and has new optimizations for +larger cases. + +- Renamed `armral_cmplx_mat_mult_aah_f32` to be `armral_cmplx_matmul_aah_f32`. +All arguments are in the same order and have the same meaning. + +- Replaced `armral_cmplx_mat_mult_ahb_f32` with `armral_cmplx_matmul_ahb_f32`. +Note that the meanings of the parameters `m`, `n`, and `k` differ between the +old function and the new; a call to the old function of the form + + `armral_cmplx_mat_mult_ahb_f32(dim1, dim2, dim3, a, b, c);` + +becomes + + `armral_cmplx_matmul_ahb_f32(dim2, dim3, dim1, a, b, c);` + +- Replaced `armral_cmplx_mat_mult_i16` with `armral_cmplx_matmul_i16`. +Note that the meanings of the parameters `m`, `n`, and `k` differ between the +old function and the new; a call to the old function of the form + + `armral_cmplx_mat_mult_i16(dim1, dim2, dim3, a, b, c);` + +becomes + + `armral_cmplx_matmul_i16(dim1, dim3, dim2, a, b, c);` + +- Replaced `armral_cmplx_mat_mult_i16_32bit` with `armral_cmplx_matmul_i16_32bit`. +Note that the meanings of the parameters `m`, `n`, and `k` differ between the +old function and the new; a call to the old function of the form + + `armral_cmplx_mat_mult_i16_32bit(dim1, dim2, dim3, a, b, c);` + +becomes + + `armral_cmplx_matmul_i16_32bit(dim1, dim3, dim2, a, b, c);` -- Added a Makefile target `bench_excel_summary` to run the benchmarks - and create an Excel spreadsheet containing the results. +- Replaced `armral_cmplx_mat_mult_f32` with `armral_cmplx_matmul_f32`. +Note that the meanings of the parameters `m`, `n`, and `k` differ between the +old function and the new; a call to the old function of the form + + `armral_cmplx_mat_mult_f32(dim1, dim2, dim3, a, b, c);` + +becomes + + `armral_cmplx_matmul_f32(dim1, dim3, dim2, a, b, c);` + +- The following functions now use unsigned integers in their interfaces to +represent the lengths of vectors and the dimensions of matrices: + - `armral_cmplx_vecdot_f32` + - `armral_cmplx_vecdot_f32_2` + - `armral_cmplx_vecdot_i16` + - `armral_cmplx_vecdot_i16_2` + - `armral_cmplx_vecdot_i16_32bit` + - `armral_cmplx_vecdot_i16_2_32bit` + - `armral_cmplx_vecmul_f32` + - `armral_cmplx_vecmul_f32_2` + - `armral_cmplx_vecmul_i16` + - `armral_cmplx_vecmul_i16_2` + - `armral_corr_coeff_i16` + - `armral_svd_cf32` + - `armral_svd_cf32_noalloc` + - `armral_svd_cf32_noalloc_buffer_size` + +- Added the CMake option `ARMRAL_ENABLE_WEXTRA` to add the compiler flag +`-Wextra` when building the library and tests. #### Performance improvements -Describes any features or components whose performance has improved in -the current release compared with the previous release. +This section describes any features or components with improved performance. + +- Performance improvements for the following routines: -- Performance improvements for SVE2 implementations of the following routines: + - `armral_cmplx_matmul_f32`. For complex 32-bit floating point matrix + multiplication, we recommend that you use this function for all cases. + This function calls existing optimized special cases with minimal overhead + and has new optimizations for larger cases. - `armral_turbo_decode_block` and `armral_turbo_decode_block_noalloc`. These functions now operate - internally on 16-bit floating point values rather than 32-bit + internally on 16-bit integer values rather than 16-bit or 32-bit floating point values. - - `armral_ldpc_encode_block` and - `armral_ldpc_encode_block_noalloc`. - #### Changes to simulation programs -Describes any changes, new features or components added to the channel -simulation programs in this release. +This section describes any changes, new features or components added to the +channel simulation programs in this release. - There are no changes to the channel simulation programs in this release. #### Resolved issues -Describes any known issues resolved in the current release. +This section describes any known issues resolved in the current release. -- There are no known issues resolved in this release. +- Documentation is now installed by the `make install` target, if it has been +built. + +- Corrected documentation for `armral_cmplx_mat_inverse_batch_f32` and +`armral_cmplx_mat_inverse_batch_f32_pa` to clarify that these functions have no +restriction on batch sizes. ### Known limitations -Describes any known limitations of the current release. +This section describes any known limitations of the current release. - There are no known limitations in this release. ## Support If you have any issues with the installation, content, or use of this -release, raise a question on the Developer Community Forum: - - - +release, raise a question on the +[Developer Community Forum]( + GCC compiler on the + [Arm Developer website](https://developer.arm.com/tools-and-software/open-source-software/developer-tools/gnu-toolchain/gnu-a/downloads). The variant to use for an AArch64 GNU/Linux target is `aarch64-none-linux-gnu`. @@ -187,28 +241,24 @@ To build or run Arm RAN Acceleration Library you will need: Additionally: - To run the benchmarks, you must have the Linux utility tool `perf` installed - and a recent version of Python 3. Arm RAN Acceleration Library has been tested - with Python 3.8.5. + and a recent version of Python 3. ArmRAL has been tested with Python 3.8.5. - To build a local version of the documentation, you must have Doxygen - installed. Arm RAN Acceleration Library has been tested with Doxygen version - 1.8.13. + installed. ArmRAL has been tested with Doxygen version 1.8.13. - To generate code coverage HTML pages, you must have `gcovr` installed. The library has been tested with `gcovr` version 4.2. -**Note:** Arm RAN Acceleration Library runs on AArch64 cores, however -to use the convolutional encoder, CRC, and sequence generator -functions you must run on a core that supports the AArch64 PMULL -extension. If your machine supports the PMULL extension, `pmull` is -listed under the "Features" list given in the `/proc/cpuinfo` file. +**Note:** ArmRAL runs on AArch64 cores, however to use the convolutional +encoder, CRC, and sequence generator functions you must run on a core that +supports the AArch64 PMULL extension. If your machine supports the PMULL +extension, `pmull` is listed under the "Features" list given in the +`/proc/cpuinfo` file. ## Release history -A full release history (with release notes) for Arm RAN Acceleration Library -is available on the Arm Developer website: - - +A full release history (with release notes) for ArmRAL is available on the +[Arm Developer website](https://developer.arm.com/downloads/-/arm-ran-acceleration-library/previous-releases-of-the-arm-ran-acceleration-library>). ## Conventions @@ -221,7 +271,7 @@ with definitions for those terms. The Arm Glossary does not contain terms that are industry standard unless the Arm meaning differs from the generally accepted meaning. -See the Arm Glossary for more information: . +See the [Arm Glossary](https://developer.arm.com/glossary) for more information. ## Non-Confidential Proprietary Notice diff --git a/THIRD_PARTY_LICENSES.md b/THIRD_PARTY_LICENSES.md index 1db67f8..35e7c17 100644 --- a/THIRD_PARTY_LICENSES.md +++ b/THIRD_PARTY_LICENSES.md @@ -1,8 +1,5 @@ This file lists the package level copyright and license information for third party software included in this release of 'Arm RAN Acceleration Library'. -Refer to the License Agreement (End User License Agreement (EULA)) that -accompanies this release of 'Arm RAN Acceleration Library' for terms and -conditions relating to your use of such third party software. The information is grouped into two sections. The first section lists out details of third party software projects, including names of the applicable diff --git a/bench/BasicMathFun/MatrixInv/Batch/GeneralMatInv/NonPA/main.cpp b/bench/BasicMathFun/MatrixInv/Batch/GeneralMatInv/NonPA/main.cpp index 8375897..761df3c 100644 --- a/bench/BasicMathFun/MatrixInv/Batch/GeneralMatInv/NonPA/main.cpp +++ b/bench/BasicMathFun/MatrixInv/Batch/GeneralMatInv/NonPA/main.cpp @@ -17,7 +17,7 @@ void run_general_batch_matinv_perf(uint32_t num_mats, uint32_t dim, "%u\n", dim, num_reps); - const auto a = gen_invertible_matrix_batch(num_mats, dim); + const auto a = armral::utils::gen_invertible_matrix_batch(num_mats, dim); auto res = std::vector(num_mats * dim * dim); const auto *a_ptr = a.data(); diff --git a/bench/BasicMathFun/MatrixInv/Batch/GeneralMatInv/PA/main.cpp b/bench/BasicMathFun/MatrixInv/Batch/GeneralMatInv/PA/main.cpp index 64a2b4f..d2cf451 100644 --- a/bench/BasicMathFun/MatrixInv/Batch/GeneralMatInv/PA/main.cpp +++ b/bench/BasicMathFun/MatrixInv/Batch/GeneralMatInv/PA/main.cpp @@ -17,7 +17,7 @@ void run_general_batch_pa_matinv_perf(uint32_t num_mats, uint32_t dim, "%u\n", dim, num_reps); - const auto a = gen_invertible_matrix_batch(num_mats, dim); + const auto a = armral::utils::gen_invertible_matrix_batch(num_mats, dim); auto res = std::vector(num_mats * dim * dim); std::vector a_pa(dim * dim); @@ -50,20 +50,20 @@ int main(int argc, char **argv) { auto num_reps = (uint32_t)atoi(argv[3]); if (dim != 2 && dim != 3 && dim != 4) { - fprintf(stderr, "unsupported matrix dimension: %d", dim); + fprintf(stderr, "unsupported matrix dimension: %u", dim); exit(EXIT_FAILURE); } if (dim >= 3 && num_mats % 4 != 0) { fprintf(stderr, - "unsupported batch size for matrix dimension >= 3: %d %% 4 != 0", + "unsupported batch size for matrix dimension >= 3: %u %% 4 != 0", num_mats); exit(EXIT_FAILURE); } if (dim == 2 && num_mats % 2 != 0) { fprintf(stderr, - "unsupported batch size for matrix dimension = 2: %d %% 2 != 0", + "unsupported batch size for matrix dimension = 2: %u %% 2 != 0", num_mats); exit(EXIT_FAILURE); } diff --git a/bench/BasicMathFun/MatrixInv/Batch/HermitianMatInv/NonPA/main.cpp b/bench/BasicMathFun/MatrixInv/Batch/HermitianMatInv/NonPA/main.cpp index 8d0e972..25f0857 100644 --- a/bench/BasicMathFun/MatrixInv/Batch/HermitianMatInv/NonPA/main.cpp +++ b/bench/BasicMathFun/MatrixInv/Batch/HermitianMatInv/NonPA/main.cpp @@ -17,8 +17,8 @@ void run_hermitian_batch_matinv_perf(uint32_t num_mats, uint32_t dim, "%u\n", dim, num_reps); - const auto a = - gen_hermitian_matrix_batch(num_mats, dim, false, 1.0, 1.0, true); + const auto a = armral::utils::gen_hermitian_matrix_batch(num_mats, dim, false, + 1.0, 1.0, true); auto res = std::vector(num_mats * dim * dim); const auto *a_ptr = a.data(); @@ -41,20 +41,20 @@ int main(int argc, char **argv) { auto num_reps = (uint32_t)atoi(argv[3]); if (dim != 2 && dim != 3 && dim != 4) { - fprintf(stderr, "unsupported matrix dimension: %d", dim); + fprintf(stderr, "unsupported matrix dimension: %u", dim); exit(EXIT_FAILURE); } if (dim >= 3 && num_mats % 4 != 0) { fprintf(stderr, - "unsupported batch size for matrix dimension >= 3: %d %% 4 != 0", + "unsupported batch size for matrix dimension >= 3: %u %% 4 != 0", num_mats); exit(EXIT_FAILURE); } if (dim == 2 && num_mats % 2 != 0) { fprintf(stderr, - "unsupported batch size for matrix dimension = 2: %d %% 2 != 0", + "unsupported batch size for matrix dimension = 2: %u %% 2 != 0", num_mats); exit(EXIT_FAILURE); } diff --git a/bench/BasicMathFun/MatrixInv/Batch/HermitianMatInv/PA/main.cpp b/bench/BasicMathFun/MatrixInv/Batch/HermitianMatInv/PA/main.cpp index cd8a7a6..f6bb3a6 100644 --- a/bench/BasicMathFun/MatrixInv/Batch/HermitianMatInv/PA/main.cpp +++ b/bench/BasicMathFun/MatrixInv/Batch/HermitianMatInv/PA/main.cpp @@ -17,8 +17,8 @@ void run_hermitian_batch_pa_matinv_perf(uint32_t num_mats, uint32_t dim, "%u\n", dim, num_reps); - const auto a = - gen_hermitian_matrix_batch(num_mats, dim, false, 1.0, 1.0, true); + const auto a = armral::utils::gen_hermitian_matrix_batch(num_mats, dim, false, + 1.0, 1.0, true); auto res = std::vector(num_mats * dim * dim); std::vector a_pa(dim * dim); @@ -52,20 +52,20 @@ int main(int argc, char **argv) { auto num_reps = (uint32_t)atoi(argv[3]); if (dim != 2 && dim != 3 && dim != 4) { - fprintf(stderr, "unsupported matrix dimension: %d", dim); + fprintf(stderr, "unsupported matrix dimension: %u", dim); exit(EXIT_FAILURE); } if (dim >= 3 && num_mats % 4 != 0) { fprintf(stderr, - "unsupported batch size for matrix dimension >= 3: %d %% 4 != 0", + "unsupported batch size for matrix dimension >= 3: %u %% 4 != 0", num_mats); exit(EXIT_FAILURE); } if (dim == 2 && num_mats % 2 != 0) { fprintf(stderr, - "unsupported batch size for matrix dimension = 2: %d %% 2 != 0", + "unsupported batch size for matrix dimension = 2: %u %% 2 != 0", num_mats); exit(EXIT_FAILURE); } diff --git a/bench/BasicMathFun/MatrixInv/Single/GeneralMatInv/main.cpp b/bench/BasicMathFun/MatrixInv/Single/GeneralMatInv/main.cpp index 2509ad5..dda075a 100644 --- a/bench/BasicMathFun/MatrixInv/Single/GeneralMatInv/main.cpp +++ b/bench/BasicMathFun/MatrixInv/Single/GeneralMatInv/main.cpp @@ -17,7 +17,7 @@ void run_general_matinv_perf(uint32_t dim, uint32_t num_reps) { "iterations = %u\n", dim, num_reps); - const auto a = gen_invertible_matrix(dim); + const auto a = armral::utils::gen_invertible_matrix(dim); auto res = std::vector(dim * dim); const auto *a_ptr = a.data(); diff --git a/bench/BasicMathFun/MatrixInv/Single/HermitianMatInv/main.cpp b/bench/BasicMathFun/MatrixInv/Single/HermitianMatInv/main.cpp index de9d111..0082492 100644 --- a/bench/BasicMathFun/MatrixInv/Single/HermitianMatInv/main.cpp +++ b/bench/BasicMathFun/MatrixInv/Single/HermitianMatInv/main.cpp @@ -17,7 +17,8 @@ void run_hermitian_matinv_perf(uint32_t dim, uint32_t num_reps) { "iterations = %u\n", dim, num_reps); - const auto a = gen_hermitian_matrix(dim, false, 1.0, 1.0, true); + const auto a = + armral::utils::gen_hermitian_matrix(dim, false, 1.0, 1.0, true); auto res = std::vector(dim * dim); const auto *a_ptr = a.data(); diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMult16/32b/main.cpp b/bench/BasicMathFun/MatrixMult/Single/MatrixMult16/32b/main.cpp index 68baf3a..48edaf7 100644 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMult16/32b/main.cpp +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMult16/32b/main.cpp @@ -24,7 +24,7 @@ void run_matmul_i16_32b_perf(uint32_t dim, uint32_t num_reps) { auto *c_ptr = c.data(); for (uint32_t i = 0; i < num_reps; ++i) { - armral_cmplx_mat_mult_i16_32bit(dim, dim, dim, a_ptr, b_ptr, c_ptr); + armral_cmplx_matmul_i16_32bit(dim, dim, dim, a_ptr, b_ptr, c_ptr); } } diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMult16/64b/main.cpp b/bench/BasicMathFun/MatrixMult/Single/MatrixMult16/64b/main.cpp index 098fc26..57faa8e 100644 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMult16/64b/main.cpp +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMult16/64b/main.cpp @@ -24,7 +24,7 @@ void run_matmul_i16_64b_perf(uint32_t dim, uint32_t num_reps) { auto *c_ptr = c.data(); for (uint32_t i = 0; i < num_reps; ++i) { - armral_cmplx_mat_mult_i16(dim, dim, dim, a_ptr, b_ptr, c_ptr); + armral_cmplx_matmul_i16(dim, dim, dim, a_ptr, b_ptr, c_ptr); } } diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/2x2/IQ/main.cpp b/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/2x2/IQ/main.cpp index d3ceec3..fab257f 100644 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/2x2/IQ/main.cpp +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/2x2/IQ/main.cpp @@ -12,12 +12,12 @@ void run_matmul_f32_2x2_iq_perf(uint32_t num_reps) { const auto a = std::vector(4); const auto b = std::vector(4); const auto c = std::vector(4); - const auto a_re = unpack_real_cf32(a); - const auto a_im = unpack_imag_cf32(a); - const auto b_re = unpack_real_cf32(b); - const auto b_im = unpack_imag_cf32(b); - auto c_re = unpack_real_cf32(c); - auto c_im = unpack_imag_cf32(c); + const auto a_re = armral::utils::unpack_real_cf32(a); + const auto a_im = armral::utils::unpack_imag_cf32(a); + const auto b_re = armral::utils::unpack_real_cf32(b); + const auto b_im = armral::utils::unpack_imag_cf32(b); + auto c_re = armral::utils::unpack_real_cf32(c); + auto c_im = armral::utils::unpack_imag_cf32(c); const auto *a_re_ptr = a_re.data(); const auto *a_im_ptr = a_im.data(); diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/4x4/IQ/main.cpp b/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/4x4/IQ/main.cpp index abf84cd..f0af1d8 100644 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/4x4/IQ/main.cpp +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/4x4/IQ/main.cpp @@ -12,12 +12,12 @@ void run_matmul_f32_4x4_iq_perf(uint32_t num_reps) { const auto a = std::vector(16); const auto b = std::vector(16); const auto c = std::vector(16); - const auto a_re = unpack_real_cf32(a); - const auto a_im = unpack_imag_cf32(a); - const auto b_re = unpack_real_cf32(b); - const auto b_im = unpack_imag_cf32(b); - auto c_re = unpack_real_cf32(c); - auto c_im = unpack_imag_cf32(c); + const auto a_re = armral::utils::unpack_real_cf32(a); + const auto a_im = armral::utils::unpack_imag_cf32(a); + const auto b_re = armral::utils::unpack_real_cf32(b); + const auto b_im = armral::utils::unpack_imag_cf32(b); + auto c_re = armral::utils::unpack_real_cf32(c); + auto c_im = armral::utils::unpack_imag_cf32(c); const auto *a_re_ptr = a_re.data(); const auto *a_im_ptr = a_im.data(); diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/bench.py b/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/bench.py index f0c6ae0..87b9058 100755 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/bench.py +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/bench.py @@ -11,25 +11,24 @@ import os def get_path(x): return x if Path(x).is_file() else os.path.join("armral", x) -exe_name = get_path("bench_matrix_mult_f32_general") +exe_name = get_path("bench_matmul_f32_general") j = { "exe_name": exe_name, "cases": [] } -reps = 300000 -lenArr = [2, 4, 8, 16] +full_reps = 300000 +lenArr = [2, 4, 8, 16, 128] mArr = [4, 8] -nArr = [32, 64, 128, 256] -kArr = [2, 3, 4] +nArr = [2, 3, 4] +kArr = [32, 64, 128, 256] for (m, n, k) in itertools.chain(zip(lenArr, lenArr, lenArr), itertools.product(mArr, nArr, kArr)): - name = "matmul_f32_{}".format(m) - if m != n: - name = name + "_{}_{}".format(n, k) + combined_size = m + n + k + reps = full_reps // combined_size if combined_size > 300 else full_reps case = { - "name": name, + "name": "matmul_f32_{}_{}_{}".format(m, n, k), "args": "{} {} {}".format(m, n, k), "reps": reps } diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/main.cpp b/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/main.cpp index a2f6657..695fbc2 100644 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/main.cpp +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMult32/general/main.cpp @@ -15,18 +15,18 @@ void run_matmul_f32_general_perf(uint32_t dim1, uint32_t dim2, uint32_t dim3, printf("[MATMUL f32 GENERAL] - dimensions %u-by-%u x %u-by-%u - " "number of iterations = " "%u\n", - dim1, dim2, dim2, dim3, num_reps); + dim1, dim3, dim3, dim2, num_reps); - const auto a = std::vector(dim1 * dim2); - const auto b = std::vector(dim2 * dim3); - auto c = std::vector(dim1 * dim3); + const auto a = std::vector(dim1 * dim3); + const auto b = std::vector(dim3 * dim2); + auto c = std::vector(dim1 * dim2); const auto *a_ptr = a.data(); const auto *b_ptr = b.data(); auto *c_ptr = c.data(); for (uint32_t i = 0; i < num_reps; ++i) { - armral_cmplx_mat_mult_f32(dim1, dim2, dim3, a_ptr, b_ptr, c_ptr); + armral_cmplx_matmul_f32(dim1, dim2, dim3, a_ptr, b_ptr, c_ptr); } } @@ -34,9 +34,10 @@ void run_matmul_f32_general_perf(uint32_t dim1, uint32_t dim2, uint32_t dim3, int main(int argc, char **argv) { if (argc != 5) { - // dim1 - The number of rows in the first input matrix - // dim2 - The number of columns in the first input matrix - // dim3 - The number of columns in the second input matrix + // dim1 - The number of rows in the output / first input matrix + // dim2 - The number of columns in the output / second input matrix + // dim3 - The number of columns in the first input matrix / rows in the + // second input matrix // nreps - The number of times to repeat the function fprintf(stderr, "usage: %s dim1 dim2 dim3 nreps\n", argv[0]); exit(EXIT_FAILURE); diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMultAAH32/main.cpp b/bench/BasicMathFun/MatrixMult/Single/MatrixMultAAH32/main.cpp index 949c879..55e290a 100644 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMultAAH32/main.cpp +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMultAAH32/main.cpp @@ -23,7 +23,7 @@ void run_matmul_aah_perf(uint32_t dim1, uint32_t dim2, uint32_t num_reps) { auto *c_ptr = c.data(); for (uint32_t i = 0; i < num_reps; ++i) { - armral_cmplx_mat_mult_aah_f32(dim1, dim2, a_ptr, c_ptr); + armral_cmplx_matmul_aah_f32(dim1, dim2, a_ptr, c_ptr); } } diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/bench.py b/bench/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/bench.py index 71916e5..9bb9862 100755 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/bench.py +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/bench.py @@ -24,23 +24,23 @@ reps = 300000 # That means we know B will be a square matrix (and due to using the matrix # inverse code) will have a size of 2x2, 3x3, 4x4, 8x8, or 16x16. So benchmark # those cases: -mkArrInverseSizes = [2, 3, 4, 8, 16] -nArr32_256 = [32, 64, 128, 256] +mArr32_256 = [32, 64, 128, 256] +nkArrInverseSizes = [2, 3, 4, 8, 16] -for (mk, n) in itertools.product(mkArrInverseSizes, nArr32_256): +for (m, nk) in itertools.product(mArr32_256, nkArrInverseSizes): case = { - "name": f"matmulahb_f32_{mk}_{n}_{mk}", - "args": f"{mk} {n} {mk}", + "name": f"matmulahb_f32_{m}_{nk}_{nk}", + "args": f"{m} {nk} {nk}", "reps": reps } j["cases"].append(case) # Try a (smaller) range of non-square B matrices too: -mkArr2_16 = [2, 4, 16] -nArr32_64 = [32, 64] +mArr32_64 = [32, 64] +nkArr2_16 = [2, 4, 16] -for (m, n, k) in itertools.product(mkArr2_16, nArr32_64, mkArr2_16): +for (m, n, k) in itertools.product(mArr32_64, nkArr2_16, nkArr2_16): if m != k: case = { "name": f"matmulahb_f32_{m}_{n}_{k}", diff --git a/bench/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/main.cpp b/bench/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/main.cpp index 873a3e3..041939c 100644 --- a/bench/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/main.cpp +++ b/bench/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/main.cpp @@ -13,19 +13,19 @@ namespace { void run_matmul_ahb_cf32_perf(uint16_t m, uint16_t n, uint16_t k, int num_reps) { printf("[MATMUL AHB] - dimensions (%u-by-%u)^H x " - "%u-by-%u, number of iterations = %u\n", - m, n, m, k, num_reps); + "%u-by-%u, number of iterations = %d\n", + k, m, k, n, num_reps); - const std::vector a(m * n); - const std::vector b(m * k); - std::vector c(n * k); + const std::vector a(k * m); + const std::vector b(k * n); + std::vector c(m * n); const auto *a_ptr = a.data(); const auto *b_ptr = b.data(); auto *c_ptr = c.data(); for (int i = 0; i < num_reps; ++i) { - armral_cmplx_mat_mult_ahb_f32(m, n, k, a_ptr, b_ptr, c_ptr); + armral_cmplx_matmul_ahb_f32(m, n, k, a_ptr, b_ptr, c_ptr); } } @@ -33,9 +33,9 @@ void run_matmul_ahb_cf32_perf(uint16_t m, uint16_t n, uint16_t k, int main(int argc, const char *argv[]) { if (argc != 5) { - // m - The number of rows in the input matrices - // n - The number of columns in input matrix A - // k - The number of columns in input matrix B + // m - The number of columns in A and rows in C + // n - The number of columns in B and C + // k - The number of rows in A and B // nreps - The number of times to repeat the function fprintf(stderr, "usage: %s m n k nreps\n", argv[0]); exit(EXIT_FAILURE); diff --git a/bench/DuRuInterface/MuLaw/Compression/14bit/main.cpp b/bench/DuRuInterface/MuLaw/Compression/14bit/main.cpp index 0cd606c..615c0a5 100644 --- a/bench/DuRuInterface/MuLaw/Compression/14bit/main.cpp +++ b/bench/DuRuInterface/MuLaw/Compression/14bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/MuLaw/Compression/8bit/main.cpp b/bench/DuRuInterface/MuLaw/Compression/8bit/main.cpp index 8a489e1..ffdc044 100644 --- a/bench/DuRuInterface/MuLaw/Compression/8bit/main.cpp +++ b/bench/DuRuInterface/MuLaw/Compression/8bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/MuLaw/Compression/9bit/main.cpp b/bench/DuRuInterface/MuLaw/Compression/9bit/main.cpp index f88240e..a1ef410 100644 --- a/bench/DuRuInterface/MuLaw/Compression/9bit/main.cpp +++ b/bench/DuRuInterface/MuLaw/Compression/9bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/MuLaw/Decompression/14bit/main.cpp b/bench/DuRuInterface/MuLaw/Decompression/14bit/main.cpp index d0b3498..e8b9dfd 100644 --- a/bench/DuRuInterface/MuLaw/Decompression/14bit/main.cpp +++ b/bench/DuRuInterface/MuLaw/Decompression/14bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = nullptr; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/MuLaw/Decompression/8bit/main.cpp b/bench/DuRuInterface/MuLaw/Decompression/8bit/main.cpp index 2a50c5d..61104eb 100644 --- a/bench/DuRuInterface/MuLaw/Decompression/8bit/main.cpp +++ b/bench/DuRuInterface/MuLaw/Decompression/8bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = nullptr; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/MuLaw/Decompression/9bit/main.cpp b/bench/DuRuInterface/MuLaw/Decompression/9bit/main.cpp index c3b1b85..ae484c9 100644 --- a/bench/DuRuInterface/MuLaw/Decompression/9bit/main.cpp +++ b/bench/DuRuInterface/MuLaw/Decompression/9bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = nullptr; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockFloat/Compression/12bit/main.cpp b/bench/DuRuInterface/ORanBlockFloat/Compression/12bit/main.cpp index b4c34d7..e094bec 100644 --- a/bench/DuRuInterface/ORanBlockFloat/Compression/12bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockFloat/Compression/12bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockFloat/Compression/14bit/main.cpp b/bench/DuRuInterface/ORanBlockFloat/Compression/14bit/main.cpp index eb1cc6a..ecdcada 100644 --- a/bench/DuRuInterface/ORanBlockFloat/Compression/14bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockFloat/Compression/14bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockFloat/Compression/8bit/main.cpp b/bench/DuRuInterface/ORanBlockFloat/Compression/8bit/main.cpp index be88dc1..bd8f377 100644 --- a/bench/DuRuInterface/ORanBlockFloat/Compression/8bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockFloat/Compression/8bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockFloat/Compression/9bit/main.cpp b/bench/DuRuInterface/ORanBlockFloat/Compression/9bit/main.cpp index a253642..3734cbb 100644 --- a/bench/DuRuInterface/ORanBlockFloat/Compression/9bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockFloat/Compression/9bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockFloat/Decompression/12bit/main.cpp b/bench/DuRuInterface/ORanBlockFloat/Decompression/12bit/main.cpp index fc9b8a8..ef13981 100644 --- a/bench/DuRuInterface/ORanBlockFloat/Decompression/12bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockFloat/Decompression/12bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = nullptr; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockFloat/Decompression/14bit/main.cpp b/bench/DuRuInterface/ORanBlockFloat/Decompression/14bit/main.cpp index 285eaa6..de72229 100644 --- a/bench/DuRuInterface/ORanBlockFloat/Decompression/14bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockFloat/Decompression/14bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = nullptr; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockFloat/Decompression/8bit/main.cpp b/bench/DuRuInterface/ORanBlockFloat/Decompression/8bit/main.cpp index 8ba5be5..954cbf2 100644 --- a/bench/DuRuInterface/ORanBlockFloat/Decompression/8bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockFloat/Decompression/8bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = nullptr; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockFloat/Decompression/9bit/main.cpp b/bench/DuRuInterface/ORanBlockFloat/Decompression/9bit/main.cpp index 8071995..5b5722a 100644 --- a/bench/DuRuInterface/ORanBlockFloat/Decompression/9bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockFloat/Decompression/9bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = nullptr; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockScaling/Compression/14bit/main.cpp b/bench/DuRuInterface/ORanBlockScaling/Compression/14bit/main.cpp index 754710c..239c2e7 100644 --- a/bench/DuRuInterface/ORanBlockScaling/Compression/14bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockScaling/Compression/14bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockScaling/Compression/8bit/main.cpp b/bench/DuRuInterface/ORanBlockScaling/Compression/8bit/main.cpp index 73958ae..3b90304 100644 --- a/bench/DuRuInterface/ORanBlockScaling/Compression/8bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockScaling/Compression/8bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockScaling/Compression/9bit/main.cpp b/bench/DuRuInterface/ORanBlockScaling/Compression/9bit/main.cpp index b3436d7..14b02de 100644 --- a/bench/DuRuInterface/ORanBlockScaling/Compression/9bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockScaling/Compression/9bit/main.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockScaling/Decompression/14bit/main.cpp b/bench/DuRuInterface/ORanBlockScaling/Decompression/14bit/main.cpp index 6e07821..9193570 100644 --- a/bench/DuRuInterface/ORanBlockScaling/Decompression/14bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockScaling/Decompression/14bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockScaling/Decompression/8bit/main.cpp b/bench/DuRuInterface/ORanBlockScaling/Decompression/8bit/main.cpp index f5bedca..f47cb54 100644 --- a/bench/DuRuInterface/ORanBlockScaling/Decompression/8bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockScaling/Decompression/8bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/DuRuInterface/ORanBlockScaling/Decompression/9bit/main.cpp b/bench/DuRuInterface/ORanBlockScaling/Decompression/9bit/main.cpp index 2c6aa9e..4d3c7d1 100644 --- a/bench/DuRuInterface/ORanBlockScaling/Decompression/9bit/main.cpp +++ b/bench/DuRuInterface/ORanBlockScaling/Decompression/9bit/main.cpp @@ -41,7 +41,7 @@ int main(int argc, char **argv) { const uint32_t scale_arg = atoi(argv[2]); const uint32_t num_reps = atoi(argv[3]); - armral_cmplx_int16_t *scale_ptr = NULL; + const armral_cmplx_int16_t *scale_ptr = NULL; armral_cmplx_int16_t scale; if (scale_arg != 0) { scale.re = scale_arg; diff --git a/bench/LowerPHY/SeqGenerator/main.cpp b/bench/LowerPHY/SeqGenerator/main.cpp index 259e102..bd83fe2 100644 --- a/bench/LowerPHY/SeqGenerator/main.cpp +++ b/bench/LowerPHY/SeqGenerator/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" static void run_sequence_generator_perf(uint32_t len, int num_reps) { uint32_t len_bytes = (((uint64_t)len) + 7) / 8; diff --git a/bench/UpperPHY/LDPC/Decoding/main.cpp b/bench/UpperPHY/LDPC/Decoding/main.cpp index 85acce3..f1fa2af 100755 --- a/bench/UpperPHY/LDPC/Decoding/main.cpp +++ b/bench/UpperPHY/LDPC/Decoding/main.cpp @@ -4,7 +4,7 @@ */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "ldpc_coding.hpp" #include "utils/allocators.hpp" diff --git a/bench/UpperPHY/LDPC/Encoding/main.cpp b/bench/UpperPHY/LDPC/Encoding/main.cpp index cbc6cbd..864a8c5 100644 --- a/bench/UpperPHY/LDPC/Encoding/main.cpp +++ b/bench/UpperPHY/LDPC/Encoding/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "ldpc_coding.hpp" namespace { diff --git a/bench/UpperPHY/Turbo/Decoding/main.cpp b/bench/UpperPHY/Turbo/Decoding/main.cpp index 1362f23..4f55145 100644 --- a/bench/UpperPHY/Turbo/Decoding/main.cpp +++ b/bench/UpperPHY/Turbo/Decoding/main.cpp @@ -72,7 +72,7 @@ int main(int argc, char **argv) { if (armral::turbo::valid_num_bits(num_bits)) { run_turbo_decoding_perf(num_prbs, num_bits, num_reps); } else { - printf("ERROR: Unsupported number of bits (%d) specified for turbo " + printf("ERROR: Unsupported number of bits (%u) specified for turbo " "decoding.\n", num_bits); exit(EXIT_FAILURE); diff --git a/bench/UpperPHY/Turbo/RateMatching/main.cpp b/bench/UpperPHY/Turbo/RateMatching/main.cpp index 3148fa2..d9535f0 100644 --- a/bench/UpperPHY/Turbo/RateMatching/main.cpp +++ b/bench/UpperPHY/Turbo/RateMatching/main.cpp @@ -18,7 +18,7 @@ void run_turbo_rate_matching_perf(uint32_t d, uint32_t e, uint32_t rv, "= %u\n", d, e, rv, num_reps); - assert(rv >= 0 && rv <= 3); + assert(rv <= 3); std::vector encoded_bits_0((d + 7) / 8); std::vector encoded_bits_1((d + 7) / 8); std::vector encoded_bits_2((d + 7) / 8); diff --git a/bench/UpperPHY/Turbo/RateRecovery/main.cpp b/bench/UpperPHY/Turbo/RateRecovery/main.cpp index 38795a3..68a80c0 100644 --- a/bench/UpperPHY/Turbo/RateRecovery/main.cpp +++ b/bench/UpperPHY/Turbo/RateRecovery/main.cpp @@ -18,7 +18,7 @@ void run_turbo_rate_recovery_perf(uint32_t d, uint32_t e, uint32_t rv, "= %u\n", d, e, rv, num_reps); - assert(rv >= 0 && rv <= 3); + assert(rv <= 3); std::vector demodulated_llrs(e); std::vector recovered_llrs_0(d); std::vector recovered_llrs_1(d); diff --git a/cmake_uninstall.cmake.in b/cmake_uninstall.cmake.in index 00ba6c2..0bceb6b 100644 --- a/cmake_uninstall.cmake.in +++ b/cmake_uninstall.cmake.in @@ -9,7 +9,7 @@ set(INSTALL_BASE_DIR "") # Try and read the base directory of the install. This is used to try and # figure out if the install is in a shared location, or in a user's home # directory. If the installation is not in the home directory, we don't want to -# force remove the installed files in case it causes unexpected behaviour for +# force remove the installed files in case it causes unexpected behavior for # others on the system foreach(file ${files}) if(file MATCHES "armral.a") diff --git a/docs/examples.md b/docs/examples.md index c789c0d..2ce0c48 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -17,7 +17,7 @@ Acceleration Library (ArmRAL). To build the library, use: - git clone -b armral-24.04 https://git.gitlab.arm.com/networking/ral.git + git clone -b armral-24.07 https://git.gitlab.arm.com/networking/ral.git mkdir ral/build cd ral/build cmake .. diff --git a/docs/frontmatter.md b/docs/frontmatter.md index fe55907..db88d5d 100644 --- a/docs/frontmatter.md +++ b/docs/frontmatter.md @@ -49,7 +49,7 @@ supplier and give: If you have any comments on content, send an e-mail to errata@arm.com. Give: * The title Arm RAN Acceleration Library Reference Guide. -* The number 102249_2404_00_en. +* The number 102249_2407_00_en. * If applicable, the relevant page number(s) to which your comments refer. * A concise explanation of your comments. @@ -160,3 +160,4 @@ Issue | Date | Confidentiality | Change 2310-00 | 06 October 2023 | Non-Confidential | Update for Arm RAN Acceleration Library v23.10 2401-00 | 19 January 2024 | Non-Confidential | Update for Arm RAN Acceleration Library v24.01 2404-00 | 19 April 2024 | Non-Confidential | Update for Arm RAN Acceleration Library v24.04 +2407-00 | 18 July 2024 | Non-Confidential | Update for Arm RAN Acceleration Library v24.07 diff --git a/include/armral.h b/include/armral.h index 60592f0..d47a5b8 100644 --- a/include/armral.h +++ b/include/armral.h @@ -23,7 +23,7 @@ * Functions are provided for working with matrices, including: * + Matrix-vector multiplication for 16-bit integer datatypes. * + Matrix-matrix multiplication. Supports both 16-bit integer and 32-bit - * floating-point datatypes. In addition, the `solve` routines + * floating-point datatypes. In addition, the `solve` functions * support specifying a custom Q-format specifier for both input and output * matrices, instead of assuming that the input is in Q15 format. * + Matrix inversion. Supports the 32-bit floating-point datatype. @@ -255,8 +255,9 @@ typedef struct { * \brief Multiplies a complex vector by another complex vector and generates a * complex result. * - * The complex arrays have a total of `2*n` real values.
The - * vector multiplication algorithm is: + * The complex arrays have a total of `2*n` real values. + * + * The vector multiplication algorithm is: *
  * for (n = 0; n < numSamples; n++) {
  *     pDst[2n+0] = pSrcA[2n+0] * pSrcB[2n+0] - pSrcA[2n+1] * pSrcB[2n+1];
@@ -265,8 +266,9 @@ typedef struct {
  * 
*/ /** - * This algorithm performs the element-wise complex multiplication between two - * complex input sequences, `A` and `B`, of the same length, (`N`).
+ * This function performs the element-wise complex multiplication between two + * complex input sequences, `A` and `B`, of the same length, (`N`). + * * The implementation uses saturating arithmetic. Intermediate operations are * performed on 32-bit variables in Q31 format. To convert the final result back * into Q15 format, the final result is right-shifted and narrowed to 16 bits. @@ -301,14 +303,15 @@ typedef struct { * @param[out] c Points to the output vector. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_cmplx_vecmul_i16(int32_t n, const armral_cmplx_int16_t *a, +armral_status armral_cmplx_vecmul_i16(uint32_t n, const armral_cmplx_int16_t *a, const armral_cmplx_int16_t *b, armral_cmplx_int16_t *c); /** - * This algorithm performs the element-wise complex multiplication between two + * This function performs the element-wise complex multiplication between two * complex [I and Q separated] input sequences, `A` and `B`, of the same length - * (`N`).
+ * (`N`). + * * The implementation uses saturating arithmetic. Intermediate operations are * performed on 32-bit variables in Q31 format. To convert the final result back * into Q15 format, the final result is right-shifted and narrowed to 16 bits. @@ -341,14 +344,14 @@ armral_status armral_cmplx_vecmul_i16(int32_t n, const armral_cmplx_int16_t *a, * @param[out] c_im Points to the imaginary part of the output result. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_cmplx_vecmul_i16_2(int32_t n, const int16_t *a_re, +armral_status armral_cmplx_vecmul_i16_2(uint32_t n, const int16_t *a_re, const int16_t *a_im, const int16_t *b_re, const int16_t *b_im, int16_t *c_re, int16_t *c_im); /** - * This algorithm performs the element-wise complex multiplication between two + * This function performs the element-wise complex multiplication between two * complex input sequences, `A` and `B`, of the same length (`N`). * *
@@ -381,12 +384,12 @@ armral_status armral_cmplx_vecmul_i16_2(int32_t n, const int16_t *a_re,
  * @param[out]    c       Points to the output vector.
  * @return     An `armral_status` value that indicates success or failure.
  */
-armral_status armral_cmplx_vecmul_f32(int32_t n, const armral_cmplx_f32_t *a,
+armral_status armral_cmplx_vecmul_f32(uint32_t n, const armral_cmplx_f32_t *a,
                                       const armral_cmplx_f32_t *b,
                                       armral_cmplx_f32_t *c);
 
 /**
- * This algorithm performs the element-wise complex multiplication between two
+ * This function performs the element-wise complex multiplication between two
  * complex [I and Q separated] input sequences, `A` and `B`, of the same length
  * (`N`).
  *
@@ -418,7 +421,7 @@ armral_status armral_cmplx_vecmul_f32(int32_t n, const armral_cmplx_f32_t *a,
  * @param[out] c_im  Points to the imaginary part of the output result.
  * @return     An `armral_status` value that indicates success or failure.
  */
-armral_status armral_cmplx_vecmul_f32_2(int32_t n, const float32_t *a_re,
+armral_status armral_cmplx_vecmul_f32_2(uint32_t n, const float32_t *a_re,
                                         const float32_t *a_im,
                                         const float32_t *b_re,
                                         const float32_t *b_im, float32_t *c_re,
@@ -457,7 +460,7 @@ armral_status armral_cmplx_vecmul_f32_2(int32_t n, const float32_t *a_re,
  * 
*/ /** - * This algorithm computes the dot product between a pair of arrays of complex + * This function computes the dot product between a pair of arrays of complex * values. The arrays are multiplied element-by-element and then summed. Array * elements are assumed to be complex float32 and with interleaved real and * imaginary parts. @@ -468,13 +471,13 @@ armral_status armral_cmplx_vecmul_f32_2(int32_t n, const float32_t *a_re, * @param[out] p_src_c Points to the output complex vector. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_cmplx_vecdot_f32(int32_t n, +armral_status armral_cmplx_vecdot_f32(uint32_t n, const armral_cmplx_f32_t *p_src_a, const armral_cmplx_f32_t *p_src_b, armral_cmplx_f32_t *p_src_c); /** - * This algorithm computes the dot product between a pair of arrays of complex + * This function computes the dot product between a pair of arrays of complex * values. The arrays are multiplied element-by-element and then summed. Array * elements are assumed to be 32-bit floats, and separate arrays are used for * the real and imaginary parts of the input data. @@ -490,7 +493,7 @@ armral_status armral_cmplx_vecdot_f32(int32_t n, * @param[out] p_src_c_im Points to the imaginary part of the output result. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_cmplx_vecdot_f32_2(int32_t n, const float32_t *p_src_a_re, +armral_status armral_cmplx_vecdot_f32_2(uint32_t n, const float32_t *p_src_a_re, const float32_t *p_src_a_im, const float32_t *p_src_b_re, const float32_t *p_src_b_im, @@ -498,9 +501,10 @@ armral_status armral_cmplx_vecdot_f32_2(int32_t n, const float32_t *p_src_a_re, float32_t *p_src_c_im); /** - * This algorithm computes the dot product between a pair of arrays of complex + * This function computes the dot product between a pair of arrays of complex * values. The arrays are multiplied element-by-element and then summed. Array - * elements are assumed to be complex int16 in Q15 format and interleaved.
+ * elements are assumed to be complex int16 in Q15 format and interleaved. + * * To avoid overflow issues input values are internally extended to 32-bit * variables and all intermediate calculations results are stored in 64-bit * internal variables. To get the final result in Q15 and to avoid overflow, @@ -512,16 +516,17 @@ armral_status armral_cmplx_vecdot_f32_2(int32_t n, const float32_t *p_src_a_re, * @param[out] p_src_c Points to the output complex result. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_cmplx_vecdot_i16(int32_t n, +armral_status armral_cmplx_vecdot_i16(uint32_t n, const armral_cmplx_int16_t *p_src_a, const armral_cmplx_int16_t *p_src_b, armral_cmplx_int16_t *p_src_c); /** - * This algorithm computes the dot product between a pair of arrays of complex + * This function computes the dot product between a pair of arrays of complex * values. The arrays are multiplied element-by-element and then summed. Array * elements are assumed to be int16 in Q15 format and separate arrays are used - * for real parts and imaginary parts of the input data.
+ * for real parts and imaginary parts of the input data. + * * To avoid overflow issues input values are internally extended to 32-bit * variables and all intermediate calculations results are stored in 64-bit * internal variables. To get the final result in Q15 and to avoid overflow, @@ -536,7 +541,7 @@ armral_status armral_cmplx_vecdot_i16(int32_t n, * @param[out] p_src_c_im Points to the imag part of output complex result. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_cmplx_vecdot_i16_2(int32_t n, const int16_t *p_src_a_re, +armral_status armral_cmplx_vecdot_i16_2(uint32_t n, const int16_t *p_src_a_re, const int16_t *p_src_a_im, const int16_t *p_src_b_re, const int16_t *p_src_b_im, @@ -544,9 +549,10 @@ armral_status armral_cmplx_vecdot_i16_2(int32_t n, const int16_t *p_src_a_re, int16_t *p_src_c_im); /** - * This algorithm computes the dot product between a pair of arrays of complex + * This function computes the dot product between a pair of arrays of complex * values. The arrays are multiplied element-by-element and then summed. Array - * elements are assumed to be complex int16 in Q15 format and interleaved.
+ * elements are assumed to be complex int16 in Q15 format and interleaved. + * * All intermediate calculations results are stored in 32-bit internal * variables, saturating the value to prevent overflow. To get the final result * in Q15 and to avoid overflow, the accumulator narrows to 16 bits with @@ -558,17 +564,18 @@ armral_status armral_cmplx_vecdot_i16_2(int32_t n, const int16_t *p_src_a_re, * @param[out] p_src_c Points to the output complex result. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_cmplx_vecdot_i16_32bit(int32_t n, +armral_status armral_cmplx_vecdot_i16_32bit(uint32_t n, const armral_cmplx_int16_t *p_src_a, const armral_cmplx_int16_t *p_src_b, armral_cmplx_int16_t *p_src_c); /** - * This algorithm computes the dot product between a pair of arrays of complex - * values. The arrays are multiplied element-by-element and then summed.
+ * This function computes the dot product between a pair of arrays of complex + * values. The arrays are multiplied element-by-element and then summed. + * * Array elements are assumed to be int16 in Q15 format and separate arrays are - * used for both the real parts and imaginary parts of the input data.
All - * intermediate calculation results are stored in 32-bit internal variables, + * used for both the real parts and imaginary parts of the input data. + * All intermediate calculation results are stored in 32-bit internal variables, * saturating the value to prevent overflow. To get the final result in Q15 and * to avoid overflow, the accumulator narrows to 16 bits with saturation. * @@ -584,7 +591,7 @@ armral_status armral_cmplx_vecdot_i16_32bit(int32_t n, * @return An `armral_status` value that indicates success or failure. */ armral_status armral_cmplx_vecdot_i16_2_32bit( - int32_t n, const int16_t *p_src_a_re, const int16_t *p_src_a_im, + uint32_t n, const int16_t *p_src_a_re, const int16_t *p_src_a_im, const int16_t *p_src_b_re, const int16_t *p_src_b_im, int16_t *p_src_c_re, int16_t *p_src_c_im); @@ -594,7 +601,7 @@ armral_status armral_cmplx_vecdot_i16_2_32bit( * @ingroup groupMatrix */ /** - * @addtogroup cmplx_matrix_vector_mult Complex Matrix-Vector Multiplication + * @addtogroup cmplx_matrix_vector_mult Complex Matrix-Vector Multiplication * @{ * \brief Computes a matrix-by-vector multiplication, storing the result in a * destination vector. @@ -602,8 +609,8 @@ armral_status armral_cmplx_vecdot_i16_2_32bit( * The destination vector is only written to and can be uninitialized. */ /** - * This algorithm performs the multiplication `A x` for matrix `A` and vector - * `x`, and assumes that: + * This function performs the multiplication `y = A x` for matrix `A` and + * vectors `x` and `y`, and assumes that: * + Matrix and vector elements are complex int16 in Q15 format. * + Matrices are stored in memory in row-major order. * @@ -616,9 +623,9 @@ armral_status armral_cmplx_vecdot_i16_2_32bit( * the output vector `y`. * @param[in] n The number of columns in matrix `A` and the length * of the input vector `x`. - * @param[in] p_src_a Points to the input matrix. - * @param[in] p_src_x Points to the input vector. - * @param[out] p_dst Points to the output vector. + * @param[in] p_src_a Points to the input matrix `A`. + * @param[in] p_src_x Points to the input vector `x`. + * @param[out] p_dst Points to the output vector `y`. * @return An `armral_status` value that indicates success or failure. */ armral_status armral_cmplx_mat_vec_mult_i16(uint16_t m, uint16_t n, @@ -627,9 +634,10 @@ armral_status armral_cmplx_mat_vec_mult_i16(uint16_t m, uint16_t n, armral_cmplx_int16_t *p_dst); /** - * This algorithm performs matrix-vector multiplication for a batch of - * `M-by-N` matrices and length `N` input vectors. Each multiplication is of the - * form `A x` for a matrix `A` and vector `x`, and assumes that: + * This function performs matrix-vector multiplication for a batch of + * `M`-by-`N` matrices and length `N` input vectors. Each multiplication is of + * the form `y = A x` for a matrix `A` and vectors `x` and `y`, and assumes + * that: * + Matrix and vector elements are complex int16 in Q15 format. * + Matrices are stored in memory in row-major order. * @@ -664,9 +672,9 @@ armral_status armral_cmplx_mat_vec_mult_i16(uint16_t m, uint16_t n, * `y`. * @param[in] n The number of columns (`N`) in each matrix * `A` and the length of each input vector `x`. - * @param[in] p_src_a Points to the input matrix. - * @param[in] p_src_x Points to the input vector. - * @param[out] p_dst Points to the output vector. + * @param[in] p_src_a Points to the input matrix `A`. + * @param[in] p_src_x Points to the input vector `x`. + * @param[out] p_dst Points to the output vector `y`. * @return An `armral_status` value that indicates success or failure. */ armral_status armral_cmplx_mat_vec_mult_batch_i16( @@ -675,10 +683,11 @@ armral_status armral_cmplx_mat_vec_mult_batch_i16( armral_cmplx_int16_t *p_dst); /** - * This algorithm performs matrix-vector multiplication for a batch of - * `M-by-N` matrices and length `N` input vectors, utilizing a "pointer array" + * This function performs matrix-vector multiplication for a batch of + * `M`-by-`N` matrices and length `N` input vectors, utilizing a "pointer array" * storage layout for the input and output matrix batches. Each multiplication - * is of the form `A x` for a matrix `A` and vector `x`, and assumes that: + * is of the form `y = A x` for a matrix `A` and vectors `x` and `y`, and + * assumes that: * + Matrix and vector elements are complex int16 in Q15 format. * + Matrices are stored in memory in row-major order. * @@ -686,7 +695,7 @@ armral_status armral_cmplx_mat_vec_mult_batch_i16( * value of `p_srcs_a[i]` is a pointer to the i-th element of the first matrix * in the batch, as represented in row-major ordering, such that the i-th * element of the j-th matrix in the batch is located at `p_srcs_a[i][j]`. - * For example, the j-th matrix in a batch of `2-by-2` matrices is formed as: + * For example, the j-th matrix in a batch of 2-by-2 matrices is formed as: *
  *   p_srcs_a[0][j]  p_srcs_a[1][j]
  *   p_srcs_a[2][j]  p_srcs_a[3][j]
@@ -732,8 +741,8 @@ armral_status armral_cmplx_mat_vec_mult_batch_i16_pa(
     const armral_cmplx_int16_t **p_srcs_x, armral_cmplx_int16_t **p_dsts);
 
 /**
- * This algorithm performs the multiplication `A x` for matrix `A` and vector
- * `x`, and assumes that:
+ * This function performs the multiplication `y = A x` for matrix `A` and
+ * vectors `x` and `y`, and assumes that:
  * + Matrix and vector elements are complex int16 in Q15 format.
  * + Matrices are stored in memory in row-major order.
  *
@@ -746,9 +755,9 @@ armral_status armral_cmplx_mat_vec_mult_batch_i16_pa(
  *                           the output vector `y`.
  * @param[in]     n          The number of columns in matrix `A` and the length
  *                           of each input vector `x`.
- * @param[in]     p_src_a    Points to the input matrix.
- * @param[in]     p_src_x    Points to the input vector.
- * @param[out]    p_dst      Points to the output matrix.
+ * @param[in]     p_src_a    Points to the input matrix `A`.
+ * @param[in]     p_src_x    Points to the input vector `x`.
+ * @param[out]    p_dst      Points to the output matrix `y`.
  * @return     An `armral_status` value that indicates success or failure.
  */
 armral_status armral_cmplx_mat_vec_mult_i16_32bit(
@@ -756,9 +765,10 @@ armral_status armral_cmplx_mat_vec_mult_i16_32bit(
     const armral_cmplx_int16_t *p_src_x, armral_cmplx_int16_t *p_dst);
 
 /**
- * This algorithm performs matrix-vector multiplication for a batch of
- * `M-by-N` matrices and length `N` input vectors. Each multiplication is of the
- * form `A x` for a matrix `A` and vector `x`, and assumes that:
+ * This function performs matrix-vector multiplication for a batch of
+ * `M`-by-`N` matrices and length `N` input vectors. Each multiplication is of
+ * the form `y = A x` for a matrix `A` and vectors `x` and `y`, and assumes
+ * that:
  * + Matrix and vector elements are complex int16 in Q15 format.
  * + Matrices are stored in memory in row-major order.
  *
@@ -788,9 +798,9 @@ armral_status armral_cmplx_mat_vec_mult_i16_32bit(
  *                                  and the length of each output vector `y`.
  * @param[in]     n                 The number of columns (`N`) in each matrix
  *                                  `A` and the length of each input vector `x`.
- * @param[in]     p_src_a           Points to the input matrix.
- * @param[in]     p_src_x           Points to the input vector.
- * @param[out]    p_dst             Points to the output vector.
+ * @param[in]     p_src_a           Points to the input matrix `A`.
+ * @param[in]     p_src_x           Points to the input vector `x`.
+ * @param[out]    p_dst             Points to the output vector `y`.
  * @return     An `armral_status` value that indicates success or failure.
  */
 armral_status armral_cmplx_mat_vec_mult_batch_i16_32bit(
@@ -799,10 +809,11 @@ armral_status armral_cmplx_mat_vec_mult_batch_i16_32bit(
     armral_cmplx_int16_t *p_dst);
 
 /**
- * This algorithm performs matrix-vector multiplication for a batch of
- * `M-by-N` matrices and length `N` input vectors, utilizing a "pointer array"
- * storage layout for the input and output matrix batches. Each multiplication
- * is of the form `A x` for a matrix `A` and vector `x`, and assumes that:
+ * This function performs matrix-vector multiplication for a batch of
+ * `M`-by-`N` matrices and length `N` input vectors, utilizing a "pointer
+ * array" storage layout for the input and output matrix batches. Each
+ * multiplication is of the form `y = A x` for a matrix `A` and vectors `x` and
+ * `y`, and assumes that:
  * + Matrix and vector elements are complex int16 in Q15 format.
  * + Matrices are stored in memory in row-major order.
  *
@@ -810,7 +821,7 @@ armral_status armral_cmplx_mat_vec_mult_batch_i16_32bit(
  * value of `p_srcs_a[i]` is a pointer to the i-th element of the first matrix
  * in the batch, as represented in row-major ordering, such that the i-th
  * element of the j-th matrix in the batch is located at `p_srcs_a[i][j]`.
- * For example, the j-th matrix in a batch of `2-by-2` matrices is formed as:
+ * For example, the j-th matrix in a batch of 2-by-2 matrices is formed as:
  * 
  *   p_srcs_a[0][j]  p_srcs_a[1][j]
  *   p_srcs_a[2][j]  p_srcs_a[3][j]
@@ -851,8 +862,8 @@ armral_status armral_cmplx_mat_vec_mult_batch_i16_32bit_pa(
     const armral_cmplx_int16_t **p_srcs_x, armral_cmplx_int16_t **p_dsts);
 
 /**
- * This algorithm performs the multiplication `A x` for matrix `A` and vector
- * `x`, and assumes that:
+ * This function performs the multiplication `y = A x` for matrix `A` and
+ * vectors `x` and `y`, and assumes that:
  * + Matrix and vector elements are complex 32-bit float values.
  * + Matrices are stored in memory in row-major order.
  *
@@ -860,9 +871,9 @@ armral_status armral_cmplx_mat_vec_mult_batch_i16_32bit_pa(
  *                           the output vector `y`.
  * @param[in]     n          The number of columns in matrix `A` and the length
  *                           of the input vector `x`.
- * @param[in]     p_src_a    Points to the first input matrix.
- * @param[in]     p_src_x    Points to the input vector.
- * @param[out]    p_dst      Points to the output matrix.
+ * @param[in]     p_src_a    Points to the input matrix `A`.
+ * @param[in]     p_src_x    Points to the input vector `x`.
+ * @param[out]    p_dst      Points to the output matrix `y`.
  * @return     An `armral_status` value that indicates success or failure.
  */
 armral_status armral_cmplx_mat_vec_mult_f32(uint16_t m, uint16_t n,
@@ -871,9 +882,10 @@ armral_status armral_cmplx_mat_vec_mult_f32(uint16_t m, uint16_t n,
                                             armral_cmplx_f32_t *p_dst);
 
 /**
- * This algorithm performs matrix-vector multiplication for a batch of
- * `M-by-N` matrices and length `N` input vectors. Each multiplication is of the
- * form `A x` for a matrix `A` and vector `x`, and assumes that:
+ * This function performs matrix-vector multiplication for a batch of
+ * `M`-by-`N` matrices and length `N` input vectors. Each multiplication is of
+ * the form `y = A x` for a matrix `A` and vectors `x` and `y`, and assumes
+ * that:
  * + Matrix and vector elements are complex 32-bit float values.
  * + Matrices are stored in memory in row-major order.
  *
@@ -902,9 +914,9 @@ armral_status armral_cmplx_mat_vec_mult_f32(uint16_t m, uint16_t n,
  *                                  `y`.
  * @param[in]     n                 The number of columns (`N`) in each matrix
  *                                  `A` and the length of each input vector `x`.
- * @param[in]     p_src_a           Points to the input matrix.
- * @param[in]     p_src_x           Points to the input vector.
- * @param[out]    p_dst             Points to the output vector.
+ * @param[in]     p_src_a           Points to the input matrix `A`.
+ * @param[in]     p_src_x           Points to the input vector `x`.
+ * @param[out]    p_dst             Points to the output vector `y`.
  * @return     An `armral_status` value that indicates success or failure.
  */
 armral_status armral_cmplx_mat_vec_mult_batch_f32(
@@ -913,10 +925,11 @@ armral_status armral_cmplx_mat_vec_mult_batch_f32(
     armral_cmplx_f32_t *p_dst);
 
 /**
- * This algorithm performs matrix-vector multiplication for a batch of
- * `M-by-N` matrices and length `N` input vectors, utilizing a "pointer array"
+ * This function performs matrix-vector multiplication for a batch of
+ * `M`-by-`N` matrices and length `N` input vectors, utilizing a "pointer array"
  * storage layout for the input and output matrix batches. Each multiplication
- * is of the form `A x` for a matrix `A` and vector `x`, and assumes that:
+ * is of the form `y = A x` for a matrix `A` and vectors `x` and `y`, and
+ * assumes that:
  * + Matrix and vector elements are complex 32-bit float values.
  * + Matrices are stored in memory in row-major order.
  *
@@ -924,7 +937,7 @@ armral_status armral_cmplx_mat_vec_mult_batch_f32(
  * value of `p_srcs_a[i]` is a pointer to the i-th element of the first matrix
  * in the batch, as represented in row-major ordering, such that the i-th
  * element of the j-th matrix in the batch is located at `p_srcs_a[i][j]`.
- * For example, the j-th matrix in a batch of `2-by-2` matrices is formed as:
+ * For example, the j-th matrix in a batch of 2-by-2 matrices is formed as:
  * 
  *   p_srcs_a[0][j]  p_srcs_a[1][j]
  *   p_srcs_a[2][j]  p_srcs_a[3][j]
@@ -965,103 +978,175 @@ armral_status armral_cmplx_mat_vec_mult_batch_f32_pa(
  * @ingroup groupMatrix
  */
 /**
- * @addtogroup cmplx_matrix_mult  Complex Matrix-Matrix Multiplication
+ * @addtogroup gen_cmplx_matrix_mult General Complex Matrix-Matrix Multiplication
  * @{
- * \brief Computes a matrix-by-matrix multiplication, storing the result in a
- * destination matrix.
+ * \brief Computes a general matrix-by-matrix multiplication, storing the
+ * result in a destination matrix.
  *
  * The destination matrix is only written to and can be uninitialized.
- *
- * To permit specifying different fixed-point formats for the input and output
- * matrices, the `solve` routines take an extra fixed-point type specifier.
  */
 /**
- * This algorithm performs the multiplication `A B` for matrices, and assumes
+ * This function performs the multiplication `C = A B` for matrices, and assumes
  * that:
  * + Matrix elements are complex int16 in Q15 format.
  * + Matrices are stored in memory in row-major order.
  *
  * A 64-bit Q32.31 accumulator is used internally. If you do not need such a
- * large range, consider using \link armral_cmplx_mat_mult_i16_32bit \endlink
+ * large range, consider using \link armral_cmplx_matmul_i16_32bit \endlink
  * instead. To get the final result in Q15 and to avoid overflow, the
  * accumulator narrows to 16 bits with saturation.
  *
- * @param[in]     m          The number of rows in matrix `A`.
- * @param[in]     n          The number of columns in matrix `A`.
- * @param[in]     k          The number of columns in matrix `B`.
- * @param[in]     p_src_a    Points to the first input matrix.
- * @param[in]     p_src_b    Points to the second input matrix.
- * @param[out]    p_dst      Points to the output matrix.
+ * @param[in]     m          The number of rows (`M`) in matrices `A` and `C`.
+ * @param[in]     n          The number of columns (`N`) in matrices `B` and
+ *                           `C`.
+ * @param[in]     k          The number of columns (`K`) in matrix `A` and the
+ *                           number of rows in matrix `B`.
+ * @param[in]     p_src_a    Points to the first input matrix `A`.
+ * @param[in]     p_src_b    Points to the second input matrix `B`.
+ * @param[out]    p_dst      Points to the output matrix `C`.
  * @return     An `armral_status` value that indicates success or failure.
  */
-armral_status armral_cmplx_mat_mult_i16(uint16_t m, uint16_t n, uint16_t k,
-                                        const armral_cmplx_int16_t *p_src_a,
-                                        const armral_cmplx_int16_t *p_src_b,
-                                        armral_cmplx_int16_t *p_dst);
+armral_status armral_cmplx_matmul_i16(uint16_t m, uint16_t n, uint16_t k,
+                                      const armral_cmplx_int16_t *p_src_a,
+                                      const armral_cmplx_int16_t *p_src_b,
+                                      armral_cmplx_int16_t *p_dst);
 
 /**
- * This algorithm performs the multiplication `A B` for matrices, and assumes
+ * This function performs the multiplication `C = A B` for matrices, and assumes
  * that:
  * + Matrix elements are complex int16 in Q15 format.
  * + Matrices are stored in memory in row-major order.
  *
  * A 32-bit Q0.31 saturating accumulator is used internally. If you need a
- * larger range, consider using \link armral_cmplx_mat_mult_i16 \endlink
+ * larger range, consider using \link armral_cmplx_matmul_i16 \endlink
  * instead. To get a Q15 result, the final result is narrowed to 16 bits with
  * saturation.
  *
- * @param[in]     m          The number of rows in matrix `A`.
- * @param[in]     n          The number of columns in matrix `A`.
- * @param[in]     k          The number of columns in matrix `B`.
- * @param[in]     p_src_a    Points to the first input matrix.
- * @param[in]     p_src_b    Points to the second input matrix.
- * @param[out]    p_dst      Points to the output matrix.
+ * @param[in]     m          The number of rows (`M`) in matrices `A` and `C`.
+ * @param[in]     n          The number of columns (`N`) in matrices `B` and
+ *                           `C`.
+ * @param[in]     k          The number of columns (`K`) in matrix `A` and the
+ *                           number of rows in matrix `B`.
+ * @param[in]     p_src_a    Points to the first input matrix `A`.
+ * @param[in]     p_src_b    Points to the second input matrix `B`.
+ * @param[out]    p_dst      Points to the output matrix `C`.
  * @return     An `armral_status` value that indicates success or failure.
  */
-armral_status armral_cmplx_mat_mult_i16_32bit(
-    uint16_t m, uint16_t n, uint16_t k, const armral_cmplx_int16_t *p_src_a,
-    const armral_cmplx_int16_t *p_src_b, armral_cmplx_int16_t *p_dst);
+armral_status armral_cmplx_matmul_i16_32bit(uint16_t m, uint16_t n, uint16_t k,
+                                            const armral_cmplx_int16_t *p_src_a,
+                                            const armral_cmplx_int16_t *p_src_b,
+                                            armral_cmplx_int16_t *p_dst);
 
 /**
- * This algorithm performs the multiplication `A B` for matrices of float
- * values, and assumes that matrices are stored in memory row-major.
+ * This function performs the multiplication `C = A B` for matrices, and assumes
+ * that:
+ * + Matrix elements are complex 32-bit floating point values.
+ * + Matrices are stored in memory in row-major order.
  *
- * @param[in]     m          The number of rows in matrix `A`.
- * @param[in]     n          The number of columns in matrix `A`.
- * @param[in]     k          The number of columns in matrix `B`.
- * @param[in]     p_src_a    Points to the first input matrix.
- * @param[in]     p_src_b    Points to the second input matrix.
- * @param[out]    p_dst      Points to the output matrix.
+ * @param[in]     m          The number of rows (`M`) in matrices `A` and `C`.
+ * @param[in]     n          The number of columns (`N`) in matrices `B` and
+ *                           `C`.
+ * @param[in]     k          The number of columns (`K`) in matrix `A` and the
+ *                           number of rows in matrix `B`.
+ * @param[in]     p_src_a    Points to the first input matrix `A`.
+ * @param[in]     p_src_b    Points to the second input matrix `B`.
+ * @param[out]    p_dst      Points to the output matrix `C`.
  * @return     An `armral_status` value that indicates success or failure.
  */
-armral_status armral_cmplx_mat_mult_f32(uint16_t m, uint16_t n, uint16_t k,
-                                        const armral_cmplx_f32_t *p_src_a,
-                                        const armral_cmplx_f32_t *p_src_b,
-                                        armral_cmplx_f32_t *p_dst);
+armral_status armral_cmplx_matmul_f32(uint16_t m, uint16_t n, uint16_t k,
+                                      const armral_cmplx_f32_t *p_src_a,
+                                      const armral_cmplx_f32_t *p_src_b,
+                                      armral_cmplx_f32_t *p_dst);
 
 /**
- * This algorithm performs an optimized product of two square `2-by-2` matrices.
- * The algorithm assumes that matrix `A` (first matrix) is column-major before
- * entering the `armral_cmplx_mat_mult_2x2_f32` function.
- * Matrix `B` (second matrix) is also assumed to be column-major. The result of - * the product is a column-major matrix. In LTE and 5G, you can use the - * `armral_cmplx_mat_mult_2x2_f32` function in the equalization - * step in the formula: + * This function performs a matrix multiplication of an input `M`-by-`N` matrix + * `A` with its conjugate transpose `A^H`: + * + *
+ * C = A A^H
+ * 
+ * + * `C` is therefore `M`-by-`M`. + * + * The input matrix `p_src_a` and output matrix `p_src_c` are stored + * contiguously in memory, in row-major order. + * + * `p_src_a` and `p_dst_c` must not alias each other. + * + * @param[in] m The number of rows (`M`) in the input matrix `A`, and + * the number of rows and columns in the output matrix + * `C`. + * @param[in] n The number of columns (`N`) in the input matrix `A`. + * @param[in] p_src_a Points to the input matrix `A`. + * @param[out] p_dst_c Points to the output matrix `C`. + * @return An `armral_status` value that indicates success or failure. + */ +armral_status armral_cmplx_matmul_aah_f32(uint16_t m, uint16_t n, + const armral_cmplx_f32_t *p_src_a, + armral_cmplx_f32_t *p_dst_c); + +/** + * This function performs the multiplication of the conjugate transpose of `A` + * with the matrix `B`, to compute the matrix `C`. That is: + * + *
+ * C = A^H B
+ * 
+ * + * Matrix `A` is `K`-by-`M`, `B` is `K`-by-`N`, and `C` is `M`-by-`N`. All + * matrices are stored contiguously in memory, in row-major order. + * + * None of the arrays passed to this function are allowed to alias. + * + * @param[in] m The number of columns (`M`) in matrix `A` and + * the number of rows in matrix `C`. + * @param[in] n The number of columns (`N`) in matrices `B` and + * `C`. + * @param[in] k The number of rows (`K`) in matrices `A` and `B`. + * @param[in] p_src_a Points to the input matrix `A`. + * @param[in] p_src_b Points to the input matrix `B`. + * @param[out] p_dst Points to the output matrix `C`. + * @return An `armral_status` value that indicates success or failure. + */ +armral_status armral_cmplx_matmul_ahb_f32(uint16_t m, uint16_t n, uint16_t k, + const armral_cmplx_f32_t *p_src_a, + const armral_cmplx_f32_t *p_src_b, + armral_cmplx_f32_t *p_dst); + +/** @} end of gen_cmplx_matrix_mult */ + +/** + * @ingroup groupMatrix + */ +/** + * @addtogroup spec_cmplx_matrix_mult Specific-Sized Complex Matrix-Matrix Multiplication + * @{ + * \brief Computes a specific-sized matrix-by-matrix multiplication, storing the + * result in a destination matrix. + * + * The destination matrix is only written to and can be uninitialized. + */ +/** + * This function performs an optimized product of two square 2-by-2 matrices. + * The function assumes matrices are stored in column-major order. In LTE and + * 5G, you can use the `armral_cmplx_mat_mult_2x2_f32` function in the + * equalization step in the formula: * *
  *  x̂ = G y
  * 
* * Equalization matrix `G` corresponds to the first input matrix (matrix `A`) of - * the function. The algorithm assumes that matrix `G` is transposed during - * computation so that the matrix presents as column-major on input.
The - * second input matrix (matrix `B`) is formed by two `2-by-1` vectors (`y` - * vectors in the preceding formula) so that each row of B represents a `2-by-1` - * vector output from each antenna port, and each call to + * the function. The function assumes that matrix `G` is transposed during + * computation so that the matrix presents as column-major on input. + * + * The second input matrix (matrix `B`) is formed by two 2-by-1 vectors (`y` + * vectors in the preceding formula) so that each row of `B` represents a + * 2-by-1 vector output from each antenna port, and each call to * `armral_cmplx_mat_mult_2x2_f32` computes two distinct `x̂` estimates. * - * @param[in] p_src_a Points to the first input matrix. - * @param[in] p_src_b Points to the second input matrix. + * @param[in] p_src_a Points to the first input matrix `A`. + * @param[in] p_src_b Points to the second input matrix `B`. * @param[out] p_dst Points to the output matrix. * @return An `armral_status` value that indicates success or failure. */ @@ -1070,13 +1155,10 @@ armral_status armral_cmplx_mat_mult_2x2_f32(const armral_cmplx_f32_t *p_src_a, armral_cmplx_f32_t *p_dst); /** - * This algorithm performs an optimized product of two square `2-by-2` matrices + * This function performs an optimized product of two square 2-by-2 matrices * whose complex elements have already been separated into real component and - * imaginary component arrays. The algorithm assumes that matrix `A` (first - * matrix) is column-major before entering the - * `armral_cmplx_mat_mult_2x2_f32_iq` function.
Matrix `B` (second matrix) - * is also considered to be column-major. The result of the product is a - * column-major matrix. In LTE and 5G, you can use the + * imaginary component arrays. The function assumes that matrices are stored in + * column-major order. In LTE and 5G, you can use the * `armral_cmplx_mat_mult_2x2_f32_iq` function in the equalization step in the * formula: * @@ -1085,11 +1167,11 @@ armral_status armral_cmplx_mat_mult_2x2_f32(const armral_cmplx_f32_t *p_src_a, *
* * Equalization matrix `G` corresponds to the first input matrix (matrix `A`) of - * the function. The algorithm assumes matrix `G` is transposed during - * computation so that the matrix presents as column-major on input.
The - * second input matrix (matrix `B`) is formed by two `2-by-1` vectors (`y` - * vectors in the preceding formula) so that each row of B represents a `2-by-1` - * vector output from each antenna port, and each call to + * the function. The function assumes matrix `G` is transposed during + * computation so that the matrix presents as column-major on input. + * The second input matrix (matrix `B`) is formed by two 2-by-1 vectors (`y` + * vectors in the preceding formula) so that each row of `B` represents a + * 2-by-1 vector output from each antenna port, and each call to * `armral_cmplx_mat_mult_2x2_f32_iq` computes two distinct `x̂` estimates. * * @param[in] src_a_re Points to the real part of the first input matrix. @@ -1108,28 +1190,25 @@ armral_status armral_cmplx_mat_mult_2x2_f32_iq(const float32_t *src_a_re, float32_t *dst_im); /** - * This algorithm performs an optimized product of two square `4-by-4` matrices. - * The algorithm assumes that matrix `A` (first matrix) is column-major before - * entering the `armral_cmplx_mat_mult_4x4_f32` function.
- * Matrix `B` (second matrix) is also considered to be column-major. The result - * of the product is a column-major matrix. In LTE and 5G, you can use the - * `armral_cmplx_mat_mult_4x4_f32` function in the equalization step in the - * formula: + * This function performs an optimized product of two square 4-by-4 matrices. + * The function assumes that matrices are stored in column-major order. In LTE + * and 5G, you can use the `armral_cmplx_mat_mult_4x4_f32` function in the + * equalization step in the formula: * *
  *  x̂ = G y
  * 
* * Equalization matrix `G` corresponds to the first input matrix (matrix `A`) of - * the function.
The algorithm assumes that matrix `G` is transposed during - * computation so that the matrix presents as column-major on input.
The - * second input matrix (matrix `B`) is formed by four `4-by-1` vectors (`y` - * vectors in the preceding formula) so that each row of B represents a `4-by-1` - * vector output from each antenna port, and each call to + * the function. The function assumes that matrix `G` is transposed during + * computation so that the matrix presents as column-major on input. + * The second input matrix (matrix `B`) is formed by four 4-by-1 vectors (`y` + * vectors in the preceding formula) so that each row of `B` represents a + * 4-by-1 vector output from each antenna port, and each call to * `cmplx_mat_mult_4x4_f32` computes four distinct `x̂` estimates. * - * @param[in] p_src_a Points to the first input matrix. - * @param[in] p_src_b Points to the second input matrix. + * @param[in] p_src_a Points to the first input matrix `A`. + * @param[in] p_src_b Points to the second input matrix `B`. * @param[out] p_dst Points to the output matrix. * @return An `armral_status` value that indicates success or failure. */ @@ -1138,13 +1217,10 @@ armral_status armral_cmplx_mat_mult_4x4_f32(const armral_cmplx_f32_t *p_src_a, armral_cmplx_f32_t *p_dst); /** - * This algorithm performs an optimized product of two square `4-by-4` matrices + * This function performs an optimized product of two square 4-by-4 matrices * whose complex elements have already been separated into real and imaginary - * component arrays. The algorithm assumes that matrix `A` (first matrix) is - * column-major before entering the `armral_cmplx_mat_mult_4x4_f32_iq` - * function.
Matrix `B` (second matrix) is also considered to be - * column-major. The result of the product is a column-major matrix. In LTE and - * 5G, you can use the `armral_cmplx_mat_mult_4x4_f32_iq` + * component arrays. The function assumes matrices are stored in column-major + * order. In LTE and 5G, you can use the `armral_cmplx_mat_mult_4x4_f32_iq` * function in the equalization step in the formula: * *
@@ -1152,11 +1228,11 @@ armral_status armral_cmplx_mat_mult_4x4_f32(const armral_cmplx_f32_t *p_src_a,
  * 
* * Equalization matrix `G` corresponds to the first input matrix (matrix `A`) of - * the function. The algorithm assumes that matrix `G` is transposed during - * computation so that the matrix presents as column-major on input.
The - * second input matrix (matrix `B`) is formed by four `4-by-1` vectors (`y` - * vectors in the preceding formula) so that each row of B represents a `4-by-1` - * vector output from each antenna port, and each call to + * the function. The function assumes that matrix `G` is transposed during + * computation so that the matrix presents as column-major on input. + * The second input matrix (matrix `B`) is formed by four 4-by-1 vectors (`y` + * vectors in the preceding formula) so that each row of `B` represents a + * 4-by-1 vector output from each antenna port, and each call to * `armral_cmplx_mat_mult_4x4_f32_iq` computes four distinct `x̂` estimates. * * @param[in] src_a_re Points to the real part of the first input matrix. @@ -1174,61 +1250,22 @@ armral_status armral_cmplx_mat_mult_4x4_f32_iq(const float32_t *src_a_re, float32_t *dst_re, float32_t *dst_im); +/** @} end of spec_cmplx_matrix_mult */ + /** - * This algorithm performs a matrix multiplication of an input `M-by-N` matrix - * `A` with its conjugate transpose `A^H`: - * - *
- * C = A A^H
- * 
- * - * `C` is therefore `M-by-M`. - * - * The input matrix `p_src_a` and output matrix `p_src_c` are stored - * contiguously in memory, in row-major order. - * - * `p_src_a` and `p_dst_c` must not alias each other. - * - * @param[in] m The number of rows (`M`) in the input matrix `A`, and - * the number of rows and columns in the output matrix - * `C`. - * @param[in] n The number of columns (`N`) in the input matrix `A`. - * @param[in] p_src_a Points to the input matrix `A`. - * @param[out] p_dst_c Points to the output matrix `C`. - * @return An `armral_status` value that indicates success or failure. + * @ingroup groupMatrix */ -armral_status armral_cmplx_mat_mult_aah_f32(uint16_t m, uint16_t n, - const armral_cmplx_f32_t *p_src_a, - armral_cmplx_f32_t *p_dst_c); - /** - * This algorithm performs the multiplication of the conjugate transpose of `A` - * with the matrix `B`, to compute the matrix `C`. That is: - * - *
- * C = A^H B
- * 
- * - * Matrix `A` is `M-by-N`, `B` is `M-by-K`, and `C` is `N-by-K`. All matrices - * are stored contiguously in memory, in row-major order. + * @addtogroup cmplx_matrix_solve Channel Matrix-Matrix Multiplication + * @{ + * \brief Computes a matrix-by-matrix multiplication, storing the result in a + * destination matrix. * - * None of the arrays passed to this function are allowed to alias. + * The destination matrix is only written to and can be uninitialized. * - * @param[in] m The number of rows (`M`) in matrices `A` and `B`. - * @param[in] n The number of columns (`N`) in matrix `A` and the - * number of rows in matrix `C`. - * @param[in] k The number of columns (`K`) in matrices `B` and - * `C`. - * @param[in] p_src_a Points to the input matrix `A`. - * @param[in] p_src_b Points to the input matrix `B`. - * @param[out] p_dst Points to the output matrix `C`. - * @return An `armral_status` value that indicates success or failure. + * To permit specifying different fixed-point formats for the input and output + * matrices, these functions take an extra fixed-point type specifier. */ -armral_status armral_cmplx_mat_mult_ahb_f32(uint16_t m, uint16_t n, uint16_t k, - const armral_cmplx_f32_t *p_src_a, - const armral_cmplx_f32_t *p_src_b, - armral_cmplx_f32_t *p_dst); - /** * In LTE and 5G, you can use the `armral_solve_2x2_f32` function * in the equalization step, as in the formula: @@ -1240,16 +1277,18 @@ armral_status armral_cmplx_mat_mult_ahb_f32(uint16_t m, uint16_t n, uint16_t k, * where `y` is a vector for the received signal, size corresponds to the * number of antennae and `x̂` is the estimate of the transmitted signal, * size corresponds to the number of layers. `G` is the equalization complex - * matrix and is assumed to be a `2-by-2` matrix. I and Q components of `G` - * elements are assumed to be stored separated in memory.
Also, each - * coefficient of `G` (`Gxy`, for `x, y = {1, 2}`) is assumed to be stored - * separated in memory locations set at `pGstride` one from the other.
+ * matrix and is assumed to be a 2-by-2 matrix. I and Q components of `G` + * elements are assumed to be stored separated in memory. + * Also, each coefficient of `G` (`Gxy`, for `x, y = {1, 2}`) is assumed to be + * stored separated in memory locations set at `p_gstride` one from the other. + * * The number of input signals is assumed to be a multiple of 12, and must - * be divisible by the number of subcarriers per `G` matrix.
For type 1 - * equalization, the number of subcarriers per `G` matrix must be four. For - * type 2 equalization, the number of subcarriers per `G` matrix must be - * six. An implementation is also available for cases where the number of - * subcarriers per `G` matrix is equal to one. + * be divisible by the number of subcarriers per `G` matrix. + * + * For type 1 equalization, the number of subcarriers per `G` matrix must be + * four. For type 2 equalization, the number of subcarriers per `G` matrix + * must be six. An implementation is also available for cases where the number + * of subcarriers per `G` matrix is equal to one. * * @param[in] num_sub_carrier The number of subcarriers to equalize. * @param[in] num_sc_per_g The number of subcarriers per `G` matrix. @@ -1283,16 +1322,22 @@ armral_solve_2x2_f32(uint32_t num_sub_carrier, uint32_t num_sc_per_g, * * where `y` is a vector for the received signal, size corresponds to the * number of antennae and `x̂` is the estimate of the transmitted signal, size - * corresponds to the number of layers.
`G` is the equalization complex - * matrix and is assumed to be a `2-by-4` matrix. I and Q components of `G` - * elements are assumed to be stored separated in memory.
Also, each - * coefficient of `G` (`Gxy`, for `x = {1, 2}` and `y = {1, 2, 3, 4}`) is - * assumed to be stored separated in memory locations set at `pGstride` one from - * the other.
The number of input signals is assumed to be a multiple of 12, - * and must be divisible by the number of subcarriers per `G` matrix.
For - * type 1 equalization, the number of subcarriers per `G` matrix must be four. - * For type 2 equalization, the number of subcarriers per `G` matrix must be - * six. An implementation is also available for cases where the number of + * corresponds to the number of layers. + * + * `G` is the equalization complex matrix and is assumed to be a 2-by-4 matrix. + * I and Q components of `G` elements are assumed to be stored separated in + * memory. + * + * Also, each coefficient of `G` (`Gxy`, for `x = {1, 2}` and + * `y = {1, 2, 3, 4}`) is assumed to be stored separated in memory locations + * set at `p_gstride` one from the other. + * + * The number of input signals is assumed to be a multiple of 12, and must be + * divisible by the number of subcarriers per `G` matrix. + * + * For type 1 equalization, the number of subcarriers per `G` matrix must be + * four. For type 2 equalization, the number of subcarriers per `G` matrix must + * be six. An implementation is also available for cases where the number of * subcarriers per `G` matrix is equal to one. * * @param[in] num_sub_carrier The number of subcarriers to equalize. @@ -1327,26 +1372,35 @@ armral_solve_2x4_f32(uint32_t num_sub_carrier, uint32_t num_sc_per_g, * * where `y` is a vector for the received signal, size corresponds to the * number of antennae and `x̂` is the estimate of the transmitted signal, size - * corresponds to the number of layers.
The input values for y are given in - * the Q0.15 fixed-point format. Each component of the vector may have a - * different number of fractional bits. The number of fractional bits per `y` - * component is passed in an array of the same length as `y`.
`G` is the - * equalization complex matrix and is assumed to be a `4-by-4` matrix. I and Q - * components of `G` elements are assumed to be stored separated in memory.
+ * corresponds to the number of layers. + * + * The input values for y are given in the Q0.15 fixed-point format. Each + * component of the vector may have a different number of fractional bits. The + * number of fractional bits per `y` component is passed in an array of the + * same length as `y`. + * + * `G` is the equalization complex matrix and is assumed to be a 4-by-4 matrix. + * I and Q components of `G` elements are assumed to be stored separated in + * memory. + * * Also, each coefficient of `G` (`Gxy`, for `x, y = {1, 2, 3, 4}`) is assumed * to be stored separated in memory locations set at `p_gstride` one from the - * other.
It is assumed that each component of the vectors `y` and `x` are - * stored in memory at `p_y_stride` and `p_x_stride` one from the other. It is - * also assumed that `p_g_stride` is greater than or equal to the number of + * other. + * + * It is assumed that each component of the vectors `y` and `x` are stored in + * memory at `p_y_stride` and `p_x_stride` one from the other. It is also + * assumed that `p_gstride` is greater than or equal to the number of * subcarriers divided by the number of subcarriers per `G`. `p_y_stride` and * `p_x_stride` are assumed greater than or equal to the number of subcarriers. * If these assumptions are violated, the results returned will be incorrect. - *
The number of input signals is assumed to be a multiple of 12, and must - * be divisible by the number of subcarriers per `G` matrix.
For type 1 - * equalization, the number of subcarriers per `G` matrix must be four. For type - * 2 equalization, the number of subcarriers per `G` matrix must be six. An - * implementation is also available for cases where the number of subcarriers - * per `G` matrix is equal to one. + * + * The number of input signals is assumed to be a multiple of 12, and must be + * divisible by the number of subcarriers per `G` matrix. + * + * For type 1 equalization, the number of subcarriers per `G` matrix must be + * four. For type 2 equalization, the number of subcarriers per `G` matrix must + * be six. An implementation is also available for cases where the number of + * subcarriers per `G` matrix is equal to one. * * @param[in] num_sub_carrier The number of subcarriers to equalize. * @param[in] num_sc_per_g The number of subcarriers per `G` matrix. @@ -1380,17 +1434,22 @@ armral_solve_4x4_f32(uint32_t num_sub_carrier, uint32_t num_sc_per_g, * * where `y` is a vector for the received signal, size corresponds to the * number of antennae and `x̂` is the estimate of the transmitted signal, size - * corresponds to the number of layers.
`G` is the equalization complex - * matrix and is assumed to be a `1-by-4` matrix (i.e. a row vector). I and Q - * components of `G` elements are assumed to be stored separated in memory.
+ * corresponds to the number of layers. + * + * `G` is the equalization complex matrix and is assumed to be a 1-by-4 matrix + * (i.e. a row vector). I and Q components of `G` elements are assumed to be + * stored separated in memory. + * * Also, each coefficient of `G` (`G1y`, for `y = {1, 2, 3, 4}`) is assumed to - * be stored separated in memory locations set at `pGstride` one from the - * other.
The number of input signals is assumed to be a multiple of 12, and - * must be divisible by the number of subcarriers per `G` matrix.
For type 1 - * equalization, the number of subcarriers per `G` matrix must be four. For type - * 2 equalization, the number of subcarriers per `G` matrix must be six. An - * implementation is also available for cases where the number of subcarriers - * per `G` matrix is equal to one. + * be stored separated in memory locations set at `p_gstride` one from the + * other. + * + * The number of input signals is assumed to be a multiple of 12, and must be + * divisible by the number of subcarriers per `G` matrix. + * For type 1 equalization, the number of subcarriers per `G` matrix must be + * four. For type 2 equalization, the number of subcarriers per `G` matrix must + * be six. An implementation is also available for cases where the number of + * subcarriers per `G` matrix is equal to one. * * @param[in] num_sub_carrier The number of subcarriers to equalize. * @param[in] num_sc_per_g The number of subcarriers per `G` matrix. @@ -1424,16 +1483,19 @@ armral_solve_1x4_f32(uint32_t num_sub_carrier, uint32_t num_sc_per_g, * where `y` is a vector for the received signal, size corresponds to the * number of antennae and `x̂` is the estimate of the transmitted signal, size * corresponds to the number of layers. `G` is the equalization complex matrix - * and is assumed to be a `1-by-2` matrix (i.e. a row vector). I and Q - * components of `G` elements are assumed to be stored separated in memory.
+ * and is assumed to be a 1-by-2 matrix (i.e. a row vector). I and Q components + * of `G` elements are assumed to be stored separated in memory. + * * Also, each coefficient of `G` (`G11`,`G12`) is assumed to be stored separated - * in memory locations set at `pGstride` one from the other.
The number of - * input signals is assumed to be a multiple of 12, and must be divisible by the - * number of subcarriers per `G` matrix.
For type 1 equalization, the number - * of subcarriers per `G` matrix must be four. For type 2 equalization, the - * number of subcarriers per `G` matrix must be six. An implementation is also - * available for cases where the number of subcarriers per `G` matrix is equal - * to one. + * in memory locations set at `p_gstride` one from the other. + * + * The number of input signals is assumed to be a multiple of 12, and must be + * divisible by the number of subcarriers per `G` matrix. + * + * For type 1 equalization, the number of subcarriers per `G` matrix must be + * four. For type 2 equalization, the number of subcarriers per `G` matrix must + * be six. An implementation is also available for cases where the number of + * subcarriers per `G` matrix is equal to one. * * @param[in] num_sub_carrier The number of subcarriers to equalize. * @param[in] num_sc_per_g The number of subcarriers per `G` matrix. @@ -1456,7 +1518,7 @@ armral_solve_1x2_f32(uint32_t num_sub_carrier, uint32_t num_sc_per_g, uint32_t p_gstride, armral_cmplx_int16_t *p_x, armral_fixed_point_index num_fract_bits_x); -/** @} end of cmplx_matrix_mult group */ +/** @} end of cmplx_matrix_solve group */ /** * @ingroup groupMatrix @@ -1465,13 +1527,16 @@ armral_solve_1x2_f32(uint32_t num_sub_carrier, uint32_t num_sc_per_g, * @addtogroup cmplx_matrix_inv Complex Matrix Inversion * @{ * \brief Computes the inverse of a complex Hermitian square matrix of size - * `N-by-N`. + * `N`-by-`N`. */ /** - * This algorithm computes the inverse of a single complex Hermitian square - * matrix of size `N-by-N`.
The supported dimensions are `2-by-2`, `3-by-3`, - * `4-by-4`, `8-by-8`, and `16-by-16`.
The input and output matrices are - * filled in row-major order with complex `float32_t` elements. + * This function computes the inverse of a single complex Hermitian square + * matrix of size `N`-by-`N`. + * + * The supported dimensions are 2-by-2, 3-by-3, 4-by-4, 8-by-8, and 16-by-16. + * + * The input and output matrices are filled in row-major order with complex + * `float32_t` elements. * * @param[in] size The size of the input matrix. * @param[in] p_src Points to the input matrix structure. @@ -1482,10 +1547,13 @@ armral_status armral_cmplx_hermitian_mat_inverse_f32( uint32_t size, const armral_cmplx_f32_t *p_src, armral_cmplx_f32_t *p_dst); /** - * This algorithm computes the inverse of a single complex square - * matrix of size `N-by-N`.
The supported dimensions are `2-by-2`, `3-by-3`, - * `4-by-4`, `8-by-8`, and `16-by-16`.
The input and output matrices are - * filled in row-major order with complex `float32_t` elements. + * This function computes the inverse of a single complex square + * matrix of size `N`-by-`N`. + * + * The supported dimensions are 2-by-2, 3-by-3, 4-by-4, 8-by-8, and 16-by-16. + * + * The input and output matrices are filled in row-major order with complex + * `float32_t` elements. * * @param[in] size The size of the input matrix. * @param[in] p_src Points to the input matrix structure. @@ -1497,15 +1565,17 @@ armral_status armral_cmplx_mat_inverse_f32(uint32_t size, armral_cmplx_f32_t *p_dst); /** - * This algorithm computes the inverse of a batch of `M` complex Hermitian - * square matrices, each of size `N-by-N`.
The supported matrix dimensions - * are `2-by-2`, `3-by-3`, and `4-by-4`.
The input and output matrices are - * filled in row-major order with complex `float32_t` elements, interleaved such - * that all elements for a particular location within the matrix are stored - * together. This means that, for instance, the first four complex numbers - * stored are the first element from each of the first four matrices in the - * batch. The offset to the next location in the same matrix is given by the - * `num_mats` batch size: + * This function computes the inverse of a batch of `M` complex Hermitian + * square matrices, each of size `N`-by-`N`. + * + * The supported matrix dimensions are 2-by-2, 3-by-3, and 4-by-4. + * + * The input and output matrices are filled in row-major order with complex + * `float32_t` elements, interleaved such that all elements for a particular + * location within the matrix are stored together. This means that, for + * instance, the first four complex numbers stored are the first element from + * each of the first four matrices in the batch. The offset to the next + * location in the same matrix is given by the `num_mats` batch size: *
  *   {Re(0), Im(0), Re(1), Im(1), ..., Re(M - 1), Im(M - 1)}
  * 
@@ -1525,21 +1595,20 @@ armral_cmplx_hermitian_mat_inverse_batch_f32(uint32_t num_mats, uint32_t size, armral_cmplx_f32_t *p_dst); /** - * This algorithm computes the inverse of a batch of `M` complex general - * square matrices, each of size `N-by-N`.
The supported matrix dimensions - * are `2-by-2`, `3-by-3`, and `4-by-4`.
The input and output matrices are - * filled in row-major order with complex `float32_t` elements, interleaved such - * that all elements for a particular location within the matrix are stored - * together. This means that, for instance, the first four complex numbers - * stored are the first element from each of the first four matrices in the - * batch. The offset to the next location in the same matrix is given by the - * `num_mats` batch size: + * This function computes the inverse of a batch of `M` complex general + * square matrices, each of size `N`-by-`N`. + * + * The supported matrix dimensions are 2-by-2, 3-by-3, and 4-by-4. + * + * The input and output matrices are filled in row-major order with complex + * `float32_t` elements, interleaved such that all elements for a particular + * location within the matrix are stored together. This means that, for + * instance, the first four complex numbers stored are the first element from + * each of the first four matrices in the batch. The offset to the next + * location in the same matrix is given by the `num_mats` batch size: *
  *   {Re(0), Im(0), Re(1), Im(1), ..., Re(M - 1), Im(M - 1)}
  * 
- * The number of matrices in a batch (`M`) must be a multiple of the matrix - * dimension. So, if `N = 2` then `M` must be a multiple of two, and if `N = 4` - * then `M` must be a multiple of four. * * @param[in] num_mats The number (`M`) of input and output matrices. * @param[in] size The size (`N`) of the input and output matrix. @@ -1553,15 +1622,16 @@ armral_cmplx_mat_inverse_batch_f32(uint32_t num_mats, uint32_t size, armral_cmplx_f32_t *p_dst); /** - * This algorithm computes the inverse of a batch of `M` complex Hermitian - * square matrices, each of size `N-by-N`, utilizing a "pointer array" storage - * layout for the input and output matrix batches.
The supported matrix - * dimensions are `2-by-2`, `3-by-3`, and `4-by-4`.
The `p_srcs` parameter - * is an array of pointers of length `N-by-N`. The value of `p_srcs[i]` is a - * pointer to the i-th element of the first matrix in the batch, as represented - * in row-major ordering, such that the i-th element of the j-th matrix in the - * batch is located at `p_srcs[i][j]`. Similarly, the j-th matrix in a batch of - * `2-by-2` matrices is formed as: + * This function computes the inverse of a batch of `M` complex Hermitian + * square matrices, each of size `N`-by-`N`, utilizing a "pointer array" storage + * layout for the input and output matrix batches. + * + * The supported matrix dimensions are 2-by-2, 3-by-3, and 4-by-4. + * The `p_srcs` parameter is an array of pointers of length `N`-by-`N`. The + * value of `p_srcs[i]` is a pointer to the i-th element of the first matrix in + * the batch, as represented in row-major ordering, such that the i-th element + * of the j-th matrix in the batch is located at `p_srcs[i][j]`. Similarly, + * the j-th matrix in a batch of 2-by-2 matrices is formed as: *
  *   p_srcs[0][j]  p_srcs[1][j]
  *   p_srcs[2][j]  p_srcs[3][j]
@@ -1584,15 +1654,16 @@ armral_status armral_cmplx_hermitian_mat_inverse_batch_f32_pa(
     armral_cmplx_f32_t **p_dsts);
 
 /**
- * This algorithm computes the inverse of a batch of `M` complex general
- * square matrices, each of size `N-by-N`, utilizing a "pointer array" storage
- * layout for the input and output matrix batches.
The supported matrix - * dimensions are `2-by-2`, `3-by-3`, and `4-by-4`.
The `p_srcs` parameter - * is an array of pointers of length `N-by-N`. The value of `p_srcs[i]` is a - * pointer to the i-th element of the first matrix in the batch, as represented - * in row-major ordering, such that the i-th element of the j-th matrix in the - * batch is located at `p_srcs[i][j]`. Similarly, the j-th matrix in a batch of - * `2-by-2` matrices is formed as: + * This function computes the inverse of a batch of `M` complex general + * square matrices, each of size `N`-by-`N`, utilizing a "pointer array" + * storage layout for the input and output matrix batches. + * + * The supported matrix dimensions are 2-by-2, 3-by-3, and 4-by-4. + * The `p_srcs` parameter is an array of pointers of length `N`-by-`N`. The + * value of `p_srcs[i]` is a pointer to the i-th element of the first matrix in + * the batch, as represented in row-major ordering, such that the i-th element + * of the j-th matrix in the batch is located at `p_srcs[i][j]`. Similarly, the + * j-th matrix in a batch of 2-by-2 matrices is formed as: *
  *   p_srcs[0][j]  p_srcs[1][j]
  *   p_srcs[2][j]  p_srcs[3][j]
@@ -1600,10 +1671,6 @@ armral_status armral_cmplx_hermitian_mat_inverse_batch_f32_pa(
  * The output array `p_dsts` points to an array of pointers, representing an
  * identical format to the input.
  *
- * The number of matrices in a batch (`M`) must be a multiple of the matrix
- * dimension. So, if `N = 2` then `M` must be a multiple of two, and if `N = 4`
- * then `M` must be a multiple of four.
- *
  * @param[in]  num_mats   The number (`M`) of input and output matrices.
  * @param[in]  size       The size (`N`) of the input and output matrix.
  * @param[in]  p_srcs     Points to the input matrix structure.
@@ -1624,11 +1691,11 @@ armral_cmplx_mat_inverse_batch_f32_pa(uint32_t num_mats, uint32_t size,
  * @addtogroup cmplx_matrix_pseudo_inv Complex Matrix Pseudo-Inverse
  * @{
  * \brief Computes the regularized pseudo-inverse of a complex matrix of size
- * `M-by-N`.
+ * `M`-by-`N`.
  */
 /**
- * Computes the regularized pseudo-inverse of a single matrix. The `N-by-M`
- * regularized pseudo-inverse `C` of an `M-by-N` matrix `A` is defined as:
+ * Computes the regularized pseudo-inverse of a single matrix. The `N`-by-`M`
+ * regularized pseudo-inverse `C` of an `M`-by-`N` matrix `A` is defined as:
  *
  * 
  *   C = A^H * (A * A^H + λ * I)^-1
@@ -1674,8 +1741,8 @@ armral_cmplx_pseudo_inverse_direct_f32(uint16_t m, uint16_t n, float32_t lambda,
  * Non-allocating variant of \link armral_cmplx_pseudo_inverse_direct_f32
  * \endlink.
  *
- * Computes the regularized pseudo-inverse of a single matrix. The `N-by-M`
- * regularized pseudo-inverse `C` of an `M-by-N` matrix `A` is defined as:
+ * Computes the regularized pseudo-inverse of a single matrix. The `N`-by-`M`
+ * regularized pseudo-inverse `C` of an `M`-by-`N` matrix `A` is defined as:
  *
  * 
  *   C = A^H * (A * A^H + λ * I)^-1
@@ -1741,12 +1808,14 @@ armral_status armral_cmplx_pseudo_inverse_direct_f32_noalloc(
  * Technical Specification (TS) 36.211, Chapter 7.2.
  */
 /**
- * This algorithm generates a pseudo-random sequence (Gold Sequence) that is
+ * This function generates a pseudo-random sequence (Gold Sequence) that is
  * used in 4G and 5G networks to scramble data of a specific channel or to
  * generate a specific sequence (for example for Downlink Reference Signal
- * generation).
The sequence generator is the same generator that is - * described in the 3GPP Technical Specification (TS) 36.211, Chapter 7.2. The - * generator uses two polynomials, `x1` and `x2`, defined as: + * generation). + * + * The sequence generator is the same generator that is described in the 3GPP + * Technical Specification (TS) 36.211, Chapter 7.2. The generator uses two + * polynomials, `x1` and `x2`, defined as: * *
  *  x1(n+31) = (x1(n+3) + x1(n)) mod 2
@@ -1820,11 +1889,12 @@ armral_status armral_modulation(uint32_t nbits, armral_modulation_type mod_type,
                                 armral_cmplx_int16_t *p_dst);
 
 /**
- * This algorithm implements the soft-demodulation (or soft bit demapping) for
- * QPSK, 16QAM, 64QAM, and 256QAM constellations.
For complex symbols `x_i`, - * the input sequence is assumed to be made of complex symbols `rx = rx_re + j * - * rx_im`, whose components I and Q are 16 bits each (format Q2.13) and in an - * interleaved form: + * This function implements the soft-demodulation (or soft bit demapping) for + * QPSK, 16QAM, 64QAM, and 256QAM constellations. + * + * For complex symbols `x_i`, the input sequence is assumed to be made of + * complex symbols `rx = rx_re + j * rx_im`, whose components I and Q are 16 + * bits each (format Q2.13) and in an interleaved form: * *
  *  {Re(0), Im(0), Re(1), Im(1), ..., Re(N - 1), Im(N - 1)}
@@ -1833,7 +1903,7 @@ armral_status armral_modulation(uint32_t nbits, armral_modulation_type mod_type,
  * The output of the soft-demodulation algorithm is a sequence of
  * log-likelihood ratio (LLR) `int8_t` values, which indicate the relative
  * confidence of the demapping decision, component by component, instead of
- * taking a hard decision and giving the bit value itself.
+ * taking a hard decision and giving the bit value itself. * * The LLRs are computed using a maximum likelihood approach. The LLR * calculations use a threshold method to approximate the maximum @@ -1886,13 +1956,16 @@ armral_status armral_demodulation(uint32_t n_symbols, uint16_t ulp, * * SQRT(SUM(y*conj(y)) - n*avg(y)*conj(avg(y))) *
* + * \warning n must be less than or equal to INT32_MAX, the largest number + * representable in a 32-bit signed integer. + * * @param[in] n The number of complex samples in each vector. * @param[in] p_src_a Points to the first input vector. * @param[in] p_src_b Points to the second input vector. * @param[out] c Points to the result. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_corr_coeff_i16(int32_t n, +armral_status armral_corr_coeff_i16(uint32_t n, const armral_cmplx_int16_t *p_src_a, const armral_cmplx_int16_t *p_src_b, armral_cmplx_int16_t *c); @@ -1903,7 +1976,7 @@ armral_status armral_corr_coeff_i16(int32_t n, * @ingroup groupLowPhy */ /** - * @addtogroup fir_filter FIR filter + * @addtogroup fir_filter FIR Filter * @{ * \brief FIR filter implemented for single-precision floating-point and 16-bit * signed integers. @@ -2153,7 +2226,7 @@ armral_status armral_mu_law_decompr_14bit( * */ /** - * The algorithm operates on a fixed block size of one Physical Resource Block + * The function operates on a fixed block size of one Physical Resource Block * (PRB). Each block consists of 12 16-bit complex resource elements. Each block * taken as input is compressed into 24 8-bit post-scaled samples and a common * unsigned scaling factor. @@ -2174,7 +2247,7 @@ armral_block_scaling_compr_8bit(uint32_t n_prb, const armral_cmplx_int16_t *src, const armral_cmplx_int16_t *scale); /** - * The algorithm operates on a fixed block size of one Physical Resource Block + * The function operates on a fixed block size of one Physical Resource Block * (PRB). Each block consists of 12 16-bit complex resource elements. Each block * taken as input is compressed into 24 9-bit post-scaled samples and a common * unsigned scaling factor. @@ -2195,7 +2268,7 @@ armral_block_scaling_compr_9bit(uint32_t n_prb, const armral_cmplx_int16_t *src, const armral_cmplx_int16_t *scale); /** - * The algorithm operates on a fixed block size of one Physical Resource Block + * The function operates on a fixed block size of one Physical Resource Block * (PRB). Each block consists of 12 16-bit complex resource elements. Each block * taken as input is compressed into 24 14-bit post-scaled samples and a common * unsigned scaling factor. @@ -2214,7 +2287,7 @@ armral_status armral_block_scaling_compr_14bit( armral_compressed_data_14bit *dst, const armral_cmplx_int16_t *scale); /** - * The algorithm operates on a fixed block size of one Physical Resource Block + * The function operates on a fixed block size of one Physical Resource Block * (PRB). Each block consists of 12 8-bit complex post-scaled resource elements * and an unsigned scaling factor. Each block taken as input is expanded into 12 * 16-bit complex samples. @@ -2233,7 +2306,7 @@ armral_status armral_block_scaling_decompr_8bit( armral_cmplx_int16_t *dst, const armral_cmplx_int16_t *scale); /** - * The algorithm operates on a fixed block size of one Physical Resource Block + * The function operates on a fixed block size of one Physical Resource Block * (PRB). Each block consists of 12 9-bit complex post-scaled resource elements * and an unsigned scaling factor. Each block taken as input is expanded into 12 * 16-bit complex samples. @@ -2252,7 +2325,7 @@ armral_status armral_block_scaling_decompr_9bit( armral_cmplx_int16_t *dst, const armral_cmplx_int16_t *scale); /** - * The algorithm operates on a fixed block size of one Physical Resource Block + * The function operates on a fixed block size of one Physical Resource Block * (PRB). Each block consists of 12 14-bit complex post-scaled resource elements * and an unsigned scaling factor. Each block taken as input is expanded into 12 * 16-bit complex samples. @@ -2285,7 +2358,7 @@ armral_status armral_block_scaling_decompr_14bit( /** * @brief Block floating-point compression to 8-bit * - * The algorithm operates on a fixed block size of one Resource Block (RB). + * The function operates on a fixed block size of one Resource Block (RB). * Each block consists of 12 16-bit complex resource elements. Each block taken * as input is compressed into 24 8-bit samples and one unsigned exponent. * @@ -2306,7 +2379,7 @@ armral_status armral_block_float_compr_8bit(uint32_t n_prb, /** * @brief Block floating point compression to 9-bit big-endian * - * The algorithm operates on a fixed block size of one Resource Block (RB). + * The function operates on a fixed block size of one Resource Block (RB). * Each block consists of 12 16-bit complex resource elements. Each block taken * as input is compressed into 24 9-bit big-endian samples and one unsigned * exponent. Big-endian means that where data from a 9-bit element is split @@ -2331,7 +2404,7 @@ armral_status armral_block_float_compr_9bit(uint32_t n_prb, /** * @brief Block floating point compression to 12-bit big-endian * - * The algorithm operates on a fixed block size of one Resource Block (RB). + * The function operates on a fixed block size of one Resource Block (RB). * Each block consists of 12 16-bit complex resource elements. Each block taken * as input is compressed into 24 12-bit big-endian samples and one unsigned * exponent. Big-endian means that where data from a 12-bit element is split @@ -2353,7 +2426,7 @@ armral_status armral_block_float_compr_12bit(uint32_t n_prb, /** * @brief Block floating point compression to 14-bit big-endian * - * The algorithm operates on a fixed block size of one Resource Block (RB). + * The function operates on a fixed block size of one Resource Block (RB). * Each block consists of 12 16-bit complex resource elements. Each block taken * as input is compressed into 24 14-bit big-endian samples and one unsigned * exponent. Big-endian means that where data from a 14-bit element is split @@ -2378,7 +2451,7 @@ armral_status armral_block_float_compr_14bit(uint32_t n_prb, /** * @brief Block floating-point decompression from 8 bit * - * The algorithm operates on a fixed block size of one Resource Block (RB). + * The function operates on a fixed block size of one Resource Block (RB). * Each block consists of 12 8-bit complex resource elements and an unsigned * exponent. Each block taken as input is expanded into 12 16-bit complex * samples. @@ -2399,7 +2472,7 @@ armral_status armral_block_float_decompr_8bit( /** * @brief Block floating point decompression from 9 bit big-endian * - * The algorithm operates on a fixed block size of one Resource Block (RB). + * The function operates on a fixed block size of one Resource Block (RB). * Each block consists of 12 9-bit big-endian complex resource elements and an * unsigned exponent. Each block taken as input is expanded into 12 16-bit * complex samples. Big-endian here means that where data from a 9-bit element @@ -2423,7 +2496,7 @@ armral_status armral_block_float_decompr_9bit( /** * @brief Block floating point decompression from 12 bit big-endian * - * The algorithm operates on a fixed block size of one Resource Block (RB). + * The function operates on a fixed block size of one Resource Block (RB). * Each block consists of 12 12-bit big-endian complex resource elements and an * unsigned exponent. Each block taken as input is expanded into 12 16-bit * complex samples. Big-endian here means that where data from a 12-bit element @@ -2447,7 +2520,7 @@ armral_status armral_block_float_decompr_12bit( /** * @brief Block floating point decompression from 14 bit big-endian * - * The algorithm operates on a fixed block size of one Resource Block (RB). + * The function operates on a fixed block size of one Resource Block (RB). * Each block consists of 12 14-bit big-endian complex resource elements and an * unsigned exponent. Each block taken as input is expanded into 12 16-bit * complex samples. Big-endian here means that where data from a 14-bit element @@ -2474,7 +2547,7 @@ armral_status armral_block_float_decompr_14bit( * @ingroup groupUpPhy */ /** - * @addtogroup crc CRC + * @addtogroup crc Cyclic Redundancy Check (CRC) * @{ * \brief Computes a Cyclic Redundancy Check (CRC) of an input buffer * using carry-less multiplication and Barret reduction. @@ -2647,7 +2720,7 @@ armral_status armral_crc6_be(uint32_t size, const uint64_t *input, * @ingroup groupUpPhy */ /** - * @addtogroup polar Polar encoding + * @addtogroup polar Polar Encoding * @{ * \brief In uplink, Polar codes are used to encode the Uplink Control * Information (UCI) over the Physical Uplink Control Channel (PUCCH) and @@ -2665,7 +2738,7 @@ armral_status armral_crc6_be(uint32_t size, const uint64_t *input, * `i` is included in the `frozen` bits set, then `u(i) = 0`. The input * information bits are stored in the remaining entries. * `[d] = [d(0), d(1), ..., d(N-1)]` is the vector of output encoded bits. - * `[G_N]` is the channel transformation matrix (`N-by-N`), obtained by + * `[G_N]` is the channel transformation matrix (`N`-by-`N`), obtained by * recursively applying the Kronecker product from the basic kernel `G_2 = |1 0; * 1 1|` to the order `n = log2(N)`. * @@ -3321,7 +3394,7 @@ armral_ldpc_get_base_graph(armral_ldpc_graph_t bg); * size and base graph. For base graph 1 the number of information bits per * code block is `22 * z`. For base graph 2 the number of information bits per * code block is `10 * z`. It is assumed that the correct number of input bits - * is passed into this routine. + * is passed into this function. * * @param[in] data_in The information bits to encode. It is assumed * that the number of bits stored in `data_in` fits @@ -3370,7 +3443,7 @@ armral_status armral_ldpc_encode_block(const uint8_t *data_in, * size and base graph. For base graph 1 the number of information bits per * code block is `22 * z`. For base graph 2 the number of information bits per * code block is `10 * z`. It is assumed that the correct number of input bits - * is passed into this routine. + * is passed into this function. * * This function takes a pre-allocated buffer (`buffer`) to use internally. * This variant will not call any system memory allocators. @@ -3813,7 +3886,7 @@ armral_status armral_ldpc_rate_recovery_noalloc( * for a single code block. */ /** - * This routine implements the LTE Turbo encoding scheme described in 3GPP + * This function implements the LTE Turbo encoding scheme described in 3GPP * Technical Specification (TS) 36.212 "Multiplexing and channel coding". It * takes as input an array `src` of length `k` bits, where `k` must be one of * the values defined in TS 36.212 Table 5.1.3-3. The outputs of the encoding @@ -3879,7 +3952,7 @@ armral_status armral_turbo_encode_block_noalloc(const uint8_t *src, uint32_t k, uint8_t *dst2, void *buffer); /** - * This routine implements a maximum a posteriori (MAP) algorithm to decode the + * This function implements a maximum a posteriori (MAP) algorithm to decode the * output of the LTE Turbo encoding scheme described in 3GPP Technical * Specification (TS) 36.212 "Multiplexing and channel coding". It takes as * input three arrays `sys`, `par` and `itl`, each of length `k + 4` bits where @@ -3890,9 +3963,9 @@ armral_status armral_turbo_encode_block_noalloc(const uint8_t *src, uint32_t k, * * The output is written into the array `dst`, which must contain enough bytes * to store `k` bits. These are hard outputs (that is, either 0 or 1); the - * routine does not return LLRs. + * function does not return LLRs. * - * The routine takes a parameter `max_iter`, which specifies the + * The function takes a parameter `max_iter`, which specifies the * maximum number of iterations that the decoder will perform. The * algorithm will terminate in fewer iterations if there is no change * in the computed LLRs between consecutive iterations. @@ -3926,9 +3999,9 @@ armral_status armral_turbo_decode_block(const int8_t *sys, const int8_t *par, * * The output is written into the array `dst`, which must contain enough bytes * to store `k` bits. These are hard outputs (that is, either 0 or 1); the - * routine does not return LLRs. + * function does not return LLRs. * - * The routine takes a parameter `max_iter`, which specifies the + * The function takes a parameter `max_iter`, which specifies the * maximum number of iterations that the decoder will perform. The * algorithm will terminate in fewer iterations if there is no change * in the computed LLRs between consecutive iterations. @@ -4156,7 +4229,7 @@ uint32_t armral_turbo_rate_recovery_noalloc_buffer_size(uint32_t d, uint32_t e, * @ingroup groupUpPhy */ /** - * @addtogroup conv LTE convolutional coding + * @addtogroup conv LTE Convolutional Coding * @{ * \brief Performs encoding and decoding of data using LTE tail biting * convolutional coding. The encoding scheme is defined in section 5.1.3.1 of @@ -4168,7 +4241,7 @@ uint32_t armral_turbo_rate_recovery_noalloc_buffer_size(uint32_t d, uint32_t e, * performed for a single code block. */ /** - * This routine implements the LTE tail biting convolutional encoding scheme + * This function implements the LTE tail biting convolutional encoding scheme * described in 3GPP Technical Specification (TS) 36.212 "Multiplexing and * channel coding". It takes as input an array `src` of length `k` bits. The * outputs of the encoding are written into the three arrays `dst0`, `dst1`, and @@ -4196,7 +4269,7 @@ armral_status armral_tail_biting_convolutional_encode_block(const uint8_t *src, uint8_t *dst2); /** - * This routine implements the Wrap Around Viterbi Algorithm (WAVA) to decode + * This function implements the Wrap Around Viterbi Algorithm (WAVA) to decode * the output of the LTE tail biting convolutional coding scheme described in * 3GPP Technical Specification (TS) 36.212 "Multiplexing and channel coding". * It takes as input three arrays containing the log-likelihood ratios (LLRs) @@ -4280,20 +4353,20 @@ uint32_t armral_tail_biting_convolutional_decode_block_noalloc_buffer_size( * @ingroup groupMatrix */ /** - * @addtogroup svd SVD decomposition of single complex matrix + * @addtogroup svd SVD of a Single Complex Matrix * @{ * \brief The Singular Value Decomposition (SVD) is used for selecting * orthogonal user equipment pairing in mMIMO channels. */ /** * - * This algorithm performs the Singular Value Decomposition (SVD) - * of an `M-by-N` single complex matrix `A`, where `M ≥ N` and `A` is stored in - * column-major order. The SVD of `A` is a two-sided decomposition in the form - * `A = U Σ V^H`, with `U` an `M-by-M` single complex orthogonal matrix. + * This function performs the Singular Value Decomposition (SVD) + * of an `M`-by-`N` single complex matrix `A`, where `M ≥ N` and `A` is stored + * in column-major order. The SVD of `A` is a two-sided decomposition in the + * form `A = U Σ V^H`, with `U` an `M`-by-`M` single complex orthogonal matrix. * Note that we only store the first `N` columns of `U` because there - * are at most `N` non-zero singular values. `V` is an `N-by-N` - * single complex orthogonal matrix, and `Σ` is an `M-by-N` + * are at most `N` non-zero singular values. `V` is an `N`-by-`N` + * single complex orthogonal matrix, and `Σ` is an `M`-by-`N` * real matrix. Entries `Σ_{i,i}`, `i < n` contain the singular * values, and other entries in `Σ` are zero. We only store * the singular values, not the full matrix `Σ`. The singular @@ -4313,7 +4386,7 @@ uint32_t armral_tail_biting_convolutional_decode_block_noalloc_buffer_size( * only the singular values are computed. * @param[in] m The number of rows (`M`) in matrix `A`. * @param[in] n The number of columns (`N`) in matrix `A`. - * @param[in,out] a On entry contains the `M-by-N` matrix + * @param[in,out] a On entry contains the `M`-by-`N` matrix * on which to perform the SVD. On exit contains * the Householder reflectors used to * perform the bidiagonalization of `A`. @@ -4321,29 +4394,31 @@ uint32_t armral_tail_biting_convolutional_decode_block_noalloc_buffer_size( * @param[out] u The left singular vectors, if required. * If `vect` is true, `u` is populated with * the left singular vectors in the SVD. - * A memory of `M-by-N` is assumed to have - * been allocated before the call to this method. + * Sufficient memory to store `M N` `float32_t` + * values is assumed to have been allocated before + * the call to this method. * @param[out] vt The right singular vectors, if required. * If `vect` is true, `vt` is populated with * the right singular vectors in the SVD. - * A memory of `N-by-N` is assumed to have been - * allocated before the call to this method. + * Sufficient memory to store `N N` `float32_t` + * values is assumed to have been allocated before + * the call to this method. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_svd_cf32(bool vect, int m, int n, armral_cmplx_f32_t *a, - float32_t *s, armral_cmplx_f32_t *u, - armral_cmplx_f32_t *vt); +armral_status armral_svd_cf32(bool vect, uint32_t m, uint32_t n, + armral_cmplx_f32_t *a, float32_t *s, + armral_cmplx_f32_t *u, armral_cmplx_f32_t *vt); /** * Non-allocating variant of \link armral_svd_cf32 \endlink. * * This function performs the Singular Value Decomposition (SVD) - * of an `M-by-N` single complex matrix `A`, where `M ≥ N` and `A` is stored in - * column-major order. The SVD of `A` is a two-sided decomposition in the form - * `A = U Σ V^H`, with `U` an `M-by-M` single complex orthogonal matrix. + * of an `M`-by-`N` single complex matrix `A`, where `M ≥ N` and `A` is stored + * in column-major order. The SVD of `A` is a two-sided decomposition in the + * form `A = U Σ V^H`, with `U` an `M`-by-`M` single complex orthogonal matrix. * Note that we only store the first `N` columns of `U` because there - * are at most `N` non-zero singular values. `V` is an `N-by-N` - * single complex orthogonal matrix, and `Σ` is an `M-by-N` + * are at most `N` non-zero singular values. `V` is an `N`-by-`N` + * single complex orthogonal matrix, and `Σ` is an `M`-by-`N` * real matrix. Entries `Σ_{i,i}`, `i < n` contain the singular * values, and other entries in `Σ` are zero. We only store * the singular values, not the full matrix `Σ`. The singular @@ -4370,7 +4445,7 @@ armral_status armral_svd_cf32(bool vect, int m, int n, armral_cmplx_f32_t *a, * only the singular values are computed. * @param[in] m The number of rows (`M`) in matrix `A`. * @param[in] n The number of columns (`N`) in matrix `A`. - * @param[in,out] a On entry contains the `M-by-N` matrix + * @param[in,out] a On entry contains the `M`-by-`N` matrix * on which to perform the SVD. On exit contains * the Householder reflectors used to * perform the bidiagonalization of `A`. @@ -4378,24 +4453,26 @@ armral_status armral_svd_cf32(bool vect, int m, int n, armral_cmplx_f32_t *a, * @param[out] u The left singular vectors, if required. * If `vect` is true, `u` is populated with * the left singular vectors in the SVD. - * A memory of `M-by-N` is assumed to have - * been allocated before the call to this method. + * Sufficient memory to store `M N` `float32_t` + * values is assumed to have been allocated before + * the call to this method. * @param[out] vt The right singular vectors, if required. * If `vect` is true, `vt` is populated with * the right singular vectors in the SVD. - * A memory of `N-by-N` is assumed to have been - * allocated before the call to this method. + * Sufficient memory to store `N N` `float32_t` + * values is assumed to have been allocated before + * the call to this method. * @param[in] buffer Workspace buffer to be used internally. * @return An `armral_status` value that indicates success or failure. */ -armral_status armral_svd_cf32_noalloc(bool vect, int m, int n, +armral_status armral_svd_cf32_noalloc(bool vect, uint32_t m, uint32_t n, armral_cmplx_f32_t *a, float32_t *s, armral_cmplx_f32_t *u, armral_cmplx_f32_t *vt, void *buffer); /** * Calculates the required buffer size in bytes needed to perform a singular - * value decomposition (SVD) of an `M-by-N` input matrix `A`. + * value decomposition (SVD) of an `M`-by-`N` input matrix `A`. * * @param[in] vect If true, both the singular values and * the singular vectors are computed, else @@ -4404,7 +4481,7 @@ armral_status armral_svd_cf32_noalloc(bool vect, int m, int n, * @param[in] n The number of columns (`N`) in matrix `A`. * @return The required buffer size in bytes. */ -uint32_t armral_svd_cf32_noalloc_buffer_size(bool vect, int m, int n); +uint32_t armral_svd_cf32_noalloc_buffer_size(bool vect, uint32_t m, uint32_t n); /** @} end svd */ @@ -4425,7 +4502,7 @@ uint32_t armral_svd_cf32_noalloc_buffer_size(bool vect, int m, int n); */ /** * - * This algorithm generates a block of scrambled bits using a pseudo-random + * This function generates a block of scrambled bits using a pseudo-random * sequence according to the scrambler described in the 3GPP Technical * Specification (TS) 38.211. For a codeword `b` with length `M` transmitted on * the physical channel, the block of bits `b(0), ..., b(M - 1)` is scrambled @@ -4437,7 +4514,7 @@ uint32_t armral_svd_cf32_noalloc_buffer_size(bool vect, int m, int n); * * where `s(0), ..., s(M - 1)` are the scrambled bits and `c` is a * pseudo-random scrambling sequence defined by a length-31 Gold sequence. - * Note that this routine cannot be used to scramble a transport block, as + * Note that this function cannot be used to scramble a transport block, as * defined in TS 38.212 section 7.1.2. * * @param[in] src Points to the array of input bits. diff --git a/simulation/CMakeLists.txt b/simulation/CMakeLists.txt index d0c5438..89eaf50 100644 --- a/simulation/CMakeLists.txt +++ b/simulation/CMakeLists.txt @@ -13,25 +13,6 @@ add_library(simulation_common INTERFACE) target_include_directories(simulation_common INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/include) -function(set_omp_cxx_flags) - if(NOT OpenMP_CXX_FLAGS STREQUAL "NOTFOUND") - return() - endif() - check_c_compiler_flag(-fopenmp OPENMP_FLAG_IS_VALID) - if(OPENMP_FLAG_IS_VALID) - set(OpenMP_CXX_FLAGS - "-fopenmp" - PARENT_SCOPE) - else() - check_c_compiler_flag(-fopenmp=libomp OPENMP_FLAG_IS_VALID) - if(OPENMP_FLAG_IS_VALID) - set(OpenMP_CXX_FLAGS - "-fopenmp=libomp" - PARENT_SCOPE) - endif() - endif() -endfunction() - find_package(OpenMP) find_package(Threads) if(Threads_FOUND) @@ -40,7 +21,6 @@ if(Threads_FOUND) if(NOT TARGET OpenMP::OpenMP_CXX) add_library(OpenMP::OpenMP_CXX IMPORTED INTERFACE) - set_omp_cxx_flags() if(OpenMP_CXX_FLAGS STREQUAL "NOTFOUND") # Sometimes we are failing to find OpenMP in testing. Needs more # investigation, but in the meantime, just don't build the project @@ -64,8 +44,6 @@ if(Threads_FOUND) set_property(TARGET OpenMP::OpenMP_CXX PROPERTY INTERFACE_COMPILE_OPTIONS ${OpenMP_CXX_FLAGS}) - # Only works if the same flag is passed to the linker; use CMake 3.9+ - # otherwise. set_property( TARGET OpenMP::OpenMP_CXX PROPERTY INTERFACE_LINK_LIBRARIES ${OpenMP_CXX_FLAGS} Threads::Threads) diff --git a/simulation/convolutional_awgn/convolutional_awgn.cpp b/simulation/convolutional_awgn/convolutional_awgn.cpp index 6c74c89..25d939c 100644 --- a/simulation/convolutional_awgn/convolutional_awgn.cpp +++ b/simulation/convolutional_awgn/convolutional_awgn.cpp @@ -4,8 +4,8 @@ */ #include "armral.h" #include "awgn.hpp" -#include "bit_utils.hpp" #include "simulation_common.hpp" +#include "utils/bits_to_bytes.hpp" #include #include @@ -158,11 +158,12 @@ int run_check(armral::utils::random_state *state, double snr_db, uint32_t ulp, data->len_out, iter_max, data->data_decoded); // To make it easier to compare the values, convert the bit array to a byte // array - bits_to_bytes(data->len_out, data->data_decoded, data->data_decoded_bytes); + armral::bits_to_bytes(data->len_out, data->data_decoded, + data->data_decoded_bytes); // Check the number of errors in decoding int num_errors = 0; - bits_to_bytes(data->len_in, data->data_in, data->data_in_bytes); + armral::bits_to_bytes(data->len_in, data->data_in, data->data_in_bytes); for (uint32_t i = 0; i < data->len_in; ++i) { if (data->data_decoded_bytes[i] != data->data_in_bytes[i]) { num_errors++; diff --git a/simulation/ldpc_awgn/ldpc_awgn.cpp b/simulation/ldpc_awgn/ldpc_awgn.cpp index 19db5f7..a7890ae 100644 --- a/simulation/ldpc_awgn/ldpc_awgn.cpp +++ b/simulation/ldpc_awgn/ldpc_awgn.cpp @@ -4,8 +4,8 @@ */ #include "armral.h" #include "awgn.hpp" -#include "bit_utils.hpp" #include "simulation_common.hpp" +#include "utils/bits_to_bytes.hpp" #include #include @@ -190,8 +190,8 @@ int run_check(armral::utils::random_state *state, uint32_t z, data->data_encoded); // To make it easier to compare the bits, convert the bit array to a byte // array - bits_to_bytes(data->len_encoded, data->data_encoded, - data->data_encoded_bytes); + armral::bits_to_bytes(data->len_encoded, data->data_encoded, + data->data_encoded_bytes); // Rate match data_encoded to create an array of length e bits from // num_mod_symbols * bit_per_symbol bits. @@ -227,11 +227,12 @@ int run_check(armral::utils::random_state *state, uint32_t z, // To make it easier to compare the values, convert the bit array to a byte // // array - bits_to_bytes(data->len_out, data->data_decoded, data->data_decoded_bytes); + armral::bits_to_bytes(data->len_out, data->data_decoded, + data->data_decoded_bytes); // Check the number of errors in decoding int num_errors = 0; - bits_to_bytes(2 * z, data->data_in, data->data_in_bytes); + armral::bits_to_bytes(2 * z, data->data_in, data->data_in_bytes); // Check that the punctured columns are the same as the input data for (uint32_t i = 0; i < 2 * z; ++i) { if (data->data_decoded_bytes[i] != data->data_in_bytes[i]) { diff --git a/simulation/modulation_awgn/modulation_awgn.cpp b/simulation/modulation_awgn/modulation_awgn.cpp index bd493e5..79f1001 100644 --- a/simulation/modulation_awgn/modulation_awgn.cpp +++ b/simulation/modulation_awgn/modulation_awgn.cpp @@ -3,8 +3,8 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "awgn.hpp" -#include "bit_utils.hpp" #include "simulation_common.hpp" +#include "utils/bits_to_bytes.hpp" #include #include @@ -87,7 +87,7 @@ int run_check(armral::utils::random_state *state, double snr_db, uint32_t ulp, data->data_mod, data->data_demod_soft); // Check the number of errors - bits_to_bytes(data->len_in, data->data_in, data->data_in_bytes); + armral::bits_to_bytes(data->len_in, data->data_in, data->data_in_bytes); int num_errors = 0; for (uint32_t i = 0; i < data->len_in; ++i) { uint8_t demod_hard = data->data_demod_soft[i] < 0 ? 1 : 0; diff --git a/simulation/polar_awgn/polar_awgn.cpp b/simulation/polar_awgn/polar_awgn.cpp index c25aa41..6c179ba 100644 --- a/simulation/polar_awgn/polar_awgn.cpp +++ b/simulation/polar_awgn/polar_awgn.cpp @@ -4,8 +4,8 @@ */ #include "armral.h" #include "awgn.hpp" -#include "bit_utils.hpp" #include "simulation_common.hpp" +#include "utils/bits_to_bytes.hpp" #include #include @@ -202,7 +202,8 @@ int run_check(armral::utils::random_state *state, double snr_db, // Convert the data to a byte array rather than a bit string. This makes // comparison of the data easier later on, where we want to count the number // of incorrect bits - bits_to_bytes(data->n, data->data_interleave, data->data_interleave_bytes); + armral::bits_to_bytes(data->n, data->data_interleave, + data->data_interleave_bytes); armral_polar_encode_block(data->n, data->data_interleave, data->data_encoded); @@ -256,8 +257,8 @@ int run_check(armral::utils::random_state *state, double snr_db, // arrays std::vector data_in_bytes(data->k); std::vector data_deint0_bytes(data->k); - bits_to_bytes(data->k, data->data_in, data_in_bytes.data()); - bits_to_bytes(data->k, data_deint0.data(), data_deint0_bytes.data()); + armral::bits_to_bytes(data->k, data->data_in, data_in_bytes.data()); + armral::bits_to_bytes(data->k, data_deint0.data(), data_deint0_bytes.data()); for (uint32_t i = 0; i < data->k; ++i) { if (data_deint0_bytes[i] != data_in_bytes[i]) { num_errors++; diff --git a/simulation/turbo_awgn/turbo_awgn.cpp b/simulation/turbo_awgn/turbo_awgn.cpp index 7dbc271..779c017 100644 --- a/simulation/turbo_awgn/turbo_awgn.cpp +++ b/simulation/turbo_awgn/turbo_awgn.cpp @@ -4,8 +4,8 @@ */ #include "armral.h" #include "awgn.hpp" -#include "bit_utils.hpp" #include "simulation_common.hpp" +#include "utils/bits_to_bytes.hpp" #include #include @@ -233,11 +233,12 @@ int run_check(armral::utils::random_state *state, double snr_db, uint32_t ulp, // To make it easier to compare the values, convert the bit array to a byte // array - bits_to_bytes(data->len_out, data->data_decoded, data->data_decoded_bytes); + armral::bits_to_bytes(data->len_out, data->data_decoded, + data->data_decoded_bytes); // Check the number of errors in decoding int num_errors = 0; - bits_to_bytes(data->len_in, data->data_in, data->data_in_bytes); + armral::bits_to_bytes(data->len_in, data->data_in, data->data_in_bytes); for (uint32_t i = 0; i < data->len_in; ++i) { if (data->data_decoded_bytes[i] != data->data_in_bytes[i]) { num_errors++; diff --git a/src/BasicMathFun/MatrixInv/arm_cmplx_hermitian_mat_inversion_f32.cpp b/src/BasicMathFun/MatrixInv/arm_cmplx_hermitian_mat_inversion_f32.cpp index c3941c2..591c01d 100644 --- a/src/BasicMathFun/MatrixInv/arm_cmplx_hermitian_mat_inversion_f32.cpp +++ b/src/BasicMathFun/MatrixInv/arm_cmplx_hermitian_mat_inversion_f32.cpp @@ -2716,7 +2716,7 @@ void invert_hermitian_matrix<16>(const armral_cmplx_f32_t *__restrict p_src, // E = C*A^-1 */ armral_cmplx_f32_t e[8 * 8]; - armral_cmplx_mat_mult_f32(8, 8, 8, c, a_inv, e); + armral_cmplx_matmul_f32(8, 8, 8, c, a_inv, e); // G = D - C*A^-1*B */ armral_cmplx_f32_t g[8 * 8]; @@ -2802,8 +2802,8 @@ void invert_hermitian_matrix<16>(const armral_cmplx_f32_t *__restrict p_src, /*Calculate E = (CA^-1) */ armral_cmplx_f32_t mat_e[64]; armral_cmplx_f32_t *p_mat_e = mat_e; - armral_cmplx_mat_mult_f32(8, 8, 8, (armral_cmplx_f32_t *)p_mat_c, - (armral_cmplx_f32_t *)p_inv_a, p_mat_e); + armral_cmplx_matmul_f32(8, 8, 8, (armral_cmplx_f32_t *)p_mat_c, + (armral_cmplx_f32_t *)p_inv_a, p_mat_e); /*Calculate F=(D-(CA^-1) * B) */ float32x4_t temp_mat[32]; @@ -2814,8 +2814,8 @@ void invert_hermitian_matrix<16>(const armral_cmplx_f32_t *__restrict p_src, const armral_cmplx_f32_t *p_mat_f = (armral_cmplx_f32_t *)mat_f; armral_cmplx_f32_t *p_inv_f = (armral_cmplx_f32_t *)inv_f; - armral_cmplx_mat_mult_f32(8, 8, 8, mat_e, (armral_cmplx_f32_t *)p_mat_b, - p_temp); + armral_cmplx_matmul_f32(8, 8, 8, mat_e, (armral_cmplx_f32_t *)p_mat_b, + p_temp); for (int i = 0; i < 16; i++) { mat_f[2 * i] = vsubq_f32(mat_d[2 * i], temp_mat[2 * i]); @@ -2826,7 +2826,7 @@ void invert_hermitian_matrix<16>(const armral_cmplx_f32_t *__restrict p_src, invert_hermitian_matrix<8>(p_mat_f, p_inv_f); /*calculate block 2-1 -F^-1 * E */ - armral_cmplx_mat_mult_f32(8, 8, 8, p_inv_f, p_mat_e, p_temp); + armral_cmplx_matmul_f32(8, 8, 8, p_inv_f, p_mat_e, p_temp); float32x4_t block21[32]; for (int i = 0; i < 16; i++) { @@ -2838,11 +2838,11 @@ void invert_hermitian_matrix<16>(const armral_cmplx_f32_t *__restrict p_src, armral_cmplx_f32_t mat_g[64]; armral_cmplx_f32_t *p_mat_g = mat_g; - armral_cmplx_mat_mult_f32(8, 8, 8, p_inv_a, (armral_cmplx_f32_t *)p_mat_b, - p_mat_g); + armral_cmplx_matmul_f32(8, 8, 8, p_inv_a, (armral_cmplx_f32_t *)p_mat_b, + p_mat_g); /*Calculate block 1-2 G*invF*/ - armral_cmplx_mat_mult_f32(8, 8, 8, p_mat_g, p_inv_f, p_temp); + armral_cmplx_matmul_f32(8, 8, 8, p_mat_g, p_inv_f, p_temp); float32x4_t block12[32]; @@ -2855,7 +2855,7 @@ void invert_hermitian_matrix<16>(const armral_cmplx_f32_t *__restrict p_src, float32x4_t temp_mat2[32]; armral_cmplx_f32_t *p_temp2 = (armral_cmplx_f32_t *)temp_mat2; - armral_cmplx_mat_mult_f32(8, 8, 8, p_temp, p_mat_e, p_temp2); + armral_cmplx_matmul_f32(8, 8, 8, p_temp, p_mat_e, p_temp2); float32x4_t block11[32]; diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_f32.c b/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_f32.c deleted file mode 100644 index b24f257..0000000 --- a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_f32.c +++ /dev/null @@ -1,1165 +0,0 @@ -/* - Arm RAN Acceleration Library - SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates -*/ -#include "armral.h" - -#ifdef ARMRAL_ARCH_SVE -#include -#endif - -#ifndef ARMRAL_ARCH_SVE -static inline float32x4_t __attribute__((always_inline)) -vzip1q_f32x2(float32x4_t a, float32x4_t b) { - // This zips a pair of 32-bit floats in a 128-bit vector, e.g. given 32-bit - // vectors - // ^: a = [a0, a1, a2, a3] - // ^: b = [b0, b1, b2, b3] - // ^: returns - // ^: c = [a0, a1, b0, b1] - return vreinterpretq_f32_f64( - vzip1q_f64(vreinterpretq_f64_f32(a), vreinterpretq_f64_f32(b))); -} -#endif - -#ifndef ARMRAL_ARCH_SVE -static inline float32x4_t __attribute__((always_inline)) -vzip2q_f32x2(float32x4_t a, float32x4_t b) { - // This zips a pair of 32-bit floats in 128-bit vector, e.g. given 32-bit - // vectors - // ^: a = [a0, a1, a2, a3] - // ^: b = [b0, b1, b2, b3] - // ^: returns - // ^: c = [a2, a3, b2, b3] - return vreinterpretq_f32_f64( - vzip2q_f64(vreinterpretq_f64_f32(a), vreinterpretq_f64_f32(b))); -} -#endif - -#ifdef ARMRAL_ARCH_SVE -static inline svfloat32_t __attribute__((always_inline)) -svtrn2iq_f32(svfloat32_t a) { - // Interleaves 32-bit floating point numbers at odd 64-bit indices in an - // SVE vector, e.g. given 256-bit vector - // ^: a = [a0, a1, a2, a3, a4, a5, a6, a7, ...] - // ^: returns - // ^: c = [a2, a3, a2, a3, a6, a7, a6, a7, ...] - return svreinterpret_f32_f64( - svtrn2_f64(svreinterpret_f64_f32(a), svreinterpret_f64_f32(a))); -} -#endif - -armral_status -armral_cmplx_mat_mult_2x2_f32(const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, - armral_cmplx_f32_t *p_dst) { - const float32_t *a_ptr = (const float32_t *)p_src_a; - const float32_t *b_ptr = (const float32_t *)p_src_b; - float32_t *out_ptr = (float32_t *)p_dst; - -#ifdef ARMRAL_ARCH_SVE - svbool_t p4 = svptrue_pat_b32(SV_VL4); - svfloat32_t a0 = svld1_f32(p4, &a_ptr[0]); - svfloat32_t a1 = svld1_f32(p4, &a_ptr[4]); - svfloat32_t b0 = svld1_f32(p4, &b_ptr[0]); - svfloat32_t b1 = svld1_f32(p4, &b_ptr[4]); - svfloat32_t c0 = svdup_n_f32(0); - svfloat32_t c1 = svdup_n_f32(0); - - c0 = svcmla_lane_f32(c0, a0, b0, 0, 0); - c0 = svcmla_lane_f32(c0, a0, b0, 0, 90); - c0 = svcmla_lane_f32(c0, a1, b0, 1, 0); - c0 = svcmla_lane_f32(c0, a1, b0, 1, 90); - c1 = svcmla_lane_f32(c1, a0, b1, 0, 0); - c1 = svcmla_lane_f32(c1, a0, b1, 0, 90); - c1 = svcmla_lane_f32(c1, a1, b1, 1, 0); - c1 = svcmla_lane_f32(c1, a1, b1, 1, 90); - - svst1_f32(p4, &out_ptr[0], c0); - svst1_f32(p4, &out_ptr[4], c1); -#else - float32x2x2_t a_col[2]; - float32x2x2_t b[2]; - float32x2x2_t result[2]; - - a_col[0] = vld2_f32(a_ptr); - a_ptr = a_ptr + 4; - - b[0] = vld2_f32(b_ptr); - b_ptr = b_ptr + 4; - - // result[0] 4 rows elem 1 RE * first column elem 1 RE - result[0].val[0] = vmul_lane_f32(a_col[0].val[0], b[0].val[0], 0); - // result[0] 4 rows elem 1 IM * first column elem 1 IM - result[0].val[0] = - vfms_lane_f32(result[0].val[0], a_col[0].val[1], b[0].val[1], 0); - b[1] = vld2_f32(b_ptr); - b_ptr = b_ptr + 4; - // result[1] 4 rows elem 1 IM * first row elem 1 RE - result[0].val[1] = vmul_lane_f32(a_col[0].val[1], b[0].val[0], 0); - // result[1] 4 rows elem 1 RE * first row elem 1 IM - result[0].val[1] = - vfma_lane_f32(result[0].val[1], a_col[0].val[0], b[0].val[1], 0); - a_col[1] = vld2_f32(a_ptr); - a_ptr = a_ptr + 4; - - // result[1].val[0] 4 rows elem 1 RE * second row elem 1 RE - result[1].val[0] = vmul_lane_f32(a_col[0].val[0], b[1].val[0], 0); - // result[1].val[0] 4 rows elem 1 IM * second row elem 1 IM - result[1].val[0] = - vfms_lane_f32(result[1].val[0], a_col[0].val[1], b[1].val[1], 0); - result[1].val[1] = vmul_lane_f32(a_col[0].val[1], b[1].val[0], 0); - result[1].val[1] = - vfma_lane_f32(result[1].val[1], a_col[0].val[0], b[1].val[1], 0); - - // result[0] 4 rows elem 2 RE * first row elem 2 RE - result[0].val[0] = - vfma_lane_f32(result[0].val[0], a_col[1].val[0], b[0].val[0], 1); - // result[0] 4 rows elem 2 IM * first row elem 2 IM - result[0].val[0] = - vfms_lane_f32(result[0].val[0], a_col[1].val[1], b[0].val[1], 1); - // result[1] 4 rows elem 2 IM * first row elem 2 RE - result[0].val[1] = - vfma_lane_f32(result[0].val[1], a_col[1].val[1], b[0].val[0], 1); - // result[1] 4 rows elem 2 RE * first row elem 2 IM - result[0].val[1] = - vfma_lane_f32(result[0].val[1], a_col[1].val[0], b[0].val[1], 1); - - // result[0] 4 rows elem 2 RE * second row elem 2 RE - result[1].val[0] = - vfma_lane_f32(result[1].val[0], a_col[1].val[0], b[1].val[0], 1); - // result[0] 4 rows elem 2 IM * second row elem 2 IM - result[1].val[0] = - vfms_lane_f32(result[1].val[0], a_col[1].val[1], b[1].val[1], 1); - // result[1] 4 rows elem 2 IM * second row elem 2 RE - result[1].val[1] = - vfma_lane_f32(result[1].val[1], a_col[1].val[1], b[1].val[0], 1); - // result[1] 4 rows elem 2 RE * second row elem 2 IM - result[1].val[1] = - vfma_lane_f32(result[1].val[1], a_col[1].val[0], b[1].val[1], 1); - - vst2_f32(out_ptr, result[0]); - out_ptr = out_ptr + 4; - - vst2_f32(out_ptr, result[1]); - out_ptr = out_ptr + 4; -#endif - - return ARMRAL_SUCCESS; -} - -armral_status armral_cmplx_mat_mult_2x2_f32_iq( - const float32_t *restrict src_a_re, const float32_t *restrict src_a_im, - const float32_t *restrict src_b_re, const float32_t *restrict src_b_im, - float32_t *dst_re, float32_t *dst_im) { - -#ifdef ARMRAL_ARCH_SVE - svbool_t p4 = svptrue_pat_b32(SV_VL4); - svfloat32_t a_re = svld1_f32(p4, src_a_re); - svfloat32_t a_im = svld1_f32(p4, src_a_im); - svfloat32_t b_re = svld1_f32(p4, src_b_re); - svfloat32_t b_im = svld1_f32(p4, src_b_im); - - svfloat32_t c_re; - svfloat32_t c_im; - - svfloat32_t tmp_a_re = svtrn2iq_f32(a_re); - svfloat32_t tmp_a_im = svtrn2iq_f32(a_im); - svfloat32_t tmp_b_re = svtrn2(b_re, b_re); - svfloat32_t tmp_b_im = svtrn2(b_im, b_im); - - c_re = svmul_f32_x(p4, tmp_a_re, tmp_b_re); - c_re = svmls_f32_x(p4, c_re, tmp_a_im, tmp_b_im); - c_re = svcmla_lane_f32(c_re, b_re, a_re, 0, 0); - c_re = svcmla_lane_f32(c_re, b_im, a_im, 0, 180); - - c_im = svmul_f32_x(p4, tmp_a_re, tmp_b_im); - c_im = svmla_f32_x(p4, c_im, tmp_a_im, tmp_b_re); - c_im = svcmla_lane_f32(c_im, b_im, a_re, 0, 0); - c_im = svcmla_lane_f32(c_im, b_re, a_im, 0, 0); - - svst1_f32(p4, dst_re, c_re); - svst1_f32(p4, dst_im, c_im); -#else - - float32x2_t a_col0_re = vld1_f32(&src_a_re[0]); - float32x2_t a_col0_im = vld1_f32(&src_a_im[0]); - float32x2_t a_col1_re = vld1_f32(&src_a_re[2]); - float32x2_t a_col1_im = vld1_f32(&src_a_im[2]); - float32x2_t b0_re = vld1_f32(&src_b_re[0]); - float32x2_t b0_im = vld1_f32(&src_b_im[0]); - float32x2_t b1_re = vld1_f32(&src_b_re[2]); - float32x2_t b1_im = vld1_f32(&src_b_im[2]); - - float32x2x2_t result[2]; - result[0].val[0] = vmul_lane_f32(a_col0_re, b0_re, 0); - result[0].val[0] = vfms_lane_f32(result[0].val[0], a_col0_im, b0_im, 0); - result[0].val[1] = vmul_lane_f32(a_col0_im, b0_re, 0); - result[0].val[1] = vfma_lane_f32(result[0].val[1], a_col0_re, b0_im, 0); - - result[1].val[0] = vmul_lane_f32(a_col0_re, b1_re, 0); - result[1].val[0] = vfms_lane_f32(result[1].val[0], a_col0_im, b1_im, 0); - result[1].val[1] = vmul_lane_f32(a_col0_im, b1_re, 0); - result[1].val[1] = vfma_lane_f32(result[1].val[1], a_col0_re, b1_im, 0); - - result[0].val[0] = vfma_lane_f32(result[0].val[0], a_col1_re, b0_re, 1); - result[0].val[0] = vfms_lane_f32(result[0].val[0], a_col1_im, b0_im, 1); - result[0].val[1] = vfma_lane_f32(result[0].val[1], a_col1_im, b0_re, 1); - result[0].val[1] = vfma_lane_f32(result[0].val[1], a_col1_re, b0_im, 1); - - result[1].val[0] = vfma_lane_f32(result[1].val[0], a_col1_re, b1_re, 1); - result[1].val[0] = vfms_lane_f32(result[1].val[0], a_col1_im, b1_im, 1); - result[1].val[1] = vfma_lane_f32(result[1].val[1], a_col1_im, b1_re, 1); - result[1].val[1] = vfma_lane_f32(result[1].val[1], a_col1_re, b1_im, 1); - - vst1_f32(&dst_re[0], result[0].val[0]); - vst1_f32(&dst_im[0], result[0].val[1]); - vst1_f32(&dst_re[2], result[1].val[0]); - vst1_f32(&dst_im[2], result[1].val[1]); -#endif - - return ARMRAL_SUCCESS; -} - -armral_status -armral_cmplx_mat_mult_4x4_f32(const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, - armral_cmplx_f32_t *p_dst) { - const float32_t *a_ptr = (const float32_t *)p_src_a; - const float32_t *b_ptr = (const float32_t *)p_src_b; - float32_t *out_ptr = (float32_t *)p_dst; -#ifdef ARMRAL_ARCH_SVE - - svbool_t p4 = svptrue_pat_b32(SV_VL4); - - svfloat32_t a00 = svld1_f32(p4, &a_ptr[0 * 4 + 0 * 8]); - svfloat32_t a10 = svld1_f32(p4, &a_ptr[1 * 4 + 0 * 8]); - svfloat32_t a01 = svld1_f32(p4, &a_ptr[0 * 4 + 1 * 8]); - svfloat32_t a11 = svld1_f32(p4, &a_ptr[1 * 4 + 1 * 8]); - svfloat32_t a02 = svld1_f32(p4, &a_ptr[0 * 4 + 2 * 8]); - svfloat32_t a12 = svld1_f32(p4, &a_ptr[1 * 4 + 2 * 8]); - svfloat32_t a03 = svld1_f32(p4, &a_ptr[0 * 4 + 3 * 8]); - svfloat32_t a13 = svld1_f32(p4, &a_ptr[1 * 4 + 3 * 8]); - - for (int j = 0; j < 4; j++) { - svfloat32_t cj0 = svdup_n_f32(0); - svfloat32_t cj1 = svdup_n_f32(0); - svfloat32_t b0j = svld1_f32(p4, &b_ptr[0 * 4 + j * 8]); - svfloat32_t b1j = svld1_f32(p4, &b_ptr[1 * 4 + j * 8]); - cj0 = svcmla_lane_f32(cj0, a00, b0j, 0, 0); - cj0 = svcmla_lane_f32(cj0, a00, b0j, 0, 90); - cj0 = svcmla_lane_f32(cj0, a01, b0j, 1, 0); - cj0 = svcmla_lane_f32(cj0, a01, b0j, 1, 90); - cj0 = svcmla_lane_f32(cj0, a02, b1j, 0, 0); - cj0 = svcmla_lane_f32(cj0, a02, b1j, 0, 90); - cj0 = svcmla_lane_f32(cj0, a03, b1j, 1, 0); - cj0 = svcmla_lane_f32(cj0, a03, b1j, 1, 90); - - cj1 = svcmla_lane_f32(cj1, a10, b0j, 0, 0); - cj1 = svcmla_lane_f32(cj1, a10, b0j, 0, 90); - cj1 = svcmla_lane_f32(cj1, a11, b0j, 1, 0); - cj1 = svcmla_lane_f32(cj1, a11, b0j, 1, 90); - cj1 = svcmla_lane_f32(cj1, a12, b1j, 0, 0); - cj1 = svcmla_lane_f32(cj1, a12, b1j, 0, 90); - cj1 = svcmla_lane_f32(cj1, a13, b1j, 1, 0); - cj1 = svcmla_lane_f32(cj1, a13, b1j, 1, 90); - - svst1_f32(p4, &out_ptr[0 * 4 + j * 8], cj0); - svst1_f32(p4, &out_ptr[1 * 4 + j * 8], cj1); - } - -#else - __asm__ __volatile__( - - "ld2 {v10.4s, v11.4s}, [%x[APtr]], #32\n" - - "ld2 {v18.4s, v19.4s}, [%x[BPtr]], #32\n" - "ld2 {v20.4s, v21.4s}, [%x[BPtr]], #32\n" - - "fmul v2.4s, v10.4s, v18.s[0]\n" - "fmls v2.4s, v11.4s, v19.s[0]\n" - "ld2 {v12.4s, v13.4s}, [%x[APtr]], #32\n" - "fmul v4.4s, v10.4s, v20.s[0]\n" - "fmls v4.4s, v11.4s, v21.s[0]\n" - "ld2 {v14.4s, v15.4s}, [%x[APtr]], #32\n" - "fmul v3.4s, v11.4s, v18.s[0]\n" - "fmla v3.4s, v10.4s, v19.s[0]\n" - "ld2 {v16.4s, v17.4s}, [%x[APtr]], #32\n" - "fmul v5.4s, v11.4s, v20.s[0]\n" - "fmla v5.4s, v10.4s, v21.s[0]\n" - - "fmla v2.4s, v12.4s, v18.s[1]\n" - "fmls v2.4s, v13.4s, v19.s[1]\n" - "fmla v3.4s, v13.4s, v18.s[1]\n" - "fmla v3.4s, v12.4s, v19.s[1]\n" - "fmla v4.4s, v12.4s, v20.s[1]\n" - "fmls v4.4s, v13.4s, v21.s[1]\n" - "fmla v5.4s, v13.4s, v20.s[1]\n" - "fmla v5.4s, v12.4s, v21.s[1]\n" - - "fmla v2.4s, v14.4s, v18.s[2]\n" - "fmls v2.4s, v15.4s, v19.s[2]\n" - "fmla v3.4s, v15.4s, v18.s[2]\n" - "fmla v3.4s, v14.4s, v19.s[2]\n" - "fmla v4.4s, v14.4s, v20.s[2]\n" - "fmls v4.4s, v15.4s, v21.s[2]\n" - "fmla v5.4s, v15.4s, v20.s[2]\n" - "fmla v5.4s, v14.4s, v21.s[2]\n" - - "fmla v2.4s, v16.4s, v18.s[3]\n" - "fmls v2.4s, v17.4s, v19.s[3]\n" - "fmla v3.4s, v17.4s, v18.s[3]\n" - "fmla v3.4s, v16.4s, v19.s[3]\n" - "st2 {v2.4s, v3.4s}, [%x[outPtr]], #32\n" - "fmla v4.4s, v16.4s, v20.s[3]\n" - "fmls v4.4s, v17.4s, v21.s[3]\n" - "ld2 {v18.4s, v19.4s}, [%x[BPtr]], #32\n" - "fmla v5.4s, v17.4s, v20.s[3]\n" - "fmla v5.4s, v16.4s, v21.s[3]\n" - - "st2 {v4.4s, v5.4s}, [%x[outPtr]], #32\n" - - "ld2 {v20.4s, v21.4s}, [%x[BPtr]], #32\n" - - "fmul v2.4s, v10.4s, v18.s[0]\n" - "fmls v2.4s, v11.4s, v19.s[0]\n" - "fmul v4.4s, v10.4s, v20.s[0]\n" - "fmls v4.4s, v11.4s, v21.s[0]\n" - "fmul v3.4s, v11.4s, v18.s[0]\n" - "fmla v3.4s, v10.4s, v19.s[0]\n" - "fmul v5.4s, v11.4s, v20.s[0]\n" - "fmla v5.4s, v10.4s, v21.s[0]\n" - - "fmla v2.4s, v12.4s, v18.s[1]\n" - "fmls v2.4s, v13.4s, v19.s[1]\n" - "fmla v3.4s, v13.4s, v18.s[1]\n" - "fmla v3.4s, v12.4s, v19.s[1]\n" - "fmla v4.4s, v12.4s, v20.s[1]\n" - "fmls v4.4s, v13.4s, v21.s[1]\n" - "fmla v5.4s, v13.4s, v20.s[1]\n" - "fmla v5.4s, v12.4s, v21.s[1]\n" - - "fmla v2.4s, v14.4s, v18.s[2]\n" - "fmls v2.4s, v15.4s, v19.s[2]\n" - "fmla v3.4s, v15.4s, v18.s[2]\n" - "fmla v3.4s, v14.4s, v19.s[2]\n" - "fmla v4.4s, v14.4s, v20.s[2]\n" - "fmls v4.4s, v15.4s, v21.s[2]\n" - "fmla v5.4s, v15.4s, v20.s[2]\n" - "fmla v5.4s, v14.4s, v21.s[2]\n" - - "fmla v2.4s, v16.4s, v18.s[3]\n" - "fmls v2.4s, v17.4s, v19.s[3]\n" - "fmla v3.4s, v17.4s, v18.s[3]\n" - "fmla v3.4s, v16.4s, v19.s[3]\n" - "fmla v4.4s, v16.4s, v20.s[3]\n" - "fmls v4.4s, v17.4s, v21.s[3]\n" - "fmla v5.4s, v17.4s, v20.s[3]\n" - "fmla v5.4s, v16.4s, v21.s[3]\n" - - "st2 {v2.4s, v3.4s}, [%x[outPtr]], #32\n" - "st2 {v4.4s, v5.4s}, [%x[outPtr]]\n" - - : [APtr] "+r"(a_ptr), [BPtr] "+r"(b_ptr), [outPtr] "+r"(out_ptr) - - : - - : "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v2", "v3", "v4", "v5", "cc"); -#endif - return ARMRAL_SUCCESS; -} - -armral_status armral_cmplx_mat_mult_4x4_f32_iq( - const float32_t *restrict src_a_re, const float32_t *restrict src_a_im, - const float32_t *restrict src_b_re, const float32_t *restrict src_b_im, - float32_t *dst_re, float32_t *dst_im) { -#ifdef ARMRAL_ARCH_SVE - svbool_t p4 = svptrue_pat_b32(SV_VL4); - - svfloat32_t a_col0_re = svld1_f32(p4, &src_a_re[0 * 4]); - svfloat32_t a_col1_re = svld1_f32(p4, &src_a_re[1 * 4]); - svfloat32_t a_col2_re = svld1_f32(p4, &src_a_re[2 * 4]); - svfloat32_t a_col3_re = svld1_f32(p4, &src_a_re[3 * 4]); - svfloat32_t a_col0_im = svld1_f32(p4, &src_a_im[0 * 4]); - svfloat32_t a_col1_im = svld1_f32(p4, &src_a_im[1 * 4]); - svfloat32_t a_col2_im = svld1_f32(p4, &src_a_im[2 * 4]); - svfloat32_t a_col3_im = svld1_f32(p4, &src_a_im[3 * 4]); - - svfloat32_t c_re; - svfloat32_t c_im; - - for (int j = 0; j < 4; j++) { - svfloat32_t b_re = svld1_f32(p4, &src_b_re[j * 4]); - svfloat32_t b_im = svld1_f32(p4, &src_b_im[j * 4]); - - c_re = svmul_lane_f32(a_col0_re, b_re, 0); - c_re = svmla_lane_f32(c_re, a_col1_re, b_re, 1); - c_re = svmla_lane_f32(c_re, a_col2_re, b_re, 2); - c_re = svmla_lane_f32(c_re, a_col3_re, b_re, 3); - c_re = svmls_lane_f32(c_re, a_col0_im, b_im, 0); - c_re = svmls_lane_f32(c_re, a_col1_im, b_im, 1); - c_re = svmls_lane_f32(c_re, a_col2_im, b_im, 2); - c_re = svmls_lane_f32(c_re, a_col3_im, b_im, 3); - - c_im = svmul_lane_f32(a_col0_re, b_im, 0); - c_im = svmla_lane_f32(c_im, a_col1_im, b_re, 1); - c_im = svmla_lane_f32(c_im, a_col2_re, b_im, 2); - c_im = svmla_lane_f32(c_im, a_col3_im, b_re, 3); - c_im = svmla_lane_f32(c_im, a_col0_im, b_re, 0); - c_im = svmla_lane_f32(c_im, a_col1_re, b_im, 1); - c_im = svmla_lane_f32(c_im, a_col2_im, b_re, 2); - c_im = svmla_lane_f32(c_im, a_col3_re, b_im, 3); - - svst1_f32(p4, &dst_re[j * 4], c_re); - svst1_f32(p4, &dst_im[j * 4], c_im); - } - -#else - const float32_t *a_ptr_re = (const float32_t *)src_a_re; - const float32_t *a_ptr_im = (const float32_t *)src_a_im; - const float32_t *b_ptr_re = (const float32_t *)src_b_re; - const float32_t *b_ptr_im = (const float32_t *)src_b_im; - float32_t *out_ptr_re = dst_re; - float32_t *out_ptr_im = dst_im; - __asm__ __volatile__( - - "ld1 {v10.4s}, [%x[APtr_re]], #16\n" - "ld1 {v11.4s}, [%x[APtr_im]], #16\n" - - "ld1 {v18.4s}, [%x[BPtr_re]], #16\n" - "ld1 {v19.4s}, [%x[BPtr_im]], #16\n" - - "ld1 {v20.4s}, [%x[BPtr_re]], #16\n" - "ld1 {v21.4s}, [%x[BPtr_im]], #16\n" - - "fmul v2.4s, v10.4s, v18.s[0]\n" - "fmls v2.4s, v11.4s, v19.s[0]\n" - "ld1 {v12.4s}, [%x[APtr_re]], #16\n" - "ld1 {v13.4s}, [%x[APtr_im]], #16\n" - "fmul v4.4s, v10.4s, v20.s[0]\n" - "fmls v4.4s, v11.4s, v21.s[0]\n" - "ld1 {v14.4s}, [%x[APtr_re]], #16\n" - "ld1 {v15.4s}, [%x[APtr_im]], #16\n" - "fmul v3.4s, v11.4s, v18.s[0]\n" - "fmla v3.4s, v10.4s, v19.s[0]\n" - "ld1 {v16.4s}, [%x[APtr_re]], #16\n" - "ld1 {v17.4s}, [%x[APtr_im]], #16\n" - "fmul v5.4s, v11.4s, v20.s[0]\n" - "fmla v5.4s, v10.4s, v21.s[0]\n" - "fmla v2.4s, v12.4s, v18.s[1]\n" - "fmls v2.4s, v13.4s, v19.s[1]\n" - "fmla v3.4s, v13.4s, v18.s[1]\n" - "fmla v3.4s, v12.4s, v19.s[1]\n" - "fmla v4.4s, v12.4s, v20.s[1]\n" - "fmls v4.4s, v13.4s, v21.s[1]\n" - "fmla v5.4s, v13.4s, v20.s[1]\n" - "fmla v5.4s, v12.4s, v21.s[1]\n" - - "fmla v2.4s, v14.4s, v18.s[2]\n" - "fmls v2.4s, v15.4s, v19.s[2]\n" - "fmla v3.4s, v15.4s, v18.s[2]\n" - "fmla v3.4s, v14.4s, v19.s[2]\n" - "fmla v4.4s, v14.4s, v20.s[2]\n" - "fmls v4.4s, v15.4s, v21.s[2]\n" - "fmla v5.4s, v15.4s, v20.s[2]\n" - "fmla v5.4s, v14.4s, v21.s[2]\n" - - "fmla v2.4s, v16.4s, v18.s[3]\n" - "fmls v2.4s, v17.4s, v19.s[3]\n" - "fmla v3.4s, v17.4s, v18.s[3]\n" - "fmla v3.4s, v16.4s, v19.s[3]\n" - "st1 {v2.4s}, [%x[outPtr_re]], #16\n" - "st1 {v3.4s}, [%x[outPtr_im]], #16\n" - "fmla v4.4s, v16.4s, v20.s[3]\n" - "fmls v4.4s, v17.4s, v21.s[3]\n" - "ld1 {v18.4s}, [%x[BPtr_re]], #16\n" - "ld1 {v19.4s}, [%x[BPtr_im]], #16\n" - "fmla v5.4s, v17.4s, v20.s[3]\n" - "fmla v5.4s, v16.4s, v21.s[3]\n" - - "st1 {v4.4s}, [%x[outPtr_re]], #16\n" - "st1 {v5.4s}, [%x[outPtr_im]], #16\n" - "ld1 {v20.4s}, [%x[BPtr_re]], #16\n" - "ld1 {v21.4s}, [%x[BPtr_im]], #16\n" - "fmul v2.4s, v10.4s, v18.s[0]\n" - "fmls v2.4s, v11.4s, v19.s[0]\n" - "fmul v4.4s, v10.4s, v20.s[0]\n" - "fmls v4.4s, v11.4s, v21.s[0]\n" - "fmul v3.4s, v11.4s, v18.s[0]\n" - "fmla v3.4s, v10.4s, v19.s[0]\n" - "fmul v5.4s, v11.4s, v20.s[0]\n" - "fmla v5.4s, v10.4s, v21.s[0]\n" - - "fmla v2.4s, v12.4s, v18.s[1]\n" - "fmls v2.4s, v13.4s, v19.s[1]\n" - "fmla v3.4s, v13.4s, v18.s[1]\n" - "fmla v3.4s, v12.4s, v19.s[1]\n" - "fmla v4.4s, v12.4s, v20.s[1]\n" - "fmls v4.4s, v13.4s, v21.s[1]\n" - "fmla v5.4s, v13.4s, v20.s[1]\n" - "fmla v5.4s, v12.4s, v21.s[1]\n" - - "fmla v2.4s, v14.4s, v18.s[2]\n" - "fmls v2.4s, v15.4s, v19.s[2]\n" - "fmla v3.4s, v15.4s, v18.s[2]\n" - "fmla v3.4s, v14.4s, v19.s[2]\n" - "fmla v4.4s, v14.4s, v20.s[2]\n" - "fmls v4.4s, v15.4s, v21.s[2]\n" - "fmla v5.4s, v15.4s, v20.s[2]\n" - "fmla v5.4s, v14.4s, v21.s[2]\n" - - "fmla v2.4s, v16.4s, v18.s[3]\n" - "fmls v2.4s, v17.4s, v19.s[3]\n" - "fmla v3.4s, v17.4s, v18.s[3]\n" - "fmla v3.4s, v16.4s, v19.s[3]\n" - "fmla v4.4s, v16.4s, v20.s[3]\n" - "fmls v4.4s, v17.4s, v21.s[3]\n" - "fmla v5.4s, v17.4s, v20.s[3]\n" - "fmla v5.4s, v16.4s, v21.s[3]\n" - - "st1 {v2.4s}, [%x[outPtr_re]], #16\n" - "st1 {v3.4s}, [%x[outPtr_im]], #16\n" - "st1 {v4.4s}, [%x[outPtr_re]], #16\n" - "st1 {v5.4s}, [%x[outPtr_im]], #16\n" - - : [APtr_re] "+r"(a_ptr_re), [APtr_im] "+r"(a_ptr_im), - [BPtr_re] "+r"(b_ptr_re), [BPtr_im] "+r"(b_ptr_im), - [outPtr_re] "+r"(out_ptr_re), [outPtr_im] "+r"(out_ptr_im) - - : - - : "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v2", "v3", "v4", "v5", "cc"); -#endif - return ARMRAL_SUCCESS; -} - -#ifdef ARMRAL_ARCH_SVE -// Calculates a vector width of consecutive output elements in a matrix product -// of a m x n and n x k matrix. p_src_a and p_src_b must point to the start -// row/column respectively, the operation must be valid and the result will be -// stored at exactly dst. -static void sve_mat_one_row_dot(svbool_t pg, const uint16_t m, const uint16_t n, - const uint16_t k, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, - armral_cmplx_f32_t *dst) { - - svfloat32_t c0 = svdup_f32(0); - for (int h = 0; h < n - 1; h += 2) { - svbool_t pa = svwhilelt_b32(h * 2, n * 2); - svfloat32_t a0i = svld1rq_f32(pa, (const float32_t *)&p_src_a[h]); - svfloat32_t bi0 = svld1_f32(pg, (const float32_t *)&p_src_b[h * k]); - svfloat32_t bi1 = svld1_f32(pg, (const float32_t *)&p_src_b[h * k + k]); - c0 = svcmla_lane_f32(c0, bi0, a0i, 0, 0); - c0 = svcmla_lane_f32(c0, bi0, a0i, 0, 90); - c0 = svcmla_lane_f32(c0, bi1, a0i, 1, 0); - c0 = svcmla_lane_f32(c0, bi1, a0i, 1, 90); - } - - // If n is odd, we have one more row/col to go - if (n % 2) { - svfloat32_t an = - svreinterpret_f32_u64(svdup_u64(*((const uint64_t *)&p_src_a[n - 1]))); - svfloat32_t bn = svld1_f32(pg, (const float32_t *)&p_src_b[(n - 1) * k]); - c0 = svcmla_f32_x(pg, c0, an, bn, 0); - c0 = svcmla_f32_x(pg, c0, an, bn, 90); - } - svst1_f32(pg, (float32_t *)dst, c0); -} -#endif - -#ifdef ARMRAL_ARCH_SVE -// Calculates 2 vector widths of consecutive output elements in a matrix product -// of a m x n and n x k matrix. p_src_a and p_src_b must point to the start -// row/column respectively, the operation must be valid and the result will be -// stored at exactly dst. -static void sve_mat_two_row_dot(svbool_t pg, const uint16_t m, const uint16_t n, - const uint16_t k, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, - armral_cmplx_f32_t *dst) { - - svfloat32_t c0 = svdup_f32(0); - svfloat32_t c1 = svdup_f32(0); - for (int h = 0; h < n - 1; h += 2) { - svbool_t pa = svwhilelt_b32(h * 2, n * 2); - svfloat32_t a0i = svld1rq_f32(pa, (const float32_t *)&p_src_a[h]); - svfloat32_t a1i = svld1rq_f32(pa, (const float32_t *)&p_src_a[h + n]); - svfloat32_t bi0 = svld1_f32(pg, (const float32_t *)&p_src_b[h * k]); - svfloat32_t bi1 = svld1_f32(pg, (const float32_t *)&p_src_b[h * k + k]); - - c0 = svcmla_lane_f32(c0, bi0, a0i, 0, 0); - c0 = svcmla_lane_f32(c0, bi0, a0i, 0, 90); - c0 = svcmla_lane_f32(c0, bi1, a0i, 1, 0); - c0 = svcmla_lane_f32(c0, bi1, a0i, 1, 90); - c1 = svcmla_lane_f32(c1, bi0, a1i, 0, 0); - c1 = svcmla_lane_f32(c1, bi0, a1i, 0, 90); - c1 = svcmla_lane_f32(c1, bi1, a1i, 1, 0); - c1 = svcmla_lane_f32(c1, bi1, a1i, 1, 90); - } - - // If n is odd, we have one more row/col to go - if (n % 2) { - svfloat32_t a0n = - svreinterpret_f32_u64(svdup_u64(*((const uint64_t *)&p_src_a[n - 1]))); - svfloat32_t a1n = svreinterpret_f32_u64( - svdup_u64(*((const uint64_t *)&p_src_a[2 * n - 1]))); - - svfloat32_t bn = svld1_f32(pg, (const float32_t *)&p_src_b[(n - 1) * k]); - c0 = svcmla_f32_x(pg, c0, a0n, bn, 0); - c0 = svcmla_f32_x(pg, c0, a0n, bn, 90); - c1 = svcmla_f32_x(pg, c1, a1n, bn, 0); - c1 = svcmla_f32_x(pg, c1, a1n, bn, 90); - } - svst1_f32(pg, (float32_t *)&dst[0], c0); - svst1_f32(pg, (float32_t *)&dst[k], c1); -} -#endif - -armral_status -armral_cmplx_mat_mult_f32(const uint16_t m, const uint16_t n, const uint16_t k, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, - armral_cmplx_f32_t *p_dst) { -#ifdef ARMRAL_ARCH_SVE - - for (int i = 0; i < m - 1; i += 2) { - for (int j = 0; j < k; j += svcntd()) { - const svbool_t pg = svwhilelt_b32(2 * j, 2 * k); - sve_mat_two_row_dot(pg, m, n, k, &p_src_a[i * n], &p_src_b[j], - &p_dst[i * k + j]); - } - } - if (m % 2) { - const int i = m - 1; - for (int j = 0; j < k; j += svcntd()) { - const svbool_t pg = svwhilelt_b32(2 * j, 2 * k); - sve_mat_one_row_dot(pg, m, n, k, &p_src_a[i * n], &p_src_b[j], - &p_dst[i * k + j]); - } - } -#else - - const float32_t *p_in1 = (const float32_t *)p_src_a; - const float32_t *p_in2 = (const float32_t *)p_src_b; - const armral_cmplx_f32_t *p_in_a = p_src_a; - armral_cmplx_f32_t *p_out = p_dst; - armral_cmplx_f32_t *px; - uint16_t num_rows_a = m; /* number of rows of input matrix A */ - uint16_t num_cols_b = k; /* number of columns of input matrix B */ - uint16_t num_cols_a = n; /* number of columns of input matrix A */ - - float32x4x2_t a0_v; - float32x4x2_t a1_v; - float32x4_t temp_r2; - float32x4_t temp_i2; - float32x4_t b0_v; - float32x4_t b1_v; - float32x4_t b2_v; - float32x4_t b3_v; - float32x4_t b_col_real; - float32x4_t b_col_im; - float32x4_t b_col_real2; - float32x4_t b_col_im2; - float32x2_t accum = vdup_n_f32(0); - const float32_t *p_in1_b = (const float32_t *)p_src_a; - const float32_t *p_in1_b2 = (const float32_t *)p_src_b; - - uint16_t col; - uint16_t i = 0U; - uint16_t j; - uint16_t row_cnt; - uint16_t row = num_rows_a; - uint16_t col_cnt; - armral_cmplx_f32_t *px_b; - - /* The following loop performs the dot-product of each row in pSrcA with each - * column in pSrcB */ - - row_cnt = row >> 1; - /* Row loop */ - while (row_cnt > 0U) { - /* Output pointer is set to starting address of the row being processed */ - px = p_out + i; - px_b = px + num_cols_b; - - /* For every row wise process, the column loop counter is to be initiated */ - col = num_cols_b; - - /* For every row wise process, the pIn2 pointer is set - ** to the starting address of the pSrcB data */ - p_in2 = (const float32_t *)p_src_b; - p_in1_b2 = p_in2 + 2 * num_cols_b; - - j = 0U; - - /* Column loop */ - col >>= 1; - while (col > 0U) { - /* Set the variable sum, that acts as accumulator, to zero */ - float32_t sum_real1 = 0.0F; - float32_t sum_imag1 = 0.0F; - float32_t sum_real1_b = 0.0F; - float32_t sum_imag1_b = 0.0F; - - float32_t sum_real2 = 0.0F; - float32_t sum_imag2 = 0.0F; - float32_t sum_real2_b = 0.0F; - float32_t sum_imag2_b = 0.0F; - - float32_t sum_real3 = 0.0F; - float32_t sum_imag3 = 0.0F; - float32_t sum_real3_b = 0.0F; - float32_t sum_imag3_b = 0.0F; - - float32_t sum_real4 = 0.0F; - float32_t sum_imag4 = 0.0F; - float32_t sum_real4_b = 0.0F; - float32_t sum_imag4_b = 0.0F; - - /* Initiate the pointer pIn1 to point to the starting address of the - * column being processed */ - p_in1 = (const float32_t *)p_in_a; - p_in1_b = p_in1 + 2 * num_cols_a; - - float32x4_t acc_r0 = {}; - float32x4_t acc_i0 = {}; - float32x4_t acc_r1 = {}; - float32x4_t acc_i1 = {}; - float32x4_t acc_r2 = {}; - float32x4_t acc_i2 = {}; - float32x4_t acc_r3 = {}; - float32x4_t acc_i3 = {}; - - /* Compute 4 MACs simultaneously. */ - col_cnt = num_cols_a >> 2; - - /* Matrix multiplication */ - while (col_cnt > 0U) { - float32x4_t temp_r = {}; - float32x4_t temp_i = {}; - // load & separate real/imag pSrcA (de-interleave 2) - a0_v = vld2q_f32(p_in1); - // load & separate real/imag pSrcA (de-interleave 2) - a1_v = vld2q_f32(p_in1_b); - - p_in1 += 8; - p_in1_b += 8; - - // load but NOT separate real/imag - b0_v = vld1q_f32(p_in2); - b1_v = vld1q_f32(p_in1_b2); - b2_v = vld1q_f32(p_in2 + 4 * num_cols_b); - b3_v = vld1q_f32(p_in1_b2 + 4 * num_cols_b); - - p_in2 = p_in2 + 8 * num_cols_b; - p_in1_b2 = p_in1_b2 + 8 * num_cols_b; - b_col_real = vtrn1q_f32(b0_v, b1_v); // even elem - b_col_im = vtrn2q_f32(b0_v, b1_v); // odd elem - b_col_real2 = vtrn1q_f32(b2_v, b3_v); // even elem - b_col_im2 = vtrn2q_f32(b2_v, b3_v); // odd elem - - /*First column*/ - temp_r = vzip1q_f32x2(b_col_real, b_col_real2); - temp_i = vzip1q_f32x2(b_col_im, b_col_im2); - - /*Second column*/ - temp_r2 = vzip2q_f32x2(b_col_real, b_col_real2); - temp_i2 = vzip2q_f32x2(b_col_im, b_col_im2); - - acc_r0 = vfmaq_f32(acc_r0, a0_v.val[0], temp_r); - acc_r0 = vfmsq_f32(acc_r0, a0_v.val[1], temp_i); - acc_i0 = vfmaq_f32(acc_i0, a0_v.val[1], temp_r); - acc_i0 = vfmaq_f32(acc_i0, a0_v.val[0], temp_i); - /* same row A, next column B*/ - acc_r2 = vfmaq_f32(acc_r2, a0_v.val[0], temp_r2); - acc_r2 = vfmsq_f32(acc_r2, a0_v.val[1], temp_i2); - - acc_i2 = vfmaq_f32(acc_i2, a0_v.val[1], temp_r2); - acc_i2 = vfmaq_f32(acc_i2, a0_v.val[0], temp_i2); - - acc_r1 = vfmaq_f32(acc_r1, a1_v.val[0], temp_r); - acc_r1 = vfmsq_f32(acc_r1, a1_v.val[1], temp_i); - - acc_i1 = vfmaq_f32(acc_i1, a1_v.val[1], temp_r); - acc_i1 = vfmaq_f32(acc_i1, a1_v.val[0], temp_i); - /* same row A, next column B*/ - acc_r3 = vfmaq_f32(acc_r3, a1_v.val[0], temp_r2); - acc_r3 = vfmsq_f32(acc_r3, a1_v.val[1], temp_i2); - - acc_i3 = vfmaq_f32(acc_i3, a1_v.val[1], temp_r2); - acc_i3 = vfmaq_f32(acc_i3, a1_v.val[0], temp_i2); - - col_cnt--; - } - - sum_real1 += vaddvq_f32(acc_r0); - sum_imag1 += vaddvq_f32(acc_i0); - sum_real3 += vaddvq_f32(acc_r2); - sum_imag3 += vaddvq_f32(acc_i2); - - sum_real1_b += vaddvq_f32(acc_r1); - sum_imag1_b += vaddvq_f32(acc_i1); - sum_real3_b += vaddvq_f32(acc_r3); - sum_imag3_b += vaddvq_f32(acc_i3); - - /* If the columns of pSrcA is not a multiple of 4, compute any remaining - *MACs here. - ** No loop unrolling is used. */ - col_cnt = num_cols_a & 3; - while (col_cnt > 0U) { - - float32_t a1 = *p_in1; - float32_t a1_b = *p_in1_b; - - float32_t c1 = *p_in2; - float32_t c1_b = *(p_in2 + 2U); - - float32_t b1 = *(p_in1 + 1U); - float32_t b1_b = *(p_in1_b + 1U); - - float32_t d1 = *(p_in2 + 1U); - float32_t d1_b = *(p_in2 + 3U); - - sum_real1 += a1 * c1; - sum_imag1 += b1 * c1; - - sum_real3 += a1 * c1_b; - sum_imag3 += b1 * c1_b; - - sum_real1_b += a1_b * c1; - sum_imag1_b += b1_b * c1; - - sum_real3_b += a1_b * c1_b; - sum_imag3_b += b1_b * c1_b; - - p_in1 += 2U; - p_in1_b += 2U; - p_in2 += 2 * num_cols_b; - - sum_real2 -= b1 * d1; - sum_imag2 += a1 * d1; - - sum_real4 -= b1 * d1_b; - sum_imag4 += a1 * d1_b; - - sum_real2_b -= b1_b * d1; - sum_imag2_b += a1_b * d1; - - sum_real4_b -= b1_b * d1_b; - sum_imag4_b += a1_b * d1_b; - - /* Decrement the loop counter */ - col_cnt--; - } - - sum_real1 += sum_real2; - sum_imag1 += sum_imag2; - - sum_real3 += sum_real4; - sum_imag3 += sum_imag4; - - sum_real1_b += sum_real2_b; - sum_imag1_b += sum_imag2_b; - - sum_real3_b += sum_real4_b; - sum_imag3_b += sum_imag4_b; - - /* Store the result in the destination buffer */ - (*px).re = sum_real1; - (*px).im = sum_imag1; - px++; - (*px).re = sum_real3; - (*px).im = sum_imag3; - px++; - (*px_b).re = sum_real1_b; - (*px_b).im = sum_imag1_b; - px_b++; - (*px_b).re = sum_real3_b; - (*px_b).im = sum_imag3_b; - px_b++; - // /* Update the pointer pIn2 to point to the starting address of the - // next column */ - j++; - p_in2 = (const float32_t *)p_src_b + 4U * j; - p_in1_b2 = p_in2 + 2U * num_cols_b; - col--; - } - - col = num_cols_b & 1; - if (col) { - /* Set the variable sum, that acts as accumulator, to zero */ - float32_t sum_real1 = 0.0F; - float32_t sum_imag1 = 0.0F; - float32_t sum_real2 = 0.0F; - float32_t sum_imag2 = 0.0F; - float32_t sum_real1_b = 0.0F; - float32_t sum_imag1_b = 0.0F; - float32_t sum_real2_b = 0.0F; - float32_t sum_imag2_b = 0.0F; - - /* Initiate the pointer pIn1 to point to the starting address of the - * column being processed */ - p_in1 = (const float32_t *)p_in_a; - p_in1_b = p_in1 + 2 * num_cols_a; - - float32x4_t acc_r0 = {}; - float32x4_t acc_i0 = {}; - float32x4_t acc_r1 = {}; - float32x4_t acc_i1 = {}; - - /* Compute 4 MACs simultaneously. */ - col_cnt = num_cols_a >> 2; - - /* Matrix multiplication */ - while (col_cnt > 0U) { - // load & separate real/imag pSrcA (de-interleave 2) - a0_v = vld2q_f32(p_in1); - a1_v = vld2q_f32(p_in1_b); - - p_in1 += 8; - p_in1_b += 8; - - // load but NOT separate real/imag - float32x2_t b_four_rows[4]; - b_four_rows[0] = vld1_f32(p_in2); - b_four_rows[1] = vld1_f32(p_in1_b2); - b_four_rows[2] = vld1_f32(p_in2 + 4 * num_cols_b); - b_four_rows[3] = vld1_f32(p_in1_b2 + 4 * num_cols_b); - - p_in2 = p_in2 + 8 * num_cols_b; - p_in1_b2 = p_in1_b2 + 8 * num_cols_b; - float32x4_t b_tmp_real; - float32x4_t b_tmp_im; - b_tmp_real = vcombine_f32(vtrn1_f32(b_four_rows[0], b_four_rows[1]), - vtrn1_f32(b_four_rows[2], b_four_rows[3])); - b_tmp_im = vcombine_f32(vtrn2_f32(b_four_rows[0], b_four_rows[1]), - vtrn2_f32(b_four_rows[2], b_four_rows[3])); - - // real * real - acc_r0 = vfmaq_f32(acc_r0, a0_v.val[0], b_tmp_real); - // imag * imag - acc_r0 = vfmsq_f32(acc_r0, a0_v.val[1], b_tmp_im); - // imag * real - acc_i0 = vfmaq_f32(acc_i0, a0_v.val[1], b_tmp_real); - // real * imag - acc_i0 = vfmaq_f32(acc_i0, a0_v.val[0], b_tmp_im); - - // real * real - acc_r1 = vfmaq_f32(acc_r1, a1_v.val[0], b_tmp_real); - // imag * imag - acc_r1 = vfmsq_f32(acc_r1, a1_v.val[1], b_tmp_im); - // imag * real - acc_i1 = vfmaq_f32(acc_i1, a1_v.val[1], b_tmp_real); - // real * imag - acc_i1 = vfmaq_f32(acc_i1, a1_v.val[0], b_tmp_im); - - col_cnt--; - } - - sum_real1 += vaddvq_f32(acc_r0); - sum_imag1 += vaddvq_f32(acc_i0); - - sum_real1_b += vaddvq_f32(acc_r1); - sum_imag1_b += vaddvq_f32(acc_i1); - - /* If the columns of pSrcA is not a multiple of 4, compute any remaining - *MACs here. - ** No loop unrolling is used. */ - col_cnt = num_cols_a & 3; - while (col_cnt > 0U) { - - float32_t a1 = *p_in1; // real part of entry from A - float32_t a1_b = *p_in1_b; - float32_t c1 = *p_in2; // real part of entry from B - - float32_t b1 = *(p_in1 + 1U); // imaginary part of entry from A - float32_t b1_b = *(p_in1_b + 1U); - float32_t d1 = *(p_in2 + 1U); // imaginary part of entry from B - - // real * real - sum_real1 += a1 * c1; - // imag * real - sum_imag1 += b1 * c1; - - // imag * imag - sum_real2 -= b1 * d1; - // real * imag - sum_imag2 += a1 * d1; - - // real * real - sum_real1_b += a1_b * c1; - // imag * real - sum_imag1_b += b1_b * c1; - - // imag * imag - sum_real2_b -= b1_b * d1; - // real * imag - sum_imag2_b += a1_b * d1; - - p_in1 += 2U; - p_in1_b += 2U; - p_in2 += 2 * num_cols_b; - - /* Decrement the loop counter */ - col_cnt--; - } - - sum_real1 += sum_real2; - sum_imag1 += sum_imag2; - - sum_real1_b += sum_real2_b; - sum_imag1_b += sum_imag2_b; - - /* Store the result in the destination buffer */ - (*px).re = sum_real1; - (*px).im = sum_imag1; - px++; - (*px_b).re = sum_real1_b; - (*px_b).im = sum_imag1_b; - px_b++; - // Update the pointer pIn2 to point to the starting address of the next - // column - j++; - } - - /* Update the pointer pInA to point to the starting address of the next 2 - * row */ - i = i + 2 * num_cols_b; - p_in_a = p_in_a + 2 * num_cols_a; - /* Decrement the row loop counter */ - row_cnt--; - } - - row_cnt = row & 1; - while (row_cnt > 0U) { - /* Output pointer is set to starting address of the row being processed */ - px = p_out + i; - - /* For every row wise process, the column loop counter is to be initiated */ - col = num_cols_b; - - /* For every row wise process, the pIn2 pointer is set - ** to the starting address of the pSrcB data */ - p_in2 = (const float32_t *)p_src_b; - - j = 0U; - - /* Column loop */ - while (col > 0U) { - /* Set the variable sum, that acts as accumulator, to zero */ - float32_t sum_real1 = 0.0F; - float32_t sum_imag1 = 0.0F; - - float32_t sum_real2 = 0.0F; - float32_t sum_imag2 = 0.0F; - - /* Initiate the pointer pIn1 to point to the starting address of the - * column being processed */ - p_in1 = (const float32_t *)p_in_a; - - float32x4_t acc_r0 = {}; - float32x4_t acc_i0 = {}; - - /* Compute 4 MACs simultaneously. */ - col_cnt = num_cols_a >> 2; - - /* Matrix multiplication */ - while (col_cnt > 0U) { - float32x4_t temp_r = {}; - float32x4_t temp_i = {}; - /* Reading real part of complex matrix A */ - // load & separate real/imag p_src_a (de-interleave 2) - a0_v = vld2q_f32(p_in1); - p_in1 += 8; - - temp_r[0] = *p_in2; - temp_i[0] = *(p_in2 + 1U); - p_in2 += 2 * num_cols_b; - - temp_r[1] = *p_in2; - temp_i[1] = *(p_in2 + 1U); - p_in2 += 2 * num_cols_b; - - temp_r[2] = *p_in2; - temp_i[2] = *(p_in2 + 1U); - p_in2 += 2 * num_cols_b; - - temp_r[3] = *p_in2; - temp_i[3] = *(p_in2 + 1U); - p_in2 += 2 * num_cols_b; - - acc_r0 = vfmaq_f32(acc_r0, a0_v.val[0], temp_r); - acc_r0 = vfmsq_f32(acc_r0, a0_v.val[1], temp_i); - - acc_i0 = vfmaq_f32(acc_i0, a0_v.val[1], temp_r); - acc_i0 = vfmaq_f32(acc_i0, a0_v.val[0], temp_i); - - /* Decrement the loop count */ - col_cnt--; - } - - accum = vpadd_f32(vget_low_f32(acc_r0), vget_high_f32(acc_r0)); - sum_real1 += accum[0] + accum[1]; - - accum = vpadd_f32(vget_low_f32(acc_i0), vget_high_f32(acc_i0)); - sum_imag1 += accum[0] + accum[1]; - - /* If the columns of pSrcA is not a multiple of 4, compute any remaining - *MACs here. - ** No loop unrolling is used. */ - col_cnt = num_cols_a & 3; - - while (col_cnt > 0U) { - /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */ - float32_t a1 = *p_in1; - float32_t c1 = *p_in2; - - float32_t b1 = *(p_in1 + 1U); - float32_t d1 = *(p_in2 + 1U); - - sum_real1 += a1 * c1; - sum_imag1 += b1 * c1; - - p_in1 += 2U; - p_in2 += 2 * num_cols_b; - - sum_real2 -= b1 * d1; - sum_imag2 += a1 * d1; - - /* Decrement the loop counter */ - col_cnt--; - } - - sum_real1 += sum_real2; - sum_imag1 += sum_imag2; - - /* Store the result in the destination buffer */ - (*px).re = sum_real1; - (*px).im = sum_imag1; - px++; - - /* Update the pointer pIn2 to point to the starting address of the next - * column */ - j++; - p_in2 = (const float32_t *)p_src_b + 2U * j; - - /* Decrement the column loop counter */ - col--; - } - - // /* Update the pointer pInA to point to the starting address of the next - // row */ - i = i + num_cols_b; - p_in_a = p_in_a + num_cols_a; - - /* Decrement the row loop counter */ - row_cnt--; - } - -#endif - - return ARMRAL_SUCCESS; -} diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_f32.c b/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_f32.c index 83f9ec1..8225ce4 100644 --- a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_f32.c +++ b/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_f32.c @@ -6,7 +6,7 @@ #include "intrinsics.h" #include -#ifdef ARMRAL_ARCH_SVE +#if ARMRAL_ARCH_SVE >= 2 #include #endif @@ -22,7 +22,7 @@ armral_cmplx_mat_vec_mult_f32(const uint16_t m, const uint16_t n, uint16_t num_rows_a = m; // number of rows of input matrix A uint16_t num_cols_a = n; // number of columns of input matrix A -#ifdef ARMRAL_ARCH_SVE +#if ARMRAL_ARCH_SVE >= 2 svbool_t ptrue = svptrue_b32(); if (num_rows_a % 2 == 0) { const float32_t *p_in1_2 = (const float32_t *)p_src_a; diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_i16.c b/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_i16.c index 325b206..e8512f3 100644 --- a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_i16.c +++ b/src/BasicMathFun/MatrixMult/arm_cmplx_mat_vec_mult_i16.c @@ -52,10 +52,10 @@ armral_cmplx_mat_vec_mult_i16(const uint16_t m, const uint16_t n, p_in1 = (const int16_t *)p_in_a; p_in1_b = p_in1 + 2 * num_cols_a; - int64x2_t acc_r0 = {}; - int64x2_t acc_i0 = {}; - int64x2_t acc_r1 = {}; - int64x2_t acc_i1 = {}; + int64x2_t acc_r0 = {0}; + int64x2_t acc_i0 = {0}; + int64x2_t acc_r1 = {0}; + int64x2_t acc_i1 = {0}; // Compute 8 MACs simultaneously uint16_t col_cnt = num_cols_a >> 3; @@ -83,10 +83,10 @@ armral_cmplx_mat_vec_mult_i16(const uint16_t m, const uint16_t n, // Load eight entries of X, splitting real and imaginary components int16x8x2_t tmp_x = vld2q_s16(p_in2); - int32x4_t r_32bit[2] = {}; + int32x4_t r_32bit[2] = {0}; r_32bit[0] = vmovl_low_s16(tmp_x.val[0]); r_32bit[1] = vmovl_high_s16(tmp_x.val[0]); - int32x4_t i_32bit[2] = {}; + int32x4_t i_32bit[2] = {0}; i_32bit[0] = vmovl_low_s16(tmp_x.val[1]); i_32bit[1] = vmovl_high_s16(tmp_x.val[1]); @@ -149,10 +149,10 @@ armral_cmplx_mat_vec_mult_i16(const uint16_t m, const uint16_t n, int16x4x2_t a0_v = vld2_s16(p_in1); int16x4x2_t a1_v = vld2_s16(p_in1_b); - int32x4_t a0_vextended[2] = {}; + int32x4_t a0_vextended[2] = {0}; a0_vextended[0] = vmovl_s16(a0_v.val[0]); a0_vextended[1] = vmovl_s16(a0_v.val[1]); - int32x4_t a1_vextended[2] = {}; + int32x4_t a1_vextended[2] = {0}; a1_vextended[0] = vmovl_s16(a1_v.val[0]); a1_vextended[1] = vmovl_s16(a1_v.val[1]); @@ -232,7 +232,7 @@ armral_cmplx_mat_vec_mult_i16(const uint16_t m, const uint16_t n, col_cnt--; } - armral_cmplx_int16_t out[2] = {}; + armral_cmplx_int16_t out[2] = {0}; out[0].re = vqmovns_s32(vqshrnd_n_s64(sum_real1_ext, 15)); out[0].im = vqmovns_s32(vqshrnd_n_s64(sum_imag1_ext, 15)); out[1].re = vqmovns_s32(vqshrnd_n_s64(sum_real1_b_ext, 15)); diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_aah_f32.cpp b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_aah_f32.cpp similarity index 90% rename from src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_aah_f32.cpp rename to src/BasicMathFun/MatrixMult/arm_cmplx_matmul_aah_f32.cpp index 339080c..88b337b 100644 --- a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_aah_f32.cpp +++ b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_aah_f32.cpp @@ -8,16 +8,20 @@ #include #endif +namespace { + #ifdef ARMRAL_ARCH_SVE template -static inline void aah_mult_iter( - svbool_t pg, uint16_t n, uint16_t k, - const armral_cmplx_f32_t *__restrict p_src_a, svfloat32_t &p_out_re0, - svfloat32_t &p_out_re1, svfloat32_t &p_out_re2, svfloat32_t &p_out_re3, - svfloat32_t &p_out_re4, svfloat32_t &p_out_re5, svfloat32_t &p_out_re6, - svfloat32_t &p_out_re7, svfloat32_t &p_out_re8, svfloat32_t &p_out_re9, - svfloat32_t &p_out_im0, svfloat32_t &p_out_im1, svfloat32_t &p_out_im2, - svfloat32_t &p_out_im3, svfloat32_t &p_out_im4, svfloat32_t &p_out_im5) { +inline void aah_mult_iter(svbool_t pg, uint16_t n, uint16_t k, + const armral_cmplx_f32_t *__restrict p_src_a, + svfloat32_t &p_out_re0, svfloat32_t &p_out_re1, + svfloat32_t &p_out_re2, svfloat32_t &p_out_re3, + svfloat32_t &p_out_re4, svfloat32_t &p_out_re5, + svfloat32_t &p_out_re6, svfloat32_t &p_out_re7, + svfloat32_t &p_out_re8, svfloat32_t &p_out_re9, + svfloat32_t &p_out_im0, svfloat32_t &p_out_im1, + svfloat32_t &p_out_im2, svfloat32_t &p_out_im3, + svfloat32_t &p_out_im4, svfloat32_t &p_out_im5) { svfloat32_t p_in0; svfloat32_t p_in1; svfloat32_t p_in2; @@ -93,10 +97,10 @@ static inline void aah_mult_iter( #endif template -static inline armral_status -armral_cmplx_mat_mult_aah_f32_m(uint16_t n, - const armral_cmplx_f32_t *__restrict p_src_a, - armral_cmplx_f32_t *p_dst_c) { +inline armral_status +armral_cmplx_matmul_aah_f32_m(uint16_t n, + const armral_cmplx_f32_t *__restrict p_src_a, + armral_cmplx_f32_t *p_dst_c) { // For each row, we have a decrementing number of real values to calculate. constexpr uint16_t num_re = (M * (M + 1)) / 2; constexpr uint16_t num_im = (M * M) - num_re; @@ -294,25 +298,25 @@ armral_cmplx_mat_mult_aah_f32_m(uint16_t n, float32x2_t p_out_extra; switch (M) { case 2: - p_out[0] = (float32x4_t){re[0], 0.F, re[1], im[0]}; - p_out[1] = (float32x4_t){re[1], -im[0], re[2], 0.F}; + p_out[0] = float32x4_t{re[0], 0.F, re[1], im[0]}; + p_out[1] = float32x4_t{re[1], -im[0], re[2], 0.F}; break; case 3: - p_out[0] = (float32x4_t){re[0], 0.F, re[1], im[0]}; - p_out[1] = (float32x4_t){re[2], im[1], re[1], -im[0]}; - p_out[2] = (float32x4_t){re[3], 0.F, re[4], im[2]}; - p_out[3] = (float32x4_t){re[2], -im[1], re[4], -im[2]}; - p_out_extra = (float32x2_t){re[5], 0.F}; + p_out[0] = float32x4_t{re[0], 0.F, re[1], im[0]}; + p_out[1] = float32x4_t{re[2], im[1], re[1], -im[0]}; + p_out[2] = float32x4_t{re[3], 0.F, re[4], im[2]}; + p_out[3] = float32x4_t{re[2], -im[1], re[4], -im[2]}; + p_out_extra = float32x2_t{re[5], 0.F}; break; case 4: - p_out[0] = (float32x4_t){re[0], 0.F, re[1], im[0]}; - p_out[1] = (float32x4_t){re[2], im[1], re[3], im[2]}; - p_out[2] = (float32x4_t){re[1], -im[0], re[4], 0.F}; - p_out[3] = (float32x4_t){re[5], im[3], re[6], im[4]}; - p_out[4] = (float32x4_t){re[2], -im[1], re[5], -im[3]}; - p_out[5] = (float32x4_t){re[7], 0.F, re[8], im[5]}; - p_out[6] = (float32x4_t){re[3], -im[2], re[6], -im[4]}; - p_out[7] = (float32x4_t){re[8], -im[5], re[9], 0.F}; + p_out[0] = float32x4_t{re[0], 0.F, re[1], im[0]}; + p_out[1] = float32x4_t{re[2], im[1], re[3], im[2]}; + p_out[2] = float32x4_t{re[1], -im[0], re[4], 0.F}; + p_out[3] = float32x4_t{re[5], im[3], re[6], im[4]}; + p_out[4] = float32x4_t{re[2], -im[1], re[5], -im[3]}; + p_out[5] = float32x4_t{re[7], 0.F, re[8], im[5]}; + p_out[6] = float32x4_t{re[3], -im[2], re[6], -im[4]}; + p_out[7] = float32x4_t{re[8], -im[5], re[9], 0.F}; break; default: return ARMRAL_ARGUMENT_ERROR; @@ -328,17 +332,19 @@ armral_cmplx_mat_mult_aah_f32_m(uint16_t n, return ARMRAL_SUCCESS; } +} // namespace + armral_status -armral_cmplx_mat_mult_aah_f32(uint16_t m, uint16_t n, - const armral_cmplx_f32_t *__restrict p_src_a, - armral_cmplx_f32_t *p_dst_c) { +armral_cmplx_matmul_aah_f32(uint16_t m, uint16_t n, + const armral_cmplx_f32_t *__restrict p_src_a, + armral_cmplx_f32_t *p_dst_c) { switch (m) { case 2: - return armral_cmplx_mat_mult_aah_f32_m<2>(n, p_src_a, p_dst_c); + return armral_cmplx_matmul_aah_f32_m<2>(n, p_src_a, p_dst_c); case 3: - return armral_cmplx_mat_mult_aah_f32_m<3>(n, p_src_a, p_dst_c); + return armral_cmplx_matmul_aah_f32_m<3>(n, p_src_a, p_dst_c); case 4: - return armral_cmplx_mat_mult_aah_f32_m<4>(n, p_src_a, p_dst_c); + return armral_cmplx_matmul_aah_f32_m<4>(n, p_src_a, p_dst_c); } #ifdef ARMRAL_ARCH_SVE const uint32_t vals_per_vector = svcntd(); diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_ahb_f32.c b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_ahb_f32.cpp similarity index 70% rename from src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_ahb_f32.c rename to src/BasicMathFun/MatrixMult/arm_cmplx_matmul_ahb_f32.cpp index 15d40f5..440c22a 100644 --- a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_ahb_f32.c +++ b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_ahb_f32.cpp @@ -5,19 +5,20 @@ #include "armral.h" #include "intrinsics.h" -#include -#include -#include +#include +#include #ifdef ARMRAL_ARCH_SVE #include #endif -static void cmplx_mat_mult_ahb_b2x2(uint16_t n, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, - armral_cmplx_f32_t *p_dst) { - const uint16_t mk = 2; +namespace { + +void cmplx_matmul_ahb_b2x2(uint16_t m, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *p_dst) { + constexpr uint16_t nk = 2; #if ARMRAL_ARCH_SVE >= 2 svbool_t p4 = svptrue_pat_b32(SV_VL4); @@ -27,7 +28,7 @@ static void cmplx_mat_mult_ahb_b2x2(uint16_t n, svfloat32_t b_r0_rev = svrev64_f32(p4, b_r0); svfloat32_t b_r1_rev = svrev64_f32(p4, b_r1); - for (uint32_t j = 0; j < n; j++) { + for (uint32_t j = 0; j < m; j++) { svfloat32_t acc_1; svfloat32_t acc_2; @@ -39,7 +40,7 @@ static void cmplx_mat_mult_ahb_b2x2(uint16_t n, acc_2 = svmul_lane_f32(b_r0_rev, a_r0j, 1); // r = 1 - svfloat32_t a_r1j = svld1_f32(p2, (float32_t const *)&p_src_a[n + j]); + svfloat32_t a_r1j = svld1_f32(p2, (float32_t const *)&p_src_a[m + j]); // [R0*r, R0*i, R0*r, R0*i] acc_1 = svmla_lane_f32(acc_1, b_r1, a_r1j, 0); // [I0*i, I0*r, I0*i, I0*r] @@ -48,7 +49,7 @@ static void cmplx_mat_mult_ahb_b2x2(uint16_t n, svfloat32_t result = svadd_f32_x( p4, acc_1, svreinterpret_f32_f64(svneg_f64_x(p4, svreinterpret_f64_f32(acc_2)))); - svst1_f32(p4, (float32_t *)&p_dst[mk * j], result); + svst1_f32(p4, (float32_t *)&p_dst[nk * j], result); } #else float32x4_t b_r0 = vld1q_f32((float32_t const *)p_src_b); @@ -56,7 +57,7 @@ static void cmplx_mat_mult_ahb_b2x2(uint16_t n, float32x4_t b_r0_rev = vrev64q_f32(b_r0); float32x4_t b_r1_rev = vrev64q_f32(b_r1); - for (uint32_t j = 0; j < n; j++) { + for (uint32_t j = 0; j < m; j++) { float32x4_t acc_1; float32x4_t acc_2; @@ -68,7 +69,7 @@ static void cmplx_mat_mult_ahb_b2x2(uint16_t n, acc_2 = vmulq_lane_f32(b_r0_rev, a_r0j, 1); // r = 1 - float32x2_t a_r1j = vld1_f32((float32_t const *)&p_src_a[n + j]); + float32x2_t a_r1j = vld1_f32((float32_t const *)&p_src_a[m + j]); // [R0*r, R0*i, R0*r, R0*i] acc_1 = vfmaq_lane_f32(acc_1, b_r1, a_r1j, 0); // [I0*i, I0*r, I0*i, I0*r] @@ -76,16 +77,16 @@ static void cmplx_mat_mult_ahb_b2x2(uint16_t n, float32x4_t result = vaddq_f32(acc_1, vnegq64_f32(acc_2)); - vst1q_f32((float32_t *)&p_dst[mk * j], result); + vst1q_f32((float32_t *)&p_dst[nk * j], result); } #endif } -static void cmplx_mat_mult_ahb_b3x3(uint16_t n, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, - armral_cmplx_f32_t *p_dst) { - const uint16_t mk = 3; +void cmplx_matmul_ahb_b3x3(uint16_t m, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *p_dst) { + constexpr uint16_t nk = 3; #if ARMRAL_ARCH_SVE >= 2 svbool_t p3 = svptrue_pat_b32(SV_VL3); @@ -96,12 +97,12 @@ static void cmplx_mat_mult_ahb_b3x3(uint16_t n, svfloat32x2_t b_r2 = svld2_f32(p3, (float32_t const *)&p_src_b[6]); svfloat32x2_t *b_rows[] = {&b_r0, &b_r1, &b_r2}; - for (uint32_t j = 0; j < n; j++) { + for (uint32_t j = 0; j < m; j++) { svfloat32_t dot_0 = svundef_f32(); svfloat32_t dot_1 = svundef_f32(); - // Note: We leave it to the compiler to unroll this loop over mk - for (uint32_t r = 0; r < mk; r++) { - svfloat32_t a_rj = svld1_f32(p2, (float32_t const *)&p_src_a[r * n + j]); + // Note: We leave it to the compiler to unroll this loop over nk + for (uint32_t r = 0; r < nk; r++) { + svfloat32_t a_rj = svld1_f32(p2, (float32_t const *)&p_src_a[r * m + j]); // Note: We leave it to the compiler to eliminate the following branch if (r == 0) { // dot.re += a_jr.re * b_ir.re + a_jr.im * b_ir.im; @@ -120,7 +121,7 @@ static void cmplx_mat_mult_ahb_b3x3(uint16_t n, } } svfloat32x2_t dot = svcreate2(dot_0, dot_1); - svst2_f32(p3, (float32_t *)&p_dst[mk * j], dot); + svst2_f32(p3, (float32_t *)&p_dst[nk * j], dot); } #else // Copy the final row of B so we can safely read one extra column: @@ -139,11 +140,11 @@ static void cmplx_mat_mult_ahb_b3x3(uint16_t n, vld2q_f32((float32_t const *)&final_row), }; - for (uint32_t j = 0; j < n; j++) { + for (uint32_t j = 0; j < m; j++) { float32x4x2_t dot; - // Note: We leave it to the compiler to unroll this loop over mk - for (uint32_t r = 0; r < mk; r++) { - float32x2_t a_rj = vld1_f32((float32_t const *)&p_src_a[r * n + j]); + // Note: We leave it to the compiler to unroll this loop over nk + for (uint32_t r = 0; r < nk; r++) { + float32x2_t a_rj = vld1_f32((float32_t const *)&p_src_a[r * m + j]); // Note: We leave it to the compiler to eliminate the following branch if (r == 0) { // dot.re += a_jr.re * b_ir.re + a_jr.im * b_ir.im; @@ -165,18 +166,18 @@ static void cmplx_mat_mult_ahb_b3x3(uint16_t n, // result = r i r i r i X X float32x4x2_t result = vzipq_f32(dot.val[0], dot.val[1]); // Store first the first two columns: - vst1q_f32((float32_t *)&p_dst[mk * j], result.val[0]); + vst1q_f32((float32_t *)&p_dst[nk * j], result.val[0]); // Store the remaining column: - vst1_f32(((float32_t *)&p_dst[mk * j]) + 4, vget_low_f32(result.val[1])); + vst1_f32(((float32_t *)&p_dst[nk * j]) + 4, vget_low_f32(result.val[1])); } #endif } -static void cmplx_mat_mult_ahb_b4x4(uint16_t n, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, - armral_cmplx_f32_t *p_dst) { - const uint16_t mk = 4; +void cmplx_matmul_ahb_b4x4(uint16_t m, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *p_dst) { + constexpr uint16_t nk = 4; #if ARMRAL_ARCH_SVE >= 2 svbool_t p4 = svptrue_pat_b32(SV_VL4); @@ -188,12 +189,12 @@ static void cmplx_mat_mult_ahb_b4x4(uint16_t n, svfloat32x2_t b_r3 = svld2_f32(p4, (float32_t const *)&p_src_b[12]); svfloat32x2_t *b_rows[] = {&b_r0, &b_r1, &b_r2, &b_r3}; - for (uint32_t j = 0; j < n; j++) { + for (uint32_t j = 0; j < m; j++) { svfloat32_t dot_0 = svundef_f32(); svfloat32_t dot_1 = svundef_f32(); - // Note: We leave it to the compiler to unroll this loop over mk - for (uint32_t r = 0; r < mk; r++) { - svfloat32_t a_rj = svld1_f32(p2, (float32_t const *)&p_src_a[r * n + j]); + // Note: We leave it to the compiler to unroll this loop over nk + for (uint32_t r = 0; r < nk; r++) { + svfloat32_t a_rj = svld1_f32(p2, (float32_t const *)&p_src_a[r * m + j]); // Note: We leave it to the compiler to eliminate the following branch if (r == 0) { // dot.re += a_jr.re * b_ir.re + a_jr.im * b_ir.im; @@ -213,7 +214,7 @@ static void cmplx_mat_mult_ahb_b4x4(uint16_t n, } svfloat32x2_t dot = svcreate2(dot_0, dot_1); - svst2_f32(p4, (float32_t *)&p_dst[mk * j], dot); + svst2_f32(p4, (float32_t *)&p_dst[nk * j], dot); } #else float32x4x2_t b_rows[4] = {vld2q_f32((float32_t const *)&p_src_b[0]), @@ -221,11 +222,11 @@ static void cmplx_mat_mult_ahb_b4x4(uint16_t n, vld2q_f32((float32_t const *)&p_src_b[8]), vld2q_f32((float32_t const *)&p_src_b[12])}; - for (uint32_t j = 0; j < n; j++) { + for (uint32_t j = 0; j < m; j++) { float32x4x2_t dot; - // Note: We leave it to the compiler to unroll this loop over mk - for (uint32_t r = 0; r < mk; r++) { - float32x2_t a_rj = vld1_f32((float32_t const *)&p_src_a[r * n + j]); + // Note: We leave it to the compiler to unroll this loop over nk + for (uint32_t r = 0; r < nk; r++) { + float32x2_t a_rj = vld1_f32((float32_t const *)&p_src_a[r * m + j]); // Note: We leave it to the compiler to eliminate the following branch if (r == 0) { // dot.re += a_jr.re * b_ir.re + a_jr.im * b_ir.im; @@ -243,27 +244,27 @@ static void cmplx_mat_mult_ahb_b4x4(uint16_t n, dot.val[1] = vfmsq_lane_f32(dot.val[1], b_rows[r].val[0], a_rj, 1); } } - vst2q_f32((float32_t *)&p_dst[mk * j], dot); + vst2q_f32((float32_t *)&p_dst[nk * j], dot); } #endif } #ifdef ARMRAL_ARCH_SVE -static inline void __attribute__((always_inline)) -cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll( +inline void __attribute__((always_inline)) +cmplx_matmul_ahb_dot_product_sve_predicated_unroll( svbool_t pg, uint16_t m, uint16_t n, uint16_t k, uint16_t i, uint16_t j, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, armral_cmplx_f32_t *p_dst) { + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, armral_cmplx_f32_t *p_dst) { svbool_t p2 = svptrue_pat_b32(SV_VL2); svfloat32_t acc_1 = svdup_n_f32(0.0F); svfloat32_t acc_2 = svdup_n_f32(0.0F); // Every row of A and B (where row of A = column of A^H) - for (uint32_t r = 0; r < m; r++) { - float32_t const *a_ptr = (float32_t const *)&p_src_a[r * n + j]; + for (uint32_t r = 0; r < k; r++) { + float32_t const *a_ptr = (float32_t const *)&p_src_a[r * m + j]; svfloat32_t a_rj = svld1rq_f32(p2, a_ptr); - float32_t const *b_ptr = (float32_t const *)&p_src_b[r * k + i]; + float32_t const *b_ptr = (float32_t const *)&p_src_b[r * n + i]; svfloat32_t b_vec = svld1_f32(pg, b_ptr); // [R0 * r, I0 * r, R0 * r, I0 * r] acc_1 = svcmla_lane_f32(acc_1, b_vec, a_rj, 0, 0); @@ -273,15 +274,15 @@ cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll( svfloat32_t result = svreinterpret_f32_f64( svneg_f64_x(pg, svreinterpret_f64_f32(svadd_f32_x(pg, acc_1, acc_2)))); - float32_t *c_ptr = (float32_t *)&p_dst[k * j + i]; + float32_t *c_ptr = (float32_t *)&p_dst[n * j + i]; svst1_f32(pg, c_ptr, result); } -static inline void __attribute__((always_inline)) -cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll_j2( +inline void __attribute__((always_inline)) +cmplx_matmul_ahb_dot_product_sve_predicated_unroll_j2( svbool_t pg, uint16_t m, uint16_t n, uint16_t k, uint16_t i, uint16_t j, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, armral_cmplx_f32_t *p_dst) { + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, armral_cmplx_f32_t *p_dst) { svbool_t p4 = svptrue_pat_b32(SV_VL4); svfloat32_t acc_1 = svdup_n_f32(0.0F); @@ -290,11 +291,11 @@ cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll_j2( svfloat32_t acc_4 = svdup_n_f32(0.0F); // Every row of A and B (where row of A = column of A^H) - for (uint32_t r = 0; r < m; r++) { + for (uint32_t r = 0; r < k; r++) { // Load Aij and Aij+1 - float32_t const *a_ptr = (float32_t const *)&p_src_a[r * n + j]; + float32_t const *a_ptr = (float32_t const *)&p_src_a[r * m + j]; svfloat32_t a_rj = svld1rq_f32(p4, a_ptr); - float32_t const *b_ptr = (float32_t const *)&p_src_b[r * k + i]; + float32_t const *b_ptr = (float32_t const *)&p_src_b[r * n + i]; svfloat32_t b_vec = svld1_f32(pg, b_ptr); // [R0 * r, I0 * r, R0 * r, I0 * r] acc_1 = svcmla_lane_f32(acc_1, b_vec, a_rj, 0, 0); @@ -310,17 +311,17 @@ cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll_j2( svneg_f64_x(pg, svreinterpret_f64_f32(svadd_f32_x(pg, acc_1, acc_2)))); svfloat32_t result_2 = svreinterpret_f32_f64( svneg_f64_x(pg, svreinterpret_f64_f32(svadd_f32_x(pg, acc_3, acc_4)))); - float32_t *c_ptr_1 = (float32_t *)&p_dst[k * j + i]; - float32_t *c_ptr_2 = (float32_t *)&p_dst[k * (j + 1) + i]; + float32_t *c_ptr_1 = (float32_t *)&p_dst[n * j + i]; + float32_t *c_ptr_2 = (float32_t *)&p_dst[n * (j + 1) + i]; svst1_f32(pg, c_ptr_1, result_1); svst1_f32(pg, c_ptr_2, result_2); } #else -static inline void __attribute__((always_inline)) -cmplx_mat_mult_ahb_dot_product_unroll_by_i2_j2( +inline void __attribute__((always_inline)) +cmplx_matmul_ahb_dot_product_unroll_by_i2_j2( uint16_t m, uint16_t n, uint16_t k, uint16_t i, uint16_t j, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, armral_cmplx_f32_t *p_dst) { + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, armral_cmplx_f32_t *p_dst) { float32x4_t acc_1 = vdupq_n_f32(0.0F); float32x4_t acc_2 = vdupq_n_f32(0.0F); @@ -328,12 +329,12 @@ cmplx_mat_mult_ahb_dot_product_unroll_by_i2_j2( float32x4_t acc_4 = vdupq_n_f32(0.0F); // Every row of A and B (where row of A = column of A^H) - for (uint32_t r = 0; r < m; r++) { + for (uint32_t r = 0; r < k; r++) { // A = [R0, I0, R1, I1] (Aij and Aij+1) - float32_t const *a_ptr = (float32_t const *)&p_src_a[r * n + j]; + float32_t const *a_ptr = (float32_t const *)&p_src_a[r * m + j]; float32x4_t a_rj = vld1q_f32(a_ptr); // B = [r, i, r, i] - float32_t const *b_ptr = (float32_t const *)&p_src_b[r * k + i]; + float32_t const *b_ptr = (float32_t const *)&p_src_b[r * n + i]; float32x4_t b_vec = vld1q_f32(b_ptr); // [R0*r, R0*i, R0*r, R0*i] @@ -349,28 +350,28 @@ cmplx_mat_mult_ahb_dot_product_unroll_by_i2_j2( // Correct sign of acc_2 (a*i*-1 + a*i*-1...) = -(a*i + a*i) float32x4_t result_1 = vaddq_f32(acc_1, vnegq64_f32(vrev64q_f32(acc_2))); float32x4_t result_2 = vaddq_f32(acc_3, vnegq64_f32(vrev64q_f32(acc_4))); - float32_t *c_ptr_1 = (float32_t *)&p_dst[k * j + i]; - float32_t *c_ptr_2 = (float32_t *)&p_dst[k * (j + 1) + i]; + float32_t *c_ptr_1 = (float32_t *)&p_dst[n * j + i]; + float32_t *c_ptr_2 = (float32_t *)&p_dst[n * (j + 1) + i]; vst1q_f32(c_ptr_1, result_1); vst1q_f32(c_ptr_2, result_2); } -static inline void __attribute__((always_inline)) -cmplx_mat_mult_ahb_dot_product_unroll_by_i2( +inline void __attribute__((always_inline)) +cmplx_matmul_ahb_dot_product_unroll_by_i2( uint16_t m, uint16_t n, uint16_t k, uint16_t i, uint16_t j, - const armral_cmplx_f32_t *restrict p_src_a, - const armral_cmplx_f32_t *restrict p_src_b, armral_cmplx_f32_t *p_dst) { + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, armral_cmplx_f32_t *p_dst) { float32x4_t acc_1 = vdupq_n_f32(0.0F); float32x4_t acc_2 = vdupq_n_f32(0.0F); // Every row of A and B (where row of A = column of A^H) - for (uint32_t r = 0; r < m; r++) { + for (uint32_t r = 0; r < k; r++) { // A = [R0, I0] - float32_t const *a_ptr = (float32_t const *)&p_src_a[r * n + j]; + float32_t const *a_ptr = (float32_t const *)&p_src_a[r * m + j]; float32x2_t a_rj = vld1_f32(a_ptr); // B = [r, i, r, i] - float32_t const *b_ptr = (float32_t const *)&p_src_b[r * k + i]; + float32_t const *b_ptr = (float32_t const *)&p_src_b[r * n + i]; float32x4_t b_vec = vld1q_f32(b_ptr); // [R0*r, R0*i, R0*r, R0*i] @@ -381,12 +382,11 @@ cmplx_mat_mult_ahb_dot_product_unroll_by_i2( // Correct sign of acc_2 (a*i*-1 + a*i*-1...) = -(a*i + a*i) float32x4_t result = vaddq_f32(acc_1, vnegq64_f32(vrev64q_f32(acc_2))); - float32_t *c_ptr = (float32_t *)&p_dst[k * j + i]; + float32_t *c_ptr = (float32_t *)&p_dst[n * j + i]; vst1q_f32(c_ptr, result); } -static inline void __attribute__((always_inline)) -cmplx_mat_mult_ahb_dot_product_scalar( +inline void __attribute__((always_inline)) cmplx_matmul_ahb_dot_product_scalar( uint16_t m, uint16_t n, uint16_t k, uint16_t i, uint16_t j, const armral_cmplx_f32_t *__restrict p_src_a, const armral_cmplx_f32_t *__restrict p_src_b, armral_cmplx_f32_t *p_dst) { @@ -394,9 +394,9 @@ cmplx_mat_mult_ahb_dot_product_scalar( float32x2_t acc_1 = vdup_n_f32(0.0F); float32x2_t acc_2 = vdup_n_f32(0.0F); // Every row of A and B (where row of A = column of A^H) - for (uint32_t r = 0; r < m; r++) { - float32x2_t a_rj = vld1_f32((float32_t const *)&p_src_a[r * n + j]); - float32x2_t b_ir = vld1_f32((float32_t const *)&p_src_b[r * k + i]); + for (uint32_t r = 0; r < k; r++) { + float32x2_t a_rj = vld1_f32((float32_t const *)&p_src_a[r * m + j]); + float32x2_t b_ir = vld1_f32((float32_t const *)&p_src_b[r * n + i]); // [R0*r, R0*i, R0*r, R0*i] acc_1 = vfma_lane_f32(acc_1, b_ir, a_rj, 0); // [I0*i, I0*r, I0*i, I0*r] @@ -404,11 +404,11 @@ cmplx_mat_mult_ahb_dot_product_scalar( } float32x2_t result = vadd_f32(acc_1, vneg64_f32(vrev64_f32(acc_2))); - vst1_f32((float32_t *)&p_dst[k * j + i], result); + vst1_f32((float32_t *)&p_dst[n * j + i], result); } -static inline void __attribute__((always_inline)) -cmplx_mat_mult_ahb_dot_product_unroll_by_j2( +inline void __attribute__((always_inline)) +cmplx_matmul_ahb_dot_product_unroll_by_j2( uint16_t m, uint16_t n, uint16_t k, uint16_t i, uint16_t j, const armral_cmplx_f32_t *__restrict p_src_a, const armral_cmplx_f32_t *__restrict p_src_b, armral_cmplx_f32_t *p_dst) { @@ -419,9 +419,9 @@ cmplx_mat_mult_ahb_dot_product_unroll_by_j2( float32x2_t acc_4 = vdup_n_f32(0.0F); // Every row of A and B (where row of A = column of A^H) - for (uint32_t r = 0; r < m; r++) { - float32x4_t a_rj = vld1q_f32((float32_t const *)&p_src_a[r * n + j]); - float32x2_t b_ir = vld1_f32((float32_t const *)&p_src_b[r * k + i]); + for (uint32_t r = 0; r < k; r++) { + float32x4_t a_rj = vld1q_f32((float32_t const *)&p_src_a[r * m + j]); + float32x2_t b_ir = vld1_f32((float32_t const *)&p_src_b[r * n + i]); // [R0*r, R0*i, R0*r, R0*i] acc_1 = vfma_laneq_f32(acc_1, b_ir, a_rj, 0); // [I0*i, I0*r, I0*i, I0*r] @@ -434,111 +434,113 @@ cmplx_mat_mult_ahb_dot_product_unroll_by_j2( float32x2_t result_1 = vadd_f32(acc_1, vneg64_f32(vrev64_f32(acc_2))); float32x2_t result_2 = vadd_f32(acc_3, vneg64_f32(vrev64_f32(acc_4))); - vst1_f32((float32_t *)&p_dst[k * j + i], result_1); - vst1_f32((float32_t *)&p_dst[k * (j + 1) + i], result_2); + vst1_f32((float32_t *)&p_dst[n * j + i], result_1); + vst1_f32((float32_t *)&p_dst[n * (j + 1) + i], result_2); } #endif -static inline void __attribute__((always_inline)) -cmplx_mat_mult_ahb_general_f32(uint16_t m, uint16_t n, uint16_t k, - const armral_cmplx_f32_t *__restrict p_src_a, - const armral_cmplx_f32_t *__restrict p_src_b, - armral_cmplx_f32_t *p_dst) { +inline void __attribute__((always_inline)) +cmplx_matmul_ahb_general_f32(uint16_t m, uint16_t n, uint16_t k, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *p_dst) { #ifdef ARMRAL_ARCH_SVE uint32_t i_unroll = svcntd(); svbool_t ptrue = svptrue_b32(); // For every row of output C (column of A, row of A^H)... uint32_t j = 0; - for (; j + 2 <= n; j += 2) { + for (; j + 2 <= m; j += 2) { uint32_t i = 0; // For every VL/2 columns of output C (column of B)... - for (; i + i_unroll <= k; i += i_unroll) { - cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll_j2( + for (; i + i_unroll <= n; i += i_unroll) { + cmplx_matmul_ahb_dot_product_sve_predicated_unroll_j2( ptrue, m, n, k, i, j, p_src_a, p_src_b, p_dst); } // For the remaining columns of output C... - if (i < k) { - svbool_t tail = svwhilelt_b32(i * 2, (uint32_t)k * 2); - cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll_j2( + if (i < n) { + svbool_t tail = svwhilelt_b32(i * 2, (uint32_t)n * 2); + cmplx_matmul_ahb_dot_product_sve_predicated_unroll_j2( tail, m, n, k, i, j, p_src_a, p_src_b, p_dst); } } - if (j != n) { + if (j != m) { uint32_t i = 0; // For every VL/2 columns of output C (column of B)... - for (; i + i_unroll <= k; i += i_unroll) { - cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll( + for (; i + i_unroll <= n; i += i_unroll) { + cmplx_matmul_ahb_dot_product_sve_predicated_unroll( ptrue, m, n, k, i, j, p_src_a, p_src_b, p_dst); } // For the remaining columns of output C... - if (i < k) { - svbool_t tail = svwhilelt_b32(i * 2, (uint32_t)k * 2); - cmplx_mat_mult_ahb_dot_product_sve_predicated_unroll( + if (i < n) { + svbool_t tail = svwhilelt_b32(i * 2, (uint32_t)n * 2); + cmplx_matmul_ahb_dot_product_sve_predicated_unroll( tail, m, n, k, i, j, p_src_a, p_src_b, p_dst); } } #else // For every two rows of output C (column of A, row of A^H)... uint32_t j = 0; - for (; j + 2 <= n; j += 2) { + for (; j + 2 <= m; j += 2) { uint32_t i = 0; // For every two columns of output C (column of B)... - for (; i + 2 <= k; i += 2) { - cmplx_mat_mult_ahb_dot_product_unroll_by_i2_j2(m, n, k, i, j, p_src_a, - p_src_b, p_dst); + for (; i + 2 <= n; i += 2) { + cmplx_matmul_ahb_dot_product_unroll_by_i2_j2(m, n, k, i, j, p_src_a, + p_src_b, p_dst); } // For the remaining columns of output C... - if (i != k) { - cmplx_mat_mult_ahb_dot_product_unroll_by_j2(m, n, k, i, j, p_src_a, - p_src_b, p_dst); + if (i != n) { + cmplx_matmul_ahb_dot_product_unroll_by_j2(m, n, k, i, j, p_src_a, p_src_b, + p_dst); } } // For the remaining rows of output C... - if (j != n) { + if (j != m) { uint32_t i = 0; // For every two columns of output C (column of B)... - for (; i + 2 <= k; i += 2) { - cmplx_mat_mult_ahb_dot_product_unroll_by_i2(m, n, k, i, j, p_src_a, - p_src_b, p_dst); + for (; i + 2 <= n; i += 2) { + cmplx_matmul_ahb_dot_product_unroll_by_i2(m, n, k, i, j, p_src_a, p_src_b, + p_dst); } // For the remaining columns of output C... - if (i != k) { - cmplx_mat_mult_ahb_dot_product_scalar(m, n, k, i, j, p_src_a, p_src_b, - p_dst); + if (i != n) { + cmplx_matmul_ahb_dot_product_scalar(m, n, k, i, j, p_src_a, p_src_b, + p_dst); } } #endif } +} // namespace + armral_status -armral_cmplx_mat_mult_ahb_f32(uint16_t m, uint16_t n, uint16_t k, - const armral_cmplx_f32_t *__restrict p_src_a, - const armral_cmplx_f32_t *__restrict p_src_b, - armral_cmplx_f32_t *p_dst) { - - if (m == k) { - if (k == 2) { - cmplx_mat_mult_ahb_b2x2(n, p_src_a, p_src_b, p_dst); +armral_cmplx_matmul_ahb_f32(uint16_t m, uint16_t n, uint16_t k, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *p_dst) { + + if (n == k) { + if (n == 2) { + cmplx_matmul_ahb_b2x2(m, p_src_a, p_src_b, p_dst); return ARMRAL_SUCCESS; } - if (k == 3) { - cmplx_mat_mult_ahb_b3x3(n, p_src_a, p_src_b, p_dst); + if (n == 3) { + cmplx_matmul_ahb_b3x3(m, p_src_a, p_src_b, p_dst); return ARMRAL_SUCCESS; } - if (k == 4) { - cmplx_mat_mult_ahb_b4x4(n, p_src_a, p_src_b, p_dst); + if (n == 4) { + cmplx_matmul_ahb_b4x4(m, p_src_a, p_src_b, p_dst); return ARMRAL_SUCCESS; } - if (k == 8) { + if (n == 8) { // Note: With compile-time constant parameters the compiler can sometimes // further optimize the 8x8 case. - cmplx_mat_mult_ahb_general_f32(8, n, 8, p_src_a, p_src_b, p_dst); + cmplx_matmul_ahb_general_f32(m, 8, 8, p_src_a, p_src_b, p_dst); return ARMRAL_SUCCESS; } } // Fallback: General vectorized implementation: - cmplx_mat_mult_ahb_general_f32(m, n, k, p_src_a, p_src_b, p_dst); + cmplx_matmul_ahb_general_f32(m, n, k, p_src_a, p_src_b, p_dst); return ARMRAL_SUCCESS; } diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_f32.cpp b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_f32.cpp new file mode 100644 index 0000000..3c38b69 --- /dev/null +++ b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_f32.cpp @@ -0,0 +1,1985 @@ +/* + Arm RAN Acceleration Library + SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates +*/ + +#include "armral.h" + +#ifdef ARMRAL_ARCH_SVE +#include +#endif + +namespace { + +#ifndef ARMRAL_ARCH_SVE +inline float32x4_t __attribute__((always_inline)) +vzip1q_f32x2(float32x4_t a, float32x4_t b) { + // This zips a pair of 32-bit floats in a 128-bit vector, e.g. given 32-bit + // vectors + // ^: a = [a0, a1, a2, a3] + // ^: b = [b0, b1, b2, b3] + // ^: returns + // ^: c = [a0, a1, b0, b1] + return vreinterpretq_f32_f64( + vzip1q_f64(vreinterpretq_f64_f32(a), vreinterpretq_f64_f32(b))); +} + +inline float32x4_t __attribute__((always_inline)) +vzip2q_f32x2(float32x4_t a, float32x4_t b) { + // This zips a pair of 32-bit floats in 128-bit vector, e.g. given 32-bit + // vectors + // ^: a = [a0, a1, a2, a3] + // ^: b = [b0, b1, b2, b3] + // ^: returns + // ^: c = [a2, a3, b2, b3] + return vreinterpretq_f32_f64( + vzip2q_f64(vreinterpretq_f64_f32(a), vreinterpretq_f64_f32(b))); +} +#endif + +#ifdef ARMRAL_ARCH_SVE +// Calculates a vector width of consecutive output elements in a matrix product +// of a m x k and k x n matrix. p_src_a and p_src_b must point to the start +// row/column respectively, the operation must be valid and the result will be +// stored at exactly dst. +template +inline void sve_mat_one_row_dot(svbool_t pg, const uint16_t k, + const uint16_t b_dst_stride, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *dst) { + + svfloat32_t c0; + if constexpr (accumulate) { + c0 = svld1_f32(pg, (float32_t *)dst); + } else { + c0 = svdup_f32(0); + } + for (int h = 0; h < k - 1; h += 2) { + svbool_t pa = svwhilelt_b32(h * 2, k * 2); + svfloat32_t a0i = svld1rq_f32(pa, (const float32_t *)&p_src_a[h]); + svfloat32_t bi0 = + svld1_f32(pg, (const float32_t *)&p_src_b[h * b_dst_stride]); + svfloat32_t bi1 = svld1_f32( + pg, (const float32_t *)&p_src_b[h * b_dst_stride + b_dst_stride]); + c0 = svcmla_lane_f32(c0, bi0, a0i, 0, 0); + c0 = svcmla_lane_f32(c0, bi0, a0i, 0, 90); + c0 = svcmla_lane_f32(c0, bi1, a0i, 1, 0); + c0 = svcmla_lane_f32(c0, bi1, a0i, 1, 90); + } + + // If k is odd, we have one more row/col to go + if (k % 2) { + svfloat32_t ak = + svreinterpret_f32_u64(svdup_u64(*((const uint64_t *)&p_src_a[k - 1]))); + svfloat32_t bk = + svld1_f32(pg, (const float32_t *)&p_src_b[(k - 1) * b_dst_stride]); + c0 = svcmla_f32_x(pg, c0, ak, bk, 0); + c0 = svcmla_f32_x(pg, c0, ak, bk, 90); + } + svst1_f32(pg, (float32_t *)dst, c0); +} + +// Calculates 2 vector widths of consecutive output elements in a matrix product +// of a m x k and k x n matrix. p_src_a and p_src_b must point to the start +// row/column respectively, the operation must be valid and the result will be +// stored at exactly dst. +template +inline void sve_mat_two_row_dot(svbool_t pg, const uint16_t k, + const uint16_t a_stride, + const uint16_t b_dst_stride, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *dst) { + + svfloat32_t c0; + svfloat32_t c1; + if constexpr (accumulate) { + c0 = svld1_f32(pg, (float32_t *)&dst[0]); + c1 = svld1_f32(pg, (float32_t *)&dst[b_dst_stride]); + } else { + c0 = svdup_f32(0); + c1 = svdup_f32(0); + } + for (int h = 0; h < k - 1; h += 2) { + svbool_t pa = svwhilelt_b32(h * 2, k * 2); + svfloat32_t a0i = svld1rq_f32(pa, (const float32_t *)&p_src_a[h]); + svfloat32_t a1i = + svld1rq_f32(pa, (const float32_t *)&p_src_a[h + a_stride]); + svfloat32_t bi0 = + svld1_f32(pg, (const float32_t *)&p_src_b[h * b_dst_stride]); + svfloat32_t bi1 = svld1_f32( + pg, (const float32_t *)&p_src_b[h * b_dst_stride + b_dst_stride]); + + c0 = svcmla_lane_f32(c0, bi0, a0i, 0, 0); + c0 = svcmla_lane_f32(c0, bi0, a0i, 0, 90); + c0 = svcmla_lane_f32(c0, bi1, a0i, 1, 0); + c0 = svcmla_lane_f32(c0, bi1, a0i, 1, 90); + c1 = svcmla_lane_f32(c1, bi0, a1i, 0, 0); + c1 = svcmla_lane_f32(c1, bi0, a1i, 0, 90); + c1 = svcmla_lane_f32(c1, bi1, a1i, 1, 0); + c1 = svcmla_lane_f32(c1, bi1, a1i, 1, 90); + } + + // If k is odd, we have one more row/col to go + if (k % 2) { + svfloat32_t a0k = + svreinterpret_f32_u64(svdup_u64(*((const uint64_t *)&p_src_a[k - 1]))); + svfloat32_t a1k = svreinterpret_f32_u64( + svdup_u64(*((const uint64_t *)&p_src_a[k - 1 + a_stride]))); + + svfloat32_t bk = + svld1_f32(pg, (const float32_t *)&p_src_b[(k - 1) * b_dst_stride]); + c0 = svcmla_f32_x(pg, c0, a0k, bk, 0); + c0 = svcmla_f32_x(pg, c0, a0k, bk, 90); + c1 = svcmla_f32_x(pg, c1, a1k, bk, 0); + c1 = svcmla_f32_x(pg, c1, a1k, bk, 90); + } + svst1_f32(pg, (float32_t *)&dst[0], c0); + svst1_f32(pg, (float32_t *)&dst[b_dst_stride], c1); +} +#endif + +// Computes c += a b or c = a b depending on whether the value of accumulate is +// true or false. a, b and c are sub-matrices of p_src_a, p_src_b and p_dst, +// respectively, where a_stride is the total number of columns of matrix +// p_src_a and b_dst_stride is the total number of columns of matrices +// p_src_b/p_dst. This function expects row-major input and gives row-major +// output. +template +inline armral_status +cmplx_matmul_f32(const uint16_t m, const uint16_t n, const uint16_t k, + uint16_t a_stride, uint16_t b_dst_stride, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *__restrict p_dst) { +#ifdef ARMRAL_ARCH_SVE + for (int i = 0; i < m - 1; i += 2) { + for (int j = 0; j < n; j += svcntd()) { + const svbool_t pg = svwhilelt_b32(2 * j, 2 * n); + sve_mat_two_row_dot(pg, k, a_stride, b_dst_stride, + &p_src_a[i * a_stride], &p_src_b[j], + &p_dst[i * b_dst_stride + j]); + } + } + if (m % 2) { + const int i = m - 1; + for (int j = 0; j < n; j += svcntd()) { + const svbool_t pg = svwhilelt_b32(2 * j, 2 * n); + sve_mat_one_row_dot(pg, k, b_dst_stride, + &p_src_a[i * a_stride], &p_src_b[j], + &p_dst[i * b_dst_stride + j]); + } + } + return ARMRAL_SUCCESS; +#else + const float32_t *p_in1 = (const float32_t *)p_src_a; + const float32_t *p_in2 = (const float32_t *)p_src_b; + const armral_cmplx_f32_t *p_in_a = p_src_a; + armral_cmplx_f32_t *p_out = p_dst; + armral_cmplx_f32_t *px; + uint16_t num_rows_a = m; + uint16_t num_cols_b = n; + uint16_t num_cols_a = k; + + float32x4x2_t a0_v; + float32x4x2_t a1_v; + float32x4_t temp_r2; + float32x4_t temp_i2; + float32x4_t b0_v; + float32x4_t b1_v; + float32x4_t b2_v; + float32x4_t b3_v; + float32x4_t b_col_real; + float32x4_t b_col_im; + float32x4_t b_col_real2; + float32x4_t b_col_im2; + float32x2_t accum = vdup_n_f32(0); + const float32_t *p_in1_b = (const float32_t *)p_src_a; + const float32_t *p_in1_b2 = (const float32_t *)p_src_b; + + uint16_t col; + uint16_t i = 0U; + uint16_t j; + uint16_t row_cnt; + uint16_t row = num_rows_a; + uint16_t col_cnt; + armral_cmplx_f32_t *px_b; + + // The following loop performs the dot-product of each row in pSrcA with each + // column in pSrcB + + row_cnt = row >> 1; + // Row loop + while (row_cnt > 0U) { + // Output pointer is set to starting address of the row being processed + px = p_out + i; + px_b = px + b_dst_stride; + + // For every row wise process, the column loop counter is to be initiated + col = num_cols_b; + + // For every row wise process, the pIn2 pointer is set + // to the starting address of the pSrcB data + p_in2 = (const float32_t *)p_src_b; + p_in1_b2 = p_in2 + 2 * b_dst_stride; + + j = 0U; + + // Column loop + col >>= 1; + while (col > 0U) { + // Set the variable sum, that acts as accumulator, to zero + float32_t sum_real1 = 0.0F; + float32_t sum_imag1 = 0.0F; + float32_t sum_real1_b = 0.0F; + float32_t sum_imag1_b = 0.0F; + + float32_t sum_real2 = 0.0F; + float32_t sum_imag2 = 0.0F; + float32_t sum_real2_b = 0.0F; + float32_t sum_imag2_b = 0.0F; + + float32_t sum_real3 = 0.0F; + float32_t sum_imag3 = 0.0F; + float32_t sum_real3_b = 0.0F; + float32_t sum_imag3_b = 0.0F; + + float32_t sum_real4 = 0.0F; + float32_t sum_imag4 = 0.0F; + float32_t sum_real4_b = 0.0F; + float32_t sum_imag4_b = 0.0F; + + // Initialize the pointer pIn1 to point to the starting address of the + // column being processed + p_in1 = (const float32_t *)p_in_a; + p_in1_b = p_in1 + 2 * a_stride; + + float32x4_t acc_r0 = {}; + float32x4_t acc_i0 = {}; + float32x4_t acc_r1 = {}; + float32x4_t acc_i1 = {}; + float32x4_t acc_r2 = {}; + float32x4_t acc_i2 = {}; + float32x4_t acc_r3 = {}; + float32x4_t acc_i3 = {}; + + // Compute 4 MACs simultaneously + col_cnt = num_cols_a >> 2; + + // Matrix multiplication + while (col_cnt > 0U) { + float32x4_t temp_r = {}; + float32x4_t temp_i = {}; + // Load & separate real/imag pSrcA (de-interleave 2) + a0_v = vld2q_f32(p_in1); + // Load & separate real/imag pSrcA (de-interleave 2) + a1_v = vld2q_f32(p_in1_b); + + p_in1 += 8; + p_in1_b += 8; + + // Load but don't separate real/imag + b0_v = vld1q_f32(p_in2); + b1_v = vld1q_f32(p_in1_b2); + b2_v = vld1q_f32(p_in2 + 4 * b_dst_stride); + b3_v = vld1q_f32(p_in1_b2 + 4 * b_dst_stride); + + p_in2 = p_in2 + 8 * b_dst_stride; + p_in1_b2 = p_in1_b2 + 8 * b_dst_stride; + b_col_real = vtrn1q_f32(b0_v, b1_v); // even elem + b_col_im = vtrn2q_f32(b0_v, b1_v); // odd elem + b_col_real2 = vtrn1q_f32(b2_v, b3_v); // even elem + b_col_im2 = vtrn2q_f32(b2_v, b3_v); // odd elem + + // First column + temp_r = vzip1q_f32x2(b_col_real, b_col_real2); + temp_i = vzip1q_f32x2(b_col_im, b_col_im2); + + // Second column + temp_r2 = vzip2q_f32x2(b_col_real, b_col_real2); + temp_i2 = vzip2q_f32x2(b_col_im, b_col_im2); + + acc_r0 = vfmaq_f32(acc_r0, a0_v.val[0], temp_r); + acc_r0 = vfmsq_f32(acc_r0, a0_v.val[1], temp_i); + acc_i0 = vfmaq_f32(acc_i0, a0_v.val[1], temp_r); + acc_i0 = vfmaq_f32(acc_i0, a0_v.val[0], temp_i); + // Same row A, next column B + acc_r2 = vfmaq_f32(acc_r2, a0_v.val[0], temp_r2); + acc_r2 = vfmsq_f32(acc_r2, a0_v.val[1], temp_i2); + + acc_i2 = vfmaq_f32(acc_i2, a0_v.val[1], temp_r2); + acc_i2 = vfmaq_f32(acc_i2, a0_v.val[0], temp_i2); + + acc_r1 = vfmaq_f32(acc_r1, a1_v.val[0], temp_r); + acc_r1 = vfmsq_f32(acc_r1, a1_v.val[1], temp_i); + + acc_i1 = vfmaq_f32(acc_i1, a1_v.val[1], temp_r); + acc_i1 = vfmaq_f32(acc_i1, a1_v.val[0], temp_i); + // Same row A, next column B + acc_r3 = vfmaq_f32(acc_r3, a1_v.val[0], temp_r2); + acc_r3 = vfmsq_f32(acc_r3, a1_v.val[1], temp_i2); + + acc_i3 = vfmaq_f32(acc_i3, a1_v.val[1], temp_r2); + acc_i3 = vfmaq_f32(acc_i3, a1_v.val[0], temp_i2); + + col_cnt--; + } + + sum_real1 += vaddvq_f32(acc_r0); + sum_imag1 += vaddvq_f32(acc_i0); + sum_real3 += vaddvq_f32(acc_r2); + sum_imag3 += vaddvq_f32(acc_i2); + + sum_real1_b += vaddvq_f32(acc_r1); + sum_imag1_b += vaddvq_f32(acc_i1); + sum_real3_b += vaddvq_f32(acc_r3); + sum_imag3_b += vaddvq_f32(acc_i3); + + // If the columns of pSrcA is not a multiple of 4, compute any remaining + // MACs here. + // No loop unrolling is used. + col_cnt = num_cols_a & 3; + while (col_cnt > 0U) { + + float32_t a1 = *p_in1; + float32_t a1_b = *p_in1_b; + + float32_t c1 = *p_in2; + float32_t c1_b = *(p_in2 + 2U); + + float32_t b1 = *(p_in1 + 1U); + float32_t b1_b = *(p_in1_b + 1U); + + float32_t d1 = *(p_in2 + 1U); + float32_t d1_b = *(p_in2 + 3U); + + sum_real1 += a1 * c1; + sum_imag1 += b1 * c1; + + sum_real3 += a1 * c1_b; + sum_imag3 += b1 * c1_b; + + sum_real1_b += a1_b * c1; + sum_imag1_b += b1_b * c1; + + sum_real3_b += a1_b * c1_b; + sum_imag3_b += b1_b * c1_b; + + p_in1 += 2U; + p_in1_b += 2U; + p_in2 += 2 * b_dst_stride; + + sum_real2 -= b1 * d1; + sum_imag2 += a1 * d1; + + sum_real4 -= b1 * d1_b; + sum_imag4 += a1 * d1_b; + + sum_real2_b -= b1_b * d1; + sum_imag2_b += a1_b * d1; + + sum_real4_b -= b1_b * d1_b; + sum_imag4_b += a1_b * d1_b; + + col_cnt--; + } + + sum_real1 += sum_real2; + sum_imag1 += sum_imag2; + + sum_real3 += sum_real4; + sum_imag3 += sum_imag4; + + sum_real1_b += sum_real2_b; + sum_imag1_b += sum_imag2_b; + + sum_real3_b += sum_real4_b; + sum_imag3_b += sum_imag4_b; + + // Store the result in the destination buffer + if constexpr (accumulate) { + (*px).re += sum_real1; + (*px).im += sum_imag1; + px++; + (*px).re += sum_real3; + (*px).im += sum_imag3; + px++; + (*px_b).re += sum_real1_b; + (*px_b).im += sum_imag1_b; + px_b++; + (*px_b).re += sum_real3_b; + (*px_b).im += sum_imag3_b; + px_b++; + } else { + (*px).re = sum_real1; + (*px).im = sum_imag1; + px++; + (*px).re = sum_real3; + (*px).im = sum_imag3; + px++; + (*px_b).re = sum_real1_b; + (*px_b).im = sum_imag1_b; + px_b++; + (*px_b).re = sum_real3_b; + (*px_b).im = sum_imag3_b; + px_b++; + } + // Update the pointer pIn2 to point to the starting address of the + // next column + j++; + p_in2 = (const float32_t *)p_src_b + 4U * j; + p_in1_b2 = p_in2 + 2U * b_dst_stride; + col--; + } + + col = num_cols_b & 1; + if (col) { + // Set the variable sum, that acts as accumulator, to zero + float32_t sum_real1 = 0.0F; + float32_t sum_imag1 = 0.0F; + float32_t sum_real2 = 0.0F; + float32_t sum_imag2 = 0.0F; + float32_t sum_real1_b = 0.0F; + float32_t sum_imag1_b = 0.0F; + float32_t sum_real2_b = 0.0F; + float32_t sum_imag2_b = 0.0F; + + // Initialize the pointer pIn1 to point to the starting address of the + // column being processed + p_in1 = (const float32_t *)p_in_a; + p_in1_b = p_in1 + 2 * a_stride; + + float32x4_t acc_r0 = {}; + float32x4_t acc_i0 = {}; + float32x4_t acc_r1 = {}; + float32x4_t acc_i1 = {}; + + // Compute 4 MACs simultaneously + col_cnt = num_cols_a >> 2; + + // Matrix multiplication + while (col_cnt > 0U) { + // Load & separate real/imag pSrcA (de-interleave 2) + a0_v = vld2q_f32(p_in1); + a1_v = vld2q_f32(p_in1_b); + + p_in1 += 8; + p_in1_b += 8; + + // Load but don't separate real/imag + float32x2_t b_four_rows[4]; + b_four_rows[0] = vld1_f32(p_in2); + b_four_rows[1] = vld1_f32(p_in1_b2); + b_four_rows[2] = vld1_f32(p_in2 + 4 * b_dst_stride); + b_four_rows[3] = vld1_f32(p_in1_b2 + 4 * b_dst_stride); + + p_in2 = p_in2 + 8 * b_dst_stride; + p_in1_b2 = p_in1_b2 + 8 * b_dst_stride; + float32x4_t b_tmp_real; + float32x4_t b_tmp_im; + b_tmp_real = vcombine_f32(vtrn1_f32(b_four_rows[0], b_four_rows[1]), + vtrn1_f32(b_four_rows[2], b_four_rows[3])); + b_tmp_im = vcombine_f32(vtrn2_f32(b_four_rows[0], b_four_rows[1]), + vtrn2_f32(b_four_rows[2], b_four_rows[3])); + + // Real * real + acc_r0 = vfmaq_f32(acc_r0, a0_v.val[0], b_tmp_real); + // Imag * imag + acc_r0 = vfmsq_f32(acc_r0, a0_v.val[1], b_tmp_im); + // Imag * real + acc_i0 = vfmaq_f32(acc_i0, a0_v.val[1], b_tmp_real); + // Real * imag + acc_i0 = vfmaq_f32(acc_i0, a0_v.val[0], b_tmp_im); + + // Real * real + acc_r1 = vfmaq_f32(acc_r1, a1_v.val[0], b_tmp_real); + // Imag * imag + acc_r1 = vfmsq_f32(acc_r1, a1_v.val[1], b_tmp_im); + // Imag * real + acc_i1 = vfmaq_f32(acc_i1, a1_v.val[1], b_tmp_real); + // Real * imag + acc_i1 = vfmaq_f32(acc_i1, a1_v.val[0], b_tmp_im); + + col_cnt--; + } + + sum_real1 += vaddvq_f32(acc_r0); + sum_imag1 += vaddvq_f32(acc_i0); + + sum_real1_b += vaddvq_f32(acc_r1); + sum_imag1_b += vaddvq_f32(acc_i1); + + // If the columns of pSrcA is not a multiple of 4, compute any remaining + // MACs here. + // No loop unrolling is used. + col_cnt = num_cols_a & 3; + while (col_cnt > 0U) { + + float32_t a1 = *p_in1; // real part of entry from A + float32_t a1_b = *p_in1_b; + float32_t c1 = *p_in2; // real part of entry from B + + float32_t b1 = *(p_in1 + 1U); // imaginary part of entry from A + float32_t b1_b = *(p_in1_b + 1U); + float32_t d1 = *(p_in2 + 1U); // imaginary part of entry from B + + // Real * real + sum_real1 += a1 * c1; + // Imag * real + sum_imag1 += b1 * c1; + + // Imag * imag + sum_real2 -= b1 * d1; + // Real * imag + sum_imag2 += a1 * d1; + + // Real * real + sum_real1_b += a1_b * c1; + // Imag * real + sum_imag1_b += b1_b * c1; + + // Imag * imag + sum_real2_b -= b1_b * d1; + // Real * imag + sum_imag2_b += a1_b * d1; + + p_in1 += 2U; + p_in1_b += 2U; + p_in2 += 2 * b_dst_stride; + + col_cnt--; + } + + sum_real1 += sum_real2; + sum_imag1 += sum_imag2; + + sum_real1_b += sum_real2_b; + sum_imag1_b += sum_imag2_b; + + // Store the result in the destination buffer + if constexpr (accumulate) { + (*px).re += sum_real1; + (*px).im += sum_imag1; + px++; + (*px_b).re += sum_real1_b; + (*px_b).im += sum_imag1_b; + px_b++; + } else { + (*px).re = sum_real1; + (*px).im = sum_imag1; + px++; + (*px_b).re = sum_real1_b; + (*px_b).im = sum_imag1_b; + px_b++; + } + // Update the pointer pIn2 to point to the starting address of the next + // column + j++; + } + + // Update the pointer pInA to point to the starting address of the next 2 + // row + i = i + 2 * b_dst_stride; + p_in_a = p_in_a + 2 * a_stride; + + row_cnt--; + } + + row_cnt = row & 1; + while (row_cnt > 0U) { + // Output pointer is set to starting address of the row being processed + px = p_out + i; + + // For every row wise process, the column loop counter is to be initiated + col = num_cols_b; + + // For every row wise process, the pIn2 pointer is set + // to the starting address of the pSrcB data + p_in2 = (const float32_t *)p_src_b; + + j = 0U; + + // Column loop + while (col > 0U) { + // Set the variable sum, that acts as accumulator, to zero + float32_t sum_real1 = 0.0F; + float32_t sum_imag1 = 0.0F; + + float32_t sum_real2 = 0.0F; + float32_t sum_imag2 = 0.0F; + + // Initialize the pointer pIn1 to point to the starting address of the + // column being processed + p_in1 = (const float32_t *)p_in_a; + + float32x4_t acc_r0 = {}; + float32x4_t acc_i0 = {}; + + // Compute 4 MACs simultaneously + col_cnt = num_cols_a >> 2; + + // Matrix multiplication + while (col_cnt > 0U) { + float32x4_t temp_r = {}; + float32x4_t temp_i = {}; + // Reading real part of complex matrix A + // load & separate real/imag p_src_a (de-interleave 2) + a0_v = vld2q_f32(p_in1); + p_in1 += 8; + + temp_r[0] = *p_in2; + temp_i[0] = *(p_in2 + 1U); + p_in2 += 2 * b_dst_stride; + + temp_r[1] = *p_in2; + temp_i[1] = *(p_in2 + 1U); + p_in2 += 2 * b_dst_stride; + + temp_r[2] = *p_in2; + temp_i[2] = *(p_in2 + 1U); + p_in2 += 2 * b_dst_stride; + + temp_r[3] = *p_in2; + temp_i[3] = *(p_in2 + 1U); + p_in2 += 2 * b_dst_stride; + + acc_r0 = vfmaq_f32(acc_r0, a0_v.val[0], temp_r); + acc_r0 = vfmsq_f32(acc_r0, a0_v.val[1], temp_i); + + acc_i0 = vfmaq_f32(acc_i0, a0_v.val[1], temp_r); + acc_i0 = vfmaq_f32(acc_i0, a0_v.val[0], temp_i); + + col_cnt--; + } + + accum = vpadd_f32(vget_low_f32(acc_r0), vget_high_f32(acc_r0)); + sum_real1 += accum[0] + accum[1]; + + accum = vpadd_f32(vget_low_f32(acc_i0), vget_high_f32(acc_i0)); + sum_imag1 += accum[0] + accum[1]; + + // If the columns of pSrcA is not a multiple of 4, compute any remaining + // MACs here. + // No loop unrolling is used. + col_cnt = num_cols_a & 3; + + while (col_cnt > 0U) { + // c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) + float32_t a1 = *p_in1; + float32_t c1 = *p_in2; + + float32_t b1 = *(p_in1 + 1U); + float32_t d1 = *(p_in2 + 1U); + + sum_real1 += a1 * c1; + sum_imag1 += b1 * c1; + + p_in1 += 2U; + p_in2 += 2 * b_dst_stride; + + sum_real2 -= b1 * d1; + sum_imag2 += a1 * d1; + + col_cnt--; + } + + sum_real1 += sum_real2; + sum_imag1 += sum_imag2; + + // Store the result in the destination buffer + if constexpr (accumulate) { + (*px).re += sum_real1; + (*px).im += sum_imag1; + } else { + (*px).re = sum_real1; + (*px).im = sum_imag1; + } + px++; + // Update the pointer pIn2 to point to the starting address of the next + // column + j++; + p_in2 = (const float32_t *)p_src_b + 2U * j; + + col--; + } + + // Update the pointer pInA to point to the starting address of the next + // row + i = i + b_dst_stride; + p_in_a = p_in_a + a_stride; + + row_cnt--; + } + + return ARMRAL_SUCCESS; +#endif +} + +#ifdef ARMRAL_ARCH_SVE +inline svfloat32_t __attribute__((always_inline)) svtrn2iq_f32(svfloat32_t a) { + // Interleaves 32-bit floating point numbers at odd 64-bit indices in an + // SVE vector, e.g. given 256-bit vector + // ^: a = [a0, a1, a2, a3, a4, a5, a6, a7, ...] + // ^: returns + // ^: c = [a2, a3, a2, a3, a6, a7, a6, a7, ...] + return svreinterpret_f32_f64( + svtrn2_f64(svreinterpret_f64_f32(a), svreinterpret_f64_f32(a))); +} +#endif + +// Computes c += a b or c = a b depending on whether the value of accumulate is +// true or false. a, b and c are sub-matrices of p_src_a, p_src_b and p_dst, +// respectively, where a, b and c have size 4-by-4. a_stride is the total +// number of columns of matrix p_src_a and b_dst_stride is the total number of +// columns of matrices p_dst_b/p_dst. This function expects row-major input and +// gives row-major output. +template +inline armral_status +cmplx_matmul_4x4_f32(uint16_t a_stride, uint16_t b_dst_stride, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *__restrict p_dst) { + const float32_t *a_ptr = (const float32_t *)p_src_a; + const float32_t *b_ptr = (const float32_t *)p_src_b; + float32_t *out_ptr = (float32_t *)p_dst; + +#if ARMRAL_ARCH_SVE >= 2 + svbool_t p4 = svptrue_pat_b32(SV_VL4); + + svfloat32_t b00 = svld1_f32(p4, &b_ptr[0]); + svfloat32_t b10 = svld1_f32(p4, &b_ptr[4]); + svfloat32_t b01 = svld1_f32(p4, &b_ptr[2 * b_dst_stride]); + svfloat32_t b11 = svld1_f32(p4, &b_ptr[2 * b_dst_stride + 4]); + svfloat32_t b02 = svld1_f32(p4, &b_ptr[4 * b_dst_stride]); + svfloat32_t b12 = svld1_f32(p4, &b_ptr[4 * b_dst_stride + 4]); + svfloat32_t b03 = svld1_f32(p4, &b_ptr[6 * b_dst_stride]); + svfloat32_t b13 = svld1_f32(p4, &b_ptr[6 * b_dst_stride + 4]); + + svfloat32_t cj0; + svfloat32_t cj1; + for (int j = 0; j < 4; j++) { + if constexpr (accumulate) { + cj0 = svld1_f32(p4, &out_ptr[j * 2 * b_dst_stride]); + cj1 = svld1_f32(p4, &out_ptr[4 + j * 2 * b_dst_stride]); + } else { + cj0 = svdup_n_f32(0); + cj1 = svdup_n_f32(0); + } + svfloat32_t a0j = svld1_f32(p4, &a_ptr[j * 2 * a_stride]); + svfloat32_t a1j = svld1_f32(p4, &a_ptr[4 + j * 2 * a_stride]); + cj0 = svcmla_lane_f32(cj0, b00, a0j, 0, 0); + cj0 = svcmla_lane_f32(cj0, b00, a0j, 0, 90); + cj0 = svcmla_lane_f32(cj0, b01, a0j, 1, 0); + cj0 = svcmla_lane_f32(cj0, b01, a0j, 1, 90); + cj0 = svcmla_lane_f32(cj0, b02, a1j, 0, 0); + cj0 = svcmla_lane_f32(cj0, b02, a1j, 0, 90); + cj0 = svcmla_lane_f32(cj0, b03, a1j, 1, 0); + cj0 = svcmla_lane_f32(cj0, b03, a1j, 1, 90); + + cj1 = svcmla_lane_f32(cj1, b10, a0j, 0, 0); + cj1 = svcmla_lane_f32(cj1, b10, a0j, 0, 90); + cj1 = svcmla_lane_f32(cj1, b11, a0j, 1, 0); + cj1 = svcmla_lane_f32(cj1, b11, a0j, 1, 90); + cj1 = svcmla_lane_f32(cj1, b12, a1j, 0, 0); + cj1 = svcmla_lane_f32(cj1, b12, a1j, 0, 90); + cj1 = svcmla_lane_f32(cj1, b13, a1j, 1, 0); + cj1 = svcmla_lane_f32(cj1, b13, a1j, 1, 90); + + svst1_f32(p4, &out_ptr[j * 2 * b_dst_stride], cj0); + svst1_f32(p4, &out_ptr[4 + j * 2 * b_dst_stride], cj1); + } + +#else + if constexpr (accumulate) { + // Accumulate result into p_dst array + __asm__ __volatile__( + "lsl x5, %x[AStride], #3\n" + "lsl x6, %x[BDstStride], #3\n" + "neg x7, x6\n" + + "ld2 {v10.4s, v11.4s}, [%x[BPtr]], x6\n" + "ld2 {v12.4s, v13.4s}, [%x[BPtr]], x6\n" + + "ld2 {v18.4s, v19.4s}, [%x[APtr]], x5\n" + "ld2 {v20.4s, v21.4s}, [%x[APtr]], x5\n" + + "ld2 {v2.4s, v3.4s}, [%x[outPtr]], x6\n" + "ld2 {v4.4s, v5.4s}, [%x[outPtr]], x7\n" + + "fmla v2.4s, v10.4s, v18.s[0]\n" + "fmls v2.4s, v11.4s, v19.s[0]\n" + "fmla v3.4s, v11.4s, v18.s[0]\n" + "fmla v3.4s, v10.4s, v19.s[0]\n" + + "fmla v4.4s, v10.4s, v20.s[0]\n" + "fmls v4.4s, v11.4s, v21.s[0]\n" + "fmla v5.4s, v11.4s, v20.s[0]\n" + "fmla v5.4s, v10.4s, v21.s[0]\n" + + "ld2 {v14.4s, v15.4s}, [%x[BPtr]], x6\n" + "ld2 {v16.4s, v17.4s}, [%x[BPtr]], x6\n" + + "fmla v2.4s, v12.4s, v18.s[1]\n" + "fmls v2.4s, v13.4s, v19.s[1]\n" + "fmla v3.4s, v13.4s, v18.s[1]\n" + "fmla v3.4s, v12.4s, v19.s[1]\n" + + "fmla v4.4s, v12.4s, v20.s[1]\n" + "fmls v4.4s, v13.4s, v21.s[1]\n" + "fmla v5.4s, v13.4s, v20.s[1]\n" + "fmla v5.4s, v12.4s, v21.s[1]\n" + + "fmla v2.4s, v14.4s, v18.s[2]\n" + "fmls v2.4s, v15.4s, v19.s[2]\n" + "fmla v3.4s, v15.4s, v18.s[2]\n" + "fmla v3.4s, v14.4s, v19.s[2]\n" + + "fmla v4.4s, v14.4s, v20.s[2]\n" + "fmls v4.4s, v15.4s, v21.s[2]\n" + "fmla v5.4s, v15.4s, v20.s[2]\n" + "fmla v5.4s, v14.4s, v21.s[2]\n" + + "fmla v2.4s, v16.4s, v18.s[3]\n" + "fmls v2.4s, v17.4s, v19.s[3]\n" + "fmla v3.4s, v17.4s, v18.s[3]\n" + "fmla v3.4s, v16.4s, v19.s[3]\n" + + "st2 {v2.4s, v3.4s}, [%x[outPtr]], x6\n" + "ld2 {v18.4s, v19.4s}, [%x[APtr]], x5\n" + + "fmla v4.4s, v16.4s, v20.s[3]\n" + "fmls v4.4s, v17.4s, v21.s[3]\n" + "fmla v5.4s, v17.4s, v20.s[3]\n" + "fmla v5.4s, v16.4s, v21.s[3]\n" + + "st2 {v4.4s, v5.4s}, [%x[outPtr]], x6\n" + "ld2 {v20.4s, v21.4s}, [%x[APtr]], x5\n" + + "ld2 {v2.4s, v3.4s}, [%x[outPtr]], x6\n" + "ld2 {v4.4s, v5.4s}, [%x[outPtr]], x7\n" + + "fmla v2.4s, v10.4s, v18.s[0]\n" + "fmls v2.4s, v11.4s, v19.s[0]\n" + "fmla v3.4s, v11.4s, v18.s[0]\n" + "fmla v3.4s, v10.4s, v19.s[0]\n" + + "fmla v4.4s, v10.4s, v20.s[0]\n" + "fmls v4.4s, v11.4s, v21.s[0]\n" + "fmla v5.4s, v11.4s, v20.s[0]\n" + "fmla v5.4s, v10.4s, v21.s[0]\n" + + "fmla v2.4s, v12.4s, v18.s[1]\n" + "fmls v2.4s, v13.4s, v19.s[1]\n" + "fmla v3.4s, v13.4s, v18.s[1]\n" + "fmla v3.4s, v12.4s, v19.s[1]\n" + + "fmla v4.4s, v12.4s, v20.s[1]\n" + "fmls v4.4s, v13.4s, v21.s[1]\n" + "fmla v5.4s, v13.4s, v20.s[1]\n" + "fmla v5.4s, v12.4s, v21.s[1]\n" + + "fmla v2.4s, v14.4s, v18.s[2]\n" + "fmls v2.4s, v15.4s, v19.s[2]\n" + "fmla v3.4s, v15.4s, v18.s[2]\n" + "fmla v3.4s, v14.4s, v19.s[2]\n" + + "fmla v4.4s, v14.4s, v20.s[2]\n" + "fmls v4.4s, v15.4s, v21.s[2]\n" + "fmla v5.4s, v15.4s, v20.s[2]\n" + "fmla v5.4s, v14.4s, v21.s[2]\n" + + "fmla v2.4s, v16.4s, v18.s[3]\n" + "fmls v2.4s, v17.4s, v19.s[3]\n" + "fmla v3.4s, v17.4s, v18.s[3]\n" + "fmla v3.4s, v16.4s, v19.s[3]\n" + + "fmla v4.4s, v16.4s, v20.s[3]\n" + "fmls v4.4s, v17.4s, v21.s[3]\n" + "fmla v5.4s, v17.4s, v20.s[3]\n" + "fmla v5.4s, v16.4s, v21.s[3]\n" + + "st2 {v2.4s, v3.4s}, [%x[outPtr]], x6\n" + "st2 {v4.4s, v5.4s}, [%x[outPtr]]\n" + + : [APtr] "+r"(a_ptr), [BPtr] "+r"(b_ptr), [outPtr] "+r"(out_ptr) + + : [AStride] "r"(a_stride), [BDstStride] "r"(b_dst_stride) + + : "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v2", "v3", "v4", "v5", "x5", "x6", "x7", + "cc", "memory"); + } else { + __asm__ __volatile__( + "lsl x5, %x[AStride], #3\n" + "lsl x6, %x[BDstStride], #3\n" + + "ld2 {v10.4s, v11.4s}, [%x[BPtr]], x6\n" + "ld2 {v18.4s, v19.4s}, [%x[APtr]], x5\n" + "ld2 {v20.4s, v21.4s}, [%x[APtr]], x5\n" + "ld2 {v12.4s, v13.4s}, [%x[BPtr]], x6\n" + + "fmul v2.4s, v10.4s, v18.s[0]\n" + "fmls v2.4s, v11.4s, v19.s[0]\n" + "fmul v3.4s, v11.4s, v18.s[0]\n" + "fmla v3.4s, v10.4s, v19.s[0]\n" + + "fmul v4.4s, v10.4s, v20.s[0]\n" + "fmls v4.4s, v11.4s, v21.s[0]\n" + "fmul v5.4s, v11.4s, v20.s[0]\n" + "fmla v5.4s, v10.4s, v21.s[0]\n" + + "ld2 {v14.4s, v15.4s}, [%x[BPtr]], x6\n" + "ld2 {v16.4s, v17.4s}, [%x[BPtr]], x6\n" + + "fmla v2.4s, v12.4s, v18.s[1]\n" + "fmls v2.4s, v13.4s, v19.s[1]\n" + "fmla v3.4s, v13.4s, v18.s[1]\n" + "fmla v3.4s, v12.4s, v19.s[1]\n" + + "fmla v4.4s, v12.4s, v20.s[1]\n" + "fmls v4.4s, v13.4s, v21.s[1]\n" + "fmla v5.4s, v13.4s, v20.s[1]\n" + "fmla v5.4s, v12.4s, v21.s[1]\n" + + "fmla v2.4s, v14.4s, v18.s[2]\n" + "fmls v2.4s, v15.4s, v19.s[2]\n" + "fmla v3.4s, v15.4s, v18.s[2]\n" + "fmla v3.4s, v14.4s, v19.s[2]\n" + + "fmla v4.4s, v14.4s, v20.s[2]\n" + "fmls v4.4s, v15.4s, v21.s[2]\n" + "fmla v5.4s, v15.4s, v20.s[2]\n" + "fmla v5.4s, v14.4s, v21.s[2]\n" + + "fmla v2.4s, v16.4s, v18.s[3]\n" + "fmls v2.4s, v17.4s, v19.s[3]\n" + "fmla v3.4s, v17.4s, v18.s[3]\n" + "fmla v3.4s, v16.4s, v19.s[3]\n" + + "st2 {v2.4s, v3.4s}, [%x[outPtr]], x6\n" + "ld2 {v18.4s, v19.4s}, [%x[APtr]], x5\n" + + "fmla v4.4s, v16.4s, v20.s[3]\n" + "fmls v4.4s, v17.4s, v21.s[3]\n" + "fmla v5.4s, v17.4s, v20.s[3]\n" + "fmla v5.4s, v16.4s, v21.s[3]\n" + + "st2 {v4.4s, v5.4s}, [%x[outPtr]], x6\n" + "ld2 {v20.4s, v21.4s}, [%x[APtr]], x5\n" + + "fmul v2.4s, v10.4s, v18.s[0]\n" + "fmls v2.4s, v11.4s, v19.s[0]\n" + "fmul v3.4s, v11.4s, v18.s[0]\n" + "fmla v3.4s, v10.4s, v19.s[0]\n" + + "fmul v4.4s, v10.4s, v20.s[0]\n" + "fmls v4.4s, v11.4s, v21.s[0]\n" + "fmul v5.4s, v11.4s, v20.s[0]\n" + "fmla v5.4s, v10.4s, v21.s[0]\n" + + "fmla v2.4s, v12.4s, v18.s[1]\n" + "fmls v2.4s, v13.4s, v19.s[1]\n" + "fmla v3.4s, v13.4s, v18.s[1]\n" + "fmla v3.4s, v12.4s, v19.s[1]\n" + + "fmla v4.4s, v12.4s, v20.s[1]\n" + "fmls v4.4s, v13.4s, v21.s[1]\n" + "fmla v5.4s, v13.4s, v20.s[1]\n" + "fmla v5.4s, v12.4s, v21.s[1]\n" + + "fmla v2.4s, v14.4s, v18.s[2]\n" + "fmls v2.4s, v15.4s, v19.s[2]\n" + "fmla v3.4s, v15.4s, v18.s[2]\n" + "fmla v3.4s, v14.4s, v19.s[2]\n" + + "fmla v4.4s, v14.4s, v20.s[2]\n" + "fmls v4.4s, v15.4s, v21.s[2]\n" + "fmla v5.4s, v15.4s, v20.s[2]\n" + "fmla v5.4s, v14.4s, v21.s[2]\n" + + "fmla v2.4s, v16.4s, v18.s[3]\n" + "fmls v2.4s, v17.4s, v19.s[3]\n" + "fmla v3.4s, v17.4s, v18.s[3]\n" + "fmla v3.4s, v16.4s, v19.s[3]\n" + + "fmla v4.4s, v16.4s, v20.s[3]\n" + "fmls v4.4s, v17.4s, v21.s[3]\n" + "fmla v5.4s, v17.4s, v20.s[3]\n" + "fmla v5.4s, v16.4s, v21.s[3]\n" + + "st2 {v2.4s, v3.4s}, [%x[outPtr]], x6\n" + "st2 {v4.4s, v5.4s}, [%x[outPtr]]\n" + + : [APtr] "+r"(a_ptr), [BPtr] "+r"(b_ptr), [outPtr] "+r"(out_ptr) + + : [AStride] "r"(a_stride), [BDstStride] "r"(b_dst_stride) + + : "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v2", "v3", "v4", "v5", "x5", "x6", "cc", + "memory"); + } +#endif + + return ARMRAL_SUCCESS; +} + +// Computes c += a b or c = a b depending on whether the value of accumulate is +// true or false. a, b and c are sub-matrices of p_src_a, p_src_b and p_dst, +// respectively, where a has size 4-by-4 and b and c have size 4-by-8. +// a_stride is the total number of columns of matrix p_src_a and b_dst_stride +// is the total number of columns of matrices p_dst_b/p_dst. This function +// expects row-major input and gives row-major output. +template +inline armral_status +cmplx_matmul_4x4_f32_unroll(uint16_t a_stride, uint16_t b_dst_stride, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *__restrict p_dst) { +#if ARMRAL_ARCH_SVE >= 2 + cmplx_matmul_4x4_f32(a_stride, b_dst_stride, p_src_a, p_src_b, + p_dst); + cmplx_matmul_4x4_f32(a_stride, b_dst_stride, p_src_a, &p_src_b[4], + &p_dst[4]); +#else + const float32_t *a_ptr = (const float32_t *)p_src_a; + const float32_t *b_ptr = (const float32_t *)p_src_b; + float32_t *out_ptr = (float32_t *)p_dst; + + if constexpr (accumulate) { + // Accumulate result into p_dst array + __asm__ __volatile__( + "lsl x5, %x[AStride], #3\n" + "lsl x6, %x[BDstStride], #3\n" + "add x7, %x[BPtr], #32\n" + "mov x8, %x[outPtr]\n" + + "ld2 {v0.4s, v1.4s}, [%x[APtr]], x5\n" + "ld2 {v8.4s, v9.4s}, [%x[BPtr]], x6\n" + "ld2 {v16.4s, v17.4s}, [%x[outPtr]], x6\n" + "ld2 {v18.4s, v19.4s}, [%x[outPtr]], x6\n" + + "fmla v16.4s, v8.4s, v0.s[0]\n" + "fmls v16.4s, v9.4s, v1.s[0]\n" + "fmla v17.4s, v9.4s, v0.s[0]\n" + "fmla v17.4s, v8.4s, v1.s[0]\n" + + "ld2 {v2.4s, v3.4s}, [%x[APtr]], x5\n" + "ld2 {v4.4s, v5.4s}, [%x[APtr]], x5\n" + + "fmla v18.4s, v8.4s, v2.s[0]\n" + "fmls v18.4s, v9.4s, v3.s[0]\n" + "fmla v19.4s, v9.4s, v2.s[0]\n" + "fmla v19.4s, v8.4s, v3.s[0]\n" + + "ld2 {v20.4s, v21.4s}, [%x[outPtr]], x6\n" + "ld2 {v22.4s, v23.4s}, [%x[outPtr]]\n" + + "mov %x[outPtr], x8\n" + + "fmla v20.4s, v8.4s, v4.s[0]\n" + "fmls v20.4s, v9.4s, v5.s[0]\n" + "fmla v21.4s, v9.4s, v4.s[0]\n" + "fmla v21.4s, v8.4s, v5.s[0]\n" + + "ld2 {v6.4s, v7.4s}, [%x[APtr]], x5\n" + "ld2 {v10.4s, v11.4s}, [%x[BPtr]], x6\n" + + "fmla v22.4s, v8.4s, v6.s[0]\n" + "fmls v22.4s, v9.4s, v7.s[0]\n" + "fmla v23.4s, v9.4s, v6.s[0]\n" + "fmla v23.4s, v8.4s, v7.s[0]\n" + + "ld2 {v8.4s, v9.4s}, [x7], x6\n" + "ld2 {v12.4s, v13.4s}, [%x[BPtr]], x6\n" + + "fmla v16.4s, v10.4s, v0.s[1]\n" + "fmls v16.4s, v11.4s, v1.s[1]\n" + "fmla v17.4s, v11.4s, v0.s[1]\n" + "fmla v17.4s, v10.4s, v1.s[1]\n" + + "fmla v18.4s, v10.4s, v2.s[1]\n" + "fmls v18.4s, v11.4s, v3.s[1]\n" + "fmla v19.4s, v11.4s, v2.s[1]\n" + "fmla v19.4s, v10.4s, v3.s[1]\n" + + "fmla v20.4s, v10.4s, v4.s[1]\n" + "fmls v20.4s, v11.4s, v5.s[1]\n" + "fmla v21.4s, v11.4s, v4.s[1]\n" + "fmla v21.4s, v10.4s, v5.s[1]\n" + + "fmla v22.4s, v10.4s, v6.s[1]\n" + "fmls v22.4s, v11.4s, v7.s[1]\n" + "fmla v23.4s, v11.4s, v6.s[1]\n" + "fmla v23.4s, v10.4s, v7.s[1]\n" + + "ld2 {v10.4s, v11.4s}, [x7], x6\n" + "ld2 {v14.4s, v15.4s}, [%x[BPtr]]\n" + + "fmla v16.4s, v12.4s, v0.s[2]\n" + "fmls v16.4s, v13.4s, v1.s[2]\n" + "fmla v17.4s, v13.4s, v0.s[2]\n" + "fmla v17.4s, v12.4s, v1.s[2]\n" + + "fmla v18.4s, v12.4s, v2.s[2]\n" + "fmls v18.4s, v13.4s, v3.s[2]\n" + "fmla v19.4s, v13.4s, v2.s[2]\n" + "fmla v19.4s, v12.4s, v3.s[2]\n" + + "fmla v20.4s, v12.4s, v4.s[2]\n" + "fmls v20.4s, v13.4s, v5.s[2]\n" + "fmla v21.4s, v13.4s, v4.s[2]\n" + "fmla v21.4s, v12.4s, v5.s[2]\n" + + "fmla v22.4s, v12.4s, v6.s[2]\n" + "fmls v22.4s, v13.4s, v7.s[2]\n" + "fmla v23.4s, v13.4s, v6.s[2]\n" + "fmla v23.4s, v12.4s, v7.s[2]\n" + + "fmla v16.4s, v14.4s, v0.s[3]\n" + "fmls v16.4s, v15.4s, v1.s[3]\n" + "fmla v17.4s, v15.4s, v0.s[3]\n" + "fmla v17.4s, v14.4s, v1.s[3]\n" + + "ld2 {v12.4s, v13.4s}, [x7], x6\n" + "st2 {v16.4s, v17.4s}, [%x[outPtr]], x6\n" + + "fmla v18.4s, v14.4s, v2.s[3]\n" + "fmls v18.4s, v15.4s, v3.s[3]\n" + "fmla v19.4s, v15.4s, v2.s[3]\n" + "fmla v19.4s, v14.4s, v3.s[3]\n" + + "fmla v20.4s, v14.4s, v4.s[3]\n" + "fmls v20.4s, v15.4s, v5.s[3]\n" + "fmla v21.4s, v15.4s, v4.s[3]\n" + "fmla v21.4s, v14.4s, v5.s[3]\n" + + "st2 {v18.4s, v19.4s}, [%x[outPtr]], x6\n" + "st2 {v20.4s, v21.4s}, [%x[outPtr]], x6\n" + + "fmla v22.4s, v14.4s, v6.s[3]\n" + "fmls v22.4s, v15.4s, v7.s[3]\n" + "fmla v23.4s, v15.4s, v6.s[3]\n" + "fmla v23.4s, v14.4s, v7.s[3]\n" + + "st2 {v22.4s, v23.4s}, [%x[outPtr]]\n" + "ld2 {v14.4s, v15.4s}, [x7]\n" + + "add %x[outPtr], x8, #32\n" + + "ld2 {v16.4s, v17.4s}, [%x[outPtr]], x6\n" + "ld2 {v18.4s, v19.4s}, [%x[outPtr]], x6\n" + + "fmla v16.4s, v8.4s, v0.s[0]\n" + "fmls v16.4s, v9.4s, v1.s[0]\n" + "fmla v17.4s, v9.4s, v0.s[0]\n" + "fmla v17.4s, v8.4s, v1.s[0]\n" + + "fmla v18.4s, v8.4s, v2.s[0]\n" + "fmls v18.4s, v9.4s, v3.s[0]\n" + "fmla v19.4s, v9.4s, v2.s[0]\n" + "fmla v19.4s, v8.4s, v3.s[0]\n" + + "ld2 {v20.4s, v21.4s}, [%x[outPtr]], x6\n" + "ld2 {v22.4s, v23.4s}, [%x[outPtr]]\n" + "add %x[outPtr], x8, #32\n" + + "fmla v20.4s, v8.4s, v4.s[0]\n" + "fmls v20.4s, v9.4s, v5.s[0]\n" + "fmla v21.4s, v9.4s, v4.s[0]\n" + "fmla v21.4s, v8.4s, v5.s[0]\n" + + "fmla v22.4s, v8.4s, v6.s[0]\n" + "fmls v22.4s, v9.4s, v7.s[0]\n" + "fmla v23.4s, v9.4s, v6.s[0]\n" + "fmla v23.4s, v8.4s, v7.s[0]\n" + + "fmla v16.4s, v10.4s, v0.s[1]\n" + "fmls v16.4s, v11.4s, v1.s[1]\n" + "fmla v17.4s, v11.4s, v0.s[1]\n" + "fmla v17.4s, v10.4s, v1.s[1]\n" + + "fmla v18.4s, v10.4s, v2.s[1]\n" + "fmls v18.4s, v11.4s, v3.s[1]\n" + "fmla v19.4s, v11.4s, v2.s[1]\n" + "fmla v19.4s, v10.4s, v3.s[1]\n" + + "fmla v20.4s, v10.4s, v4.s[1]\n" + "fmls v20.4s, v11.4s, v5.s[1]\n" + "fmla v21.4s, v11.4s, v4.s[1]\n" + "fmla v21.4s, v10.4s, v5.s[1]\n" + + "fmla v22.4s, v10.4s, v6.s[1]\n" + "fmls v22.4s, v11.4s, v7.s[1]\n" + "fmla v23.4s, v11.4s, v6.s[1]\n" + "fmla v23.4s, v10.4s, v7.s[1]\n" + + "fmla v16.4s, v12.4s, v0.s[2]\n" + "fmls v16.4s, v13.4s, v1.s[2]\n" + "fmla v17.4s, v13.4s, v0.s[2]\n" + "fmla v17.4s, v12.4s, v1.s[2]\n" + + "fmla v18.4s, v12.4s, v2.s[2]\n" + "fmls v18.4s, v13.4s, v3.s[2]\n" + "fmla v19.4s, v13.4s, v2.s[2]\n" + "fmla v19.4s, v12.4s, v3.s[2]\n" + + "fmla v20.4s, v12.4s, v4.s[2]\n" + "fmls v20.4s, v13.4s, v5.s[2]\n" + "fmla v21.4s, v13.4s, v4.s[2]\n" + "fmla v21.4s, v12.4s, v5.s[2]\n" + + "fmla v22.4s, v12.4s, v6.s[2]\n" + "fmls v22.4s, v13.4s, v7.s[2]\n" + "fmla v23.4s, v13.4s, v6.s[2]\n" + "fmla v23.4s, v12.4s, v7.s[2]\n" + + "fmla v16.4s, v14.4s, v0.s[3]\n" + "fmls v16.4s, v15.4s, v1.s[3]\n" + "fmla v17.4s, v15.4s, v0.s[3]\n" + "fmla v17.4s, v14.4s, v1.s[3]\n" + + "fmla v18.4s, v14.4s, v2.s[3]\n" + "fmls v18.4s, v15.4s, v3.s[3]\n" + "fmla v19.4s, v15.4s, v2.s[3]\n" + "fmla v19.4s, v14.4s, v3.s[3]\n" + + "st2 {v16.4s, v17.4s}, [%x[outPtr]], x6\n" + "st2 {v18.4s, v19.4s}, [%x[outPtr]], x6\n" + + "fmla v20.4s, v14.4s, v4.s[3]\n" + "fmls v20.4s, v15.4s, v5.s[3]\n" + "fmla v21.4s, v15.4s, v4.s[3]\n" + "fmla v21.4s, v14.4s, v5.s[3]\n" + + "fmla v22.4s, v14.4s, v6.s[3]\n" + "fmls v22.4s, v15.4s, v7.s[3]\n" + "fmla v23.4s, v15.4s, v6.s[3]\n" + "fmla v23.4s, v14.4s, v7.s[3]\n" + + "st2 {v20.4s, v21.4s}, [%x[outPtr]], x6\n" + "st2 {v22.4s, v23.4s}, [%x[outPtr]]\n" + + : [APtr] "+r"(a_ptr), [BPtr] "+r"(b_ptr), [outPtr] "+r"(out_ptr) + + : [AStride] "r"(a_stride), [BDstStride] "r"(b_dst_stride) + + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "x5", "x6", "x7", "x8", "cc", "memory"); + } else { + __asm__ __volatile__( + "lsl x5, %x[AStride], #3\n" + "lsl x6, %x[BDstStride], #3\n" + "add x7, %x[BPtr], #32\n" + "add x8, %x[outPtr], #32\n" + + "ld2 {v0.4s, v1.4s}, [%x[APtr]], x5\n" + "ld2 {v2.4s, v3.4s}, [%x[APtr]], x5\n" + "ld2 {v8.4s, v9.4s}, [%x[BPtr]], x6\n" + "ld2 {v10.4s, v11.4s}, [%x[BPtr]], x6\n" + + "fmul v16.4s, v8.4s, v0.s[0]\n" + "fmls v16.4s, v9.4s, v1.s[0]\n" + "fmul v17.4s, v9.4s, v0.s[0]\n" + "fmla v17.4s, v8.4s, v1.s[0]\n" + + "fmul v18.4s, v8.4s, v2.s[0]\n" + "fmls v18.4s, v9.4s, v3.s[0]\n" + "fmul v19.4s, v9.4s, v2.s[0]\n" + "fmla v19.4s, v8.4s, v3.s[0]\n" + + "ld2 {v4.4s, v5.4s}, [%x[APtr]], x5\n" + "ld2 {v6.4s, v7.4s}, [%x[APtr]], x5\n" + + "fmul v20.4s, v8.4s, v4.s[0]\n" + "fmls v20.4s, v9.4s, v5.s[0]\n" + "fmul v21.4s, v9.4s, v4.s[0]\n" + "fmla v21.4s, v8.4s, v5.s[0]\n" + + "fmul v22.4s, v8.4s, v6.s[0]\n" + "fmls v22.4s, v9.4s, v7.s[0]\n" + "fmul v23.4s, v9.4s, v6.s[0]\n" + "fmla v23.4s, v8.4s, v7.s[0]\n" + + "ld2 {v8.4s, v9.4s}, [x7], x6\n" + "ld2 {v12.4s, v13.4s}, [%x[BPtr]], x6\n" + + "fmla v16.4s, v10.4s, v0.s[1]\n" + "fmls v16.4s, v11.4s, v1.s[1]\n" + "fmla v17.4s, v11.4s, v0.s[1]\n" + "fmla v17.4s, v10.4s, v1.s[1]\n" + + "fmla v18.4s, v10.4s, v2.s[1]\n" + "fmls v18.4s, v11.4s, v3.s[1]\n" + "fmla v19.4s, v11.4s, v2.s[1]\n" + "fmla v19.4s, v10.4s, v3.s[1]\n" + + "fmla v20.4s, v10.4s, v4.s[1]\n" + "fmls v20.4s, v11.4s, v5.s[1]\n" + "fmla v21.4s, v11.4s, v4.s[1]\n" + "fmla v21.4s, v10.4s, v5.s[1]\n" + + "fmla v22.4s, v10.4s, v6.s[1]\n" + "fmls v22.4s, v11.4s, v7.s[1]\n" + "fmla v23.4s, v11.4s, v6.s[1]\n" + "fmla v23.4s, v10.4s, v7.s[1]\n" + + "ld2 {v14.4s, v15.4s}, [%x[BPtr]]\n" + "ld2 {v10.4s, v11.4s}, [x7], x6\n" + + "fmla v16.4s, v12.4s, v0.s[2]\n" + "fmls v16.4s, v13.4s, v1.s[2]\n" + "fmla v17.4s, v13.4s, v0.s[2]\n" + "fmla v17.4s, v12.4s, v1.s[2]\n" + + "fmla v18.4s, v12.4s, v2.s[2]\n" + "fmls v18.4s, v13.4s, v3.s[2]\n" + "fmla v19.4s, v13.4s, v2.s[2]\n" + "fmla v19.4s, v12.4s, v3.s[2]\n" + + "fmla v20.4s, v12.4s, v4.s[2]\n" + "fmls v20.4s, v13.4s, v5.s[2]\n" + "fmla v21.4s, v13.4s, v4.s[2]\n" + "fmla v21.4s, v12.4s, v5.s[2]\n" + + "fmla v22.4s, v12.4s, v6.s[2]\n" + "fmls v22.4s, v13.4s, v7.s[2]\n" + "fmla v23.4s, v13.4s, v6.s[2]\n" + "fmla v23.4s, v12.4s, v7.s[2]\n" + + "fmla v16.4s, v14.4s, v0.s[3]\n" + "fmls v16.4s, v15.4s, v1.s[3]\n" + "fmla v17.4s, v15.4s, v0.s[3]\n" + "fmla v17.4s, v14.4s, v1.s[3]\n" + + "fmla v18.4s, v14.4s, v2.s[3]\n" + "fmls v18.4s, v15.4s, v3.s[3]\n" + "fmla v19.4s, v15.4s, v2.s[3]\n" + "fmla v19.4s, v14.4s, v3.s[3]\n" + + "ld2 {v12.4s, v13.4s}, [x7], x6\n" + "st2 {v16.4s, v17.4s}, [%x[outPtr]], x6\n" + + "fmla v20.4s, v14.4s, v4.s[3]\n" + "fmls v20.4s, v15.4s, v5.s[3]\n" + "fmla v21.4s, v15.4s, v4.s[3]\n" + "fmla v21.4s, v14.4s, v5.s[3]\n" + + "fmla v22.4s, v14.4s, v6.s[3]\n" + "fmls v22.4s, v15.4s, v7.s[3]\n" + "fmla v23.4s, v15.4s, v6.s[3]\n" + "fmla v23.4s, v14.4s, v7.s[3]\n" + + "st2 {v18.4s, v19.4s}, [%x[outPtr]], x6\n" + "st2 {v20.4s, v21.4s}, [%x[outPtr]], x6\n" + + "fmul v16.4s, v8.4s, v0.s[0]\n" + "fmls v16.4s, v9.4s, v1.s[0]\n" + "fmul v17.4s, v9.4s, v0.s[0]\n" + "fmla v17.4s, v8.4s, v1.s[0]\n" + + "st2 {v22.4s, v23.4s}, [%x[outPtr]]\n" + "ld2 {v14.4s, v15.4s}, [x7], x6\n" + + "mov %x[outPtr], x8\n" + + "fmul v18.4s, v8.4s, v2.s[0]\n" + "fmls v18.4s, v9.4s, v3.s[0]\n" + "fmul v19.4s, v9.4s, v2.s[0]\n" + "fmla v19.4s, v8.4s, v3.s[0]\n" + + "fmul v20.4s, v8.4s, v4.s[0]\n" + "fmls v20.4s, v9.4s, v5.s[0]\n" + "fmul v21.4s, v9.4s, v4.s[0]\n" + "fmla v21.4s, v8.4s, v5.s[0]\n" + + "fmul v22.4s, v8.4s, v6.s[0]\n" + "fmls v22.4s, v9.4s, v7.s[0]\n" + "fmul v23.4s, v9.4s, v6.s[0]\n" + "fmla v23.4s, v8.4s, v7.s[0]\n" + + "fmla v16.4s, v10.4s, v0.s[1]\n" + "fmls v16.4s, v11.4s, v1.s[1]\n" + "fmla v17.4s, v11.4s, v0.s[1]\n" + "fmla v17.4s, v10.4s, v1.s[1]\n" + + "fmla v18.4s, v10.4s, v2.s[1]\n" + "fmls v18.4s, v11.4s, v3.s[1]\n" + "fmla v19.4s, v11.4s, v2.s[1]\n" + "fmla v19.4s, v10.4s, v3.s[1]\n" + + "fmla v20.4s, v10.4s, v4.s[1]\n" + "fmls v20.4s, v11.4s, v5.s[1]\n" + "fmla v21.4s, v11.4s, v4.s[1]\n" + "fmla v21.4s, v10.4s, v5.s[1]\n" + + "fmla v22.4s, v10.4s, v6.s[1]\n" + "fmls v22.4s, v11.4s, v7.s[1]\n" + "fmla v23.4s, v11.4s, v6.s[1]\n" + "fmla v23.4s, v10.4s, v7.s[1]\n" + + "fmla v16.4s, v12.4s, v0.s[2]\n" + "fmls v16.4s, v13.4s, v1.s[2]\n" + "fmla v17.4s, v13.4s, v0.s[2]\n" + "fmla v17.4s, v12.4s, v1.s[2]\n" + + "fmla v18.4s, v12.4s, v2.s[2]\n" + "fmls v18.4s, v13.4s, v3.s[2]\n" + "fmla v19.4s, v13.4s, v2.s[2]\n" + "fmla v19.4s, v12.4s, v3.s[2]\n" + + "fmla v20.4s, v12.4s, v4.s[2]\n" + "fmls v20.4s, v13.4s, v5.s[2]\n" + "fmla v21.4s, v13.4s, v4.s[2]\n" + "fmla v21.4s, v12.4s, v5.s[2]\n" + + "fmla v22.4s, v12.4s, v6.s[2]\n" + "fmls v22.4s, v13.4s, v7.s[2]\n" + "fmla v23.4s, v13.4s, v6.s[2]\n" + "fmla v23.4s, v12.4s, v7.s[2]\n" + + "fmla v16.4s, v14.4s, v0.s[3]\n" + "fmls v16.4s, v15.4s, v1.s[3]\n" + "fmla v17.4s, v15.4s, v0.s[3]\n" + "fmla v17.4s, v14.4s, v1.s[3]\n" + + "fmla v18.4s, v14.4s, v2.s[3]\n" + "fmls v18.4s, v15.4s, v3.s[3]\n" + "fmla v19.4s, v15.4s, v2.s[3]\n" + "fmla v19.4s, v14.4s, v3.s[3]\n" + + "st2 {v16.4s, v17.4s}, [%x[outPtr]], x6\n" + "st2 {v18.4s, v19.4s}, [%x[outPtr]], x6\n" + + "fmla v20.4s, v14.4s, v4.s[3]\n" + "fmls v20.4s, v15.4s, v5.s[3]\n" + "fmla v21.4s, v15.4s, v4.s[3]\n" + "fmla v21.4s, v14.4s, v5.s[3]\n" + + "fmla v22.4s, v14.4s, v6.s[3]\n" + "fmls v22.4s, v15.4s, v7.s[3]\n" + "fmla v23.4s, v15.4s, v6.s[3]\n" + "fmla v23.4s, v14.4s, v7.s[3]\n" + + "st2 {v20.4s, v21.4s}, [%x[outPtr]], x6\n" + "st2 {v22.4s, v23.4s}, [%x[outPtr]]\n" + + : [APtr] "+r"(a_ptr), [BPtr] "+r"(b_ptr), [outPtr] "+r"(out_ptr) + + : [AStride] "r"(a_stride), [BDstStride] "r"(b_dst_stride) + + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "x5", "x6", "x7", "x8", "cc", "memory"); + } +#endif + + return ARMRAL_SUCCESS; +} + +inline void use_general_matmul(uint16_t msize, uint16_t nsize, uint16_t ksize, + uint16_t n, uint16_t k, uint16_t kbi, + uint16_t ki, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *__restrict p_dst) { + if (kbi == 0 && ki == 0) { + // This is the first pass, don't accumulate into dst array + cmplx_matmul_f32(msize, nsize, ksize, k, n, p_src_a, p_src_b, p_dst); + } else { + // Accumulate into dst array + cmplx_matmul_f32(msize, nsize, ksize, k, n, p_src_a, p_src_b, p_dst); + } +} + +inline void matmul_inner(uint16_t m, uint16_t n, uint16_t k, uint16_t kbi, + uint16_t kb_size, uint16_t nb_size, + uint16_t rem_k_inner, uint16_t rem_n_inner, + uint16_t rem_m_inner, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *__restrict p_dst) { + // For every 4 rows of A/C + for (int32_t mi = 0; mi < (m - 3); mi += 4) { + uint32_t a_idx = mi * k; + uint32_t dst_idx = mi * n; + for (int32_t ki = 0; ki < (kb_size - 3); ki += 4) { + // Unrolling 4x4 inner matmul means we do 8 cols of B/C at a time + int32_t ni = 0; + for (; ni < (nb_size - 7); ni += 8) { + if (kbi == 0 && ki == 0) { + // This is the first pass, don't accumulate into dst array + cmplx_matmul_4x4_f32_unroll( + k, n, &p_src_a[a_idx], &p_src_b[ni], &p_dst[dst_idx + ni]); + } else { + // Accumulate into dst array + cmplx_matmul_4x4_f32_unroll(k, n, &p_src_a[a_idx + ki], + &p_src_b[ki * n + ni], + &p_dst[dst_idx + ni]); + } + } + // If there are at least 4 cols of B/C remaining, use the 4x4 inner kernel + for (; ni < (nb_size - 3); ni += 4) { + if (kbi == 0 && ki == 0) { + // This is the first pass, don't accumulate into dst array + cmplx_matmul_4x4_f32(k, n, &p_src_a[a_idx], &p_src_b[ni], + &p_dst[dst_idx + ni]); + } else { + // Accumulate into dst array + cmplx_matmul_4x4_f32(k, n, &p_src_a[a_idx + ki], + &p_src_b[ki * n + ni], + &p_dst[dst_idx + ni]); + } + } + // Do leftover ns + if (rem_n_inner != 0U) { + ni = nb_size - rem_n_inner; + use_general_matmul(4, rem_n_inner, 4, n, k, kbi, ki, + &p_src_a[a_idx + ki], &p_src_b[ki * n + ni], + &p_dst[dst_idx + ni]); + } + } + // Do leftover ks + if (rem_k_inner != 0U) { + uint16_t ki = kb_size - rem_k_inner; + use_general_matmul(4, nb_size, rem_k_inner, n, k, kbi, ki, + &p_src_a[a_idx + ki], &p_src_b[ki * n], + &p_dst[dst_idx]); + } + } + // Do leftover ms + if (rem_m_inner != 0U) { + uint16_t mi = m - rem_m_inner; + use_general_matmul(rem_m_inner, nb_size, kb_size, n, k, kbi, 0, + &p_src_a[mi * k], p_src_b, &p_dst[mi * n]); + } +} + +inline void matmul_n_block(uint16_t m, uint16_t n, uint16_t k, uint16_t kb_size, + uint16_t kbi, uint16_t nb, uint16_t rem_n, + uint16_t n_idx, uint16_t rem_k_inner, + uint16_t rem_m_inner, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *__restrict p_dst) { + uint16_t rem_n_inner = nb % 4; + for (int32_t nbi = 0; nbi < n - (nb - 1); nbi += nb) { + matmul_inner(m, n, k, kbi, kb_size, nb, rem_k_inner, rem_n_inner, + rem_m_inner, &p_src_a[kbi], &p_src_b[kbi * n + nbi], + &p_dst[nbi]); + } + if (rem_n != 0U) { + rem_n_inner = rem_n % 4; + matmul_inner(m, n, k, kbi, kb_size, rem_n, rem_k_inner, rem_n_inner, + rem_m_inner, &p_src_a[kbi], &p_src_b[kbi * n + n_idx], + &p_dst[n_idx]); + } +} + +} // anonymous namespace + +armral_status +armral_cmplx_mat_mult_2x2_f32(const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *p_dst) { + const float32_t *a_ptr = (const float32_t *)p_src_a; + const float32_t *b_ptr = (const float32_t *)p_src_b; + float32_t *out_ptr = (float32_t *)p_dst; + +#ifdef ARMRAL_ARCH_SVE + svbool_t p4 = svptrue_pat_b32(SV_VL4); + svfloat32_t a0 = svld1_f32(p4, &a_ptr[0]); + svfloat32_t a1 = svld1_f32(p4, &a_ptr[4]); + svfloat32_t b0 = svld1_f32(p4, &b_ptr[0]); + svfloat32_t b1 = svld1_f32(p4, &b_ptr[4]); + svfloat32_t c0 = svdup_n_f32(0); + svfloat32_t c1 = svdup_n_f32(0); + + c0 = svcmla_lane_f32(c0, a0, b0, 0, 0); + c0 = svcmla_lane_f32(c0, a0, b0, 0, 90); + c0 = svcmla_lane_f32(c0, a1, b0, 1, 0); + c0 = svcmla_lane_f32(c0, a1, b0, 1, 90); + c1 = svcmla_lane_f32(c1, a0, b1, 0, 0); + c1 = svcmla_lane_f32(c1, a0, b1, 0, 90); + c1 = svcmla_lane_f32(c1, a1, b1, 1, 0); + c1 = svcmla_lane_f32(c1, a1, b1, 1, 90); + + svst1_f32(p4, &out_ptr[0], c0); + svst1_f32(p4, &out_ptr[4], c1); +#else + float32x2x2_t a_col[2]; + float32x2x2_t b[2]; + float32x2x2_t result[2]; + + a_col[0] = vld2_f32(a_ptr); + a_ptr = a_ptr + 4; + + b[0] = vld2_f32(b_ptr); + b_ptr = b_ptr + 4; + + // result[0] 4 rows elem 1 RE * first column elem 1 RE + result[0].val[0] = vmul_lane_f32(a_col[0].val[0], b[0].val[0], 0); + // result[0] 4 rows elem 1 IM * first column elem 1 IM + result[0].val[0] = + vfms_lane_f32(result[0].val[0], a_col[0].val[1], b[0].val[1], 0); + b[1] = vld2_f32(b_ptr); + b_ptr = b_ptr + 4; + // result[1] 4 rows elem 1 IM * first row elem 1 RE + result[0].val[1] = vmul_lane_f32(a_col[0].val[1], b[0].val[0], 0); + // result[1] 4 rows elem 1 RE * first row elem 1 IM + result[0].val[1] = + vfma_lane_f32(result[0].val[1], a_col[0].val[0], b[0].val[1], 0); + a_col[1] = vld2_f32(a_ptr); + a_ptr = a_ptr + 4; + + // result[1].val[0] 4 rows elem 1 RE * second row elem 1 RE + result[1].val[0] = vmul_lane_f32(a_col[0].val[0], b[1].val[0], 0); + // result[1].val[0] 4 rows elem 1 IM * second row elem 1 IM + result[1].val[0] = + vfms_lane_f32(result[1].val[0], a_col[0].val[1], b[1].val[1], 0); + result[1].val[1] = vmul_lane_f32(a_col[0].val[1], b[1].val[0], 0); + result[1].val[1] = + vfma_lane_f32(result[1].val[1], a_col[0].val[0], b[1].val[1], 0); + + // result[0] 4 rows elem 2 RE * first row elem 2 RE + result[0].val[0] = + vfma_lane_f32(result[0].val[0], a_col[1].val[0], b[0].val[0], 1); + // result[0] 4 rows elem 2 IM * first row elem 2 IM + result[0].val[0] = + vfms_lane_f32(result[0].val[0], a_col[1].val[1], b[0].val[1], 1); + // result[1] 4 rows elem 2 IM * first row elem 2 RE + result[0].val[1] = + vfma_lane_f32(result[0].val[1], a_col[1].val[1], b[0].val[0], 1); + // result[1] 4 rows elem 2 RE * first row elem 2 IM + result[0].val[1] = + vfma_lane_f32(result[0].val[1], a_col[1].val[0], b[0].val[1], 1); + + // result[0] 4 rows elem 2 RE * second row elem 2 RE + result[1].val[0] = + vfma_lane_f32(result[1].val[0], a_col[1].val[0], b[1].val[0], 1); + // result[0] 4 rows elem 2 IM * second row elem 2 IM + result[1].val[0] = + vfms_lane_f32(result[1].val[0], a_col[1].val[1], b[1].val[1], 1); + // result[1] 4 rows elem 2 IM * second row elem 2 RE + result[1].val[1] = + vfma_lane_f32(result[1].val[1], a_col[1].val[1], b[1].val[0], 1); + // result[1] 4 rows elem 2 RE * second row elem 2 IM + result[1].val[1] = + vfma_lane_f32(result[1].val[1], a_col[1].val[0], b[1].val[1], 1); + + vst2_f32(out_ptr, result[0]); + out_ptr = out_ptr + 4; + + vst2_f32(out_ptr, result[1]); + out_ptr = out_ptr + 4; +#endif + + return ARMRAL_SUCCESS; +} + +armral_status armral_cmplx_mat_mult_2x2_f32_iq( + const float32_t *__restrict src_a_re, const float32_t *__restrict src_a_im, + const float32_t *__restrict src_b_re, const float32_t *__restrict src_b_im, + float32_t *dst_re, float32_t *dst_im) { + +#ifdef ARMRAL_ARCH_SVE + svbool_t p4 = svptrue_pat_b32(SV_VL4); + svfloat32_t a_re = svld1_f32(p4, src_a_re); + svfloat32_t a_im = svld1_f32(p4, src_a_im); + svfloat32_t b_re = svld1_f32(p4, src_b_re); + svfloat32_t b_im = svld1_f32(p4, src_b_im); + + svfloat32_t c_re; + svfloat32_t c_im; + + svfloat32_t tmp_a_re = svtrn2iq_f32(a_re); + svfloat32_t tmp_a_im = svtrn2iq_f32(a_im); + svfloat32_t tmp_b_re = svtrn2(b_re, b_re); + svfloat32_t tmp_b_im = svtrn2(b_im, b_im); + + c_re = svmul_f32_x(p4, tmp_a_re, tmp_b_re); + c_re = svmls_f32_x(p4, c_re, tmp_a_im, tmp_b_im); + c_re = svcmla_lane_f32(c_re, b_re, a_re, 0, 0); + c_re = svcmla_lane_f32(c_re, b_im, a_im, 0, 180); + + c_im = svmul_f32_x(p4, tmp_a_re, tmp_b_im); + c_im = svmla_f32_x(p4, c_im, tmp_a_im, tmp_b_re); + c_im = svcmla_lane_f32(c_im, b_im, a_re, 0, 0); + c_im = svcmla_lane_f32(c_im, b_re, a_im, 0, 0); + + svst1_f32(p4, dst_re, c_re); + svst1_f32(p4, dst_im, c_im); +#else + + float32x2_t a_col0_re = vld1_f32(&src_a_re[0]); + float32x2_t a_col0_im = vld1_f32(&src_a_im[0]); + float32x2_t a_col1_re = vld1_f32(&src_a_re[2]); + float32x2_t a_col1_im = vld1_f32(&src_a_im[2]); + float32x2_t b0_re = vld1_f32(&src_b_re[0]); + float32x2_t b0_im = vld1_f32(&src_b_im[0]); + float32x2_t b1_re = vld1_f32(&src_b_re[2]); + float32x2_t b1_im = vld1_f32(&src_b_im[2]); + + float32x2x2_t result[2]; + result[0].val[0] = vmul_lane_f32(a_col0_re, b0_re, 0); + result[0].val[0] = vfms_lane_f32(result[0].val[0], a_col0_im, b0_im, 0); + result[0].val[1] = vmul_lane_f32(a_col0_im, b0_re, 0); + result[0].val[1] = vfma_lane_f32(result[0].val[1], a_col0_re, b0_im, 0); + + result[1].val[0] = vmul_lane_f32(a_col0_re, b1_re, 0); + result[1].val[0] = vfms_lane_f32(result[1].val[0], a_col0_im, b1_im, 0); + result[1].val[1] = vmul_lane_f32(a_col0_im, b1_re, 0); + result[1].val[1] = vfma_lane_f32(result[1].val[1], a_col0_re, b1_im, 0); + + result[0].val[0] = vfma_lane_f32(result[0].val[0], a_col1_re, b0_re, 1); + result[0].val[0] = vfms_lane_f32(result[0].val[0], a_col1_im, b0_im, 1); + result[0].val[1] = vfma_lane_f32(result[0].val[1], a_col1_im, b0_re, 1); + result[0].val[1] = vfma_lane_f32(result[0].val[1], a_col1_re, b0_im, 1); + + result[1].val[0] = vfma_lane_f32(result[1].val[0], a_col1_re, b1_re, 1); + result[1].val[0] = vfms_lane_f32(result[1].val[0], a_col1_im, b1_im, 1); + result[1].val[1] = vfma_lane_f32(result[1].val[1], a_col1_im, b1_re, 1); + result[1].val[1] = vfma_lane_f32(result[1].val[1], a_col1_re, b1_im, 1); + + vst1_f32(&dst_re[0], result[0].val[0]); + vst1_f32(&dst_im[0], result[0].val[1]); + vst1_f32(&dst_re[2], result[1].val[0]); + vst1_f32(&dst_im[2], result[1].val[1]); +#endif + + return ARMRAL_SUCCESS; +} + +armral_status +armral_cmplx_mat_mult_4x4_f32(const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *p_dst) { + // Note: the a/b flip is intentional since the 4x4 function expects + // row-major input, making all matrices transposed + // i.e. C = A^T * B^T = (B * A)^T + // Unfortunately clang-tidy thinks this is an error, because + // the function has parameters named p_src_a and p_src_b + // so we must disable checking for swapped arguments here + // NOLINTBEGIN(readability-suspicious-call-argument) + return cmplx_matmul_4x4_f32(4, 4, p_src_b, p_src_a, p_dst); + // NOLINTEND(readability-suspicious-call-argument) +} + +armral_status armral_cmplx_mat_mult_4x4_f32_iq( + const float32_t *__restrict src_a_re, const float32_t *__restrict src_a_im, + const float32_t *__restrict src_b_re, const float32_t *__restrict src_b_im, + float32_t *dst_re, float32_t *dst_im) { +#ifdef ARMRAL_ARCH_SVE + svbool_t p4 = svptrue_pat_b32(SV_VL4); + + svfloat32_t a_col0_re = svld1_f32(p4, &src_a_re[0 * 4]); + svfloat32_t a_col1_re = svld1_f32(p4, &src_a_re[1 * 4]); + svfloat32_t a_col2_re = svld1_f32(p4, &src_a_re[2 * 4]); + svfloat32_t a_col3_re = svld1_f32(p4, &src_a_re[3 * 4]); + svfloat32_t a_col0_im = svld1_f32(p4, &src_a_im[0 * 4]); + svfloat32_t a_col1_im = svld1_f32(p4, &src_a_im[1 * 4]); + svfloat32_t a_col2_im = svld1_f32(p4, &src_a_im[2 * 4]); + svfloat32_t a_col3_im = svld1_f32(p4, &src_a_im[3 * 4]); + + svfloat32_t c_re; + svfloat32_t c_im; + + for (int j = 0; j < 4; j++) { + svfloat32_t b_re = svld1_f32(p4, &src_b_re[j * 4]); + svfloat32_t b_im = svld1_f32(p4, &src_b_im[j * 4]); + + c_re = svmul_lane_f32(a_col0_re, b_re, 0); + c_re = svmla_lane_f32(c_re, a_col1_re, b_re, 1); + c_re = svmla_lane_f32(c_re, a_col2_re, b_re, 2); + c_re = svmla_lane_f32(c_re, a_col3_re, b_re, 3); + c_re = svmls_lane_f32(c_re, a_col0_im, b_im, 0); + c_re = svmls_lane_f32(c_re, a_col1_im, b_im, 1); + c_re = svmls_lane_f32(c_re, a_col2_im, b_im, 2); + c_re = svmls_lane_f32(c_re, a_col3_im, b_im, 3); + + c_im = svmul_lane_f32(a_col0_re, b_im, 0); + c_im = svmla_lane_f32(c_im, a_col1_im, b_re, 1); + c_im = svmla_lane_f32(c_im, a_col2_re, b_im, 2); + c_im = svmla_lane_f32(c_im, a_col3_im, b_re, 3); + c_im = svmla_lane_f32(c_im, a_col0_im, b_re, 0); + c_im = svmla_lane_f32(c_im, a_col1_re, b_im, 1); + c_im = svmla_lane_f32(c_im, a_col2_im, b_re, 2); + c_im = svmla_lane_f32(c_im, a_col3_re, b_im, 3); + + svst1_f32(p4, &dst_re[j * 4], c_re); + svst1_f32(p4, &dst_im[j * 4], c_im); + } + +#else + const float32_t *a_ptr_re = (const float32_t *)src_a_re; + const float32_t *a_ptr_im = (const float32_t *)src_a_im; + const float32_t *b_ptr_re = (const float32_t *)src_b_re; + const float32_t *b_ptr_im = (const float32_t *)src_b_im; + float32_t *out_ptr_re = dst_re; + float32_t *out_ptr_im = dst_im; + __asm__ __volatile__( + + "ld1 {v10.4s}, [%x[APtr_re]], #16\n" + "ld1 {v11.4s}, [%x[APtr_im]], #16\n" + + "ld1 {v18.4s}, [%x[BPtr_re]], #16\n" + "ld1 {v19.4s}, [%x[BPtr_im]], #16\n" + + "ld1 {v20.4s}, [%x[BPtr_re]], #16\n" + "ld1 {v21.4s}, [%x[BPtr_im]], #16\n" + + "fmul v2.4s, v10.4s, v18.s[0]\n" + "fmls v2.4s, v11.4s, v19.s[0]\n" + "ld1 {v12.4s}, [%x[APtr_re]], #16\n" + "ld1 {v13.4s}, [%x[APtr_im]], #16\n" + "fmul v4.4s, v10.4s, v20.s[0]\n" + "fmls v4.4s, v11.4s, v21.s[0]\n" + "ld1 {v14.4s}, [%x[APtr_re]], #16\n" + "ld1 {v15.4s}, [%x[APtr_im]], #16\n" + "fmul v3.4s, v11.4s, v18.s[0]\n" + "fmla v3.4s, v10.4s, v19.s[0]\n" + "ld1 {v16.4s}, [%x[APtr_re]], #16\n" + "ld1 {v17.4s}, [%x[APtr_im]], #16\n" + "fmul v5.4s, v11.4s, v20.s[0]\n" + "fmla v5.4s, v10.4s, v21.s[0]\n" + "fmla v2.4s, v12.4s, v18.s[1]\n" + "fmls v2.4s, v13.4s, v19.s[1]\n" + "fmla v3.4s, v13.4s, v18.s[1]\n" + "fmla v3.4s, v12.4s, v19.s[1]\n" + "fmla v4.4s, v12.4s, v20.s[1]\n" + "fmls v4.4s, v13.4s, v21.s[1]\n" + "fmla v5.4s, v13.4s, v20.s[1]\n" + "fmla v5.4s, v12.4s, v21.s[1]\n" + + "fmla v2.4s, v14.4s, v18.s[2]\n" + "fmls v2.4s, v15.4s, v19.s[2]\n" + "fmla v3.4s, v15.4s, v18.s[2]\n" + "fmla v3.4s, v14.4s, v19.s[2]\n" + "fmla v4.4s, v14.4s, v20.s[2]\n" + "fmls v4.4s, v15.4s, v21.s[2]\n" + "fmla v5.4s, v15.4s, v20.s[2]\n" + "fmla v5.4s, v14.4s, v21.s[2]\n" + + "fmla v2.4s, v16.4s, v18.s[3]\n" + "fmls v2.4s, v17.4s, v19.s[3]\n" + "fmla v3.4s, v17.4s, v18.s[3]\n" + "fmla v3.4s, v16.4s, v19.s[3]\n" + "st1 {v2.4s}, [%x[outPtr_re]], #16\n" + "st1 {v3.4s}, [%x[outPtr_im]], #16\n" + "fmla v4.4s, v16.4s, v20.s[3]\n" + "fmls v4.4s, v17.4s, v21.s[3]\n" + "ld1 {v18.4s}, [%x[BPtr_re]], #16\n" + "ld1 {v19.4s}, [%x[BPtr_im]], #16\n" + "fmla v5.4s, v17.4s, v20.s[3]\n" + "fmla v5.4s, v16.4s, v21.s[3]\n" + + "st1 {v4.4s}, [%x[outPtr_re]], #16\n" + "st1 {v5.4s}, [%x[outPtr_im]], #16\n" + "ld1 {v20.4s}, [%x[BPtr_re]], #16\n" + "ld1 {v21.4s}, [%x[BPtr_im]], #16\n" + "fmul v2.4s, v10.4s, v18.s[0]\n" + "fmls v2.4s, v11.4s, v19.s[0]\n" + "fmul v4.4s, v10.4s, v20.s[0]\n" + "fmls v4.4s, v11.4s, v21.s[0]\n" + "fmul v3.4s, v11.4s, v18.s[0]\n" + "fmla v3.4s, v10.4s, v19.s[0]\n" + "fmul v5.4s, v11.4s, v20.s[0]\n" + "fmla v5.4s, v10.4s, v21.s[0]\n" + + "fmla v2.4s, v12.4s, v18.s[1]\n" + "fmls v2.4s, v13.4s, v19.s[1]\n" + "fmla v3.4s, v13.4s, v18.s[1]\n" + "fmla v3.4s, v12.4s, v19.s[1]\n" + "fmla v4.4s, v12.4s, v20.s[1]\n" + "fmls v4.4s, v13.4s, v21.s[1]\n" + "fmla v5.4s, v13.4s, v20.s[1]\n" + "fmla v5.4s, v12.4s, v21.s[1]\n" + + "fmla v2.4s, v14.4s, v18.s[2]\n" + "fmls v2.4s, v15.4s, v19.s[2]\n" + "fmla v3.4s, v15.4s, v18.s[2]\n" + "fmla v3.4s, v14.4s, v19.s[2]\n" + "fmla v4.4s, v14.4s, v20.s[2]\n" + "fmls v4.4s, v15.4s, v21.s[2]\n" + "fmla v5.4s, v15.4s, v20.s[2]\n" + "fmla v5.4s, v14.4s, v21.s[2]\n" + + "fmla v2.4s, v16.4s, v18.s[3]\n" + "fmls v2.4s, v17.4s, v19.s[3]\n" + "fmla v3.4s, v17.4s, v18.s[3]\n" + "fmla v3.4s, v16.4s, v19.s[3]\n" + "fmla v4.4s, v16.4s, v20.s[3]\n" + "fmls v4.4s, v17.4s, v21.s[3]\n" + "fmla v5.4s, v17.4s, v20.s[3]\n" + "fmla v5.4s, v16.4s, v21.s[3]\n" + + "st1 {v2.4s}, [%x[outPtr_re]], #16\n" + "st1 {v3.4s}, [%x[outPtr_im]], #16\n" + "st1 {v4.4s}, [%x[outPtr_re]], #16\n" + "st1 {v5.4s}, [%x[outPtr_im]], #16\n" + + : [APtr_re] "+r"(a_ptr_re), [APtr_im] "+r"(a_ptr_im), + [BPtr_re] "+r"(b_ptr_re), [BPtr_im] "+r"(b_ptr_im), + [outPtr_re] "+r"(out_ptr_re), [outPtr_im] "+r"(out_ptr_im) + + : + + : "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v2", "v3", "v4", "v5", "cc"); +#endif + + return ARMRAL_SUCCESS; +} + +armral_status +armral_cmplx_matmul_f32(const uint16_t m, const uint16_t n, const uint16_t k, + const armral_cmplx_f32_t *__restrict p_src_a, + const armral_cmplx_f32_t *__restrict p_src_b, + armral_cmplx_f32_t *p_dst) { + // Note: the a/b flip is intentional since the 2x2 function expects + // column-major input, making all matrices transposed + // i.e. C = A^T * B^T = (B * A)^T + // Unfortunately clang-tidy thinks this is an error, because + // the function has parameters named p_src_a and p_src_b + // so we must disable checking for swapped arguments here + // NOLINTBEGIN(readability-suspicious-call-argument) + if (m == 2 && n == 2 && k == 2) { + return armral_cmplx_mat_mult_2x2_f32(p_src_b, p_src_a, p_dst); + } + // NOLINTEND(readability-suspicious-call-argument) + if (m == 4 && n == 4 && k == 4) { + // This function expects row-major input so no need to flip a/b here + return cmplx_matmul_4x4_f32(4, 4, p_src_a, p_src_b, p_dst); + } + + constexpr uint32_t kb = 72; + constexpr uint32_t nb = 8; + const bool do_k_blocking = k >= kb; + const bool do_n_blocking = n >= nb; + + if (!do_k_blocking && !do_n_blocking) { + // No blocking + return cmplx_matmul_f32(m, n, k, k, n, p_src_a, p_src_b, p_dst); + } + // Set up parameters for 4-by-4 inner blocking + uint16_t rem_m_inner = m % 4; + uint16_t rem_n = n % nb; + uint32_t n_idx = (n / nb) * nb; + uint16_t rem_k_inner = kb % 4; + // k blocking + for (int32_t kbi = 0; kbi < int32_t(k) - (int32_t(kb) - 1); kbi += kb) { + // n blocking + matmul_n_block(m, n, k, kb, kbi, nb, rem_n, n_idx, rem_k_inner, rem_m_inner, + p_src_a, p_src_b, p_dst); + } + // Clean up remaining ks + uint16_t rem_k = k % kb; + if (rem_k != 0U) { + uint32_t kbi = (k / kb) * kb; + rem_k_inner = rem_k % 4; + matmul_n_block(m, n, k, rem_k, kbi, nb, rem_n, n_idx, rem_k_inner, + rem_m_inner, p_src_a, p_src_b, p_dst); + } + return ARMRAL_SUCCESS; +} diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16.cpp b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16.cpp new file mode 100644 index 0000000..26c11d3 --- /dev/null +++ b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16.cpp @@ -0,0 +1,14 @@ +/* + Arm RAN Acceleration Library + SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates +*/ +#include "armral.h" +#include "cmplx_matmul_i16.hpp" + +armral_status +armral_cmplx_matmul_i16(const uint16_t m, const uint16_t n, const uint16_t k, + const armral_cmplx_int16_t *__restrict p_src_a, + const armral_cmplx_int16_t *__restrict p_src_b, + armral_cmplx_int16_t *p_dst) { + return cmplx_matmul_i16(m, n, k, p_src_a, p_src_b, p_dst); +} diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16_32bit.cpp b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16_32bit.cpp new file mode 100644 index 0000000..b0fe41f --- /dev/null +++ b/src/BasicMathFun/MatrixMult/arm_cmplx_matmul_i16_32bit.cpp @@ -0,0 +1,15 @@ +/* + Arm RAN Acceleration Library + SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates +*/ +#include "armral.h" +#include "cmplx_matmul_i16_32bit.hpp" + +armral_status +armral_cmplx_matmul_i16_32bit(const uint16_t m, const uint16_t n, + const uint16_t k, + const armral_cmplx_int16_t *__restrict p_src_a, + const armral_cmplx_int16_t *__restrict p_src_b, + armral_cmplx_int16_t *p_dst) { + return cmplx_matmul_i16_32bit(m, n, k, p_src_a, p_src_b, p_dst); +} diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_i16.c b/src/BasicMathFun/MatrixMult/cmplx_matmul_i16.hpp similarity index 79% rename from src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_i16.c rename to src/BasicMathFun/MatrixMult/cmplx_matmul_i16.hpp index 8aa33c5..ebabc55 100644 --- a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_i16.c +++ b/src/BasicMathFun/MatrixMult/cmplx_matmul_i16.hpp @@ -2,10 +2,14 @@ Arm RAN Acceleration Library SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ +#pragma once + #include "armral.h" #include "intrinsics.h" -static inline int16x4_t __attribute__((always_inline)) +namespace { + +inline int16x4_t __attribute__((always_inline)) vzip1_u16x2(int16x4_t a, int16x4_t b) { // should be zip1 c.2s, a.2s, b.2s // but this expression appears to give better performance @@ -13,7 +17,7 @@ vzip1_u16x2(int16x4_t a, int16x4_t b) { vset_lane_u32(vreinterpret_u32_s16(b)[0], vreinterpret_u32_s16(a), 1)); } -static inline int16x4_t __attribute__((always_inline)) +inline int16x4_t __attribute__((always_inline)) vzip2_u16x2(int16x4_t a, int16x4_t b) { // should be zip2 c.2s, a.2s, b.2s // but this expression appears to give better performance @@ -21,40 +25,40 @@ vzip2_u16x2(int16x4_t a, int16x4_t b) { vset_lane_u32(vreinterpret_u32_s16(a)[1], vreinterpret_u32_s16(b), 0)); } -armral_status -armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, - const armral_cmplx_int16_t *restrict p_src_a, - const armral_cmplx_int16_t *restrict p_src_b, - armral_cmplx_int16_t *p_dst) { +inline armral_status +cmplx_matmul_i16(const uint16_t m, const uint16_t n, const uint16_t k, + const armral_cmplx_int16_t *__restrict p_src_a, + const armral_cmplx_int16_t *__restrict p_src_b, + armral_cmplx_int16_t *p_dst) { const int16_t *p_in1 = (const int16_t *)p_src_a; const armral_cmplx_int16_t *p_in_a = p_src_a; armral_cmplx_int16_t *p_out = p_dst; - uint16_t num_rows_a = m; /* number of rows of input matrix A */ - uint16_t num_cols_b = k; /* number of columns of input matrix B */ - uint16_t num_cols_a = n; /* number of columns of input matrix A */ + uint16_t num_rows_a = m; // Number of rows of input matrix A + uint16_t num_cols_b = n; // Number of columns of input matrix B + uint16_t num_cols_a = k; // Number of columns of input matrix A const int16_t *p_in1_b = (const int16_t *)p_src_a; const int16_t *p_in1_b2 = (const int16_t *)p_src_b; - /* Row loop */ + // Row loop for (uint16_t row_cnt = num_rows_a >> 1; row_cnt > 0; --row_cnt, p_out += 2 * num_cols_b, p_in_a += 2 * num_cols_a) { - /* Output pointer is set to starting address of the row being processed */ + // Output pointer is set to starting address of the row being processed armral_cmplx_int16_t *px = p_out; armral_cmplx_int16_t *px_b = px + num_cols_b; - /* For every row wise process, the column loop counter is to be initiated */ + // For every row-wise process, the column loop counter is initialized uint16_t col = num_cols_b; - /* For every row wise process, the pIn2 pointer is set - ** to the starting address of the pSrcB data */ + // For every row-wise process, the pIn2 pointer is set + // to the starting address of the pSrcB data const int16_t *p_in2 = (const int16_t *)p_src_b; p_in1_b2 = p_in2 + 2 * num_cols_b; uint16_t j = 0U; - /* Column loop */ + // Column loop while (col > 1U) { - /* Set the variable sum, that acts as accumulator, to zero */ + // Set the variable sum, that acts as accumulator, to zero int64_t sum_real1_ext = 0; int64_t sum_imag1_ext = 0; int64_t sum_real1_b_ext = 0; @@ -65,45 +69,45 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, int64_t sum_real_b_ext2 = 0; int64_t sum_imag_b_ext2 = 0; - /* Initiate the pointer pIn1 to point to the starting address of the - * column being processed */ + // Initialize the pointer pIn1 to point to the starting address of the + // column being processed p_in1 = (const int16_t *)p_in_a; p_in1_b = p_in1 + 2 * num_cols_a; - int64x2_t acc_r0 = {}; - int64x2_t acc_i0 = {}; - int64x2_t acc_r1 = {}; - int64x2_t acc_i1 = {}; - int64x2_t acc_r2 = {}; - int64x2_t acc_i2 = {}; - int64x2_t acc_r3 = {}; - int64x2_t acc_i3 = {}; + int64x2_t acc_r0 = {0}; + int64x2_t acc_i0 = {0}; + int64x2_t acc_r1 = {0}; + int64x2_t acc_i1 = {0}; + int64x2_t acc_r2 = {0}; + int64x2_t acc_i2 = {0}; + int64x2_t acc_r3 = {0}; + int64x2_t acc_i3 = {0}; - /* Matrix multiplication */ + // Matrix multiplication for (uint16_t col_cnt = num_cols_a >> 3; col_cnt > 0; --col_cnt) { - /*int16x8x2_t load and de-interleave*/ + // int16x8x2_t load and de-interleave int16x8x2_t a0_v = vld2q_s16(p_in1); - /*extend to 32bit Real part int32x4x2_t*/ + // Extend to 32bit Real part int32x4x2_t int32x4_t a0_v_rextended[2]; a0_v_rextended[0] = vmovl_low_s16(a0_v.val[0]); a0_v_rextended[1] = vmovl_high_s16(a0_v.val[0]); - /*extend to 32bit Imag part int32x4x2_t*/ + // Extend to 32bit Imag part int32x4x2_t int32x4_t a0_v_iextended[2]; a0_v_iextended[0] = vmovl_low_s16(a0_v.val[1]); a0_v_iextended[1] = vmovl_high_s16(a0_v.val[1]); int16x8x2_t a1_v = vld2q_s16(p_in1_b); int32x4x2_t a1_v_rextended; - a1_v_rextended.val[0] = vmovl_low_s16(a1_v.val[0]); /*extend to 32bit*/ - a1_v_rextended.val[1] = vmovl_high_s16(a1_v.val[0]); /*extend to 32bit*/ + a1_v_rextended.val[0] = vmovl_low_s16(a1_v.val[0]); // Extend to 32bit + a1_v_rextended.val[1] = vmovl_high_s16(a1_v.val[0]); // Extend to 32bit int32x4x2_t a1_v_iextended; - a1_v_iextended.val[0] = vmovl_low_s16(a1_v.val[1]); /*extend to 32bit*/ - a1_v_iextended.val[1] = vmovl_high_s16(a1_v.val[1]); /*extend to 32bit*/ + a1_v_iextended.val[0] = vmovl_low_s16(a1_v.val[1]); // Extend to 32bit + a1_v_iextended.val[1] = vmovl_high_s16(a1_v.val[1]); // Extend to 32bit p_in1 += 16; p_in1_b += 16; - /*4 B rows*/ - // load but NOT separate real/imag + // 4 B rows + // Load but NOT separate real/imag int16x4_t b0_v = vld1_s16(p_in2); int16x4_t b1_v = vld1_s16(p_in1_b2); int16x4_t b2_v = vld1_s16(p_in2 + 4 * num_cols_b); @@ -111,8 +115,8 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, p_in2 = p_in2 + 8 * num_cols_b; p_in1_b2 = p_in1_b2 + 8 * num_cols_b; - /*4 B rows*/ - // load but NOT separate real/imag + // 4 B rows + // Load but NOT separate real/imag int16x4_t b4_v = vld1_s16(p_in2); int16x4_t b5_v = vld1_s16(p_in1_b2); int16x4_t b6_v = vld1_s16(p_in2 + 4 * num_cols_b); @@ -120,55 +124,55 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, p_in2 = p_in2 + 8 * num_cols_b; p_in1_b2 = p_in1_b2 + 8 * num_cols_b; - // even elem first 2 B rows, 4 columns + // Even elem first 2 B rows, 4 columns int16x4_t b_col_real = vtrn1_s16(b0_v, b1_v); - // odd elem first 2 B rows, 4 columns + // Odd elem first 2 B rows, 4 columns int16x4_t b_col_im = vtrn2_s16(b0_v, b1_v); - // even elem 3rd and 4th B rows, 4 columns + // Even elem 3rd and 4th B rows, 4 columns int16x4_t b_col_real2 = vtrn1_s16(b2_v, b3_v); - // odd elem 3rd and 4th B rows, 4 columns + // Odd elem 3rd and 4th B rows, 4 columns int16x4_t b_col_im2 = vtrn2_s16(b2_v, b3_v); - // even elem 5th and 6th B rows, 4 columns + // Even elem 5th and 6th B rows, 4 columns int16x4_t b_col_real3 = vtrn1_s16(b4_v, b5_v); - // odd elem 5th and 6th B rows, 4 columns + // Odd elem 5th and 6th B rows, 4 columns int16x4_t b_col_im3 = vtrn2_s16(b4_v, b5_v); - // even elem 7th and 8th B rows, 4 columns + // Even elem 7th and 8th B rows, 4 columns int16x4_t b_col_real4 = vtrn1_s16(b6_v, b7_v); - // odd elem 7th and 8th B rows, 4 columns + // Odd elem 7th and 8th B rows, 4 columns int16x4_t b_col_im4 = vtrn2_s16(b6_v, b7_v); - /*First column B first 4 rows*/ + // First column B first 4 rows int16x4_t temp_r0 = vzip1_u16x2(b_col_real, b_col_real2); int16x4_t temp_i0 = vzip1_u16x2(b_col_im, b_col_im2); - /*First column second 4 rows*/ + // First column second 4 rows int16x4_t temp_r1 = vzip1_u16x2(b_col_real3, b_col_real4); int16x4_t temp_i1 = vzip1_u16x2(b_col_im3, b_col_im4); - /*Second column first four rows*/ + // Second column first four rows int16x4_t temp_r2 = vzip2_u16x2(b_col_real, b_col_real2); int16x4_t temp_i2 = vzip2_u16x2(b_col_im, b_col_im2); - /*Second column B second four rows*/ + // Second column B second four rows int16x4_t temp_r3 = vzip2_u16x2(b_col_real3, b_col_real4); int16x4_t temp_i3 = vzip2_u16x2(b_col_im3, b_col_im4); - int32x4_t temp_r0extended = vmovl_s16(temp_r0); /*32x4*/ - int32x4_t temp_i0extended = vmovl_s16(temp_i0); /*32x4*/ + int32x4_t temp_r0extended = vmovl_s16(temp_r0); // 32x4 + int32x4_t temp_i0extended = vmovl_s16(temp_i0); // 32x4 - int32x4_t temp_r1extended = vmovl_s16(temp_r1); /*32x4*/ - int32x4_t temp_i1extended = vmovl_s16(temp_i1); /*32x4*/ + int32x4_t temp_r1extended = vmovl_s16(temp_r1); // 32x4 + int32x4_t temp_i1extended = vmovl_s16(temp_i1); // 32x4 - int32x4_t temp_r2extended = vmovl_s16(temp_r2); /*32x4*/ - int32x4_t temp_i2extended = vmovl_s16(temp_i2); /*32x4*/ + int32x4_t temp_r2extended = vmovl_s16(temp_r2); // 32x4 + int32x4_t temp_i2extended = vmovl_s16(temp_i2); // 32x4 - int32x4_t temp_r3extended = vmovl_s16(temp_r3); /*32x4*/ - int32x4_t temp_i3extended = vmovl_s16(temp_i3); /*32x4*/ + int32x4_t temp_r3extended = vmovl_s16(temp_r3); // 32x4 + int32x4_t temp_i3extended = vmovl_s16(temp_i3); // 32x4 - /*FIrst row * first col (4 B rows high)*/ + // First row * first col (4 B rows high) acc_r0 = vmlal_high_s32(acc_r0, a0_v_rextended[0], temp_r0extended); acc_r0 = vmlal_low_s32(acc_r0, a0_v_rextended[0], temp_r0extended); - /*FIrst row * first col (4 B rows low)*/ + // First row * first col (4 B rows low) acc_r0 = vmlal_high_s32(acc_r0, a0_v_rextended[1], temp_r1extended); acc_r0 = vmlal_low_s32(acc_r0, a0_v_rextended[1], temp_r1extended); acc_r0 = vmlsl_high_s32(acc_r0, a0_v_iextended[0], temp_i0extended); @@ -188,10 +192,10 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, acc_i0 = vmlal_high_s32(acc_i0, a0_v_rextended[1], temp_i1extended); acc_i0 = vmlal_low_s32(acc_i0, a0_v_rextended[1], temp_i1extended); - /*Second row * first col (4 B rows high)*/ + // Second row * first col (4 B rows high) acc_r1 = vmlal_high_s32(acc_r1, a1_v_rextended.val[0], temp_r0extended); acc_r1 = vmlal_low_s32(acc_r1, a1_v_rextended.val[0], temp_r0extended); - /*Second row * first col (4 B rows high)*/ + // Second row * first col (4 B rows high) acc_r1 = vmlal_high_s32(acc_r1, a1_v_rextended.val[1], temp_r1extended); acc_r1 = vmlal_low_s32(acc_r1, a1_v_rextended.val[1], temp_r1extended); @@ -213,10 +217,10 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, acc_i1 = vmlal_high_s32(acc_i1, a1_v_rextended.val[1], temp_i1extended); acc_i1 = vmlal_low_s32(acc_i1, a1_v_rextended.val[1], temp_i1extended); - /*FIrst row * second col (4 B rows high)*/ + // First row * second col (4 B rows high) acc_r2 = vmlal_high_s32(acc_r2, a0_v_rextended[0], temp_r2extended); acc_r2 = vmlal_low_s32(acc_r2, a0_v_rextended[0], temp_r2extended); - /*FIrst row * first col (4 B rows low)*/ + // First row * first col (4 B rows low) acc_r2 = vmlal_high_s32(acc_r2, a0_v_rextended[1], temp_r3extended); acc_r2 = vmlal_low_s32(acc_r2, a0_v_rextended[1], temp_r3extended); @@ -238,10 +242,10 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, acc_i2 = vmlal_high_s32(acc_i2, a0_v_rextended[1], temp_i3extended); acc_i2 = vmlal_low_s32(acc_i2, a0_v_rextended[1], temp_i3extended); - /*Second row * second col (4 B rows high)*/ + // Second row * second col (4 B rows high) acc_r3 = vmlal_high_s32(acc_r3, a1_v_rextended.val[0], temp_r2extended); acc_r3 = vmlal_low_s32(acc_r3, a1_v_rextended.val[0], temp_r2extended); - /*Second row * second col (4 B rows high)*/ + // Second row * second col (4 B rows high) acc_r3 = vmlal_high_s32(acc_r3, a1_v_rextended.val[1], temp_r3extended); acc_r3 = vmlal_low_s32(acc_r3, a1_v_rextended.val[1], temp_r3extended); @@ -264,11 +268,11 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, acc_i3 = vmlal_low_s32(acc_i3, a1_v_rextended.val[1], temp_i3extended); } - /* Matrix multiplication */ - if (num_cols_a & 4) { + // Matrix multiplication + if ((num_cols_a & 4) != 0) { int16x4x2_t a2_v = vld2_s16(p_in1); - p_in1 += 8; /*int16x4x2_t*/ - /*extend to 32bit int32x4x2_t*/ + p_in1 += 8; // int16x4x2_t + // Extend to 32bit int32x4x2_t int32x4x2_t a0_vextended; a0_vextended.val[0] = vmovl_s16(a2_v.val[0]); a0_vextended.val[1] = vmovl_s16(a2_v.val[1]); @@ -277,37 +281,37 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, int32x4x2_t a1_vextended; a1_vextended.val[0] = vmovl_s16(a3_v.val[0]); a1_vextended.val[1] = vmovl_s16(a3_v.val[1]); - /*4 B rows*/ - // load but NOT separate real/imag + // 4 B rows + // Load but NOT separate real/imag int16x4_t b0_v = vld1_s16(p_in2); int16x4_t b1_v = vld1_s16(p_in1_b2); int16x4_t b2_v = vld1_s16(p_in2 + 4 * num_cols_b); int16x4_t b3_v = vld1_s16(p_in1_b2 + 4 * num_cols_b); p_in2 = p_in2 + 8 * num_cols_b; - // even elem first 2 B rows, 4 columns + // Even elem first 2 B rows, 4 columns int16x4_t b_col_real = vtrn1_s16(b0_v, b1_v); - // odd elem first 2 B rows, 4 columns + // Odd elem first 2 B rows, 4 columns int16x4_t b_col_im = vtrn2_s16(b0_v, b1_v); - // even elem 3rd and 4th B rows, 4 columns + // Even elem 3rd and 4th B rows, 4 columns int16x4_t b_col_real2 = vtrn1_s16(b2_v, b3_v); - // odd elem 3rd and 4th B rows, 4 columns + // Odd elem 3rd and 4th B rows, 4 columns int16x4_t b_col_im2 = vtrn2_s16(b2_v, b3_v); - /*First column B first 4 rows*/ + // First column B first 4 rows int16x4_t temp_r0 = vzip1_u16x2(b_col_real, b_col_real2); int16x4_t temp_i0 = vzip1_u16x2(b_col_im, b_col_im2); - /*Second column first four rows*/ + // Second column first four rows int16x4_t temp_r2 = vzip2_u16x2(b_col_real, b_col_real2); int16x4_t temp_i2 = vzip2_u16x2(b_col_im, b_col_im2); - int32x4_t temp_r0extended = vmovl_s16(temp_r0); /*32x4*/ - int32x4_t temp_i0extended = vmovl_s16(temp_i0); /*32x4*/ - int32x4_t temp_r2extended = vmovl_s16(temp_r2); /*32x4*/ - int32x4_t temp_i2extended = vmovl_s16(temp_i2); /*32x4*/ + int32x4_t temp_r0extended = vmovl_s16(temp_r0); // 32x4 + int32x4_t temp_i0extended = vmovl_s16(temp_i0); // 32x4 + int32x4_t temp_r2extended = vmovl_s16(temp_r2); // 32x4 + int32x4_t temp_i2extended = vmovl_s16(temp_i2); // 32x4 - /*First row * first col*/ + // First row * first col acc_r0 = vmlal_high_s32(acc_r0, a0_vextended.val[0], temp_r0extended); acc_r0 = vmlal_low_s32(acc_r0, a0_vextended.val[0], temp_r0extended); @@ -319,7 +323,7 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, acc_i0 = vmlal_high_s32(acc_i0, a0_vextended.val[0], temp_i0extended); acc_i0 = vmlal_low_s32(acc_i0, a0_vextended.val[0], temp_i0extended); - /*Second row * first col*/ + // Second row * first col acc_r1 = vmlal_high_s32(acc_r1, a1_vextended.val[0], temp_r0extended); acc_r1 = vmlal_low_s32(acc_r1, a1_vextended.val[0], temp_r0extended); @@ -331,7 +335,7 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, acc_i1 = vmlal_high_s32(acc_i1, a1_vextended.val[0], temp_i0extended); acc_i1 = vmlal_low_s32(acc_i1, a1_vextended.val[0], temp_i0extended); - /*First row * second col*/ + // First row * second col acc_r2 = vmlal_high_s32(acc_r2, a0_vextended.val[0], temp_r2extended); acc_r2 = vmlal_low_s32(acc_r2, a0_vextended.val[0], temp_r2extended); @@ -343,7 +347,7 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, acc_i2 = vmlal_high_s32(acc_i2, a0_vextended.val[0], temp_i2extended); acc_i2 = vmlal_low_s32(acc_i2, a0_vextended.val[0], temp_i2extended); - /*Second row * second col*/ + // Second row * second col acc_r3 = vmlal_high_s32(acc_r3, a1_vextended.val[0], temp_r2extended); acc_r3 = vmlal_low_s32(acc_r3, a1_vextended.val[0], temp_r2extended); @@ -405,7 +409,7 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, } // sumReal1Ext += sumReal2Ext; - int16x4_t out[2] = {}; + int16x4_t out[2] = {0}; out[0] = vset_lane_s16(vqmovns_s32(vqshrnd_n_s64(sum_real1_ext, 15)), out[0], 0); out[0] = vset_lane_s16(vqmovns_s32(vqshrnd_n_s64(sum_imag1_ext, 15)), @@ -424,14 +428,14 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, out[1] = vset_lane_s16(vqmovns_s32(vqshrnd_n_s64(sum_imag_b_ext2, 15)), out[1], 3); - /* Store the result in the destination buffer */ + // Store the result in the destination buffer vst1_s16((int16_t *)px, out[0]); vst1_s16((int16_t *)px_b, out[1]); px += 2; px_b += 2; - /* Update the pointer pIn2 to point to the starting address of the next - * column */ + // Update the pointer pIn2 to point to the starting address of the next + // column j++; p_in2 = (const int16_t *)p_src_b + 4U * j; p_in1_b2 = p_in2 + 2U * num_cols_b; @@ -439,51 +443,51 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, } // Deal with a single column of B - if (num_cols_b & 1) { - /* Set the variable sum, that acts as accumulator, to zero */ + if ((num_cols_b & 1) != 0) { + // Set the variable sum, that acts as accumulator, to zero int64_t sum_real1_ext = 0; int64_t sum_imag1_ext = 0; int64_t sum_real1_b_ext = 0; int64_t sum_imag1_b_ext = 0; - /* Initiate the pointer pIn1 to point to the starting address of the - * column being processed */ + // Initialize the pointer pIn1 to point to the starting address of the + // column being processed p_in1 = (const int16_t *)p_in_a; p_in1_b = p_in1 + 2 * num_cols_a; - int64x2_t acc_r0 = {}; - int64x2_t acc_i0 = {}; - int64x2_t acc_r1 = {}; - int64x2_t acc_i1 = {}; + int64x2_t acc_r0 = {0}; + int64x2_t acc_i0 = {0}; + int64x2_t acc_r1 = {0}; + int64x2_t acc_i1 = {0}; - /* Compute 8 MACs simultaneously. */ + // Compute 8 MACs simultaneously uint16_t col_cnt = num_cols_a >> 3; - /* Matrix multiplication */ + // Matrix multiplication while (col_cnt > 0U) { - /*int16x8x2_t load and de-interleave*/ + // int16x8x2_t load and de-interleave int16x8x2_t a0_v = vld2q_s16(p_in1); - /*extend to 32bit Real part int32x4x2_t*/ + // Extend to 32bit Real part int32x4x2_t int32x4_t a0_v_rextended[2]; a0_v_rextended[0] = vmovl_low_s16(a0_v.val[0]); a0_v_rextended[1] = vmovl_high_s16(a0_v.val[0]); - /*extend to 32bit Imag part int32x4x2_t*/ + // Extend to 32bit Imag part int32x4x2_t int32x4_t a0_v_iextended[2]; a0_v_iextended[0] = vmovl_low_s16(a0_v.val[1]); a0_v_iextended[1] = vmovl_high_s16(a0_v.val[1]); int16x8x2_t a1_v = vld2q_s16(p_in1_b); int32x4_t a1_v_rextended[2]; - a1_v_rextended[0] = vmovl_low_s16(a1_v.val[0]); /*extend to 32bit*/ - a1_v_rextended[1] = vmovl_high_s16(a1_v.val[0]); /*extend to 32bit*/ + a1_v_rextended[0] = vmovl_low_s16(a1_v.val[0]); // Extend to 32bit + a1_v_rextended[1] = vmovl_high_s16(a1_v.val[0]); // Extend to 32bit int32x4_t a1_v_iextended[2]; - a1_v_iextended[0] = vmovl_low_s16(a1_v.val[1]); /*extend to 32bit*/ - a1_v_iextended[1] = vmovl_high_s16(a1_v.val[1]); /*extend to 32bit*/ + a1_v_iextended[0] = vmovl_low_s16(a1_v.val[1]); // Extend to 32bit + a1_v_iextended[1] = vmovl_high_s16(a1_v.val[1]); // Extend to 32bit // Load the first four rows of B, splitting real and imaginary // components - int16x4x2_t tmp_first_four = {}; + int16x4x2_t tmp_first_four = {0}; tmp_first_four = vld2_lane_s16(p_in2, tmp_first_four, 0); tmp_first_four = vld2_lane_s16(p_in2 + 2 * num_cols_b, tmp_first_four, 1); @@ -493,7 +497,7 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, vld2_lane_s16(p_in2 + 6 * num_cols_b, tmp_first_four, 3); // Load the next four rows of B, splitting real and imaginary components - int16x4x2_t tmp_second_four = {}; + int16x4x2_t tmp_second_four = {0}; tmp_second_four = vld2_lane_s16(p_in2 + 8 * num_cols_b, tmp_second_four, 0); tmp_second_four = @@ -503,54 +507,54 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, tmp_second_four = vld2_lane_s16(p_in2 + 14 * num_cols_b, tmp_second_four, 3); - int32x4_t r_32bit[2] = {}; + int32x4_t r_32bit[2] = {0}; r_32bit[0] = vmovl_s16(tmp_first_four.val[0]); r_32bit[1] = vmovl_s16(tmp_second_four.val[0]); - int32x4_t i_32bit[2] = {}; + int32x4_t i_32bit[2] = {0}; i_32bit[0] = vmovl_s16(tmp_first_four.val[1]); i_32bit[1] = vmovl_s16(tmp_second_four.val[1]); - /* First row * column of B */ + // First row * column of B // Real * real acc_r0 = vmlal_low_s32(acc_r0, a0_v_rextended[0], r_32bit[0]); acc_r0 = vmlal_high_s32(acc_r0, a0_v_rextended[0], r_32bit[0]); acc_r0 = vmlal_low_s32(acc_r0, a0_v_rextended[1], r_32bit[1]); acc_r0 = vmlal_high_s32(acc_r0, a0_v_rextended[1], r_32bit[1]); - // imag * imag + // Imag * imag acc_r0 = vmlsl_low_s32(acc_r0, a0_v_iextended[0], i_32bit[0]); acc_r0 = vmlsl_high_s32(acc_r0, a0_v_iextended[0], i_32bit[0]); acc_r0 = vmlsl_low_s32(acc_r0, a0_v_iextended[1], i_32bit[1]); acc_r0 = vmlsl_high_s32(acc_r0, a0_v_iextended[1], i_32bit[1]); - // real * imag + // Real * imag acc_i0 = vmlal_low_s32(acc_i0, a0_v_rextended[0], i_32bit[0]); acc_i0 = vmlal_high_s32(acc_i0, a0_v_rextended[0], i_32bit[0]); acc_i0 = vmlal_low_s32(acc_i0, a0_v_rextended[1], i_32bit[1]); acc_i0 = vmlal_high_s32(acc_i0, a0_v_rextended[1], i_32bit[1]); - // imag * real + // Imag * real acc_i0 = vmlal_low_s32(acc_i0, a0_v_iextended[0], r_32bit[0]); acc_i0 = vmlal_high_s32(acc_i0, a0_v_iextended[0], r_32bit[0]); acc_i0 = vmlal_low_s32(acc_i0, a0_v_iextended[1], r_32bit[1]); acc_i0 = vmlal_high_s32(acc_i0, a0_v_iextended[1], r_32bit[1]); - /*Second row * column of B */ + // Second row * column of B // Real * real acc_r1 = vmlal_low_s32(acc_r1, a1_v_rextended[0], r_32bit[0]); acc_r1 = vmlal_high_s32(acc_r1, a1_v_rextended[0], r_32bit[0]); acc_r1 = vmlal_low_s32(acc_r1, a1_v_rextended[1], r_32bit[1]); acc_r1 = vmlal_high_s32(acc_r1, a1_v_rextended[1], r_32bit[1]); - // imag * imag + // Imag * imag acc_r1 = vmlsl_low_s32(acc_r1, a1_v_iextended[0], i_32bit[0]); acc_r1 = vmlsl_high_s32(acc_r1, a1_v_iextended[0], i_32bit[0]); acc_r1 = vmlsl_low_s32(acc_r1, a1_v_iextended[1], i_32bit[1]); acc_r1 = vmlsl_high_s32(acc_r1, a1_v_iextended[1], i_32bit[1]); - // real * imag + // Real * imag acc_i1 = vmlal_low_s32(acc_i1, a1_v_rextended[0], i_32bit[0]); acc_i1 = vmlal_high_s32(acc_i1, a1_v_rextended[0], i_32bit[0]); acc_i1 = vmlal_low_s32(acc_i1, a1_v_rextended[1], i_32bit[1]); acc_i1 = vmlal_high_s32(acc_i1, a1_v_rextended[1], i_32bit[1]); - // imag * real + // Imag * real acc_i1 = vmlal_low_s32(acc_i1, a1_v_iextended[0], r_32bit[0]); acc_i1 = vmlal_high_s32(acc_i1, a1_v_iextended[0], r_32bit[0]); acc_i1 = vmlal_low_s32(acc_i1, a1_v_iextended[1], r_32bit[1]); @@ -564,19 +568,19 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, // If the remainder of columns in the row of A is greater than 4, // do an unrolled loop of size 4 - if (num_cols_a & 4) { + if ((num_cols_a & 4) != 0) { // Load four complex numbers from A. Split into real and imaginary parts int16x4x2_t a0_v = vld2_s16(p_in1); int16x4x2_t a1_v = vld2_s16(p_in1_b); - int32x4_t a0_vextended[2] = {}; + int32x4_t a0_vextended[2] = {0}; a0_vextended[0] = vmovl_s16(a0_v.val[0]); a0_vextended[1] = vmovl_s16(a0_v.val[1]); - int32x4_t a1_vextended[2] = {}; + int32x4_t a1_vextended[2] = {0}; a1_vextended[0] = vmovl_s16(a1_v.val[0]); a1_vextended[1] = vmovl_s16(a1_v.val[1]); - int16x4x2_t tmp_b = {}; + int16x4x2_t tmp_b = {0}; tmp_b = vld2_lane_s16(p_in2, tmp_b, 0); tmp_b = vld2_lane_s16(p_in2 + 2 * num_cols_b, tmp_b, 1); tmp_b = vld2_lane_s16(p_in2 + 4 * num_cols_b, tmp_b, 2); @@ -593,10 +597,10 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, // Imag * imag acc_r0 = vmlsl_low_s32(acc_r0, a0_vextended[1], i_32bit); acc_r0 = vmlsl_high_s32(acc_r0, a0_vextended[1], i_32bit); - // real * imag + // Real * imag acc_i0 = vmlal_low_s32(acc_i0, a0_vextended[0], i_32bit); acc_i0 = vmlal_high_s32(acc_i0, a0_vextended[0], i_32bit); - // imag * real + // Imag * real acc_i0 = vmlal_low_s32(acc_i0, a0_vextended[1], r_32bit); acc_i0 = vmlal_high_s32(acc_i0, a0_vextended[1], r_32bit); @@ -607,10 +611,10 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, // Imag * imag acc_r1 = vmlsl_low_s32(acc_r1, a1_vextended[1], i_32bit); acc_r1 = vmlsl_high_s32(acc_r1, a1_vextended[1], i_32bit); - // real * imag + // Real * imag acc_i1 = vmlal_low_s32(acc_i1, a1_vextended[0], i_32bit); acc_i1 = vmlal_high_s32(acc_i1, a1_vextended[0], i_32bit); - // imag * real + // Imag * real acc_i1 = vmlal_low_s32(acc_i1, a1_vextended[1], r_32bit); acc_i1 = vmlal_high_s32(acc_i1, a1_vextended[1], r_32bit); @@ -637,16 +641,16 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, int16_t b_r = p_in2[0]; int16_t b_i = p_in2[1]; - // real * real + // Real * real sum_real1_ext += a_r1 * b_r; sum_real1_b_ext += a_r2 * b_r; - // imag * imag + // Imag * imag sum_real1_ext -= a_i1 * b_i; sum_real1_b_ext -= a_i2 * b_i; - // real * imag + // Real * imag sum_imag1_ext += a_r1 * b_i; sum_imag1_b_ext += a_r2 * b_i; - // imag * real + // Imag * real sum_imag1_ext += a_i1 * b_r; sum_imag1_b_ext += a_i2 * b_r; @@ -663,54 +667,54 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, out[1].re = vqmovns_s32(vqshrnd_n_s64(sum_real1_b_ext, 15)); out[1].im = vqmovns_s32(vqshrnd_n_s64(sum_imag1_b_ext, 15)); - /* Store the result in the destination buffer */ + // Store the result in the destination buffer *(px++) = out[0]; *(px_b++) = out[1]; } } - /*Odd number of rows*/ - if (num_rows_a & 1) { - /* Output pointer is set to starting address of the row being processed */ + // Odd number of rows + if ((num_rows_a & 1) != 0) { + // Output pointer is set to starting address of the row being processed armral_cmplx_int16_t *px = p_out; - /* For every row wise process, the column loop counter is to be initiated */ + // For every row-wise process, the column loop counter is initialized uint16_t col = num_cols_b; - /* For every row wise process, the pIn2 pointer is set - ** to the starting address of the pSrcB data */ + // For every row-wise process, the pIn2 pointer is set + // to the starting address of the pSrcB data const int16_t *p_in2 = (const int16_t *)p_src_b; uint16_t j = 0U; - /* Column loop */ + // Column loop while (col > 0U) { - /* Set the variable sum, that acts as accumulator, to zero */ + // Set the variable sum, that acts as accumulator, to zero int64_t sum_real1_ext = 0; int64_t sum_imag1_ext = 0; int64_t sum_real2_ext = 0; int64_t sum_imag2_ext = 0; - /* Initiate the pointer pIn1 to point to the starting address of the row - * being processed */ + // Initialize the pointer pIn1 to point to the starting address of the row + // being processed p_in1 = (const int16_t *)p_in_a; int64x2_t acc_r0 = vdupq_n_s64(0); int64x2_t acc_i0 = vdupq_n_s64(0); - /* Compute 8 MACs simultaneously. */ + // Compute 8 MACs simultaneously uint16_t col_cnt = num_cols_a >> 3; - /* Matrix multiplication */ + // Matrix multiplication while (col_cnt > 0U) { - /* load & separate real/imag pSrcA (de-interleave 2)*/ + // Load & separate real/imag pSrcA (de-interleave 2) int16x8x2_t a0_v = vld2q_s16(p_in1); - /*extend to 32bit Real part int32x4x2_t*/ + // Extend to 32bit Real part int32x4x2_t int32x4_t a0_v_rextended[2]; a0_v_rextended[0] = vmovl_low_s16(a0_v.val[0]); a0_v_rextended[1] = vmovl_high_s16(a0_v.val[0]); - /*extend to 32bit Imag part int32x4x2_t*/ + // Extend to 32bit Imag part int32x4x2_t int32x4_t a0_v_iextended[2]; a0_v_iextended[0] = vmovl_low_s16(a0_v.val[1]); a0_v_iextended[1] = vmovl_high_s16(a0_v.val[1]); @@ -753,10 +757,10 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, temp_i1[3] = *(p_in2 + 1U); p_in2 += 2 * num_cols_b; - int32x4_t temp_r0extended = vmovl_s16(temp_r0); /*32x4*/ - int32x4_t temp_i0extended = vmovl_s16(temp_i0); /*32x4*/ - int32x4_t temp_r1extended = vmovl_s16(temp_r1); - int32x4_t temp_i1extended = vmovl_s16(temp_i1); + int32x4_t temp_r0extended = vmovl_s16(temp_r0); // 32x4 + int32x4_t temp_i0extended = vmovl_s16(temp_i0); // 32x4 + int32x4_t temp_r1extended = vmovl_s16(temp_r1); // 32x4 + int32x4_t temp_i1extended = vmovl_s16(temp_i1); // 32x4 acc_r0 = vmlal_high_s32(acc_r0, a0_v_rextended[0], temp_r0extended); acc_r0 = vmlal_low_s32(acc_r0, a0_v_rextended[0], temp_r0extended); @@ -785,12 +789,12 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, col_cnt--; } - /* Compute 4 MACs simultaneously. */ + // Compute 4 MACs simultaneously col_cnt = (num_cols_a & 7) >> 2; - /* Matrix multiplication */ + // Matrix multiplication while (col_cnt > 0U) { - int16x4x2_t a2_v = vld2_s16(p_in1); /*int16x4x2_t*/ + int16x4x2_t a2_v = vld2_s16(p_in1); // int16x4x2_t int32x4x2_t a0_vextended; a0_vextended.val[0] = vmovl_s16(a2_v.val[0]); a0_vextended.val[1] = vmovl_s16(a2_v.val[1]); @@ -865,12 +869,12 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, int32_t sum_imag1_ext32 = vqshrnd_n_s64(sum_imag1_ext, 15); int16_t sum_imag1_q15 = vqmovns_s32(sum_imag1_ext32); - /* Store the result in the destination buffer */ - (*px).re = sum_real1_q15; - (*px).im = sum_imag1_q15; + // Store the result in the destination buffer + px->re = sum_real1_q15; + px->im = sum_imag1_q15; px++; - /* Update the pointer pIn2 to point to the starting address of the next - * column */ + // Update the pointer pIn2 to point to the starting address of the next + // column j++; p_in2 = (const int16_t *)p_src_b + 2U * j; @@ -880,3 +884,4 @@ armral_cmplx_mat_mult_i16(const uint16_t m, const uint16_t n, const uint16_t k, return ARMRAL_SUCCESS; } +} // anonymous namespace diff --git a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_i16_32bit.c b/src/BasicMathFun/MatrixMult/cmplx_matmul_i16_32bit.hpp similarity index 82% rename from src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_i16_32bit.c rename to src/BasicMathFun/MatrixMult/cmplx_matmul_i16_32bit.hpp index 54b182e..f734fb6 100644 --- a/src/BasicMathFun/MatrixMult/arm_cmplx_mat_mult_i16_32bit.c +++ b/src/BasicMathFun/MatrixMult/cmplx_matmul_i16_32bit.hpp @@ -2,50 +2,53 @@ Arm RAN Acceleration Library SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ +#pragma once + #include "armral.h" +namespace { + typedef struct { int32_t re; int32_t im; } cmplx_int32_t; -static void -armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, - const armral_cmplx_int16_t *restrict a, - const armral_cmplx_int16_t *restrict b, - int ldb, armral_cmplx_int16_t *dst) { +void armral_cmplx_matmul_i16_32bit_2xkx4( + uint16_t k, const armral_cmplx_int16_t *__restrict a, + const armral_cmplx_int16_t *__restrict b, int ldb, + armral_cmplx_int16_t *dst) { // Performs the multiplication of a row of matrix A by a set of four columns - // of matrix b. It is assumed that B has four columns and no bounds checking - // is done. Equally, it is assumed that the pointer to a is for a row vector - // of length n. + // of matrix B. It is assumed that B has four columns and no bounds checking + // is done. Equally, it is assumed that the pointer to A is for a row vector + // of length k. const int16_t *a_int16 = (const int16_t *)a; const int16_t *b_int16 = (const int16_t *)b; // Accumulators for the real and imaginary components of the first row - int32x4_t real_acc[4] = {}; - int32x4_t imag_acc[4] = {}; + int32x4_t real_acc[4] = {0}; + int32x4_t imag_acc[4] = {0}; // Accumulators for the real and imaginary components of the second row - int32x4_t real_acc1[4] = {}; - int32x4_t imag_acc1[4] = {}; + int32x4_t real_acc1[4] = {0}; + int32x4_t imag_acc1[4] = {0}; - // Loop over n in blocks of 8 - for (int blk8 = n >> 3; blk8 > 0; + // Loop over k in blocks of 8 + for (int blk8 = k >> 3; blk8 > 0; --blk8, a_int16 += 16, b_int16 += 16 * ldb) { // Load 8 complex numbers from A into arrays of real and complex components int16x8x2_t a_row0 = vld2q_s16(a_int16); - int16x8x2_t a_row1 = vld2q_s16(a_int16 + 2 * n); + int16x8x2_t a_row1 = vld2q_s16(a_int16 + 2 * k); // Load 8 rows from B - int16x8_t b_tmp[8] = {}; + int16x8_t b_tmp[8] = {0}; for (int i = 0; i < 8; ++i) { b_tmp[i] = vld1q_s16(b_int16 + 2 * i * ldb); } // We now want to transpose the 8x4 matrix of complex numbers, and // de-interleave into real and complex components - int16x8_t real_bs[4] = {}; - int16x8_t imag_bs[4] = {}; - // We first separate out the real and imagninary components + int16x8_t real_bs[4] = {0}; + int16x8_t imag_bs[4] = {0}; + // We first separate out the real and imaginary components for (int i = 0; i < 4; ++i) { real_bs[i] = vtrn1q_s16(b_tmp[2 * i], b_tmp[2 * i + 1]); imag_bs[i] = vtrn2q_s16(b_tmp[2 * i], b_tmp[2 * i + 1]); @@ -53,9 +56,9 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, // Now we interleave pairs of real numbers to start to get them in order. // For example, for the first two real vectors we have - // ^: [r_0, r_1, r_n, r_n+1, r_2n, r_2n+1, r_3n, r_3n+1] = trn1(v0, v1) - // ^: [r_2, r_3, r_n+2, r_n+3, r_2n+2, r_2n+3, r_3n+2, r_3n+3] = trn1(v2, v3) - // ^: zip1(trn1(v0, v1), trn1(v2, v3)) = [r_0, r_1, r_2, r_3, r_n, r_n+1, r_n+2, r_n+3] + // ^: [r_0, r_1, r_k, r_k+1, r_2k, r_2k+1, r_3k, r_3k+1] = trn1(v0, v1) + // ^: [r_2, r_3, r_k+2, r_k+3, r_2k+2, r_2k+3, r_3k+2, r_3k+3] = trn1(v2, v3) + // ^: zip1(trn1(v0, v1), trn1(v2, v3)) = [r_0, r_1, r_2, r_3, r_k, r_k+1, r_k+2, r_k+3] for (int i = 0; i < 2; ++i) { b_tmp[2 * i] = vreinterpretq_s16_u32( vzip1q_u32(vreinterpretq_u32_s16(real_bs[2 * i]), @@ -91,19 +94,19 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, // Now perform the dot product of the row of A with columns of B, expanding // to 32 bits Row 1 for (int i = 0; i < 4; ++i) { - // real * real + // Real * real real_acc[i] = vqdmlal_s16(real_acc[i], vget_low_s16(a_row0.val[0]), vget_low_s16(real_bs[i])); real_acc[i] = vqdmlal_high_s16(real_acc[i], a_row0.val[0], real_bs[i]); - // imag * imag + // Imag * imag real_acc[i] = vqdmlsl_s16(real_acc[i], vget_low_s16(a_row0.val[1]), vget_low_s16(imag_bs[i])); real_acc[i] = vqdmlsl_high_s16(real_acc[i], a_row0.val[1], imag_bs[i]); - // real * imag + // Real * imag imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row0.val[0]), vget_low_s16(imag_bs[i])); imag_acc[i] = vqdmlal_high_s16(imag_acc[i], a_row0.val[0], imag_bs[i]); - // imag * real + // Imag * real imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row0.val[1]), vget_low_s16(real_bs[i])); imag_acc[i] = vqdmlal_high_s16(imag_acc[i], a_row0.val[1], real_bs[i]); @@ -111,34 +114,34 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, // Row 2 for (int i = 0; i < 4; ++i) { - // real * real + // Real * real real_acc1[i] = vqdmlal_s16(real_acc1[i], vget_low_s16(a_row1.val[0]), vget_low_s16(real_bs[i])); real_acc1[i] = vqdmlal_high_s16(real_acc1[i], a_row1.val[0], real_bs[i]); - // imag * imag + // Imag * imag real_acc1[i] = vqdmlsl_s16(real_acc1[i], vget_low_s16(a_row1.val[1]), vget_low_s16(imag_bs[i])); real_acc1[i] = vqdmlsl_high_s16(real_acc1[i], a_row1.val[1], imag_bs[i]); - // real * imag + // Real * imag imag_acc1[i] = vqdmlal_s16(imag_acc1[i], vget_low_s16(a_row1.val[0]), vget_low_s16(imag_bs[i])); imag_acc1[i] = vqdmlal_high_s16(imag_acc1[i], a_row1.val[0], imag_bs[i]); - // imag * real + // Imag * real imag_acc1[i] = vqdmlal_s16(imag_acc1[i], vget_low_s16(a_row1.val[1]), vget_low_s16(real_bs[i])); imag_acc1[i] = vqdmlal_high_s16(imag_acc1[i], a_row1.val[1], real_bs[i]); } } - if (n & 4) { + if ((k & 4) != 0) { // Load 4 complex numbers from A at a time int16x4x2_t a_tmp = vld2_s16(a_int16); int16x8_t a_row0 = vcombine_s16(a_tmp.val[0], a_tmp.val[1]); - a_tmp = vld2_s16(a_int16 + 2 * n); + a_tmp = vld2_s16(a_int16 + 2 * k); int16x8_t a_row1 = vcombine_s16(a_tmp.val[0], a_tmp.val[1]); // Load 4 rows from B - int16x8_t b_vals[4] = {}; + int16x8_t b_vals[4] = {0}; for (int i = 0; i < 4; ++i) { b_vals[i] = vld1q_s16(b_int16 + 2 * i * ldb); } @@ -146,7 +149,7 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, // Separate out the real and imaginary components and reorder the columns // of b_tmp as rows. Real components go into the low half of int16x8 // vectors, and imaginary components to the high half - int16x8_t b_tmp[4] = {}; + int16x8_t b_tmp[4] = {0}; for (int i = 0; i < 2; ++i) { b_tmp[2 * i] = vzip1q_s16(b_vals[2 * i], b_vals[2 * i + 1]); b_tmp[2 * i + 1] = vzip2q_s16(b_vals[2 * i], b_vals[2 * i + 1]); @@ -167,12 +170,12 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, // Real * real real_acc[i] = vqdmlal_s16(real_acc[i], vget_low_s16(a_row0), vget_low_s16(b_vals[i])); - // imag * imag + // Imag * imag real_acc[i] = vqdmlsl_high_s16(real_acc[i], a_row0, b_vals[i]); - // real * imag + // Real * imag imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row0), vget_high_s16(b_vals[i])); - // imag * real + // Imag * real imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_high_s16(a_row0), vget_low_s16(b_vals[i])); } @@ -182,12 +185,12 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, // Real * real real_acc1[i] = vqdmlal_s16(real_acc1[i], vget_low_s16(a_row1), vget_low_s16(b_vals[i])); - // imag * imag + // Imag * imag real_acc1[i] = vqdmlsl_high_s16(real_acc1[i], a_row1, b_vals[i]); - // real * imag + // Real * imag imag_acc1[i] = vqdmlal_s16(imag_acc1[i], vget_low_s16(a_row1), vget_high_s16(b_vals[i])); - // imag * real + // Imag * real imag_acc1[i] = vqdmlal_s16(imag_acc1[i], vget_high_s16(a_row1), vget_low_s16(b_vals[i])); } @@ -207,11 +210,11 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, sum_32_3[i] = vqmovnd_s64(vaddlvq_s32(imag_acc1[i])); } - for (int a_col_cnt = n & 3; a_col_cnt > 0; + for (int a_col_cnt = k & 3; a_col_cnt > 0; --a_col_cnt, a_int16 += 2, b_int16 += 2 * ldb) { armral_cmplx_int16_t a_row0 = *((const armral_cmplx_int16_t *)a_int16); armral_cmplx_int16_t a_row1 = - *((const armral_cmplx_int16_t *)(a_int16 + 2 * n)); + *((const armral_cmplx_int16_t *)(a_int16 + 2 * k)); // Load four columns from B int16x8_t b_vals = vld1q_s16(b_int16); @@ -240,14 +243,14 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, // narrow, and we do this on {re, im} pairs of numbers. // Zip real and imaginary components - int32x4_t out_32bit[4] = {}; + int32x4_t out_32bit[4] = {0}; out_32bit[0] = vzip1q_s32(sum_32_0, sum_32_1); out_32bit[1] = vzip2q_s32(sum_32_0, sum_32_1); out_32bit[2] = vzip1q_s32(sum_32_2, sum_32_3); out_32bit[3] = vzip2q_s32(sum_32_2, sum_32_3); // Narrow to 16-bit - int16x8_t out[2] = {}; + int16x8_t out[2] = {0}; out[0] = vqshrn_high_n_s32(vqshrn_n_s32(out_32bit[0], 16), out_32bit[1], 16); out[1] = vqshrn_high_n_s32(vqshrn_n_s32(out_32bit[2], 16), out_32bit[3], 16); @@ -256,39 +259,38 @@ armral_cmplx_mat_mult_i16_32bit_2xnx4(uint16_t n, vst1q_s16((int16_t *)(dst + ldb), out[1]); } -static void -armral_cmplx_mat_mult_i16_32bit_1xnx4(uint16_t n, - const armral_cmplx_int16_t *restrict a, - const armral_cmplx_int16_t *restrict b, - int ldb, armral_cmplx_int16_t *dst) { +void armral_cmplx_matmul_i16_32bit_1xkx4( + uint16_t k, const armral_cmplx_int16_t *__restrict a, + const armral_cmplx_int16_t *__restrict b, int ldb, + armral_cmplx_int16_t *dst) { // Performs the multiplication of a row of matrix A by a set of four columns - // of matrix b. It is assumed that B has four columns and no bounds checking - // is done. Equally, it is assumed that the pointer to a is for a row vector - // of length n. + // of matrix B. It is assumed that B has four columns and no bounds checking + // is done. Equally, it is assumed that the pointer to A is for a row vector + // of length k. const int16_t *a_int16 = (const int16_t *)a; const int16_t *b_int16 = (const int16_t *)b; // Accumulators for the real and imaginary components - int32x4_t real_acc[4] = {}; - int32x4_t imag_acc[4] = {}; + int32x4_t real_acc[4] = {0}; + int32x4_t imag_acc[4] = {0}; - // Loop over n in blocks of 8 - for (int blk8 = n >> 3; blk8 > 0; + // Loop over k in blocks of 8 + for (int blk8 = k >> 3; blk8 > 0; --blk8, a_int16 += 16, b_int16 += 16 * ldb) { // Load 8 complex numbers from A into arrays of real and complex components int16x8x2_t a_row = vld2q_s16(a_int16); // Load 8 rows from B - int16x8_t b_tmp[8] = {}; + int16x8_t b_tmp[8] = {0}; for (int i = 0; i < 8; ++i) { b_tmp[i] = vld1q_s16(b_int16 + 2 * i * ldb); } // We now want to transpose the 8x4 matrix of complex numbers, and // de-interleave into real and complex components - int16x8_t real_bs[4] = {}; - int16x8_t imag_bs[4] = {}; - // We first separate out the real and imagninary components + int16x8_t real_bs[4] = {0}; + int16x8_t imag_bs[4] = {0}; + // We first separate out the real and imaginary components for (int i = 0; i < 4; ++i) { real_bs[i] = vtrn1q_s16(b_tmp[2 * i], b_tmp[2 * i + 1]); imag_bs[i] = vtrn2q_s16(b_tmp[2 * i], b_tmp[2 * i + 1]); @@ -296,9 +298,9 @@ armral_cmplx_mat_mult_i16_32bit_1xnx4(uint16_t n, // Now we interleave pairs of real numbers to start to get them in order. // For example, for the first two real vectors we have - // ^: [r_0, r_1, r_n, r_n+1, r_2n, r_2n+1, r_3n, r_3n+1] = trn1(v0, v1) - // ^: [r_2, r_3, r_n+2, r_n+3, r_2n+2, r_2n+3, r_3n+2, r_3n+3] = trn1(v2, v3) - // ^: zip1(trn1(v0, v1), trn1(v2, v3)) = [r_0, r_1, r_2, r_3, r_n, r_n+1, r_n+2, r_n+3] + // ^: [r_0, r_1, r_k, r_k+1, r_2k, r_2k+1, r_3k, r_3k+1] = trn1(v0, v1) + // ^: [r_2, r_3, r_k+2, r_k+3, r_2k+2, r_2k+3, r_3k+2, r_3k+3] = trn1(v2, v3) + // ^: zip1(trn1(v0, v1), trn1(v2, v3)) = [r_0, r_1, r_2, r_3, r_k, r_k+1, r_k+2, r_k+3] for (int i = 0; i < 2; ++i) { b_tmp[2 * i] = vreinterpretq_s16_u32( vzip1q_u32(vreinterpretq_u32_s16(real_bs[2 * i]), @@ -334,31 +336,31 @@ armral_cmplx_mat_mult_i16_32bit_1xnx4(uint16_t n, // Now perform the dot product of the row of A with columns of B, expanding // to 32 bits for (int i = 0; i < 4; ++i) { - // real * real + // Real * real real_acc[i] = vqdmlal_s16(real_acc[i], vget_low_s16(a_row.val[0]), vget_low_s16(real_bs[i])); real_acc[i] = vqdmlal_high_s16(real_acc[i], a_row.val[0], real_bs[i]); - // imag * imag + // Imag * imag real_acc[i] = vqdmlsl_s16(real_acc[i], vget_low_s16(a_row.val[1]), vget_low_s16(imag_bs[i])); real_acc[i] = vqdmlsl_high_s16(real_acc[i], a_row.val[1], imag_bs[i]); - // real * imag + // Real * imag imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row.val[0]), vget_low_s16(imag_bs[i])); imag_acc[i] = vqdmlal_high_s16(imag_acc[i], a_row.val[0], imag_bs[i]); - // imag * real + // Imag * real imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row.val[1]), vget_low_s16(real_bs[i])); imag_acc[i] = vqdmlal_high_s16(imag_acc[i], a_row.val[1], real_bs[i]); } } - if (n & 4) { + if ((k & 4) != 0) { // Load 4 complex numbers from A at a time int16x4x2_t a_row = vld2_s16(a_int16); // Load 4 rows from B - int16x8_t b_vals[4] = {}; + int16x8_t b_vals[4] = {0}; for (int i = 0; i < 4; ++i) { b_vals[i] = vld1q_s16(b_int16 + 2 * i * ldb); } @@ -366,7 +368,7 @@ armral_cmplx_mat_mult_i16_32bit_1xnx4(uint16_t n, // Separate out the real and imaginary components and reorder the columns of // what was read in as rows. Real components go into the low half of int16x8 // vectors, and imaginary components to the high half - int16x8_t b_tmp[4] = {}; + int16x8_t b_tmp[4] = {0}; for (int i = 0; i < 2; ++i) { b_tmp[2 * i] = vzip1q_s16(b_vals[2 * i], b_vals[2 * i + 1]); b_tmp[2 * i + 1] = vzip2q_s16(b_vals[2 * i], b_vals[2 * i + 1]); @@ -386,13 +388,13 @@ armral_cmplx_mat_mult_i16_32bit_1xnx4(uint16_t n, // Real * real real_acc[i] = vqdmlal_s16(real_acc[i], a_row.val[0], vget_low_s16(b_vals[i])); - // imag * imag + // Imag * imag real_acc[i] = vqdmlsl_s16(real_acc[i], a_row.val[1], vget_high_s16(b_vals[i])); - // real * imag + // Real * imag imag_acc[i] = vqdmlal_s16(imag_acc[i], a_row.val[0], vget_high_s16(b_vals[i])); - // imag * real + // Imag * real imag_acc[i] = vqdmlal_s16(imag_acc[i], a_row.val[1], vget_low_s16(b_vals[i])); } @@ -408,7 +410,7 @@ armral_cmplx_mat_mult_i16_32bit_1xnx4(uint16_t n, sum_32_1[i] = vqmovnd_s64(vaddlvq_s32(imag_acc[i])); } - for (int a_col_cnt = n & 3; a_col_cnt > 0; + for (int a_col_cnt = k & 3; a_col_cnt > 0; --a_col_cnt, a_int16 += 2, b_int16 += 2 * ldb) { armral_cmplx_int16_t a_row = *((const armral_cmplx_int16_t *)a_int16); @@ -432,46 +434,45 @@ armral_cmplx_mat_mult_i16_32bit_1xnx4(uint16_t n, // narrow, and we do this on {re, im} pairs of numbers. // Zip real and imaginary components - int32x4_t out_32bit[2] = {}; + int32x4_t out_32bit[2] = {0}; out_32bit[0] = vzip1q_s32(sum_32_0, sum_32_1); out_32bit[1] = vzip2q_s32(sum_32_0, sum_32_1); // Narrow to 16-bit - int16x8_t out = {}; + int16x8_t out = {0}; out = vqshrn_high_n_s32(vqshrn_n_s32(out_32bit[0], 16), out_32bit[1], 16); // Now write to the destination array vst1q_s16((int16_t *)dst, out); } -static void -armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, - const armral_cmplx_int16_t *restrict a, - const armral_cmplx_int16_t *restrict b, - int ldb, armral_cmplx_int16_t *dst) { +void armral_cmplx_matmul_i16_32bit_2xkx2( + uint16_t k, const armral_cmplx_int16_t *__restrict a, + const armral_cmplx_int16_t *__restrict b, int ldb, + armral_cmplx_int16_t *dst) { // Performs the multiplication of a row of matrix A by a pair of columns of - // matrix b. It is assumed that B has four columns and no bounds checking is - // done. Equally, it is assumed that the pointer to a is for a row vector of - // length n. + // matrix B. It is assumed that B has four columns and no bounds checking is + // done. Equally, it is assumed that the pointer to A is for a row vector of + // length k. const int16_t *a_int16 = (const int16_t *)a; const int16_t *b_int16 = (const int16_t *)b; // Accumulators for the real and imaginary components of the first row - int32x4_t real_acc[2] = {}; - int32x4_t imag_acc[2] = {}; + int32x4_t real_acc[2] = {0}; + int32x4_t imag_acc[2] = {0}; // Accumulators for the real and imaginary components of the second row - int32x4_t real_acc1[2] = {}; - int32x4_t imag_acc1[2] = {}; + int32x4_t real_acc1[2] = {0}; + int32x4_t imag_acc1[2] = {0}; - // Loop over n in blocks of 8 - for (int blk8 = n >> 3; blk8 > 0; + // Loop over k in blocks of 8 + for (int blk8 = k >> 3; blk8 > 0; --blk8, a_int16 += 16, b_int16 += 16 * ldb) { // Load 8 complex numbers from A into arrays of real and complex components int16x8x2_t a_row0 = vld2q_s16(a_int16); - int16x8x2_t a_row1 = vld2q_s16(a_int16 + 2 * n); + int16x8x2_t a_row1 = vld2q_s16(a_int16 + 2 * k); // Load 8 rows from B - int16x8_t b_vals[4] = {}; + int16x8_t b_vals[4] = {0}; for (int i = 0; i < 4; ++i) { b_vals[i] = vcombine_s16(vld1_s16(b_int16 + 4 * i * ldb), vld1_s16(b_int16 + (4 * i + 2) * ldb)); @@ -479,7 +480,7 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, // We now want to transpose the 8x2 matrix of complex numbers, and // de-interleave into real and complex components - int16x8_t b_tmp[4] = {}; + int16x8_t b_tmp[4] = {0}; for (int i = 0; i < 2; ++i) { // Real numbers only b_tmp[2 * i] = vuzp1q_s16(b_vals[2 * i], b_vals[2 * i + 1]); @@ -497,21 +498,21 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, // Now perform the dot product of the row of A with columns of B, expanding // to 32 bits Row 1 for (int i = 0; i < 2; ++i) { - // real * real + // Real * real real_acc[i] = vqdmlal_s16(real_acc[i], vget_low_s16(a_row0.val[0]), vget_low_s16(b_vals[2 * i])); real_acc[i] = vqdmlal_high_s16(real_acc[i], a_row0.val[0], b_vals[2 * i]); - // imag * imag + // Imag * imag real_acc[i] = vqdmlsl_s16(real_acc[i], vget_low_s16(a_row0.val[1]), vget_low_s16(b_vals[2 * i + 1])); real_acc[i] = vqdmlsl_high_s16(real_acc[i], a_row0.val[1], b_vals[2 * i + 1]); - // real * imag + // Real * imag imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row0.val[0]), vget_low_s16(b_vals[2 * i + 1])); imag_acc[i] = vqdmlal_high_s16(imag_acc[i], a_row0.val[0], b_vals[2 * i + 1]); - // imag * real + // Imag * real imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row0.val[1]), vget_low_s16(b_vals[2 * i])); imag_acc[i] = vqdmlal_high_s16(imag_acc[i], a_row0.val[1], b_vals[2 * i]); @@ -519,22 +520,22 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, // Row 2 for (int i = 0; i < 2; ++i) { - // real * real + // Real * real real_acc1[i] = vqdmlal_s16(real_acc1[i], vget_low_s16(a_row1.val[0]), vget_low_s16(b_vals[2 * i])); real_acc1[i] = vqdmlal_high_s16(real_acc1[i], a_row1.val[0], b_vals[2 * i]); - // imag * imag + // Imag * imag real_acc1[i] = vqdmlsl_s16(real_acc1[i], vget_low_s16(a_row1.val[1]), vget_low_s16(b_vals[2 * i + 1])); real_acc1[i] = vqdmlsl_high_s16(real_acc1[i], a_row1.val[1], b_vals[2 * i + 1]); - // real * imag + // Real * imag imag_acc1[i] = vqdmlal_s16(imag_acc1[i], vget_low_s16(a_row1.val[0]), vget_low_s16(b_vals[2 * i + 1])); imag_acc1[i] = vqdmlal_high_s16(imag_acc1[i], a_row1.val[0], b_vals[2 * i + 1]); - // imag * real + // Imag * real imag_acc1[i] = vqdmlal_s16(imag_acc1[i], vget_low_s16(a_row1.val[1]), vget_low_s16(b_vals[2 * i])); imag_acc1[i] = @@ -542,15 +543,15 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, } } - if (n & 4) { + if ((k & 4) != 0) { // Load 4 complex numbers from A at a time int16x4x2_t a_tmp = vld2_s16(a_int16); int16x8_t a_row0 = vcombine_s16(a_tmp.val[0], a_tmp.val[1]); - a_tmp = vld2_s16(a_int16 + 2 * n); + a_tmp = vld2_s16(a_int16 + 2 * k); int16x8_t a_row1 = vcombine_s16(a_tmp.val[0], a_tmp.val[1]); // Load 4 rows from B - int16x8_t b_vals[2] = {}; + int16x8_t b_vals[2] = {0}; for (int i = 0; i < 2; ++i) { b_vals[i] = vcombine_s16(vld1_s16(b_int16 + 4 * i * ldb), vld1_s16(b_int16 + (4 * i + 2) * ldb)); @@ -559,7 +560,7 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, // Separate out the real and imaginary components and reorder the columns // of b_tmp as rows. Real components go into the low half of int16x8 // vectors, and imaginary components to the high half - int16x8_t b_tmp[2] = {}; + int16x8_t b_tmp[2] = {0}; b_tmp[0] = vuzp1q_s16(b_vals[0], b_vals[1]); b_tmp[1] = vuzp2q_s16(b_vals[0], b_vals[1]); @@ -572,12 +573,12 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, // Real * real real_acc[i] = vqdmlal_s16(real_acc[i], vget_low_s16(a_row0), vget_low_s16(b_vals[i])); - // imag * imag + // Imag * imag real_acc[i] = vqdmlsl_high_s16(real_acc[i], a_row0, b_vals[i]); - // real * imag + // Real * imag imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row0), vget_high_s16(b_vals[i])); - // imag * real + // Imag * real imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_high_s16(a_row0), vget_low_s16(b_vals[i])); } @@ -587,12 +588,12 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, // Real * real real_acc1[i] = vqdmlal_s16(real_acc1[i], vget_low_s16(a_row1), vget_low_s16(b_vals[i])); - // imag * imag + // Imag * imag real_acc1[i] = vqdmlsl_high_s16(real_acc1[i], a_row1, b_vals[i]); - // real * imag + // Real * imag imag_acc1[i] = vqdmlal_s16(imag_acc1[i], vget_low_s16(a_row1), vget_high_s16(b_vals[i])); - // imag * real + // Imag * real imag_acc1[i] = vqdmlal_s16(imag_acc1[i], vget_high_s16(a_row1), vget_low_s16(b_vals[i])); } @@ -613,12 +614,12 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, } // For the last three columns of A, iterate element-by-element - for (int a_col_cnt = n & 3; a_col_cnt > 0; + for (int a_col_cnt = k & 3; a_col_cnt > 0; --a_col_cnt, a_int16 += 2, b_int16 += 2 * ldb) { // Load a single element from A (from two consecutive rows) armral_cmplx_int16_t a_row0 = *((const armral_cmplx_int16_t *)a_int16); armral_cmplx_int16_t a_row1 = - *((const armral_cmplx_int16_t *)(a_int16 + 2 * n)); + *((const armral_cmplx_int16_t *)(a_int16 + 2 * k)); // Load two columns from B int16x4_t b_vals = vld1_s16(b_int16); @@ -647,14 +648,14 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, // narrow, and we do this on {re, im} pairs of numbers. // Zip real and imaginary components - int32x4_t out_32bit[2] = {}; + int32x4_t out_32bit[2] = {0}; out_32bit[0] = vcombine_s32(vzip1_s32(sum_32_0, sum_32_1), vzip2_s32(sum_32_0, sum_32_1)); out_32bit[1] = vcombine_s32(vzip1_s32(sum_32_2, sum_32_3), vzip2_s32(sum_32_2, sum_32_3)); // Narrow to 16-bit - int16x4_t out[2] = {}; + int16x4_t out[2] = {0}; out[0] = vqshrn_n_s32(out_32bit[0], 16); out[1] = vqshrn_n_s32(out_32bit[1], 16); @@ -663,30 +664,29 @@ armral_cmplx_mat_mult_i16_32bit_2xnx2(uint16_t n, vst1_s16((int16_t *)(dst + ldb), out[1]); } -static void -armral_cmplx_mat_mult_i16_32bit_1xnx2(uint16_t n, - const armral_cmplx_int16_t *restrict a, - const armral_cmplx_int16_t *restrict b, - int ldb, armral_cmplx_int16_t *dst) { +void armral_cmplx_matmul_i16_32bit_1xkx2( + uint16_t k, const armral_cmplx_int16_t *__restrict a, + const armral_cmplx_int16_t *__restrict b, int ldb, + armral_cmplx_int16_t *dst) { // Performs the multiplication of a row of matrix A by a pair of columns of - // matrix b. It is assumed that B has four columns and no bounds checking is - // done. Equally, it is assumed that the pointer to a is for a row vector of - // length n. + // matrix B. It is assumed that B has four columns and no bounds checking is + // done. Equally, it is assumed that the pointer to A is for a row vector of + // length k. const int16_t *a_int16 = (const int16_t *)a; const int16_t *b_int16 = (const int16_t *)b; // Accumulators for the real and imaginary components of the first row - int32x4_t real_acc[2] = {}; - int32x4_t imag_acc[2] = {}; + int32x4_t real_acc[2] = {0}; + int32x4_t imag_acc[2] = {0}; - // Loop over n in blocks of 8 - for (int blk8 = n >> 3; blk8 > 0; + // Loop over k in blocks of 8 + for (int blk8 = k >> 3; blk8 > 0; --blk8, a_int16 += 16, b_int16 += 16 * ldb) { // Load 8 complex numbers from A into arrays of real and complex components int16x8x2_t a_row = vld2q_s16(a_int16); // Load 8 rows from B - int16x8_t b_vals[4] = {}; + int16x8_t b_vals[4] = {0}; for (int i = 0; i < 4; ++i) { b_vals[i] = vcombine_s16(vld1_s16(b_int16 + 4 * i * ldb), vld1_s16(b_int16 + (4 * i + 2) * ldb)); @@ -694,7 +694,7 @@ armral_cmplx_mat_mult_i16_32bit_1xnx2(uint16_t n, // We now want to transpose the 8x2 matrix of complex numbers, and // de-interleave into real and complex components - int16x8_t b_tmp[4] = {}; + int16x8_t b_tmp[4] = {0}; for (int i = 0; i < 2; ++i) { // Real numbers only b_tmp[2 * i] = vuzp1q_s16(b_vals[2 * i], b_vals[2 * i + 1]); @@ -712,34 +712,34 @@ armral_cmplx_mat_mult_i16_32bit_1xnx2(uint16_t n, // Perform the dot product of the row of A with columns of B, expanding to // 32 bits for (int i = 0; i < 2; ++i) { - // real * real + // Real * real real_acc[i] = vqdmlal_s16(real_acc[i], vget_low_s16(a_row.val[0]), vget_low_s16(b_vals[2 * i])); real_acc[i] = vqdmlal_high_s16(real_acc[i], a_row.val[0], b_vals[2 * i]); - // imag * imag + // Imag * imag real_acc[i] = vqdmlsl_s16(real_acc[i], vget_low_s16(a_row.val[1]), vget_low_s16(b_vals[2 * i + 1])); real_acc[i] = vqdmlsl_high_s16(real_acc[i], a_row.val[1], b_vals[2 * i + 1]); - // real * imag + // Real * imag imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row.val[0]), vget_low_s16(b_vals[2 * i + 1])); imag_acc[i] = vqdmlal_high_s16(imag_acc[i], a_row.val[0], b_vals[2 * i + 1]); - // imag * real + // Imag * real imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row.val[1]), vget_low_s16(b_vals[2 * i])); imag_acc[i] = vqdmlal_high_s16(imag_acc[i], a_row.val[1], b_vals[2 * i]); } } - if (n & 4) { + if ((k & 4) != 0) { // Load 4 complex numbers from A at a time int16x4x2_t a_tmp = vld2_s16(a_int16); int16x8_t a_row = vcombine_s16(a_tmp.val[0], a_tmp.val[1]); // Load 4 rows from B - int16x8_t b_vals[2] = {}; + int16x8_t b_vals[2] = {0}; for (int i = 0; i < 2; ++i) { b_vals[i] = vcombine_s16(vld1_s16(b_int16 + 4 * i * ldb), vld1_s16(b_int16 + (4 * i + 2) * ldb)); @@ -748,7 +748,7 @@ armral_cmplx_mat_mult_i16_32bit_1xnx2(uint16_t n, // Separate out the real and imaginary components and reorder the columns // of b_tmp as rows. Real components go into the low half of int16x8 // vectors, and imaginary components to the high half - int16x8_t b_tmp[2] = {}; + int16x8_t b_tmp[2] = {0}; b_tmp[0] = vuzp1q_s16(b_vals[0], b_vals[1]); b_tmp[1] = vuzp2q_s16(b_vals[0], b_vals[1]); @@ -760,12 +760,12 @@ armral_cmplx_mat_mult_i16_32bit_1xnx2(uint16_t n, // Real * real real_acc[i] = vqdmlal_s16(real_acc[i], vget_low_s16(a_row), vget_low_s16(b_vals[i])); - // imag * imag + // Imag * imag real_acc[i] = vqdmlsl_high_s16(real_acc[i], a_row, b_vals[i]); - // real * imag + // Real * imag imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_low_s16(a_row), vget_high_s16(b_vals[i])); - // imag * real + // Imag * real imag_acc[i] = vqdmlal_s16(imag_acc[i], vget_high_s16(a_row), vget_low_s16(b_vals[i])); } @@ -782,7 +782,7 @@ armral_cmplx_mat_mult_i16_32bit_1xnx2(uint16_t n, } // For the last three columns of A, iterate element-by-element - for (int a_col_cnt = n & 3; a_col_cnt > 0; + for (int a_col_cnt = k & 3; a_col_cnt > 0; --a_col_cnt, a_int16 += 2, b_int16 += 2 * ldb) { // Load a single element from A (from two consecutive rows) armral_cmplx_int16_t a_row = *((const armral_cmplx_int16_t *)a_int16); @@ -817,34 +817,33 @@ armral_cmplx_mat_mult_i16_32bit_1xnx2(uint16_t n, vst1_s16((int16_t *)dst, out); } -static void -armral_cmplx_mat_mult_i16_32bit_2xnx1(uint16_t n, - const armral_cmplx_int16_t *restrict a, - const armral_cmplx_int16_t *restrict b, - int ldb, armral_cmplx_int16_t *dst) { +void armral_cmplx_matmul_i16_32bit_2xkx1( + uint16_t k, const armral_cmplx_int16_t *__restrict a, + const armral_cmplx_int16_t *__restrict b, int ldb, + armral_cmplx_int16_t *dst) { // Performs the multiplication of a row of matrix A by a single column of - // matrix b. It is assumed that B has four columns and no bounds checking is - // done. Equally, it is assumed that the pointer to a is for a row vector of - // length n. + // matrix B. It is assumed that B has four columns and no bounds checking is + // done. Equally, it is assumed that the pointer to A is for a row vector of + // length k. const int16_t *a_int16 = (const int16_t *)a; const int16_t *b_int16 = (const int16_t *)b; // Accumulators for the real and imaginary components of the first row - int32x4_t real_acc = {}; - int32x4_t imag_acc = {}; + int32x4_t real_acc = {0}; + int32x4_t imag_acc = {0}; // Accumulators for the real and imaginary components of the second row - int32x4_t real_acc1 = {}; - int32x4_t imag_acc1 = {}; + int32x4_t real_acc1 = {0}; + int32x4_t imag_acc1 = {0}; - // Loop over n in blocks of 8 - for (int blk8 = n >> 3; blk8 > 0; + // Loop over k in blocks of 8 + for (int blk8 = k >> 3; blk8 > 0; --blk8, a_int16 += 16, b_int16 += 16 * ldb) { // Load 8 complex numbers from A into arrays of real and complex components int16x8x2_t a_row0 = vld2q_s16(a_int16); - int16x8x2_t a_row1 = vld2q_s16(a_int16 + 2 * n); + int16x8x2_t a_row1 = vld2q_s16(a_int16 + 2 * k); // Load 8 rows from B, and store real and complex parts separately - int16x8x2_t b_vals = {}; + int16x8x2_t b_vals = {0}; b_vals = vld2q_lane_s16(b_int16, b_vals, 0); b_vals = vld2q_lane_s16(b_int16 + 2 * ldb, b_vals, 1); b_vals = vld2q_lane_s16(b_int16 + 4 * ldb, b_vals, 2); @@ -859,47 +858,47 @@ armral_cmplx_mat_mult_i16_32bit_2xnx1(uint16_t n, real_acc = vqdmlal_s16(real_acc, vget_low_s16(a_row0.val[0]), vget_low_s16(b_vals.val[0])); real_acc = vqdmlal_high_s16(real_acc, a_row0.val[0], b_vals.val[0]); - // imag * imag + // Imag * imag real_acc = vqdmlsl_s16(real_acc, vget_low_s16(a_row0.val[1]), vget_low_s16(b_vals.val[1])); real_acc = vqdmlsl_high_s16(real_acc, a_row0.val[1], b_vals.val[1]); - // real * imag + // Real * imag imag_acc = vqdmlal_s16(imag_acc, vget_low_s16(a_row0.val[0]), vget_low_s16(b_vals.val[1])); imag_acc = vqdmlal_high_s16(imag_acc, a_row0.val[0], b_vals.val[1]); - // imag * real + // Imag * real imag_acc = vqdmlal_s16(imag_acc, vget_low_s16(a_row0.val[1]), vget_low_s16(b_vals.val[0])); imag_acc = vqdmlal_high_s16(imag_acc, a_row0.val[1], b_vals.val[0]); // Row 2 - // real * real + // Real * real real_acc1 = vqdmlal_s16(real_acc1, vget_low_s16(a_row1.val[0]), vget_low_s16(b_vals.val[0])); real_acc1 = vqdmlal_high_s16(real_acc1, a_row1.val[0], b_vals.val[0]); - // imag * imag + // Imag * imag real_acc1 = vqdmlsl_s16(real_acc1, vget_low_s16(a_row1.val[1]), vget_low_s16(b_vals.val[1])); real_acc1 = vqdmlsl_high_s16(real_acc1, a_row1.val[1], b_vals.val[1]); - // real * imag + // Real * imag imag_acc1 = vqdmlal_s16(imag_acc1, vget_low_s16(a_row1.val[0]), vget_low_s16(b_vals.val[1])); imag_acc1 = vqdmlal_high_s16(imag_acc1, a_row1.val[0], b_vals.val[1]); - // imag * real + // Imag * real imag_acc1 = vqdmlal_s16(imag_acc1, vget_low_s16(a_row1.val[1]), vget_low_s16(b_vals.val[0])); imag_acc1 = vqdmlal_high_s16(imag_acc1, a_row1.val[1], b_vals.val[0]); } - if (n & 4) { + if ((k & 4) != 0) { // Load 4 complex numbers from A at a time int16x4x2_t a_tmp = vld2_s16(a_int16); int16x8_t a_row0 = vcombine_s16(a_tmp.val[0], a_tmp.val[1]); - a_tmp = vld2_s16(a_int16 + 2 * n); + a_tmp = vld2_s16(a_int16 + 2 * k); int16x8_t a_row1 = vcombine_s16(a_tmp.val[0], a_tmp.val[1]); // Load 4 rows from B - int16x4x2_t b_tmp = {}; + int16x4x2_t b_tmp = {0}; b_tmp = vld2_lane_s16(b_int16, b_tmp, 0); b_tmp = vld2_lane_s16(b_int16 + 2 * ldb, b_tmp, 1); b_tmp = vld2_lane_s16(b_int16 + 4 * ldb, b_tmp, 2); @@ -911,12 +910,12 @@ armral_cmplx_mat_mult_i16_32bit_2xnx1(uint16_t n, // Real * real real_acc = vqdmlal_s16(real_acc, vget_low_s16(a_row0), vget_low_s16(b_vals)); - // imag * imag + // Imag * imag real_acc = vqdmlsl_high_s16(real_acc, a_row0, b_vals); - // real * imag + // Real * imag imag_acc = vqdmlal_s16(imag_acc, vget_low_s16(a_row0), vget_high_s16(b_vals)); - // imag * real + // Imag * real imag_acc = vqdmlal_s16(imag_acc, vget_high_s16(a_row0), vget_low_s16(b_vals)); @@ -924,12 +923,12 @@ armral_cmplx_mat_mult_i16_32bit_2xnx1(uint16_t n, // Real * real real_acc1 = vqdmlal_s16(real_acc1, vget_low_s16(a_row1), vget_low_s16(b_vals)); - // imag * imag + // Imag * imag real_acc1 = vqdmlsl_high_s16(real_acc1, a_row1, b_vals); - // real * imag + // Real * imag imag_acc1 = vqdmlal_s16(imag_acc1, vget_low_s16(a_row1), vget_high_s16(b_vals)); - // imag * real + // Imag * real imag_acc1 = vqdmlal_s16(imag_acc1, vget_high_s16(a_row1), vget_low_s16(b_vals)); @@ -944,12 +943,12 @@ armral_cmplx_mat_mult_i16_32bit_2xnx1(uint16_t n, sum_32[1].im = vqmovnd_s64(vaddlvq_s32(imag_acc1)); // For the last three columns of A, iterate element-by-element - for (int a_col_cnt = n & 3; a_col_cnt > 0; + for (int a_col_cnt = k & 3; a_col_cnt > 0; --a_col_cnt, a_int16 += 2, b_int16 += 2 * ldb) { // Load a single element from A (from two consecutive rows) armral_cmplx_int16_t a_row0 = *((const armral_cmplx_int16_t *)a_int16); armral_cmplx_int16_t a_row1 = - *((const armral_cmplx_int16_t *)(a_int16 + 2 * n)); + *((const armral_cmplx_int16_t *)(a_int16 + 2 * k)); // Load one value from B armral_cmplx_int16_t b_val = *((const armral_cmplx_int16_t *)b_int16); @@ -979,15 +978,15 @@ armral_cmplx_mat_mult_i16_32bit_2xnx1(uint16_t n, dst[ldb].im = vqshrns_n_s32(sum_32[1].im, 16); } -static inline __attribute__((always_inline)) void -armral_cmplx_mat_mult_i16_32bit_1xnx1(uint16_t n, - const armral_cmplx_int16_t *restrict a, - const armral_cmplx_int16_t *restrict b, - int ldb, armral_cmplx_int16_t *dst) { - // Just do a naive implementation for the 1xnx1 case. This is a simple loop, +inline __attribute__((always_inline)) void +armral_cmplx_matmul_i16_32bit_1xkx1(uint16_t k, + const armral_cmplx_int16_t *__restrict a, + const armral_cmplx_int16_t *__restrict b, + int ldb, armral_cmplx_int16_t *dst) { + // Just do a naive implementation for the 1xkx1 case. This is a simple loop, // and we will trust the compiler to not do something horrible with it. cmplx_int32_t accum = {}; - for (int i = 0; i < n; ++i) { + for (int i = 0; i < k; ++i) { armral_cmplx_int16_t a_tmp = a[i]; armral_cmplx_int16_t b_tmp = b[i * ldb]; accum.re = vqdmlalh_s16(accum.re, a_tmp.re, b_tmp.re); @@ -1001,53 +1000,56 @@ armral_cmplx_mat_mult_i16_32bit_1xnx1(uint16_t n, dst->im = vqshrns_n_s32(accum.im, 16); } -armral_status armral_cmplx_mat_mult_i16_32bit( - const uint16_t m, const uint16_t n, const uint16_t k, - const armral_cmplx_int16_t *restrict p_src_a, - const armral_cmplx_int16_t *restrict p_src_b, armral_cmplx_int16_t *p_dst) { +armral_status +cmplx_matmul_i16_32bit(const uint16_t m, const uint16_t n, const uint16_t k, + const armral_cmplx_int16_t *__restrict p_src_a, + const armral_cmplx_int16_t *__restrict p_src_b, + armral_cmplx_int16_t *p_dst) { // Loop over two rows of A at a time const armral_cmplx_int16_t *a_ptr = p_src_a; armral_cmplx_int16_t *out_ptr = p_dst; for (uint16_t a_row_cnt = m >> 1; a_row_cnt > 0; - --a_row_cnt, a_ptr += 2 * n, out_ptr += 2 * k) { + --a_row_cnt, a_ptr += 2 * k, out_ptr += 2 * n) { armral_cmplx_int16_t *out_row_ptr = out_ptr; // Loop over four columns of B const armral_cmplx_int16_t *b_ptr = p_src_b; - for (uint16_t b_col_cnt = k >> 2; b_col_cnt > 0; + for (uint16_t b_col_cnt = n >> 2; b_col_cnt > 0; --b_col_cnt, b_ptr += 4, out_row_ptr += 4) { - armral_cmplx_mat_mult_i16_32bit_2xnx4(n, a_ptr, b_ptr, k, out_row_ptr); + armral_cmplx_matmul_i16_32bit_2xkx4(k, a_ptr, b_ptr, n, out_row_ptr); } // If there are two or more columns left in B, unroll by two columns - if (k & 2) { - armral_cmplx_mat_mult_i16_32bit_2xnx2(n, a_ptr, b_ptr, k, out_row_ptr); + if ((n & 2) != 0) { + armral_cmplx_matmul_i16_32bit_2xkx2(k, a_ptr, b_ptr, n, out_row_ptr); b_ptr += 2; out_row_ptr += 2; } // Deal with a tail, if there is one in the columns of B - if (k & 1) { - armral_cmplx_mat_mult_i16_32bit_2xnx1(n, a_ptr, b_ptr, k, out_row_ptr); + if ((n & 1) != 0) { + armral_cmplx_matmul_i16_32bit_2xkx1(k, a_ptr, b_ptr, n, out_row_ptr); } } - if (m & 1) { + if ((m & 1) != 0) { armral_cmplx_int16_t *out_row_ptr = out_ptr; // Loop over four columns of B const armral_cmplx_int16_t *b_ptr = p_src_b; - for (uint16_t b_col_cnt = k >> 2; b_col_cnt > 0; + for (uint16_t b_col_cnt = n >> 2; b_col_cnt > 0; --b_col_cnt, b_ptr += 4, out_row_ptr += 4) { - armral_cmplx_mat_mult_i16_32bit_1xnx4(n, a_ptr, b_ptr, k, out_row_ptr); + armral_cmplx_matmul_i16_32bit_1xkx4(k, a_ptr, b_ptr, n, out_row_ptr); } - if (k & 2) { - armral_cmplx_mat_mult_i16_32bit_1xnx2(n, a_ptr, b_ptr, k, out_row_ptr); + if ((n & 2) != 0) { + armral_cmplx_matmul_i16_32bit_1xkx2(k, a_ptr, b_ptr, n, out_row_ptr); b_ptr += 2; out_row_ptr += 2; } - if (k & 1) { - armral_cmplx_mat_mult_i16_32bit_1xnx1(n, a_ptr, b_ptr, k, out_row_ptr); + if ((n & 1) != 0) { + armral_cmplx_matmul_i16_32bit_1xkx1(k, a_ptr, b_ptr, n, out_row_ptr); } } return ARMRAL_SUCCESS; } + +} // anonymous namespace diff --git a/src/BasicMathFun/MatrixPseudoInv/arm_cmplx_pseudo_inverse_direct_f32.cpp b/src/BasicMathFun/MatrixPseudoInv/arm_cmplx_pseudo_inverse_direct_f32.cpp index ebca77c..a8ec7c3 100644 --- a/src/BasicMathFun/MatrixPseudoInv/arm_cmplx_pseudo_inverse_direct_f32.cpp +++ b/src/BasicMathFun/MatrixPseudoInv/arm_cmplx_pseudo_inverse_direct_f32.cpp @@ -21,7 +21,7 @@ void left_pseudo_inverse(uint16_t m, const float32_t lambda, // We can use p_dst as an intermediate N-by-N array since it has size N-by-M, // and N < M auto *mat_aha = p_dst; - armral_cmplx_mat_mult_ahb_f32(m, n, n, p_src, p_src, mat_aha); + armral_cmplx_matmul_ahb_f32(n, n, m, p_src, p_src, mat_aha); // Compute C += lambda * I armral::cmplx_mat_pseudo_inv::add_lambda(lambda, mat_aha); @@ -49,7 +49,7 @@ void right_pseudo_inverse(uint16_t n, const float32_t lambda, // We can use p_dst as an intermediate M-by-M array since it has size N-by-M, // and N >= M auto *mat_aah = p_dst; - armral_cmplx_mat_mult_aah_f32(m, n, p_src, mat_aah); + armral_cmplx_matmul_aah_f32(m, n, p_src, mat_aah); // Compute C += lambda * I armral::cmplx_mat_pseudo_inv::add_lambda(lambda, mat_aah); @@ -65,7 +65,7 @@ void right_pseudo_inverse(uint16_t n, const float32_t lambda, } // Compute A^H * B - armral_cmplx_mat_mult_ahb_f32(m, n, m, p_src, mat_inv.get(), p_dst); + armral_cmplx_matmul_ahb_f32(n, m, m, p_src, mat_inv.get(), p_dst); } template diff --git a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_f32.c b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_f32.c index 9d95aa2..45498f7 100644 --- a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_f32.c +++ b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_f32.c @@ -9,17 +9,17 @@ #endif armral_status -armral_cmplx_vecdot_f32(int32_t n, const armral_cmplx_f32_t *restrict p_src_a, +armral_cmplx_vecdot_f32(uint32_t n, const armral_cmplx_f32_t *restrict p_src_a, const armral_cmplx_f32_t *restrict p_src_b, armral_cmplx_f32_t *p_src_c) { #ifdef ARMRAL_ARCH_SVE - int32_t num_lanes = svcntd(); + uint32_t num_lanes = svcntd(); svbool_t ptrue = svptrue_b32(); svfloat32_t acc0 = svdup_n_f32(0); svfloat32_t acc1 = svdup_n_f32(0); - int32_t i = 0; - for (; i * num_lanes <= n - 2 * num_lanes; i += 2) { + uint32_t i = 0; + for (; (i + 2) * num_lanes <= n; i += 2) { svbool_t pg = svptrue_b32(); svfloat32_t vec_a0 = svld1_vnum_f32(pg, (const float32_t *)p_src_a, i); svfloat32_t vec_b0 = svld1_vnum_f32(pg, (const float32_t *)p_src_b, i); diff --git a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_f32_2.c b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_f32_2.c index 4a2da77..3fb0835 100644 --- a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_f32_2.c +++ b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_f32_2.c @@ -7,7 +7,7 @@ #include #endif -armral_status armral_cmplx_vecdot_f32_2(int32_t n, +armral_status armral_cmplx_vecdot_f32_2(uint32_t n, const float32_t *restrict p_src_a_re, const float32_t *restrict p_src_a_im, const float32_t *restrict p_src_b_re, @@ -15,13 +15,13 @@ armral_status armral_cmplx_vecdot_f32_2(int32_t n, float32_t *p_src_c_re, float32_t *p_src_c_im) { #ifdef ARMRAL_ARCH_SVE - int32_t num_lanes = svcntw(); - int32_t full_vectors = n / num_lanes; + uint32_t num_lanes = svcntw(); + uint32_t full_vectors = n / num_lanes; svbool_t pg = svptrue_b32(); svfloat32_t acc_real = svdup_f32(0); svfloat32_t acc_imag = svdup_f32(0); - for (int32_t i = 0; i < full_vectors; ++i) { + for (uint32_t i = 0; i < full_vectors; ++i) { svfloat32_t vec_a_real = svld1_f32(pg, p_src_a_re); svfloat32_t vec_a_imag = svld1_f32(pg, p_src_a_im); svfloat32_t vec_b_real = svld1_f32(pg, p_src_b_re); @@ -41,9 +41,9 @@ armral_status armral_cmplx_vecdot_f32_2(int32_t n, acc_imag = svmla_f32_x(pg, acc_imag, vec_a_imag, vec_b_real); } - int32_t tail_size = n % num_lanes; + uint32_t tail_size = n % num_lanes; if (tail_size) { - svbool_t tail_pg = svwhilelt_b32(0, tail_size); + svbool_t tail_pg = svwhilelt_b32(0U, tail_size); svfloat32_t vec_a_real = svld1_f32(tail_pg, p_src_a_re); svfloat32_t vec_a_imag = svld1_f32(tail_pg, p_src_a_im); diff --git a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16.c b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16.c index 8005432..32d88d6 100644 --- a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16.c +++ b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16.c @@ -10,17 +10,18 @@ #endif armral_status -armral_cmplx_vecdot_i16(int32_t n, const armral_cmplx_int16_t *restrict p_src_a, +armral_cmplx_vecdot_i16(uint32_t n, + const armral_cmplx_int16_t *restrict p_src_a, const armral_cmplx_int16_t *restrict p_src_b, armral_cmplx_int16_t *p_src_c) { #if ARMRAL_ARCH_SVE >= 2 - int32_t num_32bit_lanes = svcntw(); - int32_t full_vectors = n / num_32bit_lanes; + uint32_t num_32bit_lanes = svcntw(); + uint32_t full_vectors = n / num_32bit_lanes; svbool_t pg = svptrue_b16(); svint64_t acc_real = svdup_n_s64(0); svint64_t acc_imag = svdup_n_s64(0); - for (int32_t i = 0; i < full_vectors; ++i) { + for (uint32_t i = 0; i < full_vectors; ++i) { svint16_t vec_a = svld1_s16(pg, (const int16_t *)p_src_a); svint16_t vec_b = svld1_s16(pg, (const int16_t *)p_src_b); p_src_a += num_32bit_lanes; @@ -30,9 +31,9 @@ armral_cmplx_vecdot_i16(int32_t n, const armral_cmplx_int16_t *restrict p_src_a, acc_imag = svcdot_s64(acc_imag, vec_a, vec_b, 90); } - int32_t tail_size = n % num_32bit_lanes; + uint32_t tail_size = n % num_32bit_lanes; if (tail_size) { - svbool_t tail_pg = svwhilelt_b16(0, 2 * tail_size); + svbool_t tail_pg = svwhilelt_b16(0U, 2 * tail_size); svint16_t vec_a = svld1_s16(tail_pg, (const int16_t *)p_src_a); svint16_t vec_b = svld1_s16(tail_pg, (const int16_t *)p_src_b); diff --git a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_2.c b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_2.c index 23516b9..af44504 100644 --- a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_2.c +++ b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_2.c @@ -9,7 +9,7 @@ #include #endif -armral_status armral_cmplx_vecdot_i16_2(int32_t n, +armral_status armral_cmplx_vecdot_i16_2(uint32_t n, const int16_t *restrict p_src_a_re, const int16_t *restrict p_src_a_im, const int16_t *restrict p_src_b_re, @@ -17,8 +17,8 @@ armral_status armral_cmplx_vecdot_i16_2(int32_t n, int16_t *p_src_c_re, int16_t *p_src_c_im) { #if ARMRAL_ARCH_SVE >= 2 - int32_t num_32bit_lanes = svcntw(); - int32_t full_vectors = n / num_32bit_lanes; + uint32_t num_32bit_lanes = svcntw(); + uint32_t full_vectors = n / num_32bit_lanes; svbool_t pg = svptrue_b32(); svint64_t acc_real_top = svdup_s64(0); @@ -26,7 +26,7 @@ armral_status armral_cmplx_vecdot_i16_2(int32_t n, svint64_t acc_imag_top = svdup_s64(0); svint64_t acc_imag_bot = svdup_s64(0); - for (int32_t i = 0; i < full_vectors; ++i) { + for (uint32_t i = 0; i < full_vectors; ++i) { svint32_t vec_a_real = svld1sh_s32(pg, p_src_a_re); svint32_t vec_a_imag = svld1sh_s32(pg, p_src_a_im); svint32_t vec_b_real = svld1sh_s32(pg, p_src_b_re); @@ -48,9 +48,9 @@ armral_status armral_cmplx_vecdot_i16_2(int32_t n, p_src_b_im += num_32bit_lanes; } - int32_t tail_size = n % num_32bit_lanes; + uint32_t tail_size = n % num_32bit_lanes; if (tail_size) { - svbool_t tail_pg = svwhilelt_b32(0, tail_size); + svbool_t tail_pg = svwhilelt_b32(0U, tail_size); svint32_t vec_a_real = svld1sh_s32(tail_pg, p_src_a_re); svint32_t vec_a_imag = svld1sh_s32(tail_pg, p_src_a_im); svint32_t vec_b_real = svld1sh_s32(tail_pg, p_src_b_re); diff --git a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_2_32bit.c b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_2_32bit.c index acf15f3..190e459 100644 --- a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_2_32bit.c +++ b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_2_32bit.c @@ -9,14 +9,14 @@ #endif armral_status -armral_cmplx_vecdot_i16_2_32bit(int32_t n, const int16_t *restrict p_src_a_re, +armral_cmplx_vecdot_i16_2_32bit(uint32_t n, const int16_t *restrict p_src_a_re, const int16_t *restrict p_src_a_im, const int16_t *restrict p_src_b_re, const int16_t *restrict p_src_b_im, int16_t *p_src_c_re, int16_t *p_src_c_im) { #if ARMRAL_ARCH_SVE >= 2 - int32_t num_lanes = svcnth(); - int32_t full_vectors = n / num_lanes; + uint32_t num_lanes = svcnth(); + uint32_t full_vectors = n / num_lanes; svbool_t pg = svptrue_b16(); svint32_t real_acc_top = svdup_s32(0); @@ -24,7 +24,7 @@ armral_cmplx_vecdot_i16_2_32bit(int32_t n, const int16_t *restrict p_src_a_re, svint32_t imag_acc_top = svdup_s32(0); svint32_t imag_acc_bot = svdup_s32(0); - for (int32_t i = 0; i < full_vectors; ++i) { + for (uint32_t i = 0; i < full_vectors; ++i) { svint16_t vec_a_real = svld1_s16(pg, p_src_a_re); svint16_t vec_a_imag = svld1_s16(pg, p_src_a_im); svint16_t vec_b_real = svld1_s16(pg, p_src_b_re); @@ -48,9 +48,9 @@ armral_cmplx_vecdot_i16_2_32bit(int32_t n, const int16_t *restrict p_src_a_re, imag_acc_bot = svqdmlalb_s32(imag_acc_bot, vec_a_real, vec_b_imag); } - int32_t tail_size = n % num_lanes; + uint32_t tail_size = n % num_lanes; if (tail_size) { - svbool_t tail_pg = svwhilelt_b16(0, tail_size); + svbool_t tail_pg = svwhilelt_b16(0U, tail_size); svint16_t vec_a_real = svld1_s16(tail_pg, p_src_a_re); svint16_t vec_a_imag = svld1_s16(tail_pg, p_src_a_im); svint16_t vec_b_real = svld1_s16(tail_pg, p_src_b_re); diff --git a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_32bit.c b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_32bit.c index 0eb7ab5..60aec7f 100644 --- a/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_32bit.c +++ b/src/BasicMathFun/VectorDotProd/arm_cmplx_vecdot_i16_32bit.c @@ -9,21 +9,21 @@ #endif armral_status -armral_cmplx_vecdot_i16_32bit(int32_t n, +armral_cmplx_vecdot_i16_32bit(uint32_t n, const armral_cmplx_int16_t *restrict p_src_a, const armral_cmplx_int16_t *restrict p_src_b, armral_cmplx_int16_t *p_src_c) { #if ARMRAL_ARCH_SVE >= 2 svbool_t pg = svptrue_b16(); - int32_t num_16bit_lanes = svcnth(); - int32_t full_vectors = n / num_16bit_lanes; + uint32_t num_16bit_lanes = svcnth(); + uint32_t full_vectors = n / num_16bit_lanes; svint32_t real_acc_top = svdup_n_s32(0); svint32_t real_acc_bot = svdup_n_s32(0); svint32_t imag_acc_top = svdup_n_s32(0); svint32_t imag_acc_bot = svdup_n_s32(0); - for (int32_t i = 0; i < full_vectors; ++i) { + for (uint32_t i = 0; i < full_vectors; ++i) { svint16x2_t vec_a = svld2_s16(pg, (const int16_t *)p_src_a); svint16x2_t vec_b = svld2_s16(pg, (const int16_t *)p_src_b); p_src_a += num_16bit_lanes; @@ -50,9 +50,9 @@ armral_cmplx_vecdot_i16_32bit(int32_t n, svqdmlalb_s32(imag_acc_bot, svget2_s16(vec_a, 0), svget2_s16(vec_b, 1)); } - int32_t tail_size = n % num_16bit_lanes; + uint32_t tail_size = n % num_16bit_lanes; if (tail_size) { - svbool_t tail_pg = svwhilelt_b16(0, tail_size); + svbool_t tail_pg = svwhilelt_b16(0U, tail_size); svint16x2_t vec_a = svld2_s16(tail_pg, (const int16_t *)p_src_a); svint16x2_t vec_b = svld2_s16(tail_pg, (const int16_t *)p_src_b); diff --git a/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_f32.c b/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_f32.c index 36de2e3..87e9636 100644 --- a/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_f32.c +++ b/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_f32.c @@ -7,16 +7,16 @@ #include #endif -armral_status armral_cmplx_vecmul_f32(int32_t n, +armral_status armral_cmplx_vecmul_f32(uint32_t n, const armral_cmplx_f32_t *restrict a, const armral_cmplx_f32_t *restrict b, armral_cmplx_f32_t *c) { #ifdef ARMRAL_ARCH_SVE - int32_t num_64bit_lanes = svcntd(); - int32_t full_vectors = n / num_64bit_lanes; + uint32_t num_64bit_lanes = svcntd(); + uint32_t full_vectors = n / num_64bit_lanes; svbool_t pg = svptrue_b32(); - for (int32_t i = 0; i < full_vectors; i++) { + for (uint32_t i = 0; i < full_vectors; i++) { svfloat32_t vec_a = svld1_f32(pg, (const float32_t *)a); svfloat32_t vec_b = svld1_f32(pg, (const float32_t *)b); svfloat32_t vec_c = svdup_n_f32(0); @@ -31,9 +31,9 @@ armral_status armral_cmplx_vecmul_f32(int32_t n, c += num_64bit_lanes; } - int32_t tail_size = n % num_64bit_lanes; + uint32_t tail_size = n % num_64bit_lanes; if (tail_size) { - pg = svwhilelt_b32(0, 2 * tail_size); + pg = svwhilelt_b32(0U, 2 * tail_size); svfloat32_t vec_a = svld1_f32(pg, (const float32_t *)a); svfloat32_t vec_b = svld1_f32(pg, (const float32_t *)b); diff --git a/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_f32_2.c b/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_f32_2.c index 5357eb7..0bff5db 100644 --- a/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_f32_2.c +++ b/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_f32_2.c @@ -7,18 +7,18 @@ #include #endif -armral_status armral_cmplx_vecmul_f32_2(int32_t n, +armral_status armral_cmplx_vecmul_f32_2(uint32_t n, const float32_t *restrict a_re, const float32_t *restrict a_im, const float32_t *restrict b_re, const float32_t *restrict b_im, float32_t *c_re, float32_t *c_im) { #ifdef ARMRAL_ARCH_SVE - int32_t num_lanes = svcntw(); + uint32_t num_lanes = svcntw(); svbool_t pg = svptrue_b32(); - int i = 0; - for (; i * num_lanes <= n - 4 * num_lanes; i += 4) { + uint32_t i = 0; + for (; (i + 4) * num_lanes <= n; i += 4) { svfloat32_t vec_a_0_re = svld1_vnum_f32(pg, a_re, i); svfloat32_t vec_a_0_im = svld1_vnum_f32(pg, a_im, i); svfloat32_t vec_b_0_re = svld1_vnum_f32(pg, b_re, i); diff --git a/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_i16.cpp b/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_i16.cpp index 373c25f..3eb94d5 100644 --- a/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_i16.cpp +++ b/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_i16.cpp @@ -10,15 +10,15 @@ #include #endif -armral_status armral_cmplx_vecmul_i16(int32_t n, const armral_cmplx_int16_t *a, +armral_status armral_cmplx_vecmul_i16(uint32_t n, const armral_cmplx_int16_t *a, const armral_cmplx_int16_t *b, armral_cmplx_int16_t *c) { #if ARMRAL_ARCH_SVE >= 2 - int32_t num_lanes = svcnth(); - int32_t full_vectors = n / num_lanes; + uint32_t num_lanes = svcnth(); + uint32_t full_vectors = n / num_lanes; svbool_t pg = svptrue_b16(); - for (int32_t i = 0; i < full_vectors; ++i) { + for (uint32_t i = 0; i < full_vectors; ++i) { svint16x2_t vec_a = svld2_s16(pg, (const int16_t *)a); svint16x2_t vec_b = svld2_s16(pg, (const int16_t *)b); @@ -30,9 +30,9 @@ armral_status armral_cmplx_vecmul_i16(int32_t n, const armral_cmplx_int16_t *a, c += num_lanes; } - int32_t tail_size = n % num_lanes; + uint32_t tail_size = n % num_lanes; if (tail_size != 0) { - pg = svwhilelt_b16(0, tail_size); + pg = svwhilelt_b16(0U, tail_size); svint16x2_t vec_a = svld2_s16(pg, (const int16_t *)a); svint16x2_t vec_b = svld2_s16(pg, (const int16_t *)b); diff --git a/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_i16_2.c b/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_i16_2.c index 5e13b7f..3b3cd28 100644 --- a/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_i16_2.c +++ b/src/BasicMathFun/VectorMult/arm_cmplx_vecmul_i16_2.c @@ -8,17 +8,18 @@ #include #endif -armral_status armral_cmplx_vecmul_i16_2(int32_t n, const int16_t *restrict a_re, +armral_status armral_cmplx_vecmul_i16_2(uint32_t n, + const int16_t *restrict a_re, const int16_t *restrict a_im, const int16_t *restrict b_re, const int16_t *restrict b_im, int16_t *c_re, int16_t *c_im) { #if ARMRAL_ARCH_SVE >= 2 - int32_t num_lanes = svcnth(); - int32_t full_vectors = n / num_lanes; + uint32_t num_lanes = svcnth(); + uint32_t full_vectors = n / num_lanes; svbool_t pg = svptrue_b16(); - for (int32_t i = 0; i < full_vectors; ++i) { + for (uint32_t i = 0; i < full_vectors; ++i) { svint16_t vec_a_real = svld1(pg, (const int16_t *)a_re); svint16_t vec_a_imag = svld1(pg, (const int16_t *)a_im); svint16_t vec_b_real = svld1(pg, (const int16_t *)b_re); @@ -64,9 +65,9 @@ armral_status armral_cmplx_vecmul_i16_2(int32_t n, const int16_t *restrict a_re, c_im += num_lanes; } - int32_t tail_size = n % num_lanes; + uint32_t tail_size = n % num_lanes; if (tail_size) { - pg = svwhilelt_b16(0, tail_size); + pg = svwhilelt_b16(0U, tail_size); svint16_t vec_a_real = svld1(pg, (const int16_t *)a_re); svint16_t vec_a_imag = svld1(pg, (const int16_t *)a_im); diff --git a/src/DuRuInterface/MuLawCompression/arm_mu_law_decompression.cpp b/src/DuRuInterface/MuLawCompression/arm_mu_law_decompression.cpp index 623c6ba..3679a89 100644 --- a/src/DuRuInterface/MuLawCompression/arm_mu_law_decompression.cpp +++ b/src/DuRuInterface/MuLawCompression/arm_mu_law_decompression.cpp @@ -635,16 +635,16 @@ armral_status armral_mu_law_decompr_14bit( // (Aa... and Dd... now contiguous) // note this pattern repeats for bytes 7-13 as well uint8x16_t contig01_idx = - (uint8x16_t){1, 0, 3, 2, 4, 3, 6, 5, 8, 7, 10, 9, 11, 10, 13, 12}; + uint8x16_t{1, 0, 3, 2, 4, 3, 6, 5, 8, 7, 10, 9, 11, 10, 13, 12}; uint8x16_t contig2_idx = - (uint8x16_t){3, 2, 5, 4, 6, 5, 8, 7, 10, 9, 12, 11, 13, 12, 15, 14}; + uint8x16_t{3, 2, 5, 4, 6, 5, 8, 7, 10, 9, 12, 11, 13, 12, 15, 14}; uint16x8_t contig[3]; contig[0] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[0], contig01_idx)); contig[1] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[1], contig01_idx)); contig[2] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[2], contig2_idx)); // shift MSBs to most significant bit positions - int16x8_t contig_shift = (int16x8_t){0, -2, 4, 2, 0, -2, 4, 2}; + int16x8_t contig_shift = int16x8_t{0, -2, 4, 2, 0, -2, 4, 2}; contig[0] = vshlq_u16(contig[0], contig_shift); contig[1] = vshlq_u16(contig[1], contig_shift); contig[2] = vshlq_u16(contig[2], contig_shift); @@ -654,17 +654,17 @@ armral_status armral_mu_law_decompr_14bit( // [dddddddd|ccDddddd|cccccccc|bbbbCccc|bbbbbbbb|aaaaaaBb|Aaaaaaaa] // into // [00000000|00000000|00000000|0000ccDd|Bb000000|00000000|00000000|00000000] - uint8x16_t fill01_idx = (uint8x16_t){255, 255, 255, 1, 5, 255, 255, 255, - 255, 255, 255, 8, 12, 255, 255, 255}; - uint8x16_t fill2_idx = (uint8x16_t){255, 255, 255, 3, 7, 255, 255, 255, - 255, 255, 255, 10, 14, 255, 255, 255}; + uint8x16_t fill01_idx = uint8x16_t{255, 255, 255, 1, 5, 255, 255, 255, + 255, 255, 255, 8, 12, 255, 255, 255}; + uint8x16_t fill2_idx = uint8x16_t{255, 255, 255, 3, 7, 255, 255, 255, + 255, 255, 255, 10, 14, 255, 255, 255}; uint16x8_t fill[3]; fill[0] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[0], fill01_idx)); fill[1] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[1], fill01_idx)); fill[2] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[2], fill2_idx)); - int16x8_t fill_shift = (int16x8_t){0, 6, -4, 0, 0, 6, -4, 0}; + int16x8_t fill_shift = int16x8_t{0, 6, -4, 0, 0, 6, -4, 0}; fill[0] = vshlq_u16(fill[0], fill_shift); fill[1] = vshlq_u16(fill[1], fill_shift); fill[2] = vshlq_u16(fill[2], fill_shift); diff --git a/src/DuRuInterface/ORanBlockFloat/arm_block_float_decompression.cpp b/src/DuRuInterface/ORanBlockFloat/arm_block_float_decompression.cpp index 495ec44..d6fc89c 100644 --- a/src/DuRuInterface/ORanBlockFloat/arm_block_float_decompression.cpp +++ b/src/DuRuInterface/ORanBlockFloat/arm_block_float_decompression.cpp @@ -279,19 +279,19 @@ armral_status armral_block_float_decompr_12bit( // permute bytes [ ... bb|aB|Aa] (capital indicates msb of original data) // into [ ... aB|bb|Aa|aB] (note Bbb and Aaa now contiguous) uint8x16_t idx01 = - (uint8x16_t){1, 0, 2, 1, 4, 3, 5, 4, 7, 6, 8, 7, 10, 9, 11, 10}; + uint8x16_t{1, 0, 2, 1, 4, 3, 5, 4, 7, 6, 8, 7, 10, 9, 11, 10}; uint8x16_t idx2 = - (uint8x16_t){5, 4, 6, 5, 8, 7, 9, 8, 11, 10, 12, 11, 14, 13, 15, 14}; + uint8x16_t{5, 4, 6, 5, 8, 7, 9, 8, 11, 10, 12, 11, 14, 13, 15, 14}; int16x8_t pack0 = vreinterpretq_s16_u8(vqtbl1q_u8(in0, idx01)); int16x8_t pack1 = vreinterpretq_s16_u8(vqtbl1q_u8(in1, idx01)); int16x8_t pack2 = vreinterpretq_s16_u8(vqtbl1q_u8(in2, idx2)); // shift and mask from [ ... aBbb|AaaB] (from above) // into [ ... Bbb0|Aaa0] - int16x8_t mask = (int16x8_t){ + int16x8_t mask = int16x8_t{ (int16_t)0xfff0, (int16_t)0x0fff, (int16_t)0xfff0, (int16_t)0x0fff, (int16_t)0xfff0, (int16_t)0x0fff, (int16_t)0xfff0, (int16_t)0x0fff}; - int16x8_t shift = (int16x8_t){0, 4, 0, 4, 0, 4, 0, 4}; + int16x8_t shift = int16x8_t{0, 4, 0, 4, 0, 4, 0, 4}; pack0 = vshlq_s16(vandq_s16(pack0, mask), shift); pack1 = vshlq_s16(vandq_s16(pack1, mask), shift); pack2 = vshlq_s16(vandq_s16(pack2, mask), shift); @@ -447,9 +447,9 @@ armral_status armral_block_float_decompr_14bit( int16x8_t out = vreinterpretq_s16_u16(in0 | in1 | in2); int16x8_t pack0 = out >> (2 - exp); - idx0 = (uint8x8_t){255, 4, 6, 8, 255, 11, 13, 15}; - idx1 = (uint8x8_t){3, 5, 7, 255, 10, 12, 14, 255}; - idx2 = (uint8x8_t){2, 3, 5, 7, 9, 10, 12, 14}; + idx0 = uint8x8_t{255, 4, 6, 8, 255, 11, 13, 15}; + idx1 = uint8x8_t{3, 5, 7, 255, 10, 12, 14, 255}; + idx2 = uint8x8_t{2, 3, 5, 7, 9, 10, 12, 14}; in0_8b = vqtbl1_u8(in_b, idx0); in1_8b = vqtbl1_u8(in_b, idx1); diff --git a/src/DuRuInterface/ORanBlockScaling/arm_block_scaling_decompression.cpp b/src/DuRuInterface/ORanBlockScaling/arm_block_scaling_decompression.cpp index 3de46c3..01caaf5 100644 --- a/src/DuRuInterface/ORanBlockScaling/arm_block_scaling_decompression.cpp +++ b/src/DuRuInterface/ORanBlockScaling/arm_block_scaling_decompression.cpp @@ -317,16 +317,16 @@ armral_status armral_block_scaling_decompr_14bit( // (Aa... and Dd... now contiguous) // note this pattern repeats for bytes 7-13 as well uint8x16_t contig01_idx = - (uint8x16_t){1, 0, 3, 2, 4, 3, 6, 5, 8, 7, 10, 9, 11, 10, 13, 12}; + uint8x16_t{1, 0, 3, 2, 4, 3, 6, 5, 8, 7, 10, 9, 11, 10, 13, 12}; uint8x16_t contig2_idx = - (uint8x16_t){3, 2, 5, 4, 6, 5, 8, 7, 10, 9, 12, 11, 13, 12, 15, 14}; + uint8x16_t{3, 2, 5, 4, 6, 5, 8, 7, 10, 9, 12, 11, 13, 12, 15, 14}; uint16x8_t contig[3]; contig[0] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[0], contig01_idx)); contig[1] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[1], contig01_idx)); contig[2] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[2], contig2_idx)); // shift MSBs to most significant bit positions - int16x8_t contig_shift = (int16x8_t){0, -2, 4, 2, 0, -2, 4, 2}; + int16x8_t contig_shift = int16x8_t{0, -2, 4, 2, 0, -2, 4, 2}; contig[0] = vshlq_u16(contig[0], contig_shift); contig[1] = vshlq_u16(contig[1], contig_shift); contig[2] = vshlq_u16(contig[2], contig_shift); @@ -336,17 +336,17 @@ armral_status armral_block_scaling_decompr_14bit( // [dddddddd|ccDddddd|cccccccc|bbbbCccc|bbbbbbbb|aaaaaaBb|Aaaaaaaa] // into // [00000000|00000000|00000000|0000ccDd|Bb000000|00000000|00000000|00000000] - uint8x16_t fill01_idx = (uint8x16_t){255, 255, 255, 1, 5, 255, 255, 255, - 255, 255, 255, 8, 12, 255, 255, 255}; - uint8x16_t fill2_idx = (uint8x16_t){255, 255, 255, 3, 7, 255, 255, 255, - 255, 255, 255, 10, 14, 255, 255, 255}; + uint8x16_t fill01_idx = uint8x16_t{255, 255, 255, 1, 5, 255, 255, 255, + 255, 255, 255, 8, 12, 255, 255, 255}; + uint8x16_t fill2_idx = uint8x16_t{255, 255, 255, 3, 7, 255, 255, 255, + 255, 255, 255, 10, 14, 255, 255, 255}; uint16x8_t fill[3]; fill[0] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[0], fill01_idx)); fill[1] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[1], fill01_idx)); fill[2] = vreinterpretq_u16_u8(vqtbl1q_u8(uprb_in[2], fill2_idx)); - int16x8_t fill_shift = (int16x8_t){0, 6, -4, 0, 0, 6, -4, 0}; + int16x8_t fill_shift = int16x8_t{0, 6, -4, 0, 0, 6, -4, 0}; fill[0] = vshlq_u16(fill[0], fill_shift); fill[1] = vshlq_u16(fill[1], fill_shift); fill[2] = vshlq_u16(fill[2], fill_shift); diff --git a/src/LowerPHY/Correlation/arm_correlation.c b/src/LowerPHY/Correlation/arm_correlation.c index 85cca8c..b260b4d 100644 --- a/src/LowerPHY/Correlation/arm_correlation.c +++ b/src/LowerPHY/Correlation/arm_correlation.c @@ -53,9 +53,15 @@ armral_cmplx_vecdot_i16_n8_conj_32bit(int16x8_t vec_a_re, int16x8_t vec_a_im, #endif armral_status -armral_corr_coeff_i16(int32_t n, const armral_cmplx_int16_t *restrict p_src_a, +armral_corr_coeff_i16(uint32_t n, const armral_cmplx_int16_t *restrict p_src_a, const armral_cmplx_int16_t *restrict p_src_b, armral_cmplx_int16_t *c) { + /* We have to cast uint32_t n to an int32_t, but if it is too large to */ + /* be represented by the new datatype we return an argument error. */ + if (n > INT32_MAX) { + return ARMRAL_ARGUMENT_ERROR; + } + const int16_t *p_ini_a = (const int16_t *)p_src_a; const int16_t *p_ini_b = (const int16_t *)p_src_b; @@ -71,7 +77,7 @@ armral_corr_coeff_i16(int32_t n, const armral_cmplx_int16_t *restrict p_src_a, svint64_t xyconj_im_vec = svdup_n_s64(0); /* Compute 8 outputs at a time */ - for (int i = 0; i < n; i += svcnth()) { + for (uint32_t i = 0; i < n; i += svcnth()) { svbool_t pg = svwhilelt_b16(i, n); /* Load samples */ svint16x2_t x = svld2_s16(pg, p_ini_a); @@ -219,14 +225,16 @@ armral_corr_coeff_i16(int32_t n, const armral_cmplx_int16_t *restrict p_src_a, } #endif - xavg_re = xavg_re / n; /*16.15 format*/ - xavg_im = xavg_im / n; - yavg_re = yavg_re / n; - yavg_im = yavg_im / n; + xavg_re = xavg_re / (int32_t)n; /*16.15 format*/ + xavg_im = xavg_im / (int32_t)n; + yavg_re = yavg_re / (int32_t)n; + yavg_im = yavg_im / (int32_t)n; /* (n*xavg*yavg), 33.30 format */ - int64_t temp_re = n * (int64_t)((xavg_re * yavg_re) - (xavg_im * yavg_im)); - int64_t temp_im = n * (int64_t)((xavg_im * yavg_re) + (xavg_re * yavg_im)); + int64_t temp_re = + (int64_t)n * (int64_t)((xavg_re * yavg_re) - (xavg_im * yavg_im)); + int64_t temp_im = + (int64_t)n * (int64_t)((xavg_im * yavg_re) + (xavg_re * yavg_im)); int64_t num_re = xyconj_re - temp_re; /* 33.30 format*/ int64_t num_im = xyconj_im - temp_im; diff --git a/src/LowerPHY/FFT/fft_plan.cpp b/src/LowerPHY/FFT/fft_plan.cpp index 4b10a61..30635e1 100644 --- a/src/LowerPHY/FFT/fft_plan.cpp +++ b/src/LowerPHY/FFT/fft_plan.cpp @@ -49,9 +49,9 @@ Tw *make_twiddles(int n1, int n2, armral_fft_direction_t dir, float input = base_m * i * (j + jj); float a = j + jj < n2 ? cosf(input) : 0; float b = j + jj < n2 ? sinf(input) : 0; - twids[x++] = (Tw){a, b}; + twids[x++] = Tw{a, b}; if (want_conj_twids) { - twids[x++] = (Tw){-b, a}; + twids[x++] = Tw{-b, a}; } } } diff --git a/src/LowerPHY/FFT/rader.cpp b/src/LowerPHY/FFT/rader.cpp index 1678c6a..70dd18f 100644 --- a/src/LowerPHY/FFT/rader.cpp +++ b/src/LowerPHY/FFT/rader.cpp @@ -66,7 +66,7 @@ rader make_rader(int n, armral_fft_direction_t dir) { for (int i = 0; i < n - 1; i++) { double x = ginvmul_fw_perm[i]; double in = ((2. * M_PI * x) / n) * dir_float; - b[i] = (Tw){(real_t)cos(in), (real_t)sin(in)}; + b[i] = Tw{(real_t)cos(in), (real_t)sin(in)}; } armral::fft::execute(pf, b, b, 1, 1, 1); @@ -218,8 +218,8 @@ void execute_rader(const rader &r, const Tx *x, Ty *y, int istride, for (int i = 0; i < howmany; ++i) { y[i * odist] = armral::fft::cast(y0[i]); for (int j = 0; j < nm1; ++j) { - auto yelem = (Tw){work_ptr[i + j * howmany].re + x0[i].re, - work_ptr[i + j * howmany].im + x0[i].im}; + auto yelem = Tw{work_ptr[i + j * howmany].re + x0[i].re, + work_ptr[i + j * howmany].im + x0[i].im}; y[i * odist + r.ginvmul_fw_perm[j] * ostride] = armral::fft::cast(yelem); } diff --git a/src/LowerPHY/SeqGenerator/arm_mat_seq_generator.cpp b/src/LowerPHY/SeqGenerator/arm_mat_seq_generator.cpp index d332880..0ed8fd3 100644 --- a/src/LowerPHY/SeqGenerator/arm_mat_seq_generator.cpp +++ b/src/LowerPHY/SeqGenerator/arm_mat_seq_generator.cpp @@ -27,7 +27,7 @@ static inline void generate_seq_128(uint64_t *x) { template static inline void generate_seq_64(uint64_t *x) { - static_assert((N == 1) | (N == 2)); + static_assert((N == 1) || (N == 2)); poly64_t pmask[3]; if (N == 1) { diff --git a/src/MatrixFactorizations/SVD/arm_svd.cpp b/src/MatrixFactorizations/SVD/arm_svd.cpp index b9e6cb1..e155b69 100644 --- a/src/MatrixFactorizations/SVD/arm_svd.cpp +++ b/src/MatrixFactorizations/SVD/arm_svd.cpp @@ -21,19 +21,18 @@ namespace { // Compute dot product c = a . conj(b) with a, b and c complex vectors // of length n. -inline void cmplx_vecdot_conj_f32(int32_t n, const armral_cmplx_f32_t *p_src_a, +inline void cmplx_vecdot_conj_f32(uint32_t n, const armral_cmplx_f32_t *p_src_a, const armral_cmplx_f32_t *p_src_b, armral_cmplx_f32_t *p_src_c) { #ifdef ARMRAL_ARCH_SVE - - int32_t num_lanes = svcntd(); + uint32_t num_lanes = svcntd(); svbool_t ptrue = svptrue_b32(); svfloat32_t acc0 = svdup_n_f32(0); svfloat32_t acc1 = svdup_n_f32(0); - int32_t i = 0; - for (; i * num_lanes <= n - 2 * num_lanes; i += 2) { + uint32_t i = 0; + for (; (i + 2) * num_lanes <= n; i += 2) { svbool_t pg = svptrue_b32(); svfloat32_t vec_a0 = svld1_vnum_f32(pg, (const float32_t *)p_src_a, i); svfloat32_t vec_b0 = svld1_vnum_f32(pg, (const float32_t *)p_src_b, i); @@ -59,7 +58,6 @@ inline void cmplx_vecdot_conj_f32(int32_t n, const armral_cmplx_f32_t *p_src_a, p_src_c->im = svaddv_f32(ptrue, svtrn2_f32(acc0, acc1)); #else - float32_t real_sum = 0; float32_t imag_sum = 0; @@ -145,7 +143,7 @@ inline void cmplx_vecdot_conj_f32(int32_t n, const armral_cmplx_f32_t *p_src_a, // Compute c -= a * b with a and c complex vectors of length n and b a complex // constant value. -inline void cmplx_axmy_f32(int32_t n, const armral_cmplx_f32_t *p_src_a, +inline void cmplx_axmy_f32(uint32_t n, const armral_cmplx_f32_t *p_src_a, const armral_cmplx_f32_t *p_src_b, armral_cmplx_f32_t *p_src_c) { float32x4_t re_cte = vdupq_n_f32(p_src_b->re); @@ -293,13 +291,13 @@ inline armral_cmplx_f32_t inv_cf32(armral_cmplx_f32_t a) { // a given vector to annihilate all the entries // except the first. The second entry to the // last are overwritten by the reflectors. -inline armral_cmplx_f32_t clarfg(int n, armral_cmplx_f32_t &aii, - armral_cmplx_f32_t *x, int incx) { +inline armral_cmplx_f32_t clarfg(uint32_t n, armral_cmplx_f32_t &aii, + armral_cmplx_f32_t *x, uint32_t incx) { armral_cmplx_f32_t alpha = aii; // Sum of x[i] * conj(x[i]) float32_t sum = 0.0F; - for (int i = 0; i < n * incx; i += incx) { + for (uint32_t i = 0; i < n * incx; i += incx) { sum += square_conj_cf32(x[i]); } @@ -312,8 +310,8 @@ inline armral_cmplx_f32_t clarfg(int n, armral_cmplx_f32_t &aii, sum += square_conj_cf32(alpha); float32_t beta = -copysign(sqrt(sum), alpha.re); float32_t rsafemin = 1.0F / safemin; - int cnt = 0; - int max_attempt = 10; + uint32_t cnt = 0; + uint32_t max_attempt = 10; float32_t scale = 1.0F; // Check if beta is small enough to induce // overflow when taking the inverse, and @@ -328,13 +326,13 @@ inline armral_cmplx_f32_t clarfg(int n, armral_cmplx_f32_t &aii, if (cnt > 0) { alpha.re *= scale; alpha.im *= scale; - for (int i = 0; i < n * incx; i += incx) { + for (uint32_t i = 0; i < n * incx; i += incx) { x[i].re *= scale; x[i].im *= scale; } // The new beta is at most 1, at least safmin, sum = square_conj_cf32(alpha); - for (int i = 0; i < n * incx; i += incx) { + for (uint32_t i = 0; i < n * incx; i += incx) { sum += square_conj_cf32(x[i]); } beta = -copysign(sqrt(sum), alpha.re); @@ -345,7 +343,7 @@ inline armral_cmplx_f32_t clarfg(int n, armral_cmplx_f32_t &aii, tau.im = -alpha.im / beta; armral_cmplx_f32_t normalization_factor = inv_cf32({alpha.re - beta, alpha.im}); - for (int i = 0; i < n * incx; i += incx) { + for (uint32_t i = 0; i < n * incx; i += incx) { x[i] = mult_cf32(normalization_factor, x[i]); } beta /= scale; @@ -380,10 +378,10 @@ inline void rotg(float32_t f, float32_t g, float32_t &cs, float32_t &sn, // This routine updates singular vectors // by applying the Givens rotations // used to update the bidiagonal matrix -inline void update_sigvect(int m, float32_t cs, float32_t sn, +inline void update_sigvect(uint32_t m, float32_t cs, float32_t sn, armral_cmplx_f32_t *v1, armral_cmplx_f32_t *v2, - int incv) { - for (int i = 0; i < m * incv; i += incv) { + uint32_t incv) { + for (uint32_t i = 0; i < m * incv; i += incv) { auto t = v1[i]; v1[i].re = cs * t.re + sn * v2[i].re; v1[i].im = cs * t.im + sn * v2[i].im; @@ -400,20 +398,20 @@ inline void update_sigvect(int m, float32_t cs, float32_t sn, // stored explicitly, but can be formed using // the armral_assemble_q routine, or applied, // using armral_apply_q. -armral_status armral_householder_qr(int m, int n, armral_cmplx_f32_t *a, - armral_cmplx_f32_t *tau) { +armral_status householder_qr(uint32_t m, uint32_t n, armral_cmplx_f32_t *a, + armral_cmplx_f32_t *tau) { if (m < n) { return ARMRAL_ARGUMENT_ERROR; } column_major_matrix_view a_mat{a, m}; - for (int i = 0; i < n; i++) { - int k = std::min(i + 1, m - 1); + for (uint32_t i = 0; i < n; i++) { + uint32_t k = std::min(i + 1, m - 1); tau[i] = clarfg(m - i - 1, a_mat(i, i), &a_mat(k, i), a_mat.row_increment()); auto tmp = a_mat(i, i); a_mat(i, i) = {1.0F, 0.0F}; if (i < n - 1) { - for (int col = i + 1; col < n; col++) { + for (uint32_t col = i + 1; col < n; col++) { // w = A(row, col) * conj(A(row, i)) armral_cmplx_f32_t w = {0.0F, 0.0F}; if (m > i) { @@ -430,9 +428,8 @@ armral_status armral_householder_qr(int m, int n, armral_cmplx_f32_t *a, // Generate explicitly Q from QR factorization or from // the bidiagonalization A = Q * B * P^H -armral_status armral_assemble_q(int m, int n, const armral_cmplx_f32_t *a, - const armral_cmplx_f32_t *tau, - armral_cmplx_f32_t *q) { +armral_status assemble_q(uint32_t m, uint32_t n, const armral_cmplx_f32_t *a, + const armral_cmplx_f32_t *tau, armral_cmplx_f32_t *q) { if (m < n) { return ARMRAL_ARGUMENT_ERROR; } @@ -442,11 +439,16 @@ armral_status armral_assemble_q(int m, int n, const armral_cmplx_f32_t *a, // Accumulate reflectors from right to left // Q = H1 * H2....* Hn. They are applied to identity. column_major_matrix_view q_mat{q, m}; - for (int i = n - 1; i >= 0; i--) { + + // n will always be >=1 because of the fast + // return in the top-level function + uint32_t i = n; + do { + i--; if (i < n - 1) { q_mat(i, i) = {1.0F, 0.0F}; // Apply reflector from the left - for (int col = i + 1; col < n; col++) { + for (uint32_t col = i + 1; col < n; col++) { armral_cmplx_f32_t w = {0.0F, 0.0F}; if (m > i) { cmplx_vecdot_conj_f32(m - i, &q_mat(i, col), &q_mat(i, i), &w); @@ -457,16 +459,16 @@ armral_status armral_assemble_q(int m, int n, const armral_cmplx_f32_t *a, } if (i < m - 1) { // Scale entries i+1 to m-1 of the i-th column - for (int r = i + 1; r < m; r++) { + for (uint32_t r = i + 1; r < m; r++) { q_mat(r, i) = mult_cf32(q_mat(r, i), {-tau[i].re, -tau[i].im}); } } q_mat(i, i) = {1.0F - tau[i].re, -tau[i].im}; // Set the entries 0 to i-1 of the i-th column to zero - for (int r = 0; r <= i - 1; r++) { + for (uint32_t r = 0; r < i; r++) { q_mat(r, i) = {0.0F, 0.0F}; } - } + } while (i != 0); return ARMRAL_SUCCESS; } @@ -474,8 +476,8 @@ armral_status armral_assemble_q(int m, int n, const armral_cmplx_f32_t *a, // the bidiagonalization A = Q * B * P^H, // note that P^H is generated directly // instead of P -void armral_assemble_p(int m, int n, const armral_cmplx_f32_t *a, - const armral_cmplx_f32_t *tau, armral_cmplx_f32_t *p) { +void assemble_p(uint32_t m, uint32_t n, const armral_cmplx_f32_t *a, + const armral_cmplx_f32_t *tau, armral_cmplx_f32_t *p) { // Shifted copy of A to P with first // column and first row set to zero // Set first column to zero @@ -484,46 +486,51 @@ void armral_assemble_p(int m, int n, const armral_cmplx_f32_t *a, memset(&p[1], 0, n * sizeof(armral_cmplx_f32_t)); column_major_matrix_view p_mat{p, n}; column_major_matrix_view a_mat{a, m}; - for (int j = 1; j < n; j++) { + for (uint32_t j = 1; j < n; j++) { // Set for row to zero p_mat(0, j) = {0.0F, 0.0F}; - for (int i = 1; i <= j; i++) { + for (uint32_t i = 1; i <= j; i++) { p_mat(i, j) = a_mat(i - 1, j); } } // Work on shifted matrix with reflector // just above the diagonal to fall back a // case similar to QR - int n1 = n - 1; + + // n will always be >=1 because of the fast + // return in the top-level function + uint32_t n1 = n - 1; auto *p1 = &p[p_mat.stride() + 1]; // Apply householder reflectors from the right column_major_matrix_view p1_mat{p1, n}; - for (int i = n1 - 1; i >= 0; i--) { + uint32_t i = n1; + do { + i--; if (i < n1 - 1) { p1_mat(i, i) = {1.0F, 0.0F}; - for (int row = i + 1; row < n1; row++) { + for (uint32_t row = i + 1; row < n1; row++) { armral_cmplx_f32_t w = {0.0F, 0.0F}; - for (int col = i; col < n1; col++) { + for (uint32_t col = i; col < n1; col++) { w = mult_conj_add_cf32(p1_mat(row, col), p1_mat(i, col), w); } auto tmp = mult_conj_cf32(w, tau[i]); - for (int col = i; col < n1; col++) { + for (uint32_t col = i; col < n1; col++) { p1_mat(row, col) = mult_sub_cf32(p1_mat(i, col), tmp, p1_mat(row, col)); } } // Scale - for (int col = i + 1; col < n1; col++) { + for (uint32_t col = i + 1; col < n1; col++) { p1_mat(i, col) = mult_cf32(p1_mat(i, col), {-tau[i].re, tau[i].im}); } } p1_mat(i, i) = {1.0F - tau[i].re, tau[i].im}; // Set entries 0 to i-1 of the i-th row to zero - for (int col = 0; col < i; col++) { + for (uint32_t col = 0; col < i; col++) { p1_mat(i, col) = {0.0F, 0.0F}; } - } + } while (i != 0); } // This routine reduces a general complex m-by-n matrix A @@ -541,25 +548,25 @@ void armral_assemble_p(int m, int n, const armral_cmplx_f32_t *a, // the bidiagonal matrix B. Note that this routine // returns directly the conjugate transpose of the // left orthogonal matrix. -armral_status armral_bidiagonalization(int m, int n, armral_cmplx_f32_t *a, - float32_t *d, float32_t *e, - armral_cmplx_f32_t *tauq, - armral_cmplx_f32_t *taup) { +armral_status bidiagonalization(uint32_t m, uint32_t n, armral_cmplx_f32_t *a, + float32_t *d, float32_t *e, + armral_cmplx_f32_t *tauq, + armral_cmplx_f32_t *taup) { if (m < n) { return ARMRAL_ARGUMENT_ERROR; } column_major_matrix_view a_mat{a, m}; - for (int i = 0; i < n; i++) { + for (uint32_t i = 0; i < n; i++) { // QR steps, generate elementary reflector H(i) to annihilate // the entries i+1 to the last of the i-th column - int k = std::min(i + 1, m - 1); + uint32_t k = std::min(i + 1, m - 1); tauq[i] = clarfg(m - i - 1, a_mat(i, i), &a_mat(k, i), a_mat.row_increment()); d[i] = a_mat(i, i).re; a_mat(i, i) = {1.0F, 0.0F}; // Apply householder reflectors from left if (i < n - 1) { - for (int col = i + 1; col < n; col++) { + for (uint32_t col = i + 1; col < n; col++) { // w = A(i:m-1, col) * conj(A(i:m-1, i)) armral_cmplx_f32_t w = {0.0F, 0.0F}; if (m > i) { @@ -579,31 +586,31 @@ armral_status armral_bidiagonalization(int m, int n, armral_cmplx_f32_t *a, // Transpose conjugate entries i+1 to the last // of the i-th row - for (int col = i + 1; col < n; col++) { + for (uint32_t col = i + 1; col < n; col++) { a_mat(i, col).im = -a_mat(i, col).im; } // Generate reflectors to annihilate the entries i+2 // to the last of the i-th row. - int j = std::min(i + 2, n - 1); + uint32_t j = std::min(i + 2, n - 1); taup[i] = clarfg(n - i - 2, a_mat(i, i + 1), &a_mat(i, j), a_mat.column_increment()); e[i] = a_mat(i, i + 1).re; a_mat(i, i + 1) = {1.0F, 0.0F}; // Apply the reflectors - for (int row = i + 1; row < m; row++) { + for (uint32_t row = i + 1; row < m; row++) { armral_cmplx_f32_t w = {0.0F, 0.0F}; - for (int col = i + 1; col < n; col++) { + for (uint32_t col = i + 1; col < n; col++) { w = mult_add_cf32(a_mat(row, col), a_mat(i, col), w); } auto tmp = mult_cf32(taup[i], w); - for (int col = i + 1; col < n; col++) { + for (uint32_t col = i + 1; col < n; col++) { a_mat(row, col) = mult_conj_sub_cf32(tmp, a_mat(i, col), a_mat(row, col)); } } // Conjugate transpose the reflectors - for (int col = i + 1; col < n; col++) { + for (uint32_t col = i + 1; col < n; col++) { a_mat(i, col).im = -a_mat(i, col).im; } a_mat(i, i + 1) = {e[i], 0.0F}; @@ -621,12 +628,11 @@ armral_status armral_bidiagonalization(int m, int n, armral_cmplx_f32_t *a, // left and right singular vectors are updated if required. // This algorithm is developed by G. H. Golub and C. Reinsch. // For more detail, the algorithm is well explained in -// "Singular Value Decomposition and Least Squares Solutions" +// "Singular Value Decomposition and Least Squares Solutions" // published in Numer. Math. 14, 403--420 (1970). -armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, - float32_t *d, float32_t *e, - armral_cmplx_f32_t *u, - armral_cmplx_f32_t *vt, int u_stride) { +armral_status svd_bidiagonal(bool gen_singular_vectors, uint32_t m, uint32_t n, + float32_t *d, float32_t *e, armral_cmplx_f32_t *u, + armral_cmplx_f32_t *vt, uint32_t u_stride) { if (m < n) { return ARMRAL_ARGUMENT_ERROR; @@ -634,7 +640,7 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, // Shift the off-diagonal elements down by 1 // This helps to have D[i] and E[i] as the i-th // column of the bidiagonal matrix B. - for (int i = n - 1; i > 0; i--) { + for (uint32_t i = n - 1; i > 0; i--) { e[i] = e[i - 1]; } e[0] = 0.0F; @@ -642,7 +648,7 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, // Compute the 1-norm of the bidiagonal matrix // for the computation of the stopping criteria. float32_t anorm = 0; - for (int i = 0; i < n; i++) { + for (uint32_t i = 0; i < n; i++) { float32_t tmp = std::abs(d[i]) + std::abs(e[i]); if (anorm < tmp) { anorm = tmp; @@ -650,19 +656,25 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, } float32_t tol = anorm * eps; - int maxiter = n * n; + uint32_t maxiter = n * n; // Loop over the columns column_major_matrix_view u_mat{u, u_stride}; column_major_matrix_view vt_mat{vt, n}; - for (int curr_col = n - 1; curr_col >= 0; curr_col--) { + + // n will always be >=1 because of the fast + // return in the top-level function + uint32_t curr_col = n; + do { + curr_col--; // iteration to annihilate the off-diagonal E[curr_col]. - int iter = 0; + uint32_t iter = 0; for (; iter <= maxiter; iter++) { bool diag_is_zero = false; - int next_col; + uint32_t next_col = curr_col + 1; // Check if an off-diagonal is zero. - for (next_col = curr_col; next_col >= 0; next_col--) { + do { + next_col--; if (std::abs(e[next_col]) < tol) { break; } @@ -671,8 +683,8 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, diag_is_zero = true; break; } - } - // If the diagonal D[next_col] = 0; then at least one singular + } while (next_col != 0); + // If the diagonal D[next_col] = 0; then at least one singular // value must be equal to zero. In the absence of roundoff error, // the matrix will break if a shift of zero is performed. // In this case, an extra sequence of Givens rotations is @@ -680,7 +692,7 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, if (diag_is_zero) { float32_t cs = 0.0; float32_t sn = 1.0; - for (int i = next_col; i < curr_col; i++) { + for (uint32_t i = next_col; i < curr_col; i++) { float32_t f = sn * e[i]; e[i] *= cs; if (std::abs(f) <= tol) { @@ -706,7 +718,7 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, if (gen_singular_vectors) { // For the sake of performance we copy data that is contiguous in // memory - for (int row = 0; row < m; row++) { + for (uint32_t row = 0; row < m; row++) { u_mat(row, curr_col).re = -u_mat(row, curr_col).re; u_mat(row, curr_col).im = -u_mat(row, curr_col).im; } @@ -739,7 +751,7 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, // successive Givens rotations from right then from left. float32_t c = 1.0F; float32_t s = 1.0F; - for (int i = next_col + 1; i <= curr_col; i++) { + for (uint32_t i = next_col + 1; i <= curr_col; i++) { g = e[i]; y = d[i]; h = s * g; @@ -771,12 +783,12 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, e[curr_col] = f; d[curr_col] = x; } - } + } while (curr_col != 0); // Sort the singular values in decreasing order // and the singular vectors if required. - for (int i = 0; i < n - 1; i++) { - int max_pos = i; - for (int j = i + 1; j < n; j++) { + for (uint32_t i = 0; i < n - 1; i++) { + uint32_t max_pos = i; + for (uint32_t j = i + 1; j < n; j++) { if (d[j] > d[max_pos]) { max_pos = j; } @@ -785,12 +797,12 @@ armral_status armral_svd_bidiagonal(bool gen_singular_vectors, int m, int n, std::swap(d[i], d[max_pos]); if (gen_singular_vectors) { // Swap corresponding columns in left singular vectors. - for (int row = 0; row < m; row++) { + for (uint32_t row = 0; row < m; row++) { std::swap(u_mat(row, i).re, u_mat(row, max_pos).re); std::swap(u_mat(row, i).im, u_mat(row, max_pos).im); } // Swap corresponding rows in right singular vectors. - for (int col = 0; col < n; col++) { + for (uint32_t col = 0; col < n; col++) { std::swap(vt_mat(i, col).re, vt_mat(max_pos, col).re); std::swap(vt_mat(i, col).im, vt_mat(max_pos, col).im); } @@ -806,10 +818,11 @@ struct apply_q_work_buffers { armral_cmplx_f32_t *q; }; -inline armral_status armral_apply_q(int m, int n, const armral_cmplx_f32_t *a, - const armral_cmplx_f32_t *tau, - armral_cmplx_f32_t *c, - apply_q_work_buffers work_buffers) { +inline armral_status apply_q(uint32_t m, uint32_t n, + const armral_cmplx_f32_t *a, + const armral_cmplx_f32_t *tau, + armral_cmplx_f32_t *c, + apply_q_work_buffers work_buffers) { if (m < n) { return ARMRAL_ARGUMENT_ERROR; } @@ -817,11 +830,16 @@ inline armral_status armral_apply_q(int m, int n, const armral_cmplx_f32_t *a, memcpy(work_buffers.q, a, (size_t)m * n * sizeof(armral_cmplx_f32_t)); column_major_matrix_view q_mat{work_buffers.q, m}; column_major_matrix_view c_mat{c, m}; - for (int i = n - 1; i >= 0; i--) { + + // n will always be >=1 because of the fast + // return in the top-level function + uint32_t i = n; + do { + i--; q_mat(i, i) = {1.0F, 0.0F}; // Apply reflector from the left to all columns // of C from row index i, to the end. - for (int col = 0; col < n; col++) { + for (uint32_t col = 0; col < n; col++) { armral_cmplx_f32_t w = {0.0F, 0.0F}; if (m > i) { cmplx_vecdot_conj_f32(m - i, &c_mat(i, col), &q_mat(i, i), &w); @@ -829,7 +847,7 @@ inline armral_status armral_apply_q(int m, int n, const armral_cmplx_f32_t *a, cmplx_axmy_f32(m - i, &q_mat(i, i), &tmp, &c_mat(i, col)); } } - } + } while (i != 0); return ARMRAL_SUCCESS; } @@ -839,7 +857,7 @@ inline armral_status armral_apply_q(int m, int n, const armral_cmplx_f32_t *a, // bidiagonal form, if m / n exceeds this value, // a QR factorization is used first to reduce the // matrix to a triangular form. -inline int threshold_svd_qr(bool vector_needed, int m, int n) { +inline uint32_t threshold_svd_qr(bool vector_needed, uint32_t m, uint32_t n) { float32_t crossover_point; if (vector_needed) { @@ -857,10 +875,10 @@ inline int threshold_svd_qr(bool vector_needed, int m, int n) { crossover_point = n * 1.6; } // return a crossover_point rounded toward zero - return (int)crossover_point; + return (uint32_t)crossover_point; } -// armral_svd computes the SVD decomposition +// qr_svd computes the SVD decomposition // of an m-by-n matrix A in 4 steps. // 1- QR factorization of A. // 2- Bidiagonalization of R. @@ -868,10 +886,9 @@ inline int threshold_svd_qr(bool vector_needed, int m, int n) { // 4- Update of the left singular vectors // with the orthogonal matrix from QR. template -armral_status armral_qr_svd(bool gen_singular_vect, int m, int n, - armral_cmplx_f32_t *a, float32_t *s, - armral_cmplx_f32_t *u, armral_cmplx_f32_t *vt, - Allocator &allocator) { +armral_status qr_svd(bool gen_singular_vect, uint32_t m, uint32_t n, + armral_cmplx_f32_t *a, float32_t *s, armral_cmplx_f32_t *u, + armral_cmplx_f32_t *vt, Allocator &allocator) { assert(m >= n && "Invalid arguments: m < n is not supported"); @@ -897,61 +914,62 @@ armral_status armral_qr_svd(bool gen_singular_vect, int m, int n, } // Perform the QR factorization of A. - armral_householder_qr(m, n, a, tau.get()); + householder_qr(m, n, a, tau.get()); // Extract the R. column_major_matrix_view r_mat{r.get(), n}; column_major_matrix_view a_mat{a, m}; - for (int j = 0; j < n; j++) { - for (int i = 0; i <= j; i++) { + for (uint32_t j = 0; j < n; j++) { + for (uint32_t i = 0; i <= j; i++) { r_mat(i, j) = a_mat(i, j); } } // Bidiagonalization of R. - armral_bidiagonalization(n, n, r.get(), s, e.get(), tauq.get(), taup.get()); + bidiagonalization(n, n, r.get(), s, e.get(), tauq.get(), taup.get()); // Generate left and right orthogonal vectors. if (maybe_u1.has_value()) { auto *u1 = maybe_u1.value().get(); // Generate Q, and store it in u1. - armral_assemble_q(n, n, r.get(), tauq.get(), u1); + assemble_q(n, n, r.get(), tauq.get(), u1); // Copy u1 in u column_major_matrix_view u_mat{u, m}; column_major_matrix_view u1_mat{u1, n}; - for (int j = 0; j < n; j++) { - for (int i = 0; i < n; i++) { + for (uint32_t j = 0; j < n; j++) { + for (uint32_t i = 0; i < n; i++) { u_mat(i, j) = u1_mat(i, j); } } // Initialize last n*(m-n) elements of u // to zero in case it is not. - int remainder = m - n; - for (int j = 0; j < n; j++) { + // m >=n is a prerequisite of this function so this + // will never be negative + uint32_t remainder = m - n; + for (uint32_t j = 0; j < n; j++) { memset(&u[n + j * u_mat.stride()], 0, remainder * sizeof(armral_cmplx_f32_t)); } // Generate P and store it in vt. - armral_assemble_p(n, n, r.get(), taup.get(), vt); + assemble_p(n, n, r.get(), taup.get(), vt); } // Compute the singular values // and singular vectors if required. // Note: U is treated as N-by-N, but still stored in an M-by-N matrix. - armral_svd_bidiagonal(gen_singular_vect, n, n, s, e.get(), u, vt, m); + svd_bidiagonal(gen_singular_vect, n, n, s, e.get(), u, vt, m); // Apply Q to U if (maybe_q.has_value()) { - armral_apply_q(m, n, a, tau.get(), u, {maybe_q.value().get()}); + apply_q(m, n, a, tau.get(), u, {maybe_q.value().get()}); } return ARMRAL_SUCCESS; } template -armral_status armral_svd(bool gen_singular_vect, int m, int n, - armral_cmplx_f32_t *a, float32_t *s, - armral_cmplx_f32_t *u, armral_cmplx_f32_t *vt, - Allocator &allocator) { +armral_status svd(bool gen_singular_vect, uint32_t m, uint32_t n, + armral_cmplx_f32_t *a, float32_t *s, armral_cmplx_f32_t *u, + armral_cmplx_f32_t *vt, Allocator &allocator) { // Bidiagonalization: A = Q * B * P^H. auto tauq = allocate_uninitialized(allocator, n); auto taup = allocate_uninitialized(allocator, n); @@ -961,62 +979,77 @@ armral_status armral_svd(bool gen_singular_vect, int m, int n, return ARMRAL_SUCCESS; } - armral_bidiagonalization(m, n, a, s, e.get(), tauq.get(), taup.get()); + bidiagonalization(m, n, a, s, e.get(), tauq.get(), taup.get()); // Generate left and right orthogonal vectors if required. if (gen_singular_vect) { // Generate Q and store it in u. - armral_assemble_q(m, n, a, tauq.get(), u); + assemble_q(m, n, a, tauq.get(), u); // Generate P and store it in vt. - armral_assemble_p(m, n, a, taup.get(), vt); + assemble_p(m, n, a, taup.get(), vt); } // Compute the singular values and singular vectors // if required. - armral_svd_bidiagonal(gen_singular_vect, m, n, s, e.get(), u, vt, m); + svd_bidiagonal(gen_singular_vect, m, n, s, e.get(), u, vt, m); return ARMRAL_SUCCESS; } -// armral_svd computes the SVD decomposition +// svd_cf32 computes the SVD decomposition // of an m-by-n matrix. It either performs // a direct SVD decomposition of the input matrix, // or performs QR factorization first followed // by the SVD of R depending on the ratio m/n. template -armral_status armral_svd_cf32(bool gen_singular_vect, int m, int n, - armral_cmplx_f32_t *a, float32_t *s, - armral_cmplx_f32_t *u, armral_cmplx_f32_t *vt, - Allocator &allocator) { - +armral_status svd_cf32(bool gen_singular_vect, uint32_t m, uint32_t n, + armral_cmplx_f32_t *a, float32_t *s, + armral_cmplx_f32_t *u, armral_cmplx_f32_t *vt, + Allocator &allocator) { // Call arm_qr_svd if m is much larger than n - int crossover_point = threshold_svd_qr(gen_singular_vect, m, n); + uint32_t crossover_point = threshold_svd_qr(gen_singular_vect, m, n); if (m > crossover_point) { - return armral_qr_svd(gen_singular_vect, m, n, a, s, u, vt, allocator); + return qr_svd(gen_singular_vect, m, n, a, s, u, vt, allocator); } - return armral_svd(gen_singular_vect, m, n, a, s, u, vt, allocator); + return svd(gen_singular_vect, m, n, a, s, u, vt, allocator); } } // anonymous namespace -armral_status armral_svd_cf32(bool gen_singular_vect, int m, int n, +armral_status armral_svd_cf32(bool gen_singular_vect, uint32_t m, uint32_t n, armral_cmplx_f32_t *a, float32_t *s, armral_cmplx_f32_t *u, armral_cmplx_f32_t *vt) { + // This function is only implemented for m >= n + if (m < n) { + return ARMRAL_ARGUMENT_ERROR; + } + // Trivial case: no work to do + if (m == 0 || n == 0) { + return ARMRAL_SUCCESS; + } heap_allocator allocator{}; - return armral_svd_cf32(gen_singular_vect, m, n, a, s, u, vt, allocator); + return svd_cf32(gen_singular_vect, m, n, a, s, u, vt, allocator); } -armral_status armral_svd_cf32_noalloc(bool gen_singular_vect, int m, int n, - armral_cmplx_f32_t *a, float32_t *s, - armral_cmplx_f32_t *u, +armral_status armral_svd_cf32_noalloc(bool gen_singular_vect, uint32_t m, + uint32_t n, armral_cmplx_f32_t *a, + float32_t *s, armral_cmplx_f32_t *u, armral_cmplx_f32_t *vt, void *buffer) { + // This function is only implemented for m >= n + if (m < n) { + return ARMRAL_ARGUMENT_ERROR; + } + // Trivial case: no work to do + if (m == 0 || n == 0) { + return ARMRAL_SUCCESS; + } buffer_bump_allocator allocator{buffer}; - return armral_svd_cf32(gen_singular_vect, m, n, a, s, u, vt, allocator); + return svd_cf32(gen_singular_vect, m, n, a, s, u, vt, allocator); } -uint32_t armral_svd_cf32_noalloc_buffer_size(bool gen_singular_vect, int m, - int n) { +uint32_t armral_svd_cf32_noalloc_buffer_size(bool gen_singular_vect, uint32_t m, + uint32_t n) { counting_allocator allocator{}; - (void)armral_svd_cf32(gen_singular_vect, m, n, nullptr, nullptr, nullptr, - nullptr, allocator); + (void)svd_cf32(gen_singular_vect, m, n, nullptr, nullptr, nullptr, nullptr, + allocator); return allocator.required_bytes(); } diff --git a/src/MatrixFactorizations/SVD/matrix_view.hpp b/src/MatrixFactorizations/SVD/matrix_view.hpp index 6747418..49eaf7d 100644 --- a/src/MatrixFactorizations/SVD/matrix_view.hpp +++ b/src/MatrixFactorizations/SVD/matrix_view.hpp @@ -5,31 +5,33 @@ #pragma once +#include + /* A non-owning column major view of a matrix to provide more convenient indexing. */ template struct column_major_matrix_view { - column_major_matrix_view(T *data, int stride) + column_major_matrix_view(T *data, uint32_t stride) : m_data(data), m_stride(stride) {} - T &operator()(int i, int j) { + T &operator()(uint32_t i, uint32_t j) { return m_data[i + stride() * j]; } - int stride() const { + uint32_t stride() const { return m_stride; } - int row_increment() const { + uint32_t row_increment() const { return 1; } - int column_increment() const { + uint32_t column_increment() const { return stride(); } private: T *const m_data; - const int m_stride; + const uint32_t m_stride; }; diff --git a/src/UpperPHY/ConvolutionalEncoder/arm_convolutional_decoder.cpp b/src/UpperPHY/ConvolutionalEncoder/arm_convolutional_decoder.cpp index c2c6210..bc5d238 100644 --- a/src/UpperPHY/ConvolutionalEncoder/arm_convolutional_decoder.cpp +++ b/src/UpperPHY/ConvolutionalEncoder/arm_convolutional_decoder.cpp @@ -3,8 +3,8 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "bit_utils.hpp" #include "utils/allocators.hpp" +#include "utils/bits_to_bytes.hpp" #include "convolutional_code_table.hpp" @@ -307,9 +307,9 @@ armral_status tail_biting_convolutional_decode_block( // == Output decoded stream == // Convert the bytes back to bits if (ro_tb_best_i != states) { // if TB path found - bytes_to_bits(k, &bytes_dst[ro_tb_best_i * k], dst); + armral::bytes_to_bits(k, &bytes_dst[ro_tb_best_i * k], dst); } else { - bytes_to_bits(k, &bytes_dst[ro_best_i * k], dst); + armral::bytes_to_bits(k, &bytes_dst[ro_best_i * k], dst); } return ARMRAL_SUCCESS; diff --git a/src/UpperPHY/LDPC/ldpc_coding.hpp b/src/UpperPHY/LDPC/ldpc_coding.hpp index 33c4576..c69f753 100644 --- a/src/UpperPHY/LDPC/ldpc_coding.hpp +++ b/src/UpperPHY/LDPC/ldpc_coding.hpp @@ -10,7 +10,7 @@ namespace armral::ldpc { constexpr uint32_t num_lifting_sets = 8; -uint32_t get_ldpc_lifting_index(uint32_t lifting_size); +uint32_t get_lifting_index(uint32_t lifting_size); template void decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, uint32_t z, diff --git a/src/UpperPHY/LDPC/ldpc_decoder.cpp b/src/UpperPHY/LDPC/ldpc_decoder.cpp index ba5297f..7cb98bd 100644 --- a/src/UpperPHY/LDPC/ldpc_decoder.cpp +++ b/src/UpperPHY/LDPC/ldpc_decoder.cpp @@ -3,9 +3,9 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ -#include "bit_utils.hpp" #include "ldpc_coding.hpp" #include "utils/allocators.hpp" +#include "utils/bits_to_bytes.hpp" #ifdef ARMRAL_ARCH_SVE #include @@ -89,7 +89,7 @@ public: } // Hard decode - llrs_to_bits(m_total_bits, m_llrs.get(), m_buffer.get()); + armral::llrs_to_bits(m_total_bits, m_llrs.get(), m_buffer.get()); // Generate the CRC parity bits uint64_t crc; @@ -119,7 +119,7 @@ template<> bool parity_check(const int8_t *llrs, uint32_t z, uint32_t lsi, const armral_ldpc_base_graph_t *graph, int32_t num_lanes, int32_t full_vec, - int32_t tail_size, int8_t *check_array) { + int32_t tail_size, int8_t *check) { // Loop through the rows in the base graph bool passed = true; for (uint32_t row = 0; row < graph->nrows && passed; ++row) { @@ -1342,7 +1342,7 @@ void armral::ldpc::decode_block(const int8_t *llrs, armral_ldpc_graph_t bg, uint8_t *data_out, Allocator &allocator) { // Get the base graph and the lifting size const auto *graph = armral_ldpc_get_base_graph(bg); - uint32_t lsi = armral::ldpc::get_ldpc_lifting_index(z); + uint32_t lsi = get_lifting_index(z); // Only allocate the CRC checker if necessary. std::optional> maybe_crc_checker; diff --git a/src/UpperPHY/LDPC/ldpc_encoder.cpp b/src/UpperPHY/LDPC/ldpc_encoder.cpp index 655d7cb..951a167 100644 --- a/src/UpperPHY/LDPC/ldpc_encoder.cpp +++ b/src/UpperPHY/LDPC/ldpc_encoder.cpp @@ -3,9 +3,9 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "bit_utils.hpp" #include "ldpc_coding.hpp" #include "utils/allocators.hpp" +#include "utils/bits_to_bytes.hpp" #ifdef ARMRAL_ARCH_SVE #include @@ -2356,10 +2356,37 @@ inline void calc_hdsm_rhs(uint32_t z, const uint8_t *parity_hdsm, #endif } +} // anonymous namespace + +namespace armral::ldpc { + +uint32_t get_lifting_index(uint32_t lifting_size) { + // Each lifting size is either a power of two, + // or an odd multiple (up to 15) of a power of two. Find the first odd + // number when shifting right, + // e.g. (112 -> 56 -> 28 -> 14 -> 7) + // then divide that by two to get the index from + // the mapping: + // 2 -> 0 + // 3 -> 1 + // 5 -> 2 + // 7 -> 3 + // 9 -> 4 + // 11 -> 5 + // 13 -> 6 + // 15 -> 7 + // Using the example above, 112 would then be mapped onto index set 3 + assert(lifting_size > 0); + auto lifting_set_index = lifting_size >> __builtin_ctz(lifting_size); + assert(lifting_set_index <= 15); + lifting_set_index >>= 1; + return lifting_set_index; +} + template -armral_status ldpc_encode_block(const uint8_t *data_in, armral_ldpc_graph_t bg, - uint32_t z, uint32_t len_filler_bits, - uint8_t *data_out, Allocator &allocator) { +armral_status encode_block(const uint8_t *data_in, armral_ldpc_graph_t bg, + uint32_t z, uint32_t len_filler_bits, + uint8_t *data_out, Allocator &allocator) { // Get a pointer to the graph to be working with const auto *graph = armral_ldpc_get_base_graph(bg); @@ -2380,7 +2407,7 @@ armral_status ldpc_encode_block(const uint8_t *data_in, armral_ldpc_graph_t bg, bytes_in.get()); // Get the lifting set index - auto lsi = armral::ldpc::get_ldpc_lifting_index(z); + auto lsi = get_lifting_index(z); // The encoding is done by computing: // 1- Parity bits for the high-density submatrix (hdsm) @@ -2417,15 +2444,15 @@ armral_status ldpc_encode_block(const uint8_t *data_in, armral_ldpc_graph_t bg, return ARMRAL_SUCCESS; } -} // anonymous namespace +} // namespace armral::ldpc armral_status armral_ldpc_encode_block(const uint8_t *data_in, armral_ldpc_graph_t bg, uint32_t z, uint32_t len_filler_bits, uint8_t *data_out) { heap_allocator allocator{}; - return ldpc_encode_block(data_in, bg, z, len_filler_bits, data_out, - allocator); + return armral::ldpc::encode_block(data_in, bg, z, len_filler_bits, data_out, + allocator); } armral_status @@ -2433,15 +2460,16 @@ armral_ldpc_encode_block_noalloc(const uint8_t *data_in, armral_ldpc_graph_t bg, uint32_t z, uint32_t len_filler_bits, uint8_t *data_out, void *buffer) { buffer_bump_allocator allocator{buffer}; - return ldpc_encode_block(data_in, bg, z, len_filler_bits, data_out, - allocator); + return armral::ldpc::encode_block(data_in, bg, z, len_filler_bits, data_out, + allocator); } uint32_t armral_ldpc_encode_block_noalloc_buffer_size(armral_ldpc_graph_t bg, uint32_t z, uint32_t len_filler_bits) { counting_allocator allocator{}; - (void)ldpc_encode_block(nullptr, bg, z, len_filler_bits, nullptr, allocator); + (void)armral::ldpc::encode_block(nullptr, bg, z, len_filler_bits, nullptr, + allocator); return allocator.required_bytes(); } @@ -2449,26 +2477,3 @@ const armral_ldpc_base_graph_t * armral_ldpc_get_base_graph(armral_ldpc_graph_t bg) { return bg == LDPC_BASE_GRAPH_1 ? &base_graph_1 : &base_graph_2; } - -uint32_t armral::ldpc::get_ldpc_lifting_index(uint32_t lifting_size) { - // Each lifting size is either a power of two, - // or an odd multiple (up to 15) of a power of two. Find the first odd - // number when shifting right, - // e.g. (112 -> 56 -> 28 -> 14 -> 7) - // then divide that by two to get the index from - // the mapping: - // 2 -> 0 - // 3 -> 1 - // 5 -> 2 - // 7 -> 3 - // 9 -> 4 - // 11 -> 5 - // 13 -> 6 - // 15 -> 7 - // Using the example above, 112 would then be mapped onto index set 3 - assert(lifting_size > 0); - auto lifting_set_index = lifting_size >> __builtin_ctz(lifting_size); - assert(lifting_set_index <= 15); - lifting_set_index >>= 1; - return lifting_set_index; -} diff --git a/src/UpperPHY/LDPC/ldpc_rate_matching.cpp b/src/UpperPHY/LDPC/ldpc_rate_matching.cpp index 9ab6760..56171f4 100644 --- a/src/UpperPHY/LDPC/ldpc_rate_matching.cpp +++ b/src/UpperPHY/LDPC/ldpc_rate_matching.cpp @@ -3,9 +3,9 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "bit_utils.hpp" #include "ldpc_rate_common.hpp" #include "utils/allocators.hpp" +#include "utils/bits_to_bytes.hpp" #include #include #include diff --git a/src/UpperPHY/Polar/arm_polar_decoder.cpp b/src/UpperPHY/Polar/arm_polar_decoder.cpp index b1db620..83eb412 100644 --- a/src/UpperPHY/Polar/arm_polar_decoder.cpp +++ b/src/UpperPHY/Polar/arm_polar_decoder.cpp @@ -30,8 +30,8 @@ inline void f_l(const int8_t *in, int8_t *out) { } template -inline void g_l(const int8_t *in, const uint8_t *dec, const uint8_t *hist, - int8_t *out) { +inline void g_l(const int8_t *in, const uint8_t *dec, + [[maybe_unused]] const uint8_t *hist, int8_t *out) { // Calculate beliefs for right children in the successive cancellation list // (SCL) algorithm: // g(a_h, b_h, c_i=0) = a_h + b_h diff --git a/src/UpperPHY/Polar/arm_polar_encoder.c b/src/UpperPHY/Polar/arm_polar_encoder.c index 5936f57..71ad88e 100644 --- a/src/UpperPHY/Polar/arm_polar_encoder.c +++ b/src/UpperPHY/Polar/arm_polar_encoder.c @@ -53,8 +53,8 @@ polar_encoding_algo_n128(const uint32_t *u, uint32_t *d128) { static inline void __attribute__((always_inline)) polar_encoding_algo_n256(const uint32_t *u, uint32_t *d256) { - uint32x4_t tmp_low = {}; - uint32x4_t tmp_hi = {}; + uint32x4_t tmp_low = {0}; + uint32x4_t tmp_hi = {0}; // Computing [dLow] = [uLow]*[G_256Low] polar_encoding_algo_n128(u, (uint32_t *)&tmp_low); @@ -69,8 +69,8 @@ polar_encoding_algo_n256(const uint32_t *u, uint32_t *d256) { static inline void __attribute__((always_inline)) polar_encoding_algo_n512(const uint32_t *u, uint32_t *d512) { - uint32x4_t tmp_low[2] = {}; - uint32x4_t tmp_hi[2] = {}; + uint32x4_t tmp_low[2] = {0}; + uint32x4_t tmp_hi[2] = {0}; // Computing [dLow] = [uLow]*[G_512Low] polar_encoding_algo_n256(u, (uint32_t *)&tmp_low); @@ -89,8 +89,8 @@ polar_encoding_algo_n512(const uint32_t *u, uint32_t *d512) { static inline void __attribute__((always_inline)) polar_encoding_algo_n1024(const uint32_t *u, uint32_t *d1024) { - uint32x4_t tmp_low[4] = {}; - uint32x4_t tmp_hi[4] = {}; + uint32x4_t tmp_low[4] = {0}; + uint32x4_t tmp_hi[4] = {0}; // Computing [dLow] = [uLow]*[G_1024Low] polar_encoding_algo_n512(u, (uint32_t *)tmp_low); diff --git a/src/UpperPHY/Polar/arm_polar_frozen_bits.cpp b/src/UpperPHY/Polar/arm_polar_frozen_bits.cpp index 6c5fe66..f670c9a 100644 --- a/src/UpperPHY/Polar/arm_polar_frozen_bits.cpp +++ b/src/UpperPHY/Polar/arm_polar_frozen_bits.cpp @@ -878,8 +878,10 @@ polar_frozen_mask_impl_repetition(uint32_t e, uint32_t k, uint8_t *frozen) { } // finally, set the indices for any remaining parity bits. - for (uint32_t i = n_pc_wm; i < n_pc; ++i) { - frozen[arrs->q[k + i]] = ARMRAL_POLAR_PARITY_BIT; + if constexpr (n_pc > 0) { + for (uint32_t i = n_pc_wm; i < n_pc; ++i) { + frozen[arrs->q[k + i]] = ARMRAL_POLAR_PARITY_BIT; + } } } diff --git a/src/UpperPHY/Turbo/arm_turbo_decoder.cpp b/src/UpperPHY/Turbo/arm_turbo_decoder.cpp index f2415fd..e9b89ec 100644 --- a/src/UpperPHY/Turbo/arm_turbo_decoder.cpp +++ b/src/UpperPHY/Turbo/arm_turbo_decoder.cpp @@ -8,11 +8,7 @@ #include "turbo_tables.hpp" #include "utils/allocators.hpp" -#ifdef ARMRAL_ARCH_SVE -#include "turbo_decoder_fp16.hpp" -#else -#include "turbo_decoder_fp32.hpp" -#endif +#include "arm_turbo_decoder.hpp" template void armral::turbo::decode_block( const int8_t *sys, const int8_t *par, const int8_t *itl, uint32_t k, diff --git a/src/UpperPHY/Turbo/turbo_decoder_fp16.hpp b/src/UpperPHY/Turbo/arm_turbo_decoder.hpp similarity index 56% rename from src/UpperPHY/Turbo/turbo_decoder_fp16.hpp rename to src/UpperPHY/Turbo/arm_turbo_decoder.hpp index b924436..dcd60a8 100644 --- a/src/UpperPHY/Turbo/turbo_decoder_fp16.hpp +++ b/src/UpperPHY/Turbo/arm_turbo_decoder.hpp @@ -1,6 +1,6 @@ /* Arm RAN Acceleration Library - SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates + SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #pragma once @@ -8,117 +8,83 @@ namespace { -struct float16x4x8_t { - float16x4_t val[8]; +struct int16x4x8_t { + int16x4_t val[8]; }; // With Turbo codes n (=k) is always divisible by 8 so we // do not have to worry about tail bits -inline void turbo_llrs_to_bits(uint32_t n, const float16x8_t *llr, +inline void turbo_llrs_to_bits(uint32_t n, const int16x8_t *llr, uint8_t *data_out) { uint32_t full_bytes = n >> 3; constexpr uint16x8_t ones = {128, 64, 32, 16, 8, 4, 2, 1}; for (uint32_t i = 0; i < full_bytes; ++i) { // The first bit to write in the byte is the most significant - uint16x8_t pred = vcltzq_f16(llr[i]); + uint16x8_t pred = vcltzq_s16(llr[i]); uint16x8_t mask = vandq_u16(pred, ones); data_out[i] = (uint8_t)vaddvq_u16(mask); } } -// Take the input int8_t LLRs and convert them to float16x8_ts -inline void convert_llrs(uint32_t k, const int8_t *llrs, - float16x8_t *llrs_f16) { - constexpr int8x16_t idx = {127, 0, 127, 1, 127, 2, 127, 3, - 127, 4, 127, 5, 127, 6, 127, 7}; +// Take the input int8_t LLRs and convert them to int16x8_ts +inline void convert_llrs(uint32_t k, const int8_t *llrs, int16x8_t *llrs_i16) { // With turbo codes k is always a multiple of 8 so we do 8 LLRs at a time for (uint32_t i = 0, j = 0; i < k; i += 8, j++) { - int8x8_t data = vld1_s8(&llrs[i]); - int16x8_t data_i16 = vreinterpretq_s16_s8(vtbl1q_s8(data, idx)); - llrs_f16[j] = vcvtq_n_f16_s16(data_i16, 8); - } -} - -// Calculate the PDF of the state transition probability on the assumption that -// we are operating on an AWGN channel: -// PDF = (x1/2 (l_uk + l_c*y1)) + (l_c/2 x2 y2) -// In our implementation we assume the channel reliability, l_c, -// has been prescaled by 1/2 to avoid doing so repeatedly here. -template -inline float16x4_t transition_pdf(float16x8_t l_c, float16x8_t y1, - float16x8_t y2) { - return vget_low_f16( - vmulq_f16(l_c, vaddq_f16(vmulq_n_f16(y1, (float16_t)x1), - vmulq_n_f16(y2, (float16_t)x2)))); -} - -template -inline float16x8_t transition_pdf(float16x8_t l_uk, float16x8_t l_c, - float16x8_t y1, float16x8_t y2) { - if constexpr (use_extrinsic) { - float16x8_t term1 = - vmulq_n_f16(vfmaq_f16(vmulq_n_f16(l_uk, 0.5F), l_c, y1), x1); - float16x8_t term2 = vmulq_f16(vmulq_n_f16(l_c, (float16_t)x2), y2); - return vaddq_f16(term1, term2); - } else { - return vmulq_f16(l_c, vaddq_f16(vmulq_n_f16(y1, (float16_t)x1), - vmulq_n_f16(y2, (float16_t)x2))); + llrs_i16[j] = vshll_n_s8(vld1_s8(&llrs[i]), 0); } } // Update the extrinsic information output from the decoding stage -// based on the computed LLRs, the old extrinsic information and the input. -inline void update_extrinsic(uint32_t len, const float16x8_t *llr, - float16x8_t *extrinsic, const float16x8_t *input) { +// based on the computed LLRs, the old extrinsic information and the input +inline void update_extrinsic(uint32_t len, const int16x8_t *llr, + int16x8_t *extrinsic, const int16x8_t *input) { for (uint32_t i = 0; i < len; i++) { - extrinsic[i] = vsubq_f16(vsubq_f16(llr[i], extrinsic[i]), input[i]); + extrinsic[i] = vqsubq_s16(vqsubq_s16(llr[i], extrinsic[i]), input[i]); } } // Calculate the trellis termination values. These are independent of the // extrinsic information and so can be done once without needing to be updated // on every iteration. -float16x8_t trellis_termination(const float16x8_t *sys, const float16x8_t *par, - uint32_t k8, float16x8_t l_c) { +int16x8_t trellis_termination(const int16x8_t *sys, const int16x8_t *par, + uint32_t k8, int16x8_t l_c) { // We handle the gammas for the trellis termination bits separately // as the state transitions are different. The x_{kl} are never 1 // here, because we always use inputs of 0 to drive the trellis back // to state 0 in the encoder, so we only need to consider a smaller // number of state transitions. We also do not have any extrinsic - // information. Because some of the gamma terms will - // always be -INFINITY (specifically indices [1] and [3]) we can forgo - // adding to them to beta or taking the max with them, compared with - // when we calculate beta in the main calculations. - float16x4_t pdf_00 = transition_pdf<1, 1>(l_c, sys[k8], par[k8]); - float16x4_t pdf_01 = transition_pdf<1, -1>(l_c, sys[k8], par[k8]); + // information. Because some of the gamma terms will always be + // -INFINITY (specifically indices [1] and [3]) we can forgo adding + // to them to beta or taking the max with them, compared with when + // we calculate beta in the main calculations. As above, we assume + // that the channel reliability parameter l_c/2 = 1. + int16x4_t pdf_00 = vget_low_s16(vqaddq_s16(sys[k8], par[k8])); + int16x4_t pdf_01 = vget_low_s16(vqsubq_s16(sys[k8], par[k8])); - float16x8_t g0102 = {pdf_00[1], pdf_01[1], pdf_00[1], pdf_01[1], - pdf_00[1], pdf_01[1], pdf_00[1], pdf_01[1]}; + int16x8_t g0102 = {pdf_00[1], pdf_01[1], pdf_00[1], pdf_01[1], + pdf_00[1], pdf_01[1], pdf_00[1], pdf_01[1]}; - float16x8_t b01 = {pdf_00[2], pdf_00[2], pdf_01[2], pdf_01[2], - pdf_00[2], pdf_00[2], pdf_01[2], pdf_01[2]}; + int16x8_t b01 = {pdf_00[2], pdf_00[2], pdf_01[2], pdf_01[2], + pdf_00[2], pdf_00[2], pdf_01[2], pdf_01[2]}; - float16x8_t beta_term = vaddq_f16(g0102, b01); + int16x8_t beta_term = vqaddq_s16(g0102, b01); - float16x8_t g = {pdf_00[0], pdf_01[0], pdf_00[0], pdf_01[0], - pdf_01[0], pdf_00[0], pdf_01[0], pdf_00[0]}; + int16x8_t g = {pdf_00[0], pdf_01[0], pdf_00[0], pdf_01[0], + pdf_01[0], pdf_00[0], pdf_01[0], pdf_00[0]}; - float64x2_t beta_term_f64 = vreinterpretq_f64_f16(beta_term); - beta_term_f64 = vsetq_lane_f64(beta_term_f64[0], beta_term_f64, 1); - float16x8_t b0123 = vzip1q_f16(vreinterpretq_f16_f64(beta_term_f64), - vreinterpretq_f16_f64(beta_term_f64)); + int16x8_t b0123 = vzip1q_s16(beta_term, beta_term); - return vaddq_f16(g, b0123); + return vqaddq_s16(g, b0123); } // A single max-log-MAP decoder that works on an array of systematic bits (sys), // an array of parity bits (par), and an array of extrinsic values from a // previous decoding stage (extrinsic) -void decode_step(const float16x8_t *sys, const float16x8_t *par, - const float16x8_t *extrinsic, uint32_t k8, float16x8_t *llr, - float16x8_t *alpha, float16x8_t beta_tail, float16x4x8_t *pdf4, - float16x8_t l_c) { +void decode_step(const int16x8_t *sys, const int16x8_t *par, + const int16x8_t *extrinsic, uint32_t k8, int16x8_t *llr, + int16x8_t *alpha, int16x8_t beta_tail, int16x4x8_t *pdf4, + int16x8_t l_c) { uint32_t k_idx; uint32_t kp1_idx; @@ -135,24 +101,31 @@ void decode_step(const float16x8_t *sys, const float16x8_t *par, // They can only ever be either 0 or 1 so we precompute the four possible // values in the exponential for x = (0,0), (0,1), (1,0) and (1,1). Note // that these 0s and 1s have to be converted to 1s and -1s to match the - // values in y + // values in y. // // The y_{kl} values are the observed systematic and parity inputs. - // These have potentially been perturbed by noise on the channel + // These have potentially been perturbed by noise on the channel. // // Although each of the 8 states of the encoder has in theory 8 // predecessor states, the encoder's structure means that not all state // transitions are possible. Each state actually only has 2 predecessor // states so we only have to compute 16 non-zero values for each input // LLR. - float16x8_t pdf_00 = - transition_pdf<1, 1, true>(extrinsic[i], l_c, sys[i], par[i]); - float16x8_t pdf_10 = - transition_pdf<-1, 1, true>(extrinsic[i], l_c, sys[i], par[i]); - float16x8_t pdf_01 = - transition_pdf<1, -1, true>(extrinsic[i], l_c, sys[i], par[i]); - float16x8_t pdf_11 = - transition_pdf<-1, -1, true>(extrinsic[i], l_c, sys[i], par[i]); + // + // We calculate the PDF of the state transition probability on the + // assumption that we are operating on an AWGN channel: + // PDF = (x1/2 (l_uk + l_c*y1)) + (l_c/2 x2 y2) + // where l_uk is the extrinsic information, y1 is the systematic + // input, and y2 is the parity input. We assume the channel + // reliability, l_c, is set such that l_c/2 = 1 and therefore omit + // it from the calculation. See arm_turbo_decoder.cpp for + // justification. + + int16x8_t term = vqaddq_s16(extrinsic[i] >> 1, sys[i]); + int16x8_t pdf_00 = vqaddq_s16(term, par[i]); + int16x8_t pdf_10 = vqsubq_s16(par[i], term); + int16x8_t pdf_01 = vqsubq_s16(term, par[i]); + int16x8_t pdf_11 = vqsubq_s16(vqnegq_s16(term), par[i]); // There is considerable duplication in the values we could store. For // example, for a single state the 16 gamma values are: @@ -162,49 +135,57 @@ void decode_step(const float16x8_t *sys, const float16x8_t *par, // gamma[g_k_idx+2] = {pdf_01[j], pdf_10[j], pdf_10[j], pdf_01[j]}; // gamma[g_k_idx+3] = {pdf_11[j], pdf_00[j], pdf_00[j], pdf_11[j]}; // - // We therefore choose to store the 4 unique pdf values (using st4) - // as this allows us to access the pdf values contiguously in the - // calculations needed for the alpha and beta values - vst4q_f16((float16_t *)&pdf4[i], - float16x8x4_t({pdf_00, pdf_10, pdf_01, pdf_11})); - - // Accumulate the state transition probabilities forwards through the - // state transition trellis starting from the known encoder start state 0 + // We therefore choose to store the 4 unique pdf values (using + // st4) as this allows us to access the pdf values contiguously in + // the calculations needed for the alpha and beta values. + vst4q_s16((int16_t *)&pdf4[i], + int16x8x4_t({pdf_00, pdf_10, pdf_01, pdf_11})); + + // Accumulate the state transition probabilities forwards through + // the state transition trellis starting from the known encoder + // start state 0. + + constexpr int8x16_t idx_0123321 = {0, 1, 2, 3, 4, 5, 6, 7, + 6, 7, 4, 5, 2, 3, 0, 1}; + + constexpr int8x16_t idx_32100123 = {6, 7, 4, 5, 2, 3, 0, 1, + 0, 1, 2, 3, 4, 5, 6, 7}; + for (uint32_t j = 0; j < 8; j++) { k_idx = 8 * i + j; kp1_idx = k_idx + 1; - float16x4_t fdp = vrev64_f16(pdf4[i].val[j]); - // We need g02 = {gamma[g_k_idx][0], gamma[g_k_idx + 1][0], // gamma[g_k_idx + 2][0], gamma[g_k_idx + 3][0], // gamma[g_k_idx][2], gamma[g_k_idx + 1][2], // gamma[g_k_idx + 2][2], gamma[g_k_idx + 3][2]}; - float16x8_t g02 = vcombine_f16(pdf4[i].val[j], fdp); + int16x8_t g02 = vreinterpretq_s16_s8( + vtbl1q_s8(vreinterpret_s8_s16(pdf4[i].val[j]), idx_0123321)); // We need a02 = {alpha[k_idx][0], alpha[k_idx][2], // alpha[k_idx + 1][0], alpha[k_idx + 1][2], // alpha[k_idx][0], alpha[k_idx][2], // alpha[k_idx + 1][0], alpha[k_idx + 1][2]}; - float16x8_t a02 = vuzp1q_f16(alpha[k_idx], alpha[k_idx]); - float16x8_t left = vaddq_f16(g02, a02); + int16x8_t a02 = vuzp1q_s16(alpha[k_idx], alpha[k_idx]); + int16x8_t left = vqaddq_s16(g02, a02); // This is g02 with the 64-bit elements swapped - float16x8_t g20 = vcombine_f16(fdp, pdf4[i].val[j]); + int16x8_t g20 = vreinterpretq_s16_s8( + vtbl1q_s8(vreinterpret_s8_s16(pdf4[i].val[j]), idx_32100123)); // We need a13 = {alpha[k_idx][1], alpha[k_idx][3], // alpha[k_idx + 1][1], alpha[k_idx + 1][3], // alpha[k_idx][1], alpha[k_idx][3], // alpha[k_idx + 1][1], alpha[k_idx + 1][3]}; - float16x8_t a13 = vuzp2q_f16(alpha[k_idx], alpha[k_idx]); - float16x8_t right = vaddq_f16(g20, a13); + int16x8_t a13 = vuzp2q_s16(alpha[k_idx], alpha[k_idx]); + int16x8_t right = vqaddq_s16(g20, a13); - alpha[kp1_idx] = vmaxq_f16(left, right); + alpha[kp1_idx] = vmaxq_s16(left, right); // Normalize alpha if (j % 4 == 0) { - float16x8_t alpha0 = vdupq_n_f16(alpha[kp1_idx][0]); - alpha[kp1_idx] = vsubq_f16(alpha[kp1_idx], alpha0); + int16x8_t alpha0 = vdupq_n_s16(alpha[kp1_idx][0]); + alpha[kp1_idx] = vqsubq_s16(alpha[kp1_idx], alpha0); } } } @@ -212,7 +193,7 @@ void decode_step(const float16x8_t *sys, const float16x8_t *par, // Accumulate the state transition probabilities backwards through the state // transition trellis starting from the beginning of the precomputed tail // and calculate the conditional probabilities of each bit being either 0 - // or 1 + // or 1. constexpr uint8x16_t idx_even_odd = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15}; @@ -232,73 +213,66 @@ void decode_step(const float16x8_t *sys, const float16x8_t *par, constexpr uint8x16_t idx_1302 = {6, 7, 0, 1, 4, 5, 2, 3, 2, 3, 4, 5, 0, 1, 6, 7}; - float16x8_t beta_kp1 = beta_tail; + int16x8_t beta_kp1 = beta_tail; for (int32_t i = k8 - 1; i >= 0; i--) { - float16x8_t prob_0; - float16x8_t prob_1; + int16x8_t prob_0; + int16x8_t prob_1; for (int32_t j = 7; j >= 0; j--) { k_idx = 8 * i + j; // Normalize beta if (j % 4 == 0) { - float16x8_t beta0 = vdupq_n_f16(beta_kp1[0]); - beta_kp1 = vsubq_f16(beta_kp1, beta0); + int16x8_t beta0 = vdupq_n_s16(beta_kp1[0]); + beta_kp1 = vqsubq_s16(beta_kp1, beta0); } uint8x16_t pdf8_u8 = - vreinterpretq_u8_f16(vcombine_f16(pdf4[i].val[j], pdf4[i].val[j])); + vreinterpretq_u8_s16(vcombine_s16(pdf4[i].val[j], pdf4[i].val[j])); // g0213 = {pdf[0], pdf[3], pdf[1], pdf[2], // pdf[2], pdf[1], pdf[3], pdf[0]}; - float16x8_t g0213 = vreinterpretq_f16_u8(vqtbl1q_u8(pdf8_u8, idx_0213)); + int16x8_t g0213 = vreinterpretq_s16_u8(vqtbl1q_u8(pdf8_u8, idx_0213)); // Reverse 32-bit elements in g0213 // g1302 = {pdf[3], pdf[0], pdf[2], pdf[1], // pdf[1], pdf[2], pdf[0], pdf[3]}; - float16x8_t g1302 = vreinterpretq_f16_u8(vqtbl1q_u8(pdf8_u8, idx_1302)); + int16x8_t g1302 = vreinterpretq_s16_u8(vqtbl1q_u8(pdf8_u8, idx_1302)); // b0123 = {beta_kp1[0], beta_kp1[0], beta_kp1[1], beta_kp1[1], // beta_kp1[2], beta_kp1[2], beta_kp1[3], beta_kp1[3]}; // b4567 = {beta_kp1[4], beta_kp1[4], beta_kp1[5], beta_kp1[5], // beta_kp1[6], beta_kp1[6], beta_kp1[7], beta_kp1[7]}; - float16x8_t b0123 = vzip1q_f16(beta_kp1, beta_kp1); - float16x8_t b4567 = vzip2q_f16(beta_kp1, beta_kp1); + int16x8_t b0123 = vzip1q_s16(beta_kp1, beta_kp1); + int16x8_t b4567 = vzip2q_s16(beta_kp1, beta_kp1); - float16x8_t left = vaddq_f16(g0213, b0123); - float16x8_t right = vaddq_f16(g1302, b4567); + int16x8_t left = vqaddq_s16(g0213, b0123); + int16x8_t right = vqaddq_s16(g1302, b4567); - float16x8_t beta_k = vmaxq_f16(left, right); + int16x8_t beta_k = vmaxq_s16(left, right); // a0213 = {alpha[k_idx][0], alpha[k_idx][2], alpha[k_idx][4], alpha[k_idx][6], // alpha[k_idx][1], alpha[k_idx][3], alpha[k_idx][5], alpha[k_idx][7]}; - float16x8_t a0213 = vreinterpretq_f16_u8( - vqtbl1q_u8(vreinterpretq_u8_f16(alpha[k_idx]), idx_even_odd)); + int16x8_t a0213 = vreinterpretq_s16_u8( + vqtbl1q_u8(vreinterpretq_u8_s16(alpha[k_idx]), idx_even_odd)); // b0213_1302 = {beta_kp1[0], beta_kp1[5], beta_kp1[2], beta_kp1[7], // beta_kp1[4], beta_kp1[1], beta_kp1[6], beta_kp1[3]}; - float16x8_t b0213_1302 = vreinterpretq_f16_u8( - vqtbl1q_u8(vreinterpretq_u8_f16(beta_kp1), idx_05274163)); - float16x8_t b1302_0213 = vextq_f16(b0213_1302, b0213_1302, 4); + int16x8_t b0213_1302 = vreinterpretq_s16_u8( + vqtbl1q_u8(vreinterpretq_u8_s16(beta_kp1), idx_05274163)); + int16x8_t b1302_0213 = vextq_s16(b0213_1302, b0213_1302, 4); // g0101 = {pdf[0], pdf[2], pdf[2], pdf[0]}; - float16x8_t g0101 = vreinterpretq_f16_u8(vqtbl1q_u8(pdf8_u8, idx_0220)); - - float16x8_t left_right_0 = vaddq_f16(vaddq_f16(a0213, b0213_1302), g0101); - float16x4_t left_0 = vget_low_f16(left_right_0); - float16x4_t right_0 = vget_high_f16(left_right_0); + int16x8_t g0101 = vreinterpretq_s16_u8(vqtbl1q_u8(pdf8_u8, idx_0220)); + int16x8_t left_right_0 = vqaddq_s16(vqaddq_s16(a0213, b0213_1302), g0101); // g1010 = {pdf[3], pdf[1], pdf[1], pdf[3]}; - float16x8_t g1010 = vreinterpretq_f16_u8(vqtbl1q_u8(pdf8_u8, idx_3113)); + int16x8_t g1010 = vreinterpretq_s16_u8(vqtbl1q_u8(pdf8_u8, idx_3113)); + int16x8_t left_right_1 = vqaddq_s16(vqaddq_s16(a0213, b1302_0213), g1010); - float16x8_t left_right_1 = vaddq_f16(vaddq_f16(a0213, b1302_0213), g1010); - - float16x4_t left_1 = vget_low_f16(left_right_1); - float16x4_t right_1 = vget_high_f16(left_right_1); - - prob_0[j] = vmaxv_f16(vmax_f16(left_0, right_0)); - prob_1[j] = vmaxv_f16(vmax_f16(left_1, right_1)); + prob_0[j] = vmaxvq_s16(left_right_0); + prob_1[j] = vmaxvq_s16(left_right_1); // Store the current value of beta to use in the next // round of calculations @@ -306,7 +280,7 @@ void decode_step(const float16x8_t *sys, const float16x8_t *par, } // Calculate the LLRs - llr[i] = vsubq_f16(prob_0, prob_1); + llr[i] = vqsubq_s16(prob_0, prob_1); } } @@ -321,20 +295,20 @@ void armral::turbo::decode_block(const int8_t *sys, const int8_t *par, float32_t l_c, uint32_t max_iter, Allocator &allocator) { // This implements multiple steps of the max-log-MAP algorithm, - // which is an approximation to the MAP (BCJR) algorithm. - // It returns a hard decision rather than raw LLRs + // which is an approximation to the MAP (BCJR) algorithm. It returns + // a hard decision rather than raw LLRs. - // We will be working with float16x8_t, so work out how - // many of these will be needed to store k float16_ts. - // k is always a multiple of 8, so no need to worry about remainders. + // We will be working with int16x8_t, so work out how many of these + // will be needed to store k int16_ts. k is always a multiple of 8, + // so no need to worry about remainders. uint32_t k8 = k >> 3; - auto sys_f16 = allocate_uninitialized(allocator, k8 + 1); - auto par_f16 = allocate_uninitialized(allocator, k8 + 1); - auto itl_f16 = allocate_uninitialized(allocator, k8 + 1); + auto sys_s16 = allocate_uninitialized(allocator, k8 + 1); + auto par_s16 = allocate_uninitialized(allocator, k8 + 1); + auto itl_s16 = allocate_uninitialized(allocator, k8 + 1); auto perm_idx = allocate_uninitialized(allocator, k); - auto perm_sys = allocate_uninitialized(allocator, k8 + 1); + auto perm_sys = allocate_uninitialized(allocator, k8 + 1); struct perm_pair { uint16_t first; @@ -345,33 +319,33 @@ void armral::turbo::decode_block(const int8_t *sys, const int8_t *par, // Allocate space to hold the extrinsic and permuted extrinsic information // to be passed between the two decoders. Extrinsic is initially set to 0. - auto extrinsic = allocate_zeroed(allocator, k8); - auto perm_extrinsic = allocate_zeroed(allocator, k8); + auto extrinsic = allocate_zeroed(allocator, k8); + auto perm_extrinsic = allocate_zeroed(allocator, k8); // Allocate space for log likelihood ratios from both stages of decoding - auto l1_uky = allocate_uninitialized(allocator, k8); - auto l2_uky = allocate_uninitialized(allocator, k8); - auto prev_l2_uky = allocate_zeroed(allocator, k8); + auto l1_uky = allocate_uninitialized(allocator, k8); + auto l2_uky = allocate_uninitialized(allocator, k8); + auto prev_l2_uky = allocate_zeroed(allocator, k8); // Allocate space to hold alpha and gamma // alpha stores the forward-accumulated state probabilities for each decoded // bit, where the LTE encoder has 8 states and there are k+3 bits to decode // plus the starting condition - auto alpha = allocate_uninitialized(allocator, 8 * k8 + 1); + auto alpha = allocate_uninitialized(allocator, 8 * k8 + 1); // gamma stores the conditional state transition probabilities for each of the // k+3 bits to decode - auto gamma = allocate_uninitialized(allocator, k8); + auto gamma = allocate_uninitialized(allocator, k8); // NOTE: All allocations done. if constexpr (Allocator::is_counting) { return; } - // Convert our LLRs from int8_ts into float16_ts - convert_llrs(k, sys, sys_f16.get()); - convert_llrs(k, par, par_f16.get()); - convert_llrs(k, itl, itl_f16.get()); + // Convert our LLRs from int8_ts into int16_ts + convert_llrs(k, sys, sys_s16.get()); + convert_llrs(k, par, par_s16.get()); + convert_llrs(k, itl, itl_s16.get()); // Unperturb the trellis termination bits. They are transmitted as: // X0 Z1 X'0 Z'1 Z0 X2 Z'0 X'2 X1 Z2 X'1 @@ -382,20 +356,20 @@ void armral::turbo::decode_block(const int8_t *sys, const int8_t *par, // We append to the systematic (X), the parity (Z) and the interleaved parity // (Z') values here, and to the interleaved systematic values (X') further // down. - sys_f16[k8][0] = (float16_t)sys[k]; - sys_f16[k8][1] = (float16_t)itl[k]; - sys_f16[k8][2] = (float16_t)par[k + 1]; + sys_s16[k8][0] = (int16_t)sys[k]; + sys_s16[k8][1] = (int16_t)itl[k]; + sys_s16[k8][2] = (int16_t)par[k + 1]; - par_f16[k8][0] = (float16_t)par[k]; - par_f16[k8][1] = (float16_t)sys[k + 1]; - par_f16[k8][2] = (float16_t)itl[k + 1]; + par_s16[k8][0] = (int16_t)par[k]; + par_s16[k8][1] = (int16_t)sys[k + 1]; + par_s16[k8][2] = (int16_t)itl[k + 1]; - itl_f16[k8][0] = (float16_t)par[k + 2]; - itl_f16[k8][1] = (float16_t)sys[k + 3]; - itl_f16[k8][2] = (float16_t)itl[k + 3]; + itl_s16[k8][0] = (int16_t)par[k + 2]; + itl_s16[k8][1] = (int16_t)sys[k + 3]; + itl_s16[k8][2] = (int16_t)itl[k + 3]; - // Prescale l_c to avoid doing it repeatedly in the PDF calculations later. - const float16x8_t channel_reliability = vdupq_n_f16((float16_t)l_c / 2); + // Prescale l_c to avoid doing it repeatedly in the PDF calculations later + const int16x8_t channel_reliability = vdupq_n_s16((int16_t)l_c / 2); // Generate the permutation vector for the input value of k // Find the index into the array of parameter arrays corresponding @@ -413,16 +387,16 @@ void armral::turbo::decode_block(const int8_t *sys, const int8_t *par, // with the second decoder for (uint32_t i = 0; i < k8; i++) { for (uint32_t j = 0; j < 8; j++) { - perm_sys[i][j] = (float16_t)sys[perm_idx[(i * 8) + j]]; + perm_sys[i][j] = (int16_t)sys[perm_idx[(i * 8) + j]]; } } - perm_sys[k8][0] = (float16_t)sys[k + 2]; - perm_sys[k8][1] = (float16_t)itl[k + 2]; - perm_sys[k8][2] = (float16_t)par[k + 3]; + perm_sys[k8][0] = (int16_t)sys[k + 2]; + perm_sys[k8][1] = (int16_t)itl[k + 2]; + perm_sys[k8][2] = (int16_t)par[k + 3]; // Create a look-up of the permutation vector that maps [0,...k-1] indices // to vector element/vector lane pairs. This avoids having to a modulo - // operator every time we want to apply the permutation to vector elements + // operator every time we want to apply the permutation to vector elements. for (uint32_t i = 0; i < k; i++) { uint16_t vec_idx = perm_idx[i] / 8; uint16_t vec_lane = perm_idx[i] % 8; @@ -430,26 +404,26 @@ void armral::turbo::decode_block(const int8_t *sys, const int8_t *par, } // Initialize alpha - alpha[0] = vdupq_n_f16(-INFINITY); + alpha[0] = vdupq_n_s16(std::numeric_limits::min()); alpha[0][0] = 0; // Calculate the trellis termination state transition probabilities, which // do not require any extrinsic information - float16x8_t beta_tail = trellis_termination(sys_f16.get(), par_f16.get(), k8, - channel_reliability); - float16x8_t perm_beta_tail = trellis_termination( - perm_sys.get(), itl_f16.get(), k8, channel_reliability); + int16x8_t beta_tail = trellis_termination(sys_s16.get(), par_s16.get(), k8, + channel_reliability); + int16x8_t perm_beta_tail = trellis_termination(perm_sys.get(), itl_s16.get(), + k8, channel_reliability); // Initialize the number of iterations uint32_t num_iter = 0; while (num_iter < max_iter) { // Run the first decoder step - decode_step(sys_f16.get(), par_f16.get(), extrinsic.get(), k8, l1_uky.get(), + decode_step(sys_s16.get(), par_s16.get(), extrinsic.get(), k8, l1_uky.get(), alpha.get(), beta_tail, gamma.get(), channel_reliability); // Compute the new extrinsic information to pass into the second decoder - update_extrinsic(k8, l1_uky.get(), extrinsic.get(), sys_f16.get()); + update_extrinsic(k8, l1_uky.get(), extrinsic.get(), sys_s16.get()); // Need to unpermute extrinsic to match input to second decoder for (uint32_t i = 0; i < k8; i++) { @@ -460,7 +434,7 @@ void armral::turbo::decode_block(const int8_t *sys, const int8_t *par, } // Run the second decoder step - decode_step(perm_sys.get(), itl_f16.get(), perm_extrinsic.get(), k8, + decode_step(perm_sys.get(), itl_s16.get(), perm_extrinsic.get(), k8, l2_uky.get(), alpha.get(), perm_beta_tail, gamma.get(), channel_reliability); @@ -476,23 +450,18 @@ void armral::turbo::decode_block(const int8_t *sys, const int8_t *par, } // Compare this iteration's results with those from the previous iteration - float16_t max_abs_diff = 0.0; - float16_t max_abs_val = 0.0; + int16_t max_abs_diff = 0; for (uint32_t i = 0; i < k8; i++) { - float16_t abs_diff = vmaxvq_f16(vabdq_f16(l2_uky[i], prev_l2_uky[i])); - float16_t abs_val = vmaxvq_f16(vabsq_f16(l2_uky[i])); + int16_t abs_diff = + vmaxvq_s16(vqabsq_s16(vqsubq_s16(l2_uky[i], prev_l2_uky[i]))); if (abs_diff > max_abs_diff) { max_abs_diff = abs_diff; } - if (abs_val > max_abs_val) { - max_abs_val = abs_val; - } } // If we've converged, finish decoding if constexpr (check_convergence) { - if (max_abs_diff / max_abs_val < - std::numeric_limits::epsilon()) { + if (max_abs_diff == 0) { break; } } diff --git a/src/UpperPHY/Turbo/arm_turbo_rate_matching.cpp b/src/UpperPHY/Turbo/arm_turbo_rate_matching.cpp index cd41869..25b3ede 100644 --- a/src/UpperPHY/Turbo/arm_turbo_rate_matching.cpp +++ b/src/UpperPHY/Turbo/arm_turbo_rate_matching.cpp @@ -207,7 +207,7 @@ armral_status rate_matching(uint32_t d, uint32_t e, uint32_t rv, Allocator &allocator) { assert(d > 0); assert(e > 0); - assert(rv >= 0 && rv <= 3); + assert(rv <= 3); // The minimum number of rows which gives rtc * ctc >= d. const uint32_t rtc = (d + armral::turbo::ctc - 1) / armral::turbo::ctc; diff --git a/src/UpperPHY/Turbo/arm_turbo_rate_recovery.cpp b/src/UpperPHY/Turbo/arm_turbo_rate_recovery.cpp index 6f00ac7..02dfa28 100644 --- a/src/UpperPHY/Turbo/arm_turbo_rate_recovery.cpp +++ b/src/UpperPHY/Turbo/arm_turbo_rate_recovery.cpp @@ -136,7 +136,7 @@ armral_status rate_recovery(uint32_t d, uint32_t e, uint32_t rv, int8_t *dst2, Allocator &allocator) { assert(d > 0); assert(e > 0); - assert(rv >= 0 && rv <= 3); + assert(rv <= 3); // The minimum number of rows which gives rtc * ctc >= d. const uint32_t rtc = (d + armral::turbo::ctc - 1) / armral::turbo::ctc; diff --git a/src/UpperPHY/Turbo/turbo_decoder_fp32.hpp b/src/UpperPHY/Turbo/turbo_decoder_fp32.hpp deleted file mode 100644 index 1fdb679..0000000 --- a/src/UpperPHY/Turbo/turbo_decoder_fp32.hpp +++ /dev/null @@ -1,533 +0,0 @@ -/* - Arm RAN Acceleration Library - SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -*/ -#pragma once - -#include - -namespace { - -// With Turbo codes n (=k) is always divisible by 8 so we -// do not have to worry about tail bits -inline void turbo_llrs_to_bits(uint32_t n, const float32x4_t *llr, - uint8_t *data_out) { - uint32_t full_bytes = n >> 3; - constexpr uint32x4_t ones_0 = {128, 64, 32, 16}; - constexpr uint32x4_t ones_1 = {8, 4, 2, 1}; - - for (uint32_t i = 0; i < full_bytes; ++i) { - // The first bit to write in the byte is the most significant - uint32x4_t pred_0 = vcltzq_f32(llr[i * 2]); - uint32x4_t pred_1 = vcltzq_f32(llr[i * 2 + 1]); - uint32x4_t mask_0 = vandq_u32(pred_0, ones_0); - uint32x4_t mask_1 = vandq_u32(pred_1, ones_1); - uint32x4_t mask_2 = vorrq_u32(mask_0, mask_1); - data_out[i] = (uint8_t)vaddvq_u32(mask_2); - } -} - -// Take the input int8_t LLRs and convert them to float32x4_ts -inline void convert_llrs(uint32_t k, const int8_t *llrs, - float32x4_t *llrs_f32) { - constexpr int8x16_t idx_0 = {127, 127, 127, 0, 127, 127, 127, 1, - 127, 127, 127, 2, 127, 127, 127, 3}; - constexpr int8x16_t idx_1 = {127, 127, 127, 4, 127, 127, 127, 5, - 127, 127, 127, 6, 127, 127, 127, 7}; - // With turbo codes k is always a multiple of 8 so we do 8 LLRs at a time - for (uint32_t i = 0, j = 0; i < k; i += 8, j += 2) { - int8x8_t data = vld1_s8(&llrs[i]); - int32x4_t ldata = vreinterpretq_s32_s8(vtbl1q_s8(data, idx_0)); - int32x4_t hdata = vreinterpretq_s32_s8(vtbl1q_s8(data, idx_1)); - llrs_f32[j] = vcvtq_n_f32_s32(ldata, 24); - llrs_f32[j + 1] = vcvtq_n_f32_s32(hdata, 24); - } -} - -// Calculate the PDF of the state transition probability on the assumption that -// we are operating on an AWGN channel: -// PDF = (x1/2 (l_uk + l_c*y1)) + (l_c/2 x2 y2) -// In our implementation we assume the channel reliability, l_c, -// has been prescaled by 1/2 to avoid doing so repeatedly here. -template -inline float32x4_t transition_pdf(float32x4_t l_uk, float32x4_t l_c, - float32x4_t y1, float32x4_t y2) { - if constexpr (use_extrinsic) { - float32x4_t term1 = - vmulq_n_f32(vfmaq_f32(vmulq_n_f32(l_uk, 0.5F), l_c, y1), x1); - float32x4_t term2 = vmulq_f32(vmulq_n_f32(l_c, (float32_t)x2), y2); - return vaddq_f32(term1, term2); - } else { - return vmulq_f32(l_c, vaddq_f32(vmulq_n_f32(y1, (float32_t)x1), - vmulq_n_f32(y2, (float32_t)x2))); - } -} - -// Update the extrinsic information output from the decoding stage -// based on the computed LLRs, the old extrinsic information and the input. -inline void update_extrinsic(uint32_t len, const float32x4_t *llr, - float32x4_t *extrinsic, const float32x4_t *input) { - for (uint32_t i = 0; i < len; i++) { - extrinsic[i] = vsubq_f32(vsubq_f32(llr[i], extrinsic[i]), input[i]); - } -} - -// Calculate the trellis termination values. These are independent of the -// extrinsic information and so can be done once without needing to be updated -// on every iteration. -void trellis_termination(const float32x4_t *sys, const float32x4_t *par, - uint32_t k4, float32x4_t l_c, float32x4_t *beta_out) { - // We handle the gammas for the trellis termination bits separately - // as the state transitions are different. The x_{kl} are never 1 - // here, because we always use inputs of 0 to drive the trellis back - // to state 0 in the encoder, so we only need to consider a smaller - // number of state transitions. We also do not have any extrinsic - // information. Because some of the gamma terms will - // always be -INFINITY (specifically indices [1] and [3]) we can forgo - // adding to them to beta or taking the max with them, compared with - // when we calculate beta in the main calculations. - float32x4_t unused_extrinsic = {0}; - float32x4_t pdf_00 = - transition_pdf<1, 1, false>(unused_extrinsic, l_c, sys[k4], par[k4]); - float32x4_t pdf_01 = - transition_pdf<1, -1, false>(unused_extrinsic, l_c, sys[k4], par[k4]); - - // We need b01 = {pdf_00[2], pdf_00[2], pdf_01[2], pdf_01[2]} - float32x4_t pdf_uzp1 = vuzp1q_f32(pdf_00, pdf_01); - float32x4_t b01 = vtrn2q_f32(pdf_uzp1, pdf_uzp1); - - // We need g01_02 = {pdf_00[1], pdf_01[1], pdf_00[1], pdf_01[1]}; - float32x4_t pdf_uzp2 = vuzp2q_f32(pdf_00, pdf_01); - float32x4_t g01_02 = vuzp1q_f32(pdf_uzp2, pdf_uzp2); - - float32x4_t beta_term = vaddq_f32(g01_02, b01); - - // We need g01_02_1 = {pdf_00[0], pdf_01[0], pdf_00[0], pdf_01[0]}; - float32x4_t g01_02_1 = vuzp1q_f32(pdf_uzp1, pdf_uzp1); - - // We need b01_1 = {beta_term[0], beta_term[0], beta_term[1], beta_term[1]}; - float32x4_t b01_1 = vzip1q_f32(beta_term, beta_term); - beta_out[0] = vaddq_f32(g01_02_1, b01_1); - - // We need g23_02_1 = {pdf_01[0], pdf_00[0], pdf_01[0], pdf_00[0]}; - float32x4_t g23_02_1 = vrev64q_f32(g01_02_1); - - // We need b23_1 = {beta_term[2], beta_term[2], beta_term[3], beta_term[3]}; - float32x4_t b23_1 = vzip2q_f32(beta_term, beta_term); - beta_out[1] = vaddq_f32(g23_02_1, b23_1); -} - -// A single max-log-MAP decoder that works on an array of systematic bits (sys), -// an array of parity bits (par), and an array of extrinsic values from a -// previous decoding stage (extrinsic) -void decode_step(const float32x4_t *sys, const float32x4_t *par, - const float32x4_t *extrinsic, uint32_t k4, float32x4_t *llr, - float32x4_t *alpha, const float32x4_t *beta_tail, - float32x4x4_t *pdf4, float32x4_t l_c) { - uint32_t k_idx; - uint32_t kp1_idx; - - constexpr uint8x16_t rev_idx = {12, 13, 14, 15, 8, 9, 10, 11, - 4, 5, 6, 7, 0, 1, 2, 3}; - - // Start by computing the non-zero conditional state transition probabilities - // from state s' to state s for every k, denoted gamma_k(s',s). In general for - // an AWGN channel (ignoring extrinsic information in l_uk): - // gamma_k(s',s) = exp(L_c / 2 \sum_{l=1}^{n} x_{kl} y_{kl}) - // Here there are only 2 possible state transitions into each state - // (corresponding to encoding a 0 bit or a 1 bit) so the summation only has 2 - // terms. - for (uint32_t i = 0; i < k4; i++) { - // The x_{kl} values are the actual systematic and parity values that - // would result from the encoder having transited from state s' to s. - // They can only ever be either 0 or 1 so we precompute the four possible - // values in the exponential for x = (0,0), (0,1), (1,0) and (1,1). Note - // that these 0s and 1s have to be converted to 1s and -1s to match the - // values in y - // - // The y_{kl} values are the observed systematic and parity inputs. - // These have potentially been perturbed by noise on the channel - // - // Although each of the 8 states of the encoder has in theory 8 - // predecessor states, the encoder's structure means that not all state - // transitions are possible. Each state actually only has 2 predecessor - // states so we only have to compute 16 non-zero values for each input - // LLR. - float32x4_t pdf_00 = - transition_pdf<1, 1, true>(extrinsic[i], l_c, sys[i], par[i]); - float32x4_t pdf_10 = - transition_pdf<-1, 1, true>(extrinsic[i], l_c, sys[i], par[i]); - float32x4_t pdf_01 = - transition_pdf<1, -1, true>(extrinsic[i], l_c, sys[i], par[i]); - float32x4_t pdf_11 = - transition_pdf<-1, -1, true>(extrinsic[i], l_c, sys[i], par[i]); - - // There is considerable duplication in the values we could store. For - // example, for a single state the 16 gamma values are: - // - // gamma[g_k_idx] = {pdf_00[j], pdf_11[j], pdf_11[j], pdf_00[j]}; - // gamma[g_k_idx+1] = {pdf_10[j], pdf_01[j], pdf_01[j], pdf_10[j]}; - // gamma[g_k_idx+2] = {pdf_01[j], pdf_10[j], pdf_10[j], pdf_01[j]}; - // gamma[g_k_idx+3] = {pdf_11[j], pdf_00[j], pdf_00[j], pdf_11[j]}; - // - // We therefore choose to store the 4 unique pdf values (using st4) - // as this allows us to access the pdf values contiguously in the - // calculations needed for the alpha and beta values - vst4q_f32((float32_t *)&pdf4[i], - float32x4x4_t({pdf_00, pdf_10, pdf_01, pdf_11})); - - // Accumulate the state transition probabilities forwards through the - // state transition trellis starting from the known encoder start state 0 - for (uint32_t j = 0; j < 4; j++) { - k_idx = 8 * i + j * 2; - kp1_idx = k_idx + 2; - - // We need g0 = {gamma[g_k_idx][0], gamma[g_k_idx + 1][0], - // gamma[g_k_idx + 2][0], gamma[g_k_idx + 3][0]}; - // a02 = {alpha[k_idx][0], alpha[k_idx][2], - // alpha[k_idx + 1][0], alpha[k_idx + 1][2]}; - float32x4_t g0 = pdf4[i].val[j]; - float32x4_t a02 = vuzp1q_f32(alpha[k_idx], alpha[k_idx + 1]); - float32x4_t left_1 = vaddq_f32(g0, a02); - // We need g2 = {gamma[g_k_idx][2], gamma[g_k_idx + 1][2], - // gamma[g_k_idx + 2][2], gamma[g_k_idx + 3][2]}; - // a13 = {alpha[k_idx][1], alpha[k_idx][3], - // alpha[k_idx + 1][1], alpha[k_idx + 1][3]}; - float32x4_t g2 = vreinterpretq_f32_u8( - vqtbl1q_u8(vreinterpretq_u8_f32(pdf4[i].val[j]), rev_idx)); - float32x4_t a13 = vuzp2q_f32(alpha[k_idx], alpha[k_idx + 1]); - float32x4_t right_1 = vaddq_f32(g2, a13); - alpha[kp1_idx] = vmaxq_f32(left_1, right_1); - - // We need g1 = {gamma[g_k_idx][1], gamma[g_k_idx + 1][1], - // gamma[g_k_idx + 2][1], gamma[g_k_idx + 3][1]}; - // which is g2 above - float32x4_t left_2 = vaddq_f32(g2, a02); - // We need g3 = {gamma[g_k_idx][3], gamma[g_k_idx + 1][3], - // gamma[g_k_idx + 2][3], gamma[g_k_idx + 3][3]}; - // which is g0 above - float32x4_t right_2 = vaddq_f32(g0, a13); - alpha[kp1_idx + 1] = vmaxq_f32(left_2, right_2); - } - } - - // Accumulate the state transition probabilities backwards through the state - // transition trellis starting from the beginning of the precomputed tail - // and calculate the conditional probabilities of each bit being either 0 - // or 1 - constexpr uint8x16_t idx_0312 = {0, 1, 2, 3, 12, 13, 14, 15, - 4, 5, 6, 7, 8, 9, 10, 11}; - constexpr uint8x16_t idx_3021 = {12, 13, 14, 15, 0, 1, 2, 3, - 8, 9, 10, 11, 4, 5, 6, 7}; - constexpr uint8x16_t idx_2130 = {8, 9, 10, 11, 4, 5, 6, 7, - 12, 13, 14, 15, 0, 1, 2, 3}; - constexpr uint8x16_t idx_1203 = {4, 5, 6, 7, 8, 9, 10, 11, - 0, 1, 2, 3, 12, 13, 14, 15}; - constexpr uint8x16_t idx_0220 = {0, 1, 2, 3, 8, 9, 10, 11, - 8, 9, 10, 11, 0, 1, 2, 3}; - constexpr uint8x16_t idx_3113 = {12, 13, 14, 15, 4, 5, 6, 7, - 4, 5, 6, 7, 12, 13, 14, 15}; - - float32x4x2_t beta_k; - float32x4x2_t beta_kp1 = {beta_tail[0], beta_tail[1]}; - - for (int32_t i = k4 - 1; i >= 0; i--) { - float32x4_t prob_0; - float32x4_t prob_1; - for (int32_t j = 3; j >= 0; j--) { - k_idx = 8 * i + j * 2; - - // We need g01_02 = {gamma[g_k_idx][0], gamma[g_k_idx][2], - // gamma[g_k_idx + 1][0], gamma[g_k_idx + 1][2]}; - // b01 = {beta[b_kp1_idx][0], beta[b_kp1_idx][0], - // beta[b_kp1_idx][1], beta[b_kp1_idx][1]}; - float32x4_t g01_02 = vreinterpretq_f32_u8( - vqtbl1q_u8(vreinterpretq_u8_f32(pdf4[i].val[j]), idx_0312)); - float32x4_t b01 = vzip1q_f32(beta_kp1.val[0], beta_kp1.val[0]); - float32x4_t left_1 = vaddq_f32(g01_02, b01); - - // We need g13 = {gamma[g_k_idx][1], gamma[g_k_idx][3], - // gamma[g_k_idx + 1][1], gamma[g_k_idx + 1][3]}; - // bp1_01 = {beta[b_kp1_idx + 1][0], beta[b_kp1_idx + 1][0], - // beta[b_kp1_idx + 1][1], beta[b_kp1_idx + 1][1]}; - float32x4_t g13 = vreinterpretq_f32_u8( - vqtbl1q_u8(vreinterpretq_u8_f32(pdf4[i].val[j]), idx_3021)); - float32x4_t bp1_01 = vzip1q_f32(beta_kp1.val[1], beta_kp1.val[1]); - float32x4_t right_1 = vaddq_f32(g13, bp1_01); - beta_k.val[0] = vmaxq_f32(left_1, right_1); - - // We need g23_02 = {gamma[g_k_idx + 2][0], gamma[g_k_idx + 2][2], - // gamma[g_k_idx + 3][0], gamma[g_k_idx + 3][2]}; - // We need b23 = {beta[b_kp1_idx][2], beta[b_kp1_idx][2], - // beta[b_kp1_idx][3], beta[b_kp1_idx][3]}; - float32x4_t g23_02 = vreinterpretq_f32_u8( - vqtbl1q_u8(vreinterpretq_u8_f32(pdf4[i].val[j]), idx_2130)); - float32x4_t b23 = vzip2q_f32(beta_kp1.val[0], beta_kp1.val[0]); - float32x4_t left_2 = vaddq_f32(g23_02, b23); - - // We need g23_13 = {gamma[g_k_idx + 2][1], gamma[g_k_idx + 2][3], - // gamma[g_k_idx + 3][1], gamma[g_k_idx + 3][3]}; - // bp1_23 = {beta[b_kp1_idx + 1][2], beta[b_kp1_idx + 1][2], - // beta[b_kp1_idx + 1][3], beta[b_kp1_idx + 1][3]}; - float32x4_t g23_13 = vreinterpretq_f32_u8( - vqtbl1q_u8(vreinterpretq_u8_f32(pdf4[i].val[j]), idx_1203)); - float32x4_t bp1_23 = vzip2q_f32(beta_kp1.val[1], beta_kp1.val[1]); - float32x4_t right_2 = vaddq_f32(g23_13, bp1_23); - beta_k.val[1] = vmaxq_f32(left_2, right_2); - - // We need a02 = {alpha[k_idx][0], alpha[k_idx][2], - // alpha[k_idx + 1][0], alpha[k_idx + 1][2]}; - // a13 = {alpha[k_idx][1], alpha[k_idx][3], - // alpha[k_idx + 1][1], alpha[k_idx + 1][3]}; - // b02_13 = {beta[b_kp1_idx][0], beta[b_kp1_idx + 1][1], - // beta[b_kp1_idx][2], beta[b_kp1_idx + 1][3]}; - // b13_02 = {beta[b_kp1_idx + 1][0], beta[b_kp1_idx][1], - // beta[b_kp1_idx + 1][2], beta[b_kp1_idx][3]}; - float32x4_t a02 = vuzp1q_f32(alpha[k_idx], alpha[k_idx + 1]); - float32x4_t a13 = vuzp2q_f32(alpha[k_idx], alpha[k_idx + 1]); - float32x4_t b02_13 = - vtrn2q_f32(vrev64q_f32(beta_kp1.val[0]), beta_kp1.val[1]); - float32x4_t b13_02 = - vtrn2q_f32(vrev64q_f32(beta_kp1.val[1]), beta_kp1.val[0]); - - // Find the most probable path in which bit i was a 0 - // We need g01_01 = {gamma[g_k_idx][0], gamma[g_k_idx + 1][1], - // gamma[g_k_idx + 2][0], gamma[g_k_idx + 3][1]}; - float32x4_t g01_01 = vreinterpretq_f32_u8( - vqtbl1q_u8(vreinterpretq_u8_f32(pdf4[i].val[j]), idx_0220)); - left_1 = vaddq_f32(vaddq_f32(a02, b02_13), g01_01); - right_1 = vaddq_f32(vaddq_f32(a13, b13_02), g01_01); - prob_0[j] = vmaxvq_f32(vmaxq_f32(left_1, right_1)); - - // Find the most probable path in which bit i was a 1 - // We need g10_10 = {gamma[g_k_idx][1], gamma[g_k_idx + 1][0], - // gamma[g_k_idx + 2][1], gamma[g_k_idx + 3][0]}; - float32x4_t g10_10 = vreinterpretq_f32_u8( - vqtbl1q_u8(vreinterpretq_u8_f32(pdf4[i].val[j]), idx_3113)); - left_2 = vaddq_f32(vaddq_f32(a02, b13_02), g10_10); - right_2 = vaddq_f32(vaddq_f32(a13, b02_13), g10_10); - prob_1[j] = vmaxvq_f32(vmaxq_f32(left_2, right_2)); - - // Store the current value of beta to use in the next - // round of calculations - beta_kp1 = beta_k; - } - - // Calculate the LLRs - llr[i] = vsubq_f32(prob_0, prob_1); - } -} - -} // namespace - -// The template parameter allows us to disable checking for convergence (and -// thus terminating the iterations early) so we always run a fixed number of -// iterations in our benchmarking -template -void armral::turbo::decode_block(const int8_t *sys, const int8_t *par, - const int8_t *itl, uint32_t k, uint8_t *dst, - float32_t l_c, uint32_t max_iter, - Allocator &allocator) { - // This implements multiple steps of the max-log-MAP algorithm, - // which is an approximation to the MAP (BCJR) algorithm. - // It returns a hard decision rather than raw LLRs - - // We will be working with float32x4_t, so work out how - // many of these will be needed to store k float32_ts. - // k is always a multiple of 8, so no need to worry about remainders. - uint32_t k4 = k >> 2; - - auto sys_f32 = allocate_uninitialized(allocator, k4 + 1); - auto par_f32 = allocate_uninitialized(allocator, k4 + 1); - auto itl_f32 = allocate_uninitialized(allocator, k4 + 1); - - auto perm_idx = allocate_uninitialized(allocator, k); - auto perm_sys = allocate_uninitialized(allocator, k4 + 1); - - struct perm_pair { - uint16_t first; - uint16_t second; - }; - - auto perm_lookup = allocate_uninitialized(allocator, k); - - // Allocate space to hold the extrinsic and permuted extrinsic information - // to be passed between the two decoders. Extrinsic is initially set to 0. - auto extrinsic = allocate_zeroed(allocator, k4); - auto perm_extrinsic = allocate_zeroed(allocator, k4); - - // Allocate space for log likelihood ratios from both stages of decoding - auto l1_uky = allocate_uninitialized(allocator, k4); - auto l2_uky = allocate_uninitialized(allocator, k4); - auto prev_l2_uky = allocate_zeroed(allocator, k4); - - // Allocate space to hold alpha and gamma - // alpha stores the forward-accumulated state probabilities for each decoded - // bit, where the LTE encoder has 8 states and there are k+3 bits to decode - // plus the starting condition - auto alpha = allocate_uninitialized(allocator, 8 * k4 + 2); - // gamma stores the conditional state transition probabilities for each of the - // k+3 bits to decode - auto gamma = allocate_uninitialized(allocator, k4); - - // NOTE: All allocations done. - if constexpr (Allocator::is_counting) { - return; - } - - // Convert our LLRs from int8_ts into float32_ts - convert_llrs(k, sys, sys_f32.get()); - convert_llrs(k, par, par_f32.get()); - convert_llrs(k, itl, itl_f32.get()); - - // Unperturb the trellis termination bits. They are transmitted as: - // X0 Z1 X'0 Z'1 Z0 X2 Z'0 X'2 X1 Z2 X'1 - // Z'2 - // but need to appended to the inputs as: - // X0 X1 X2 Z0 Z1 Z2 X'0 X'1 X'2 Z'0 Z'1 - // Z'2 - // We append to the systematic (X), the parity (Z) and the interleaved parity - // (Z') values here, and to the interleaved systematic values (X') further - // down. - sys_f32[k4][0] = (float32_t)sys[k]; - sys_f32[k4][1] = (float32_t)itl[k]; - sys_f32[k4][2] = (float32_t)par[k + 1]; - - par_f32[k4][0] = (float32_t)par[k]; - par_f32[k4][1] = (float32_t)sys[k + 1]; - par_f32[k4][2] = (float32_t)itl[k + 1]; - - itl_f32[k4][0] = (float32_t)par[k + 2]; - itl_f32[k4][1] = (float32_t)sys[k + 3]; - itl_f32[k4][2] = (float32_t)itl[k + 3]; - - // Prescale l_c to avoid doing it repeatedly in the PDF calculations later. - const float32x4_t channel_reliability = vdupq_n_f32(l_c / 2); - - // Generate the permutation vector for the input value of k - // Find the index into the array of parameter arrays corresponding - // to the current k. Subtract 40 because k=40 is the lowest value. - int param_idx = armral::turbo::perm_params_lookup[(k - 40) >> 3]; - // and extract the correct values of f1 and f2 to build the - // interleaving polynomial - uint16_t f1 = armral::turbo::perm_params[param_idx][0]; - uint16_t f2 = armral::turbo::perm_params[param_idx][1]; - for (uint32_t i = 0; i < k; i++) { - perm_idx[i] = generate_perm_idx(i, f1, f2, k); - } - - // Create a permuted version of the systematic output for use - // with the second decoder - for (uint32_t i = 0; i < k4; i++) { - for (uint32_t j = 0; j < 4; j++) { - perm_sys[i][j] = (float32_t)sys[perm_idx[(i * 4) + j]]; - } - } - perm_sys[k4][0] = (float32_t)sys[k + 2]; - perm_sys[k4][1] = (float32_t)itl[k + 2]; - perm_sys[k4][2] = (float32_t)par[k + 3]; - - // Create a look-up of the permutation vector that maps [0,...k-1] indices - // to vector element/vector lane pairs. This avoids having to a modulo - // operator every time we want to apply the permutation to vector elements - for (uint32_t i = 0; i < k; i++) { - uint16_t vec_idx = perm_idx[i] / 4; - uint16_t vec_lane = perm_idx[i] % 4; - perm_lookup[i] = perm_pair{vec_idx, vec_lane}; - } - - // Separate arrays to hold the betas of the trellis termination bits for the - // original and permuted inputs - float32x4_t beta_tail[2]; - float32x4_t perm_beta_tail[2]; - - // Initialize alpha - alpha[0] = vdupq_n_f32(-INFINITY); - alpha[1] = vdupq_n_f32(-INFINITY); - alpha[0][0] = 0; - - // Calculate the trellis termination state transition probabilities, which - // do not require any extrinsic information - trellis_termination(sys_f32.get(), par_f32.get(), k4, channel_reliability, - beta_tail); - trellis_termination(perm_sys.get(), itl_f32.get(), k4, channel_reliability, - perm_beta_tail); - - // Initialize the number of iterations - uint32_t num_iter = 0; - - while (num_iter < max_iter) { - // Run the first decoder step - decode_step(sys_f32.get(), par_f32.get(), extrinsic.get(), k4, l1_uky.get(), - alpha.get(), beta_tail, gamma.get(), channel_reliability); - - // Compute the new extrinsic information to pass into the second decoder - update_extrinsic(k4, l1_uky.get(), extrinsic.get(), sys_f32.get()); - - // Need to unpermute extrinsic to match input to second decoder - for (uint32_t i = 0; i < k4; i++) { - for (uint32_t j = 0; j < 4; j++) { - perm_extrinsic[i][j] = extrinsic[perm_lookup[i * 4 + j].first] - [perm_lookup[i * 4 + j].second]; - } - } - - // Run the second decoder step - decode_step(perm_sys.get(), itl_f32.get(), perm_extrinsic.get(), k4, - l2_uky.get(), alpha.get(), perm_beta_tail, gamma.get(), - channel_reliability); - - // Compute the new extrinsic information to pass back into the first encoder - update_extrinsic(k4, l2_uky.get(), perm_extrinsic.get(), perm_sys.get()); - - // But need to unpermute extrinsic first - for (uint32_t i = 0; i < k4; i++) { - for (uint32_t j = 0; j < 4; j++) { - extrinsic[perm_lookup[i * 4 + j].first][perm_lookup[i * 4 + j].second] = - perm_extrinsic[i][j]; - } - } - - // Compare this iteration's results with those from the previous iteration - float32_t max_abs_diff = 0.0; - float32_t max_abs_val = 0.0; - for (uint32_t i = 0; i < k4; i++) { - float32_t abs_diff = vmaxvq_f32(vabdq_f32(l2_uky[i], prev_l2_uky[i])); - float32_t abs_val = vmaxvq_f32(vabsq_f32(l2_uky[i])); - if (abs_diff > max_abs_diff) { - max_abs_diff = abs_diff; - } - if (abs_val > max_abs_val) { - max_abs_val = abs_val; - } - } - - // If we've converged, finish decoding - if constexpr (check_convergence) { - if (max_abs_diff / max_abs_val < - std::numeric_limits::epsilon()) { - break; - } - } - - // Store the current "final" LLRs to use in convergence checking next - // iteration - for (uint32_t i = 0; i < k4; i++) { - prev_l2_uky[i] = l2_uky[i]; - } - - num_iter++; - } - - // Return unpermuted final output from second encoder - // Rather than allocate another new vector, copy into l1_uky and return that - for (uint32_t i = 0; i < k4; i++) { - for (uint32_t j = 0; j < 4; j++) { - l1_uky[perm_lookup[i * 4 + j].first][perm_lookup[i * 4 + j].second] = - l2_uky[i][j]; - } - } - - // Make a hard decision based on the final LLRs - turbo_llrs_to_bits(k, l1_uky.get(), dst); -} diff --git a/utils/bit_utils.hpp b/src/utils/bits_to_bytes.hpp similarity index 91% rename from utils/bit_utils.hpp rename to src/utils/bits_to_bytes.hpp index 1ed60cf..2cc811d 100644 --- a/utils/bit_utils.hpp +++ b/src/utils/bits_to_bytes.hpp @@ -9,11 +9,13 @@ #include #include +namespace armral { + // Given a byte array, where we are interested in each bit, create // an array of bytes instead in the passed-in array "out" // Data is read from the most significant bit in each byte to the least // significant -static inline void bits_to_bytes(uint32_t n, const uint8_t *in, uint8_t *out) { +inline void bits_to_bytes(uint32_t n, const uint8_t *in, uint8_t *out) { uint32_t full_bytes = n >> 3; // Set the mask uint8x16_t mask = vdupq_n_u8(1); @@ -68,8 +70,7 @@ static inline void bits_to_bytes(uint32_t n, const uint8_t *in, uint8_t *out) { // Given a byte array, where we are interested in each bit, create // an array of bytes instead and return it in a std::vector -static inline std::vector bits_to_bytes(uint32_t n, - const uint8_t *in) { +inline std::vector bits_to_bytes(uint32_t n, const uint8_t *in) { std::vector out(n); bits_to_bytes(n, in, out.data()); return out; @@ -78,7 +79,7 @@ static inline std::vector bits_to_bytes(uint32_t n, // Given a byte array of zeros and ones, write this out to // consecutive bits instead. Bytes are assumed to be big endian // so the first bit in a byte goes to the highest bit position -static inline void bytes_to_bits(uint32_t n, const uint8_t *in, uint8_t *out) { +inline void bytes_to_bits(uint32_t n, const uint8_t *in, uint8_t *out) { uint32_t full_bytes = n >> 3; uint32_t tail_bits = n & 7; for (uint32_t i = 0; i < full_bytes; ++i) { @@ -99,7 +100,7 @@ static inline void bytes_to_bits(uint32_t n, const uint8_t *in, uint8_t *out) { // negative, otherwise to 0. We do not assume that the data_out pointer is // initialized template -static inline void llrs_to_bits(uint32_t n, const T *llr, uint8_t *data_out) { +inline void llrs_to_bits(uint32_t n, const T *llr, uint8_t *data_out) { uint32_t full_bytes = n >> 3; uint32_t tail_bits = n & 7; for (uint32_t i = 0; i < full_bytes; ++i) { @@ -124,3 +125,5 @@ static inline void llrs_to_bits(uint32_t n, const T *llr, uint8_t *data_out) { } } } + +} // namespace armral diff --git a/src/utils/vec_mul.hpp b/src/utils/vec_mul.hpp index a352058..b056fd9 100644 --- a/src/utils/vec_mul.hpp +++ b/src/utils/vec_mul.hpp @@ -99,23 +99,21 @@ static inline int16x8x3_t load3_cmplx_and_scale(const int16_t *src, int16x8_t res1 = cmplx_mul_combined_re_im(in1, scale); int16x8_t res2 = cmplx_mul_combined_re_im(in2, scale); - return (int16x8x3_t){res0, res1, res2}; + return int16x8x3_t{res0, res1, res2}; } static inline void scale_and_store3_cmplx(int16_t *dst, int16x8_t out_0, int16x8_t out_1, int16x8_t out_2, int16x8x2_t scale) { // Multiply by a complex scale factor and store three vectors of output - vst2q_s16(dst, - cmplx_mul_split_re_im((int16x8x2_t){{vuzp1q_s16(out_0, out_1), - vuzp2q_s16(out_0, out_1)}}, - scale)); - vst2_s16(dst + 16, - cmplx_mul_split_re_im( - (int16x4x2_t){{vget_low_s16(vuzp1q_s16(out_2, out_2)), - vget_low_s16(vuzp2q_s16(out_2, out_2))}}, - (int16x4x2_t){ - {vget_low_s16(scale.val[0]), vget_low_s16(scale.val[1])}})); + vst2q_s16(dst, cmplx_mul_split_re_im(int16x8x2_t{{vuzp1q_s16(out_0, out_1), + vuzp2q_s16(out_0, out_1)}}, + scale)); + vst2_s16(dst + 16, cmplx_mul_split_re_im( + int16x4x2_t{{vget_low_s16(vuzp1q_s16(out_2, out_2)), + vget_low_s16(vuzp2q_s16(out_2, out_2))}}, + int16x4x2_t{{vget_low_s16(scale.val[0]), + vget_low_s16(scale.val[1])}})); } #if ARMRAL_ARCH_SVE >= 2 diff --git a/test/BasicMathFun/MatrixInv/Batch/main.cpp b/test/BasicMathFun/MatrixInv/Batch/main.cpp index c8ff2e5..ec1ad37 100644 --- a/test/BasicMathFun/MatrixInv/Batch/main.cpp +++ b/test/BasicMathFun/MatrixInv/Batch/main.cpp @@ -11,6 +11,15 @@ #include #include +using armral::utils::check_results_identity; +using armral::utils::check_results_mat_inv; +using armral::utils::gen_hermitian_matrix_batch; +using armral::utils::gen_invertible_matrix_batch; +using armral::utils::pack_data; +using armral::utils::print_cmplx_mat; +using armral::utils::reference_matinv_block; +using armral::utils::unpack_data; + /* * Run reference Matrix Inversion based on blockwise approach (batched/parallel * version) @@ -262,7 +271,7 @@ int main(int argc, char **argv) { // optimistic estimation. // Random test cases - std::vector m_hermitian = {2, 3, 4}; + std::vector m_hermitian = {2, 3, 4}; for (auto m : m_hermitian) { // Minimum number of matrices in batch // Required because m=3 and m=4 unroll by b=4, @@ -285,10 +294,10 @@ int main(int argc, char **argv) { } } } - std::vector m_general = {2, 3, 4}; + std::vector m_general = {2, 3, 4}; for (auto m : m_general) { unsigned b = m > 2 ? 4 : 2; - for (unsigned num_mats : {b, b * 2, b * 7, b * 100}) { + for (unsigned num_mats : {1U, m, m + 1, b * 2, b * 7, b * 100}) { for (int r = 0; r < num_reps; ++r) { passed &= run_batch_matinv_test(num_mats, m); passed &= run_batch_matinv_test(num_mats, m, 1.0F, 0.0F); diff --git a/test/BasicMathFun/MatrixInv/Single/main.cpp b/test/BasicMathFun/MatrixInv/Single/main.cpp index 480b5d7..88d4bec 100644 --- a/test/BasicMathFun/MatrixInv/Single/main.cpp +++ b/test/BasicMathFun/MatrixInv/Single/main.cpp @@ -11,6 +11,14 @@ #include #include +using armral::utils::allocate_random_cf32_lin_ind; +using armral::utils::check_results_identity; +using armral::utils::check_results_mat_inv; +using armral::utils::gen_hermitian_matrix; +using armral::utils::gen_invertible_matrix; +using armral::utils::print_cmplx_mat; +using armral::utils::reference_matinv_block; + /* * Run test for Hermitian Matrix Inversion and Reference for randomly generated * input matrix If inputs are random enough then matrix has high probability to diff --git a/test/BasicMathFun/MatrixMult/Batch/ArmSolve/main.cpp b/test/BasicMathFun/MatrixMult/Batch/ArmSolve/main.cpp index 97d3776..f7447aa 100644 --- a/test/BasicMathFun/MatrixMult/Batch/ArmSolve/main.cpp +++ b/test/BasicMathFun/MatrixMult/Batch/ArmSolve/main.cpp @@ -4,7 +4,7 @@ */ #include "cf32_utils.hpp" #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "qint64.hpp" #include @@ -22,11 +22,9 @@ static bool check_tolerance_cs16(const char *name, const int16_t *result, // GCOVR_EXCL_STOP } } - if (passed) { - printf("[%s] - check result: OK\n", name); - } else { - printf("[%s] - check result: ERROR\n", name); // GCOVR_EXCL_LINE - } + + printf("[%s] - check result: %s\n", name, passed ? "OK" : "ERROR"); + return passed; } @@ -77,8 +75,8 @@ static armral_cmplx_int16_t convert_cs16_cf64(std::complex x, armral_fixed_point_index i) { int sh = (int)i; // number of decimal bits x *= (1 << sh); - qint64_t re = (int64_t)x.real(); - qint64_t im = (int64_t)x.imag(); + armral::utils::qint64_t re = (int64_t)x.real(); + armral::utils::qint64_t im = (int64_t)x.imag(); return {re.get16(), im.get16()}; } @@ -128,24 +126,24 @@ static bool run_solve_test(int sc_per_g, int num_samples, // input vector (y) is significantly larger than the output vector (x) due to // the tests/impl using different intermediate storage formats and fused // versus unfused arithmetic. - auto num_fract_bits_x = - (armral_fixed_point_index)allocate_random_u8(1, 0, 15)[0]; + armral::utils::int_random random_u8; + auto num_fract_bits_x = (armral_fixed_point_index)random_u8.one(0, 15); std::vector num_fract_bits_y(Y); for (int i = 0; i < Y; ++i) { int min_y = std::max(0, (int)num_fract_bits_x - 6); - num_fract_bits_y[i] = - (armral_fixed_point_index)allocate_random_u8(1, min_y, 15)[0]; + num_fract_bits_y[i] = (armral_fixed_point_index)random_u8.one(min_y, 15); } // arrangement is batches of num_samples, each row/col separated. - auto x = allocate_random_cs16(num_samples * x_rows); - auto y = allocate_random_cs16(num_samples * y_rows); + armral::utils::cs16_random random_cs16; + auto x = random_cs16.vector(num_samples * x_rows); + auto y = random_cs16.vector(num_samples * y_rows); // arrangement is batches of num_blocks, each row/col separated. - auto g = allocate_random_cf32(x_rows * y_rows * num_blocks); + auto g = armral::utils::allocate_random_cf32(x_rows * y_rows * num_blocks); - auto g_real = unpack_real_cf32(g); - auto g_imag = unpack_imag_cf32(g); + auto g_real = armral::utils::unpack_real_cf32(g); + auto g_imag = armral::utils::unpack_imag_cf32(g); auto x_ref = x; run_reference_solve(num_samples, sc_per_g, num_fract_bits_y.data(), diff --git a/test/BasicMathFun/MatrixMult/Batch/MatrixVectorMult16/main.cpp b/test/BasicMathFun/MatrixMult/Batch/MatrixVectorMult16/main.cpp index d47fc50..16f9985 100644 --- a/test/BasicMathFun/MatrixMult/Batch/MatrixVectorMult16/main.cpp +++ b/test/BasicMathFun/MatrixMult/Batch/MatrixVectorMult16/main.cpp @@ -5,6 +5,12 @@ #include "matrix_utils.hpp" +using armral::utils::check_results_cs16; +using armral::utils::cs16_random; +using armral::utils::pack_data; +using armral::utils::reference_matmul_cs16; +using armral::utils::unpack_data; + static void reference_batch_matvecmul_cs16( uint16_t num_mats, uint16_t vecs_per_mat, uint16_t m, uint16_t n, const armral_cmplx_int16_t *a, const armral_cmplx_int16_t *x, @@ -24,7 +30,7 @@ static void reference_batch_matvecmul_cs16( unpack_data(vec_batch_start, total_vectors, ref, single_ref.data(), m); // do one mxn matrix-vector multiplication - reference_matmul_cs16(m, n, 1, single_a.data(), single_x.data(), + reference_matmul_cs16(m, 1, n, single_a.data(), single_x.data(), single_ref.data(), round); // pack the answer back into ref for comparison with the batched results @@ -39,12 +45,13 @@ static bool run_general_matvecmul_batch_test_32(uint16_t num_mats, const char *name = "MATVECMULBATCH32 armral_cmplx_int16_t"; // choose min/max values to avoid hitting saturation on the problems // we care about (m,n <= 16). - constexpr int16_t min = -4096; - constexpr int16_t max = 4095; + constexpr armral_cmplx_int16_t min = {-4096, -4096}; + constexpr armral_cmplx_int16_t max = {4095, 4095}; auto total_vectors = num_mats * vecs_per_mat; - const auto a = allocate_random_cs16(num_mats * m * n, min, max); - const auto x = allocate_random_cs16(total_vectors * n, min, max); - auto y = allocate_random_cs16(total_vectors * m, min, max); + cs16_random random; + const auto a = random.vector(num_mats * m * n, min, max); + const auto x = random.vector(total_vectors * n, min, max); + auto y = random.vector(total_vectors * m, min, max); auto ref = y; printf("[%s] - num_mats %u vecs_per_mat %u dimension %u %u\n", name, num_mats, @@ -67,12 +74,13 @@ static bool run_general_matvecmul_batch_pa_test_32(uint16_t num_mats, const char *name = "MATVECMULBATCHPA32 armral_cmplx_int16_t"; // choose min/max values to avoid hitting saturation on the problems // we care about (m,n <= 16). - constexpr int16_t min = -4096; - constexpr int16_t max = 4095; + constexpr armral_cmplx_int16_t min = {-4096, -4096}; + constexpr armral_cmplx_int16_t max = {4095, 4095}; auto total_vectors = num_mats * vecs_per_mat; - const auto a = allocate_random_cs16(num_mats * m * n, min, max); - const auto x = allocate_random_cs16(total_vectors * n, min, max); - auto y = allocate_random_cs16(total_vectors * m, min, max); + cs16_random random; + const auto a = random.vector(num_mats * m * n, min, max); + const auto x = random.vector(total_vectors * n, min, max); + auto y = random.vector(total_vectors * m, min, max); auto ref = y; printf("[%s] - num_mats %u vecs_per_mat %u dimension %u %u\n", name, num_mats, @@ -109,9 +117,10 @@ static bool run_general_matvecmul_batch_test_64(uint16_t num_mats, uint16_t m, uint16_t n) { const char *name = "MATVECMULBATCH64 armral_cmplx_int16_t"; auto total_vectors = num_mats * vecs_per_mat; - const auto a = allocate_random_cs16(num_mats * m * n); - const auto x = allocate_random_cs16(total_vectors * n); - auto y = allocate_random_cs16(total_vectors * m); + cs16_random random; + const auto a = random.vector(num_mats * m * n); + const auto x = random.vector(total_vectors * n); + auto y = random.vector(total_vectors * m); auto ref = y; printf("[%s] - num_mats %u vecs_per_mat %u dimension %u %u\n", name, num_mats, @@ -133,9 +142,10 @@ static bool run_general_matvecmul_batch_pa_test_64(uint16_t num_mats, uint16_t m, uint16_t n) { const char *name = "MATVECMULBATCHPA64 armral_cmplx_int16_t"; auto total_vectors = num_mats * vecs_per_mat; - const auto a = allocate_random_cs16(num_mats * m * n); - const auto x = allocate_random_cs16(total_vectors * n); - auto y = allocate_random_cs16(total_vectors * m); + cs16_random random; + const auto a = random.vector(num_mats * m * n); + const auto x = random.vector(total_vectors * n); + auto y = random.vector(total_vectors * m); auto ref = y; printf("[%s] - num_mats %u vecs_per_mat %u dimension %u %u\n", name, num_mats, diff --git a/test/BasicMathFun/MatrixMult/Batch/MatrixVectorMult32/main.cpp b/test/BasicMathFun/MatrixMult/Batch/MatrixVectorMult32/main.cpp index 9c56ee4..516f075 100644 --- a/test/BasicMathFun/MatrixMult/Batch/MatrixVectorMult32/main.cpp +++ b/test/BasicMathFun/MatrixMult/Batch/MatrixVectorMult32/main.cpp @@ -22,16 +22,19 @@ static void reference_batch_matvecmul_cf32(uint16_t num_mats, auto vec_batch_start = mat * vecs_per_mat + vec; // unpack a, x, and ref into local buffers - unpack_data(mat, num_mats, a, single_a.data(), m * n); - unpack_data(vec_batch_start, total_vectors, x, single_x.data(), n); - unpack_data(vec_batch_start, total_vectors, ref, single_ref.data(), m); + armral::utils::unpack_data(mat, num_mats, a, single_a.data(), m * n); + armral::utils::unpack_data(vec_batch_start, total_vectors, x, + single_x.data(), n); + armral::utils::unpack_data(vec_batch_start, total_vectors, ref, + single_ref.data(), m); // do one mxn matrix-vector multiplication - reference_matmul_cf32(m, n, 1, single_a.data(), single_x.data(), - single_ref.data()); + armral::utils::reference_matmul_cf32(m, 1, n, single_a.data(), + single_x.data(), single_ref.data()); // pack the answer back into ref for comparison with the batched results - pack_data(vec_batch_start, total_vectors, single_ref.data(), ref, m); + armral::utils::pack_data(vec_batch_start, total_vectors, + single_ref.data(), ref, m); } } } @@ -41,9 +44,10 @@ static bool run_general_matvecmul_batch_test(uint16_t num_mats, uint16_t n) { const char *name = "MATVECMULBATCH armral_cmplx_f32_t"; auto total_vectors = num_mats * vecs_per_mat; - const auto a = allocate_random_cf32(num_mats * m * n); - const auto x = allocate_random_cf32(total_vectors * n); - auto y = allocate_random_cf32(total_vectors * m); + armral::utils::cf32_random random; + const auto a = random.vector(num_mats * m * n); + const auto x = random.vector(total_vectors * n); + auto y = random.vector(total_vectors * m); auto ref = y; printf("[%s] - num_mats %u vecs_per_mat %u dimension %u %u\n", name, num_mats, @@ -57,7 +61,8 @@ static bool run_general_matvecmul_batch_test(uint16_t num_mats, reference_batch_matvecmul_cf32(num_mats, vecs_per_mat, m, n, a.data(), x.data(), ref.data()); - return check_results_cf32(name, y.data(), ref.data(), total_vectors * m); + return armral::utils::check_results_cf32(name, y.data(), ref.data(), + total_vectors * m); } static bool run_general_matvecmul_batch_pa_test(uint16_t num_mats, @@ -65,9 +70,10 @@ static bool run_general_matvecmul_batch_pa_test(uint16_t num_mats, uint16_t m, uint16_t n) { const char *name = "MATVECMULBATCHPA armral_cmplx_f32_t"; auto total_vectors = num_mats * vecs_per_mat; - const auto a = allocate_random_cf32(num_mats * m * n); - const auto x = allocate_random_cf32(total_vectors * n); - auto y = allocate_random_cf32(total_vectors * m); + armral::utils::cf32_random random; + const auto a = random.vector(num_mats * m * n); + const auto x = random.vector(total_vectors * n); + auto y = random.vector(total_vectors * m); auto ref = y; printf("[%s] - num_mats %u vecs_per_mat %u dimension %u %u\n", name, num_mats, @@ -96,7 +102,8 @@ static bool run_general_matvecmul_batch_pa_test(uint16_t num_mats, reference_batch_matvecmul_cf32(num_mats, vecs_per_mat, m, n, a.data(), x.data(), ref.data()); - return check_results_cf32(name, y.data(), ref.data(), total_vectors * m); + return armral::utils::check_results_cf32(name, y.data(), ref.data(), + total_vectors * m); } // Entry point for unit testing for cf32 batched matrix-vector multiplication diff --git a/test/BasicMathFun/MatrixMult/Single/MatrixMult16/main.cpp b/test/BasicMathFun/MatrixMult/Single/MatrixMult16/main.cpp index e0322ac..4dfadcd 100644 --- a/test/BasicMathFun/MatrixMult/Single/MatrixMult16/main.cpp +++ b/test/BasicMathFun/MatrixMult/Single/MatrixMult16/main.cpp @@ -6,34 +6,38 @@ static bool run_general_matmul_test_64(uint16_t m, uint16_t n, uint16_t k) { const char *name = "MATMUL64 armral_cmplx_int16_t"; - const auto a = allocate_random_cs16(m * n); - const auto b = allocate_random_cs16(n * k); - auto c = allocate_random_cs16(m * k); + armral::utils::cs16_random random; + const auto a = random.vector(m * k); + const auto b = random.vector(k * n); + auto c = random.vector(m * n); auto ref = c; printf("[%s] - dimension %u %u %u\n", name, m, n, k); - reference_matmul_cs16(m, n, k, a.data(), b.data(), ref.data(), 0); - armral_cmplx_mat_mult_i16(m, n, k, a.data(), b.data(), c.data()); - return check_results_cs16(name, c.data(), ref.data(), m * k); + armral::utils::reference_matmul_cs16(m, n, k, a.data(), b.data(), ref.data(), + 0); + armral_cmplx_matmul_i16(m, n, k, a.data(), b.data(), c.data()); + return armral::utils::check_results_cs16(name, c.data(), ref.data(), m * n); } static bool run_general_matmul_test_32(uint16_t m, uint16_t n, uint16_t k) { const char *name = "MATMUL32 armral_cmplx_int16_t"; // choose min/max values to avoid hitting saturation on the problems // we care about (m,n,k <= 16). - constexpr int16_t min = -4096; - constexpr int16_t max = 4095; - auto a = allocate_random_cs16(m * n, min, max); - auto b = allocate_random_cs16(n * k, min, max); - auto c = allocate_random_cs16(m * k, min, max); + constexpr armral_cmplx_int16_t min = {-4096, -4096}; + constexpr armral_cmplx_int16_t max = {4095, 4095}; + armral::utils::cs16_random random; + const auto a = random.vector(m * k, min, max); + const auto b = random.vector(k * n, min, max); + auto c = random.vector(m * n, min, max); auto ref = c; printf("[%s] - dimension %u %u %u\n", name, m, n, k); - reference_matmul_cs16(m, n, k, a.data(), b.data(), ref.data(), 0); - armral_cmplx_mat_mult_i16_32bit(m, n, k, a.data(), b.data(), c.data()); - return check_results_cs16(name, c.data(), ref.data(), m * k); + armral::utils::reference_matmul_cs16(m, n, k, a.data(), b.data(), ref.data(), + 0); + armral_cmplx_matmul_i16_32bit(m, n, k, a.data(), b.data(), c.data()); + return armral::utils::check_results_cs16(name, c.data(), ref.data(), m * n); } // Entry point for unit testing for 16-bit matrix multiplication diff --git a/test/BasicMathFun/MatrixMult/Single/MatrixMult32/main.cpp b/test/BasicMathFun/MatrixMult/Single/MatrixMult32/main.cpp index 3c70e00..fd21f7d 100644 --- a/test/BasicMathFun/MatrixMult/Single/MatrixMult32/main.cpp +++ b/test/BasicMathFun/MatrixMult/Single/MatrixMult32/main.cpp @@ -5,22 +5,31 @@ #include "cf32_utils.hpp" #include "reference_linalg.hpp" +using armral::utils::cf32_random; +using armral::utils::check_results_cf32; +using armral::utils::pack_cf32; +using armral::utils::reference_matmul_cf32; +using armral::utils::unpack_imag_cf32; +using armral::utils::unpack_real_cf32; + static bool run_general_matmul_test(uint16_t m, uint16_t n, uint16_t k) { - const auto a = allocate_random_cf32(m * n); - const auto b = allocate_random_cf32(n * k); - auto c = allocate_random_cf32(m * k); + cf32_random random; + const auto a = random.vector(m * k); + const auto b = random.vector(k * n); + auto c = random.vector(m * n); auto ref = c; reference_matmul_cf32(m, n, k, a.data(), b.data(), ref.data()); - armral_cmplx_mat_mult_f32(m, n, k, a.data(), b.data(), c.data()); + armral_cmplx_matmul_f32(m, n, k, a.data(), b.data(), c.data()); return check_results_cf32("MATMUL armral_cmplx_f32_t", c.data(), ref.data(), - m * k); + m * n); } static bool run_specific_2x2_matmul_test() { - const auto a = allocate_random_cf32(4); - const auto b = allocate_random_cf32(4); - auto c = allocate_random_cf32(4); + cf32_random random; + const auto a = random.vector(4); + const auto b = random.vector(4); + auto c = random.vector(4); auto ref = c; // note: the a/b flip is intentional since all matrices are given transposed. @@ -32,9 +41,10 @@ static bool run_specific_2x2_matmul_test() { } static bool run_specific_2x2_iq_matmul_test() { - const auto a = allocate_random_cf32(4); - const auto b = allocate_random_cf32(4); - auto ref = allocate_random_cf32(4); + cf32_random random; + const auto a = random.vector(4); + const auto b = random.vector(4); + auto ref = random.vector(4); // note: the a/b flip is intentional since all matrices are given transposed. // i.e. C = A^T * B^T = (B * A)^T @@ -54,9 +64,10 @@ static bool run_specific_2x2_iq_matmul_test() { } static bool run_specific_4x4_matmul_test() { - const auto a = allocate_random_cf32(16); - const auto b = allocate_random_cf32(16); - auto c = allocate_random_cf32(16); + cf32_random random; + const auto a = random.vector(16); + const auto b = random.vector(16); + auto c = random.vector(16); auto ref = c; // note: the a/b flip is intentional since all matrices are given transposed. @@ -68,9 +79,10 @@ static bool run_specific_4x4_matmul_test() { } static bool run_specific_4x4_iq_matmul_test() { - const auto a = allocate_random_cf32(16); - const auto b = allocate_random_cf32(16); - auto ref = allocate_random_cf32(16); + cf32_random random; + const auto a = random.vector(16); + const auto b = random.vector(16); + auto ref = random.vector(16); // note: the a/b flip is intentional since all matrices are given transposed. // i.e. C = A^T * B^T = (B * A)^T @@ -98,11 +110,11 @@ int main(int argc, char **argv) { } } } - const uint16_t n_size[] = {32, 64, 128, 256}; - for (uint16_t n : n_size) { - for (uint16_t k = 2; k <= 4; k++) { + const uint16_t sizes[] = {32, 64, 128, 255}; + for (uint16_t n : sizes) { + for (uint16_t k : sizes) { passed &= run_general_matmul_test(4, n, k); - passed &= run_general_matmul_test(8, n, k); + passed &= run_general_matmul_test(9, n, k); } } passed &= run_specific_2x2_matmul_test(); diff --git a/test/BasicMathFun/MatrixMult/Single/MatrixMultAAH32/main.cpp b/test/BasicMathFun/MatrixMult/Single/MatrixMultAAH32/main.cpp index 854bb26..2324d97 100644 --- a/test/BasicMathFun/MatrixMult/Single/MatrixMultAAH32/main.cpp +++ b/test/BasicMathFun/MatrixMult/Single/MatrixMultAAH32/main.cpp @@ -7,17 +7,19 @@ #include "reference_linalg.hpp" static bool run_matmul_aah_cf32_test(uint16_t m, uint16_t n) { - const auto a = allocate_random_cf32(m * n); - auto c = allocate_random_cf32(m * m); + armral::utils::cf32_random random; + const auto a = random.vector(m * n); + auto c = random.vector(m * m); auto ref = c; - reference_matmul_aah_cf32(m, n, a.data(), ref.data()); - armral_cmplx_mat_mult_aah_f32(m, n, a.data(), c.data()); + armral::utils::reference_matmul_aah_cf32(m, n, a.data(), ref.data()); + armral_cmplx_matmul_aah_f32(m, n, a.data(), c.data()); printf("%ix%i -> %ix%i\n", m, n, m, m); // Each element in c is computed by a length-n complex dot product - return check_results_cf32("MATMUL_AAH armral_cmplx_f32_t", c.data(), - ref.data(), m * m, cmplx_dot_nflops(n)); + return armral::utils::check_results_cf32("MATMUL_AAH armral_cmplx_f32_t", + c.data(), ref.data(), m * m, + armral::utils::cmplx_dot_nflops(n)); } int main(int argc, char **argv) { diff --git a/test/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/main.cpp b/test/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/main.cpp index 883f8bb..d5348b6 100644 --- a/test/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/main.cpp +++ b/test/BasicMathFun/MatrixMult/Single/MatrixMultAHB32/main.cpp @@ -15,20 +15,21 @@ static bool run_matmul_ahb_cf32_test(uint16_t m, uint16_t n, uint16_t k) { const char *name = "MATMUL_AHB armral_cmplx_f32_t"; printf("[%s] m=%d n=%d k=%d\n", name, m, n, k); - cf32_random random; - const auto a = random.flip_signs(random.vector(m * n)); - const auto b = random.flip_signs(random.vector(m * k)); - auto output = random.vector(n * k); + armral::utils::cf32_random random; + const auto a = random.flip_signs(random.vector(k * m)); + const auto b = random.flip_signs(random.vector(k * n)); + auto output = random.vector(m * n); auto reference_output = output; - reference_matmul_ahb_cf32(m, n, k, a.data(), b.data(), - reference_output.data()); + armral::utils::reference_matmul_ahb_cf32(m, n, k, a.data(), b.data(), + reference_output.data()); - armral_cmplx_mat_mult_ahb_f32(m, n, k, a.data(), b.data(), output.data()); + armral_cmplx_matmul_ahb_f32(m, n, k, a.data(), b.data(), output.data()); - // Each element in the output is computed by a length-m complex dot product - return check_results_cf32(name, output.data(), reference_output.data(), n * k, - cmplx_dot_nflops(m)); + // Each element in the output is computed by a length-k complex dot product + return armral::utils::check_results_cf32(name, output.data(), + reference_output.data(), m * n, + armral::utils::cmplx_dot_nflops(k)); } int main() { @@ -40,14 +41,14 @@ int main() { } } } - std::array n_sizes{32, 64, 128, 256}; - std::array mk_sizes{2, 3, 4, 8, 16}; - for (auto n : n_sizes) { - for (auto mk : mk_sizes) { + std::array m_sizes{32, 64, 128, 256}; + std::array nk_sizes{2, 3, 4, 8, 16}; + for (auto m : m_sizes) { + for (auto nk : nk_sizes) { // Larger A matrix, square B matrix - passed &= run_matmul_ahb_cf32_test(mk, n, mk); + passed &= run_matmul_ahb_cf32_test(m, nk, nk); // Larger A matrix, rectangular B matrix - passed &= run_matmul_ahb_cf32_test(mk * 2, n, mk); + passed &= run_matmul_ahb_cf32_test(m, nk, nk * 2); } } diff --git a/test/BasicMathFun/MatrixMult/Single/MatrixVectorMult16/main.cpp b/test/BasicMathFun/MatrixMult/Single/MatrixVectorMult16/main.cpp index c859193..eb024e3 100644 --- a/test/BasicMathFun/MatrixMult/Single/MatrixVectorMult16/main.cpp +++ b/test/BasicMathFun/MatrixMult/Single/MatrixVectorMult16/main.cpp @@ -7,34 +7,38 @@ static bool run_general_matvecmul_test_64(uint16_t m, uint16_t n) { const char *name = "MATVECMUL64 armral_cmplx_int16_t"; - const auto a = allocate_random_cs16(m * n); - const auto x = allocate_random_cs16(n); - auto y = allocate_random_cs16(m); + armral::utils::cs16_random random; + const auto a = random.vector(m * n); + const auto x = random.vector(n); + auto y = random.vector(m); auto ref = y; printf("[%s] - dimension %u %u\n", name, m, n); - reference_matmul_cs16(m, n, 1, a.data(), x.data(), ref.data(), 0); + armral::utils::reference_matmul_cs16(m, 1, n, a.data(), x.data(), ref.data(), + 0); armral_cmplx_mat_vec_mult_i16(m, n, a.data(), x.data(), y.data()); - return check_results_cs16(name, y.data(), ref.data(), m); + return armral::utils::check_results_cs16(name, y.data(), ref.data(), m); } static bool run_general_matvecmul_test_32(uint16_t m, uint16_t n) { const char *name = "MATVECMUL32 armral_cmplx_int16_t"; // choose min/max values to avoid hitting saturation on the problems // we care about (m,n <= 16). - constexpr int16_t min = -4096; - constexpr int16_t max = 4095; - auto a = allocate_random_cs16(m * n, min, max); - auto x = allocate_random_cs16(n, min, max); - auto y = allocate_random_cs16(m, min, max); + constexpr armral_cmplx_int16_t min = {-4096, -4096}; + constexpr armral_cmplx_int16_t max = {4095, 4095}; + armral::utils::cs16_random random; + const auto a = random.vector(m * n, min, max); + const auto x = random.vector(n, min, max); + auto y = random.vector(m, min, max); auto ref = y; printf("[%s] - dimension %u %u\n", name, m, n); - reference_matmul_cs16(m, n, 1, a.data(), x.data(), ref.data(), 0); + armral::utils::reference_matmul_cs16(m, 1, n, a.data(), x.data(), ref.data(), + 0); armral_cmplx_mat_vec_mult_i16_32bit(m, n, a.data(), x.data(), y.data()); - return check_results_cs16(name, y.data(), ref.data(), m); + return armral::utils::check_results_cs16(name, y.data(), ref.data(), m); } // Entry point for unit testing for 16-bit matrix-vector multiplication diff --git a/test/BasicMathFun/MatrixMult/Single/MatrixVectorMult32/main.cpp b/test/BasicMathFun/MatrixMult/Single/MatrixVectorMult32/main.cpp index d186048..1b9e1ab 100644 --- a/test/BasicMathFun/MatrixMult/Single/MatrixVectorMult32/main.cpp +++ b/test/BasicMathFun/MatrixMult/Single/MatrixVectorMult32/main.cpp @@ -6,15 +6,16 @@ #include "reference_linalg.hpp" static bool run_general_matvecmul_test(uint16_t m, uint16_t n) { - const auto a = allocate_random_cf32(m * n); - const auto x = allocate_random_cf32(n); - auto y = allocate_random_cf32(m); + armral::utils::cf32_random random; + const auto a = random.vector(m * n); + const auto x = random.vector(n); + auto y = random.vector(m); auto ref = y; - reference_matmul_cf32(m, n, 1, a.data(), x.data(), ref.data()); + armral::utils::reference_matmul_cf32(m, 1, n, a.data(), x.data(), ref.data()); armral_cmplx_mat_vec_mult_f32(m, n, a.data(), x.data(), y.data()); - return check_results_cf32("MATVECMUL armral_cmplx_f32_t", y.data(), - ref.data(), m); + return armral::utils::check_results_cf32("MATVECMUL armral_cmplx_f32_t", + y.data(), ref.data(), m); } int main(int argc, char **argv) { diff --git a/test/BasicMathFun/MatrixPseudoInv/Direct/main.cpp b/test/BasicMathFun/MatrixPseudoInv/Direct/main.cpp index e644718..f2f3915 100644 --- a/test/BasicMathFun/MatrixPseudoInv/Direct/main.cpp +++ b/test/BasicMathFun/MatrixPseudoInv/Direct/main.cpp @@ -17,7 +17,7 @@ reference_left_pseudo_inverse_direct(uint32_t m, uint32_t n, float32_t lambda, // We can use p_dst as an intermediate N-by-N array since it has size N-by-M, // and N < M auto *mat_aha = p_dst; - reference_matmul_aha_cf32(m, n, p_src, mat_aha); + armral::utils::reference_matmul_aha_cf32(m, n, p_src, mat_aha); // Compute C + lambda * I for (uint32_t i = 0; i < n; i++) { @@ -31,11 +31,11 @@ reference_left_pseudo_inverse_direct(uint32_t m, uint32_t n, float32_t lambda, mat_inv[0].re = 1.F / mat_aha[0].re; mat_inv[0].im = 0.F; } else { - reference_matinv_block(n, mat_aha, mat_inv.data()); + armral::utils::reference_matinv_block(n, mat_aha, mat_inv.data()); } // Compute B * A^H - reference_matmul_bah_cf32(m, n, p_src, mat_inv.data(), p_dst); + armral::utils::reference_matmul_bah_cf32(m, n, p_src, mat_inv.data(), p_dst); } static inline void reference_right_pseudo_inverse_direct( @@ -45,7 +45,7 @@ static inline void reference_right_pseudo_inverse_direct( // We can use p_dst as an intermediate M-by-M array since it has size N-by-M, // and N >= M auto *mat_aah = p_dst; - reference_matmul_aah_cf32(m, n, p_src, mat_aah); + armral::utils::reference_matmul_aah_cf32(m, n, p_src, mat_aah); // Compute C + lambda * I for (uint32_t i = 0; i < m; i++) { @@ -59,11 +59,12 @@ static inline void reference_right_pseudo_inverse_direct( mat_inv[0].re = 1.F / mat_aah[0].re; mat_inv[0].im = 0.F; } else { - reference_matinv_block(m, mat_aah, mat_inv.data()); + armral::utils::reference_matinv_block(m, mat_aah, mat_inv.data()); } // Compute A^H * B - reference_matmul_ahb_cf32(m, n, m, p_src, mat_inv.data(), p_dst); + armral::utils::reference_matmul_ahb_cf32(n, m, m, p_src, mat_inv.data(), + p_dst); } static inline void @@ -80,7 +81,7 @@ template static bool run_pseudo_inverse_direct_cf32_test( const char *name, uint32_t m, uint32_t n, float32_t lambda, PseudoInverseFunction pseudo_inverse_under_test) { - cf32_random random; + armral::utils::cf32_random random; const auto src = random.flip_signs(random.vector(m * n)); std::vector ans(n * m); std::vector ref(n * m); @@ -89,8 +90,8 @@ static bool run_pseudo_inverse_direct_cf32_test( reference_pseudo_inverse_direct(m, n, lambda, src.data(), ref.data()); - return check_results_cf32(name, ans.data(), ref.data(), n * m, - cmplx_dot_nflops(n)); + return armral::utils::check_results_cf32(name, ans.data(), ref.data(), n * m, + armral::utils::cmplx_dot_nflops(n)); } template diff --git a/test/BasicMathFun/VectorDotProd/VecDot16/main.cpp b/test/BasicMathFun/VectorDotProd/VecDot16/main.cpp index 6003f61..9f4c4a0 100644 --- a/test/BasicMathFun/VectorDotProd/VecDot16/main.cpp +++ b/test/BasicMathFun/VectorDotProd/VecDot16/main.cpp @@ -8,28 +8,29 @@ #define NAME "VECDOT armral_cmplx_int16_t 64" static bool run_vec_dot_test(uint32_t num_samples) { - const auto a = allocate_random_cs16(num_samples); - const auto b = allocate_random_cs16(num_samples); - auto c = allocate_random_cs16(1); + armral::utils::cs16_random random; + const auto a = random.vector(num_samples); + const auto b = random.vector(num_samples); + auto c = random.one(); printf("[" NAME "] - %u samples\n", num_samples); - armral_cmplx_vecdot_i16(num_samples, a.data(), b.data(), c.data()); + armral_cmplx_vecdot_i16(num_samples, a.data(), b.data(), &c); - std::complex acc; + std::complex acc; for (uint32_t i = 0; i < num_samples; ++i) { - acc += cmplx_mul_widen_cs16(a[i], b[i]); + acc += armral::utils::cmplx_mul_widen_cs16(a[i], b[i]); } armral_cmplx_int16_t ref{(acc.real() >> 15).get16(), (acc.imag() >> 15).get16()}; - return check_results_cs16(NAME, c.data(), &ref, 1); + return armral::utils::check_results_cs16(NAME, &c, &ref, 1); } // Entry point for unit testing of 16-bit vector dot product int main(int argc, char **argv) { - std::vector params; - for (int i = 1; i <= 33; ++i) { + std::vector params; + for (uint32_t i = 1; i <= 33; ++i) { params.push_back(i); } params.push_back(64); @@ -41,7 +42,7 @@ int main(int argc, char **argv) { params.push_back(512); params.push_back(1024); bool passed = true; - for (auto &n : params) { + for (const auto &n : params) { passed &= run_vec_dot_test(n); } exit(passed ? EXIT_SUCCESS : EXIT_FAILURE); diff --git a/test/BasicMathFun/VectorDotProd/VecDot16_2/main.cpp b/test/BasicMathFun/VectorDotProd/VecDot16_2/main.cpp index 8ac1092..bf33cab 100644 --- a/test/BasicMathFun/VectorDotProd/VecDot16_2/main.cpp +++ b/test/BasicMathFun/VectorDotProd/VecDot16_2/main.cpp @@ -8,35 +8,36 @@ #define NAME "VECDOT armral_cmplx_int16_t IQ" static bool run_vec_dot_test(uint32_t num_samples) { - const auto a = allocate_random_cs16(num_samples); - const auto b = allocate_random_cs16(num_samples); - auto c = allocate_random_cs16(1); + armral::utils::cs16_random random; + const auto a = random.vector(num_samples); + const auto b = random.vector(num_samples); + auto c = random.vector(1); printf("[" NAME "] - %u samples\n", num_samples); - const auto a_re = unpack_real_cs16(a); - const auto a_im = unpack_imag_cs16(a); - const auto b_re = unpack_real_cs16(b); - const auto b_im = unpack_imag_cs16(b); - auto c_re = unpack_real_cs16(c); - auto c_im = unpack_imag_cs16(c); + const auto a_re = armral::utils::unpack_real_cs16(a); + const auto a_im = armral::utils::unpack_imag_cs16(a); + const auto b_re = armral::utils::unpack_real_cs16(b); + const auto b_im = armral::utils::unpack_imag_cs16(b); + auto c_re = armral::utils::unpack_real_cs16(c); + auto c_im = armral::utils::unpack_imag_cs16(c); armral_cmplx_vecdot_i16_2(num_samples, a_re.data(), a_im.data(), b_re.data(), b_im.data(), c_re.data(), c_im.data()); - c = pack_cs16(c_re, c_im); + c = armral::utils::pack_cs16(c_re, c_im); - std::complex acc; + std::complex acc; for (uint32_t i = 0; i < num_samples; ++i) { - acc += cmplx_mul_widen_cs16(a[i], b[i]); + acc += armral::utils::cmplx_mul_widen_cs16(a[i], b[i]); } armral_cmplx_int16_t ref{(acc.real() >> 15).get16(), (acc.imag() >> 15).get16()}; - return check_results_cs16(NAME, c.data(), &ref, 1); + return armral::utils::check_results_cs16(NAME, c.data(), &ref, 1); } int main(int argc, char **argv) { - const int params[] = { + const uint32_t params[] = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 32, 64, 100, 128, 151, 256, 512, 1024, }; bool passed = true; diff --git a/test/BasicMathFun/VectorDotProd/VecDot16_2_32bit/main.cpp b/test/BasicMathFun/VectorDotProd/VecDot16_2_32bit/main.cpp index 904b13f..303faa7 100644 --- a/test/BasicMathFun/VectorDotProd/VecDot16_2_32bit/main.cpp +++ b/test/BasicMathFun/VectorDotProd/VecDot16_2_32bit/main.cpp @@ -7,40 +7,41 @@ #define NAME "VECDOT armral_cmplx_int16_t IQ 32" -static bool run_vec_dot_test(int num_samples) { +static bool run_vec_dot_test(uint32_t num_samples) { // restrict min/max to avoid hitting saturation in accumulator. - int min = -4096; - int max = 4095; - const auto a = allocate_random_cs16(num_samples, min, max); - const auto b = allocate_random_cs16(num_samples, min, max); - auto c = allocate_random_cs16(1); - - printf("[" NAME "] - %d samples\n", num_samples); - - const auto a_re = unpack_real_cs16(a); - const auto a_im = unpack_imag_cs16(a); - const auto b_re = unpack_real_cs16(b); - const auto b_im = unpack_imag_cs16(b); - auto c_re = unpack_real_cs16(c); - auto c_im = unpack_imag_cs16(c); + constexpr armral_cmplx_int16_t min = {-4096, -4096}; + constexpr armral_cmplx_int16_t max = {4095, 4095}; + armral::utils::cs16_random random; + const auto a = random.vector(num_samples, min, max); + const auto b = random.vector(num_samples, min, max); + auto c = random.vector(1); + + printf("[" NAME "] - %u samples\n", num_samples); + + const auto a_re = armral::utils::unpack_real_cs16(a); + const auto a_im = armral::utils::unpack_imag_cs16(a); + const auto b_re = armral::utils::unpack_real_cs16(b); + const auto b_im = armral::utils::unpack_imag_cs16(b); + auto c_re = armral::utils::unpack_real_cs16(c); + auto c_im = armral::utils::unpack_imag_cs16(c); armral_cmplx_vecdot_i16_2_32bit(num_samples, a_re.data(), a_im.data(), b_re.data(), b_im.data(), c_re.data(), c_im.data()); - c = pack_cs16(c_re, c_im); + c = armral::utils::pack_cs16(c_re, c_im); - std::complex acc; - for (int i = 0; i < num_samples; ++i) { - acc += cmplx_mul_widen_cs16(a[i], b[i]); + std::complex acc; + for (uint32_t i = 0; i < num_samples; ++i) { + acc += armral::utils::cmplx_mul_widen_cs16(a[i], b[i]); } armral_cmplx_int16_t ref{(acc.real() >> 16).get16(), (acc.imag() >> 16).get16()}; - return check_results_cs16(NAME, c.data(), &ref, 1); + return armral::utils::check_results_cs16(NAME, c.data(), &ref, 1); } int main(int argc, char **argv) { - const int params[] = { + const uint32_t params[] = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 32, 64, 100, 128, 151, 256, 512, 1024, }; bool passed = true; diff --git a/test/BasicMathFun/VectorDotProd/VecDot16_32bit/main.cpp b/test/BasicMathFun/VectorDotProd/VecDot16_32bit/main.cpp index b4784ff..8bfb96c 100644 --- a/test/BasicMathFun/VectorDotProd/VecDot16_32bit/main.cpp +++ b/test/BasicMathFun/VectorDotProd/VecDot16_32bit/main.cpp @@ -9,29 +9,30 @@ static bool run_vec_dot_test(uint32_t num_samples) { // restrict min/max to avoid hitting saturation in accumulator. - int min = -4096; - int max = 4095; - const auto a = allocate_random_cs16(num_samples, min, max); - const auto b = allocate_random_cs16(num_samples, min, max); - auto c = allocate_random_cs16(1); + constexpr armral_cmplx_int16_t min = {-4096, -4096}; + constexpr armral_cmplx_int16_t max = {4095, 4095}; + armral::utils::cs16_random random; + const auto a = random.vector(num_samples, min, max); + const auto b = random.vector(num_samples, min, max); + auto c = random.one(); printf("[" NAME "] - %u samples\n", num_samples); - armral_cmplx_vecdot_i16_32bit(num_samples, a.data(), b.data(), c.data()); + armral_cmplx_vecdot_i16_32bit(num_samples, a.data(), b.data(), &c); - std::complex acc; + std::complex acc; for (uint32_t i = 0; i < num_samples; ++i) { - acc += cmplx_mul_widen_cs16(a[i], b[i]); + acc += armral::utils::cmplx_mul_widen_cs16(a[i], b[i]); } armral_cmplx_int16_t ref{(acc.real() >> 16).get16(), (acc.imag() >> 16).get16()}; - return check_results_cs16(NAME, c.data(), &ref, 1); + return armral::utils::check_results_cs16(NAME, &c, &ref, 1); } int main(int argc, char **argv) { - std::vector params; - for (int i = 1; i <= 33; ++i) { + std::vector params; + for (uint32_t i = 1; i <= 33; ++i) { params.push_back(i); } params.push_back(64); @@ -43,7 +44,7 @@ int main(int argc, char **argv) { params.push_back(512); params.push_back(1024); bool passed = true; - for (auto &n : params) { + for (const auto &n : params) { passed &= run_vec_dot_test(n); } exit(passed ? EXIT_SUCCESS : EXIT_FAILURE); diff --git a/test/BasicMathFun/VectorDotProd/VecDot32/main.cpp b/test/BasicMathFun/VectorDotProd/VecDot32/main.cpp index a72d8ea..e2ed128 100644 --- a/test/BasicMathFun/VectorDotProd/VecDot32/main.cpp +++ b/test/BasicMathFun/VectorDotProd/VecDot32/main.cpp @@ -7,25 +7,26 @@ #define NAME "VECDOT armral_cmplx_f32_t" static bool run_vec_dot_test(uint32_t num_samples) { - const auto a = allocate_random_cf32(num_samples); - const auto b = allocate_random_cf32(num_samples); - auto c = allocate_random_cf32(1); + armral::utils::cf32_random random; + const auto a = random.vector(num_samples); + const auto b = random.vector(num_samples); + auto c = random.one(); printf("[" NAME "] - %u samples\n", num_samples); - armral_cmplx_vecdot_f32(num_samples, a.data(), b.data(), c.data()); + armral_cmplx_vecdot_f32(num_samples, a.data(), b.data(), &c); std::complex acc; for (uint32_t i = 0; i < num_samples; ++i) { - acc += cmplx_mul_widen_cf32(a[i], b[i]); + acc += armral::utils::cmplx_mul_widen_cf32(a[i], b[i]); } armral_cmplx_f32_t ref{(float32_t)acc.real(), (float32_t)acc.imag()}; - return check_results_cf32(NAME, c.data(), &ref, 1); + return armral::utils::check_results_cf32(NAME, &c, &ref, 1); } int main(int argc, char **argv) { - const int params[] = { + const uint32_t params[] = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 32, 64, 100, 128, 151, 256, 512, 1024, }; bool passed = true; diff --git a/test/BasicMathFun/VectorDotProd/VecDot32_2/main.cpp b/test/BasicMathFun/VectorDotProd/VecDot32_2/main.cpp index 232c14b..35f7be3 100644 --- a/test/BasicMathFun/VectorDotProd/VecDot32_2/main.cpp +++ b/test/BasicMathFun/VectorDotProd/VecDot32_2/main.cpp @@ -7,34 +7,35 @@ #define NAME "VECDOT armral_cmplx_f32_t IQ" static bool run_vec_dot_test(uint32_t num_samples) { - const auto a = allocate_random_cf32(num_samples); - const auto b = allocate_random_cf32(num_samples); - auto c = allocate_random_cf32(1); + armral::utils::cf32_random random; + const auto a = random.vector(num_samples); + const auto b = random.vector(num_samples); + auto c = random.vector(1); printf("[" NAME "] - %u samples\n", num_samples); - const auto a_re = unpack_real_cf32(a); - const auto a_im = unpack_imag_cf32(a); - const auto b_re = unpack_real_cf32(b); - const auto b_im = unpack_imag_cf32(b); - auto c_re = unpack_real_cf32(c); - auto c_im = unpack_imag_cf32(c); + const auto a_re = armral::utils::unpack_real_cf32(a); + const auto a_im = armral::utils::unpack_imag_cf32(a); + const auto b_re = armral::utils::unpack_real_cf32(b); + const auto b_im = armral::utils::unpack_imag_cf32(b); + auto c_re = armral::utils::unpack_real_cf32(c); + auto c_im = armral::utils::unpack_imag_cf32(c); armral_cmplx_vecdot_f32_2(num_samples, a_re.data(), a_im.data(), b_re.data(), b_im.data(), c_re.data(), c_im.data()); - c = pack_cf32(c_re, c_im); + c = armral::utils::pack_cf32(c_re, c_im); std::complex acc; for (uint32_t i = 0; i < num_samples; ++i) { - acc += cmplx_mul_widen_cf32(a[i], b[i]); + acc += armral::utils::cmplx_mul_widen_cf32(a[i], b[i]); } armral_cmplx_f32_t ref{(float32_t)acc.real(), (float32_t)acc.imag()}; - return check_results_cf32(NAME, c.data(), &ref, 1); + return armral::utils::check_results_cf32(NAME, c.data(), &ref, 1); } int main(int argc, char **argv) { - const int params[] = { + const uint32_t params[] = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 32, 64, 100, 128, 151, 256, 512, 1024, }; bool passed = true; diff --git a/test/BasicMathFun/VectorMult/VecMul16/main.cpp b/test/BasicMathFun/VectorMult/VecMul16/main.cpp index aadebde..8c31bea 100644 --- a/test/BasicMathFun/VectorMult/VecMul16/main.cpp +++ b/test/BasicMathFun/VectorMult/VecMul16/main.cpp @@ -8,9 +8,10 @@ #define NAME "VECMUL armral_cmplx_int16_t" static bool run_vec_mul_test(uint32_t num_samples) { - const auto a = allocate_random_cs16(num_samples); - const auto b = allocate_random_cs16(num_samples); - auto c = allocate_random_cs16(num_samples); + armral::utils::cs16_random random; + const auto a = random.vector(num_samples); + const auto b = random.vector(num_samples); + auto c = random.vector(num_samples); auto ref = c; printf("[" NAME "] - %u samples\n", num_samples); @@ -18,13 +19,15 @@ static bool run_vec_mul_test(uint32_t num_samples) { armral_cmplx_vecmul_i16(num_samples, a.data(), b.data(), c.data()); for (uint32_t i = 0; i < num_samples; ++i) { - std::complex res = cmplx_mul_widen_cs16(a[i], b[i]); + std::complex res = + armral::utils::cmplx_mul_widen_cs16(a[i], b[i]); // add one to intermediate result to ensure correct rounding ref[i].re = (((res.real() >> 14) + 1) >> 1).get16(); ref[i].im = (((res.imag() >> 14) + 1) >> 1).get16(); } - return check_results_cs16(NAME, c.data(), ref.data(), num_samples); + return armral::utils::check_results_cs16(NAME, c.data(), ref.data(), + num_samples); } static bool vec_mul_single_val_test(uint32_t num_samples, @@ -38,16 +41,18 @@ static bool vec_mul_single_val_test(uint32_t num_samples, const auto a_vec = std::vector(num_samples, a_val); const auto b_vec = std::vector(num_samples, b_val); - auto c = allocate_random_cs16(num_samples); + auto c = armral::utils::allocate_random_cs16(num_samples); armral_cmplx_vecmul_i16(num_samples, a_vec.data(), b_vec.data(), c.data()); - std::complex res = cmplx_mul_widen_cs16(a_val, b_val); + std::complex res = + armral::utils::cmplx_mul_widen_cs16(a_val, b_val); armral_cmplx_int16_t res_int16; res_int16.re = (((res.real() >> 14) + 1) >> 1).get16(); res_int16.im = (((res.imag() >> 14) + 1) >> 1).get16(); std::vector ref(num_samples, res_int16); - if (!check_results_cs16(NAME, c.data(), ref.data(), num_samples)) { + if (!armral::utils::check_results_cs16(NAME, c.data(), ref.data(), + num_samples)) { // GCOVR_EXCL_START printf("Error for saturating multiplication with values:\n\t " "(%d + %di) * (%d + %di)\n", @@ -83,14 +88,14 @@ static bool run_vec_mul_saturation_test(uint32_t num_samples) { // Entry point for unit test for 16-bit vector multiplication int main(int argc, char **argv) { - const int params[] = { + const uint32_t params[] = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 32, 64, 100, 128, 151, 256, 512, 1024, }; bool passed = true; for (const auto &n : params) { passed &= run_vec_mul_test(n); } - const int saturation_len[] = {1, 3, 8, 9}; + const uint32_t saturation_len[] = {1, 3, 8, 9}; for (auto n : saturation_len) { passed &= run_vec_mul_saturation_test(n); } diff --git a/test/BasicMathFun/VectorMult/VecMul16_2/main.cpp b/test/BasicMathFun/VectorMult/VecMul16_2/main.cpp index 3fa482b..3ee9745 100644 --- a/test/BasicMathFun/VectorMult/VecMul16_2/main.cpp +++ b/test/BasicMathFun/VectorMult/VecMul16_2/main.cpp @@ -3,40 +3,43 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" +#include "int_utils.hpp" #include "qint64.hpp" #define NAME "VECMUL armral_cmplx_int16_t IQ" static bool run_vec_mul_test(uint32_t num_samples) { - constexpr int16_t min = -4096; - constexpr int16_t max = 4095; - - const auto a = allocate_random_cs16(num_samples, min, max); - const auto b = allocate_random_cs16(num_samples, min, max); - auto c = allocate_random_cs16(num_samples, min, max); + constexpr armral_cmplx_int16_t min = {-4096, -4096}; + constexpr armral_cmplx_int16_t max = {4095, 4095}; + armral::utils::cs16_random random_cs16; + const auto a = random_cs16.vector(num_samples, min, max); + const auto b = random_cs16.vector(num_samples, min, max); + auto c = random_cs16.vector(num_samples, min, max); auto ref = c; - const auto a_re = unpack_real_cs16(a); - const auto a_im = unpack_imag_cs16(a); - const auto b_re = unpack_real_cs16(b); - const auto b_im = unpack_imag_cs16(b); - auto c_re = unpack_real_cs16(c); - auto c_im = unpack_imag_cs16(c); + const auto a_re = armral::utils::unpack_real_cs16(a); + const auto a_im = armral::utils::unpack_imag_cs16(a); + const auto b_re = armral::utils::unpack_real_cs16(b); + const auto b_im = armral::utils::unpack_imag_cs16(b); + auto c_re = armral::utils::unpack_real_cs16(c); + auto c_im = armral::utils::unpack_imag_cs16(c); printf("[" NAME "] - %u samples\n", num_samples); armral_cmplx_vecmul_i16_2(num_samples, a_re.data(), a_im.data(), b_re.data(), b_im.data(), c_re.data(), c_im.data()); - c = pack_cs16(c_re, c_im); + c = armral::utils::pack_cs16(c_re, c_im); for (uint32_t i = 0; i < num_samples; ++i) { - std::complex res = cmplx_mul_widen_cs16(a[i], b[i]); + std::complex res = + armral::utils::cmplx_mul_widen_cs16(a[i], b[i]); // add one to intermediate result to ensure correct rounding ref[i].re = (((res.real() >> 14) + 1) >> 1).get16(); ref[i].im = (((res.imag() >> 14) + 1) >> 1).get16(); } - return check_results_cs16(NAME, c.data(), ref.data(), num_samples); + return armral::utils::check_results_cs16(NAME, c.data(), ref.data(), + num_samples); } static bool vec_mul_single_val_test(uint32_t num_samples, @@ -52,20 +55,25 @@ static bool vec_mul_single_val_test(uint32_t num_samples, const auto a_im = std::vector(num_samples, a_val.im); const auto b_re = std::vector(num_samples, b_val.re); const auto b_im = std::vector(num_samples, b_val.im); - auto c = allocate_random_cs16(num_samples); - auto c_re = allocate_random_i16(num_samples); - auto c_im = allocate_random_i16(num_samples); + + auto c = armral::utils::allocate_random_cs16(num_samples); + + armral::utils::int_random random_i16; + auto c_re = random_i16.vector(num_samples); + auto c_im = random_i16.vector(num_samples); armral_cmplx_vecmul_i16_2(num_samples, a_re.data(), a_im.data(), b_re.data(), b_im.data(), c_re.data(), c_im.data()); - c = pack_cs16(c_re, c_im); + c = armral::utils::pack_cs16(c_re, c_im); - std::complex res = cmplx_mul_widen_cs16(a_val, b_val); + std::complex res = + armral::utils::cmplx_mul_widen_cs16(a_val, b_val); armral_cmplx_int16_t res_int16; res_int16.re = (((res.real() >> 14) + 1) >> 1).get16(); res_int16.im = (((res.imag() >> 14) + 1) >> 1).get16(); std::vector ref(num_samples, res_int16); - if (!check_results_cs16(NAME, c.data(), ref.data(), num_samples)) { + if (!armral::utils::check_results_cs16(NAME, c.data(), ref.data(), + num_samples)) { // GCOVR_EXCL_START printf("Error for saturating multiplication with values:\n\t " "(%d + %di) * (%d + %di)\n", @@ -100,14 +108,14 @@ static bool run_vec_mul_saturation_test(uint32_t num_samples) { } int main(int argc, char **argv) { - const int params[] = { + const uint32_t params[] = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 32, 64, 100, 128, 151, 256, 512, 1024, }; bool passed = true; for (const auto &n : params) { passed &= run_vec_mul_test(n); } - const int saturation_len[] = { + const uint32_t saturation_len[] = { 1, 3, 4, 5, 8, 16, 17, }; for (auto n : saturation_len) { diff --git a/test/BasicMathFun/VectorMult/VecMul32/main.cpp b/test/BasicMathFun/VectorMult/VecMul32/main.cpp index 0455a42..991b6bd 100644 --- a/test/BasicMathFun/VectorMult/VecMul32/main.cpp +++ b/test/BasicMathFun/VectorMult/VecMul32/main.cpp @@ -7,9 +7,10 @@ #define NAME "VECMUL armral_cmplx_f32_t" static bool run_vec_mul_test(uint32_t num_samples) { - const auto a = allocate_random_cf32(num_samples); - const auto b = allocate_random_cf32(num_samples); - auto c = allocate_random_cf32(num_samples); + armral::utils::cf32_random random; + const auto a = random.vector(num_samples); + const auto b = random.vector(num_samples); + auto c = random.vector(num_samples); auto ref = c; printf("[" NAME "] - %u samples\n", num_samples); @@ -17,16 +18,17 @@ static bool run_vec_mul_test(uint32_t num_samples) { armral_cmplx_vecmul_f32(num_samples, a.data(), b.data(), c.data()); for (uint32_t i = 0; i < num_samples; ++i) { - auto res = cmplx_mul_widen_cf32(a[i], b[i]); + auto res = armral::utils::cmplx_mul_widen_cf32(a[i], b[i]); ref[i].re = res.real(); ref[i].im = res.imag(); } - return check_results_cf32(NAME, c.data(), ref.data(), num_samples); + return armral::utils::check_results_cf32(NAME, c.data(), ref.data(), + num_samples); } int main(int argc, char **argv) { - const int params[] = { + const uint32_t params[] = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 32, 64, 100, 128, 151, 256, 512, 1024, }; bool passed = true; diff --git a/test/BasicMathFun/VectorMult/VecMul32_2/main.cpp b/test/BasicMathFun/VectorMult/VecMul32_2/main.cpp index bda9a5e..38ac06e 100644 --- a/test/BasicMathFun/VectorMult/VecMul32_2/main.cpp +++ b/test/BasicMathFun/VectorMult/VecMul32_2/main.cpp @@ -9,35 +9,37 @@ #define NAME "VECMUL armral_cmplx_f32_t IQ" static bool run_vec_mul_test(uint32_t num_samples) { - const auto a = allocate_random_cf32(num_samples); - const auto b = allocate_random_cf32(num_samples); - auto c = allocate_random_cf32(num_samples); + armral::utils::cf32_random random; + const auto a = random.vector(num_samples); + const auto b = random.vector(num_samples); + auto c = random.vector(num_samples); auto ref = c; - const auto a_re = unpack_real_cf32(a); - const auto a_im = unpack_imag_cf32(a); - const auto b_re = unpack_real_cf32(b); - const auto b_im = unpack_imag_cf32(b); - auto c_re = unpack_real_cf32(c); - auto c_im = unpack_imag_cf32(c); + const auto a_re = armral::utils::unpack_real_cf32(a); + const auto a_im = armral::utils::unpack_imag_cf32(a); + const auto b_re = armral::utils::unpack_real_cf32(b); + const auto b_im = armral::utils::unpack_imag_cf32(b); + auto c_re = armral::utils::unpack_real_cf32(c); + auto c_im = armral::utils::unpack_imag_cf32(c); printf("[" NAME "] - %u samples\n", num_samples); armral_cmplx_vecmul_f32_2(num_samples, a_re.data(), a_im.data(), b_re.data(), b_im.data(), c_re.data(), c_im.data()); - c = pack_cf32(c_re, c_im); + c = armral::utils::pack_cf32(c_re, c_im); for (uint32_t i = 0; i < num_samples; ++i) { - auto res = cmplx_mul_widen_cf32(a[i], b[i]); + auto res = armral::utils::cmplx_mul_widen_cf32(a[i], b[i]); ref[i].re = res.real(); ref[i].im = res.imag(); } - return check_results_cf32(NAME, c.data(), ref.data(), num_samples); + return armral::utils::check_results_cf32(NAME, c.data(), ref.data(), + num_samples); } int main(int argc, char **argv) { - const int params[] = { + const uint32_t params[] = { 1, 2, 3, 4, 5, 7, 8, 15, 16, 32, 64, 100, 128, 151, 256, 512, 1024, }; bool passed = true; diff --git a/test/DuRuInterface/MuLaw/Compression/main.cpp b/test/DuRuInterface/MuLaw/Compression/main.cpp index e107e1f..d8daa67 100644 --- a/test/DuRuInterface/MuLaw/Compression/main.cpp +++ b/test/DuRuInterface/MuLaw/Compression/main.cpp @@ -4,12 +4,18 @@ */ #include "armral.h" #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "qint64.hpp" #include #include +using armral::utils::allocate_random_cs16; +using armral::utils::allocate_random_i8; +using armral::utils::allocate_random_shifted_cs16; +using armral::utils::cmplx_mul_widen_cs16; +using armral::utils::qint64_t; + namespace { int16_t sign(int16_t x) { return x >= 0 ? 1 : -1; @@ -346,12 +352,12 @@ int main(int argc, char **argv) { for (int num_prbs : params) { for (auto [interval_min, interval_max] : intervals) { printf("Testing (%d, %d)\n", interval_min, interval_max); - passed &= run_compression_test_8b(num_prbs, interval_min, interval_max, - nullptr); - passed &= run_compression_test_9b(num_prbs, interval_min, interval_max, - nullptr); - passed &= run_compression_test_14b(num_prbs, interval_min, interval_max, - nullptr); + passed &= + run_compression_test_8b(num_prbs, interval_min, interval_max, NULL); + passed &= + run_compression_test_9b(num_prbs, interval_min, interval_max, NULL); + passed &= + run_compression_test_14b(num_prbs, interval_min, interval_max, NULL); // Magnitude of scale factor is not expected to be greater than 1, so get // a random val in range [(-sqrt(0.5),-sqrt(0.5)), (sqrt(0.5),sqrt(0.5))] diff --git a/test/DuRuInterface/MuLaw/Decompression/main.cpp b/test/DuRuInterface/MuLaw/Decompression/main.cpp index eb6ed52..f59c352 100644 --- a/test/DuRuInterface/MuLaw/Decompression/main.cpp +++ b/test/DuRuInterface/MuLaw/Decompression/main.cpp @@ -4,7 +4,7 @@ */ #include "armral.h" #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #if ARMRAL_ARCH_SVE >= 2 #include "qint64.hpp" #endif @@ -12,6 +12,12 @@ #include #include +using armral::utils::allocate_random_cs16; +using armral::utils::allocate_random_i8; +using armral::utils::check_results_cs16; +using armral::utils::cmplx_mul_widen_cs16; +using armral::utils::qint64_t; + namespace { int sign(int x) { @@ -180,7 +186,7 @@ bool run_mu_law_decompression_test_8b(const int num_prbs, auto dst = allocate_random_cs16(num_prbs * ARMRAL_NUM_COMPLEX_SAMPLES); assert(ref.size() == dst.size()); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } printf("[%s] - %d resources\n", name, num_prbs); @@ -196,7 +202,7 @@ bool run_mu_law_decompression_test_9b(const int num_prbs, auto dst = allocate_random_cs16(num_prbs * ARMRAL_NUM_COMPLEX_SAMPLES); assert(ref.size() == dst.size()); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } printf("[%s] - %d resources\n", name, num_prbs); @@ -212,7 +218,7 @@ bool run_mu_law_decompression_test_14b(const int num_prbs, auto dst = allocate_random_cs16(num_prbs * ARMRAL_NUM_COMPLEX_SAMPLES); assert(ref.size() == dst.size()); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } printf("[%s] - %d resources\n", name, num_prbs); @@ -228,9 +234,9 @@ int main(int argc, char **argv) { bool passed = true; for (int nprbs : params) { - passed &= run_mu_law_decompression_test_8b(nprbs, nullptr); - passed &= run_mu_law_decompression_test_9b(nprbs, nullptr); - passed &= run_mu_law_decompression_test_14b(nprbs, nullptr); + passed &= run_mu_law_decompression_test_8b(nprbs, NULL); + passed &= run_mu_law_decompression_test_9b(nprbs, NULL); + passed &= run_mu_law_decompression_test_14b(nprbs, NULL); armral_cmplx_int16_t scale = allocate_random_cs16(1)[0]; passed &= run_mu_law_decompression_test_8b(nprbs, &scale); diff --git a/test/DuRuInterface/ORanBlockFloat/Compression/main.cpp b/test/DuRuInterface/ORanBlockFloat/Compression/main.cpp index ecf5d28..9888d63 100644 --- a/test/DuRuInterface/ORanBlockFloat/Compression/main.cpp +++ b/test/DuRuInterface/ORanBlockFloat/Compression/main.cpp @@ -3,12 +3,18 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "qint64.hpp" #include #include +using armral::utils::allocate_random_cs16; +using armral::utils::allocate_random_i8; +using armral::utils::allocate_random_shifted_cs16; +using armral::utils::cmplx_mul_widen_cs16; +using armral::utils::qint64_t; + template static bool check_results_cd(const char *name, const T *result, const T *expected, uint32_t n) { @@ -235,7 +241,7 @@ static bool run_compression_test_8b(const int num_prbs, int16_t min, const auto ref = compression_reference_8b(src, scale); auto dst = allocate_random_cd8(num_prbs); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } printf("[%s] - %d resources\n", name, num_prbs); @@ -252,7 +258,7 @@ static bool run_compression_test_9b(const int num_prbs, int16_t min, const auto ref = compression_reference_9b(src, scale); auto dst = allocate_random_cd9(num_prbs); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } printf("[%s] - %d resources\n", name, num_prbs); @@ -269,7 +275,7 @@ static bool run_compression_test_12b(const int num_prbs, int16_t min, const auto ref = compression_reference_12b(src, scale); auto dst = allocate_random_cd12(num_prbs); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } printf("[%s] - %d resources\n", name, num_prbs); @@ -286,7 +292,7 @@ static bool run_compression_test_14b(const int num_prbs, int16_t min, const auto ref = compression_reference_14b(src, scale); auto dst = allocate_random_cd14(num_prbs); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } printf("[%s] - %d resources\n", name, num_prbs); @@ -302,7 +308,7 @@ int main(int argc, char **argv) { // Magnitude of scale factor is not expected to be greater than 1, so get // a random val in range [(-sqrt(0.5),-sqrt(0.5)), (sqrt(0.5),sqrt(0.5))] armral_cmplx_int16_t scale = allocate_random_cs16(1, 0xA581, 0x5A7F)[0]; - std::vector scales{nullptr, &scale}; + std::vector scales{NULL, &scale}; bool passed = true; for (auto *const s : scales) { diff --git a/test/DuRuInterface/ORanBlockFloat/Decompression/main.cpp b/test/DuRuInterface/ORanBlockFloat/Decompression/main.cpp index 087d3a0..2e8a135 100644 --- a/test/DuRuInterface/ORanBlockFloat/Decompression/main.cpp +++ b/test/DuRuInterface/ORanBlockFloat/Decompression/main.cpp @@ -3,10 +3,15 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include +using armral::utils::allocate_random_cs16; +using armral::utils::allocate_random_i8; +using armral::utils::check_results_cs16; +using armral::utils::scale_and_truncate_cs16; + template static std::vector allocate_random_cd(uint32_t len) { const auto bytes = allocate_random_i8(len * sizeof(T)); @@ -154,7 +159,7 @@ static bool run_decompression_test_8b(const int num_prbs, auto dst = allocate_random_cs16(num_prbs * ARMRAL_NUM_COMPLEX_SAMPLES); assert(ref.size() == dst.size()); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } @@ -171,7 +176,7 @@ static bool run_decompression_test_9b(const int num_prbs, auto dst = allocate_random_cs16(num_prbs * ARMRAL_NUM_COMPLEX_SAMPLES); assert(ref.size() == dst.size()); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } @@ -188,7 +193,7 @@ static bool run_decompression_test_12b(const int num_prbs, auto dst = allocate_random_cs16(num_prbs * ARMRAL_NUM_COMPLEX_SAMPLES); assert(ref.size() == dst.size()); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } @@ -205,7 +210,7 @@ static bool run_decompression_test_14b(const int num_prbs, auto dst = allocate_random_cs16(num_prbs * ARMRAL_NUM_COMPLEX_SAMPLES); assert(ref.size() == dst.size()); - if (scale != nullptr) { + if (scale != NULL) { printf("[%s] - scale = (%d, %d)\n", name, scale->re, scale->im); } @@ -220,10 +225,10 @@ int main(int argc, char **argv) { }; bool passed = true; for (int nprbs : params) { - passed &= run_decompression_test_8b(nprbs, nullptr); - passed &= run_decompression_test_9b(nprbs, nullptr); - passed &= run_decompression_test_12b(nprbs, nullptr); - passed &= run_decompression_test_14b(nprbs, nullptr); + passed &= run_decompression_test_8b(nprbs, NULL); + passed &= run_decompression_test_9b(nprbs, NULL); + passed &= run_decompression_test_12b(nprbs, NULL); + passed &= run_decompression_test_14b(nprbs, NULL); armral_cmplx_int16_t scale = allocate_random_cs16(1)[0]; passed &= run_decompression_test_8b(nprbs, &scale); diff --git a/test/DuRuInterface/ORanBlockScaling/Compression/main.cpp b/test/DuRuInterface/ORanBlockScaling/Compression/main.cpp index 72d680c..6cc3baa 100644 --- a/test/DuRuInterface/ORanBlockScaling/Compression/main.cpp +++ b/test/DuRuInterface/ORanBlockScaling/Compression/main.cpp @@ -3,12 +3,18 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "qint64.hpp" #include #include +using armral::utils::allocate_random_cs16; +using armral::utils::allocate_random_i8; +using armral::utils::allocate_random_shifted_cs16; +using armral::utils::cmplx_mul_widen_cs16; +using armral::utils::qint64_t; + namespace { int calculate_scaling_fact(const armral_cmplx_int16_t *prb, uint32_t n, diff --git a/test/DuRuInterface/ORanBlockScaling/Decompression/main.cpp b/test/DuRuInterface/ORanBlockScaling/Decompression/main.cpp index 45cb5fe..37bf2d1 100644 --- a/test/DuRuInterface/ORanBlockScaling/Decompression/main.cpp +++ b/test/DuRuInterface/ORanBlockScaling/Decompression/main.cpp @@ -3,11 +3,16 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include +using armral::utils::allocate_random_cs16; +using armral::utils::allocate_random_i8; +using armral::utils::check_results_cs16; +using armral::utils::scale_and_truncate_cs16; + namespace { template diff --git a/test/LowerPHY/Correlation/main.cpp b/test/LowerPHY/Correlation/main.cpp index fa31a8b..1a90846 100644 --- a/test/LowerPHY/Correlation/main.cpp +++ b/test/LowerPHY/Correlation/main.cpp @@ -5,6 +5,8 @@ #include "cs16_utils.hpp" #include "qint64.hpp" +using armral::utils::qint64_t; + static bool check_single_cs16(const char *name, armral_cmplx_int16_t result, armral_cmplx_int16_t expected) { bool passed = true; @@ -22,11 +24,9 @@ static bool check_single_cs16(const char *name, armral_cmplx_int16_t result, expected.im); // GCOVR_EXCL_STOP } - if (passed) { - printf("[%s] - check result: OK\n", name); - } else { - printf("[%s] - check result: ERROR\n", name); // GCOVR_EXCL_LINE - } + + printf("[%s] - check result: %s\n", name, passed ? "OK" : "ERROR"); + return passed; } @@ -90,21 +90,24 @@ reference_correlation(int n, const armral_cmplx_int16_t *a, return {ret64.real().get16(), ret64.imag().get16()}; } -static bool run_correlation_test(int n) { +static bool run_correlation_test(uint32_t n) { // use a smaller domain of inputs to avoid hitting saturation - const auto a = allocate_random_cs16(n, -1024, 1024); - const auto b = allocate_random_cs16(n, -1024, 1024); + constexpr armral_cmplx_int16_t min = {-1024, -1024}; + constexpr armral_cmplx_int16_t max = {1024, 1024}; + armral::utils::cs16_random random; + const auto a = random.vector(n, min, max); + const auto b = random.vector(n, min, max); const auto ref = reference_correlation(n, a.data(), b.data()); - auto c = allocate_random_cs16(1); + auto c = random.one(); - printf("[CORRELATION] - %d samples\n", n); - armral_corr_coeff_i16(n, a.data(), b.data(), c.data()); - return check_single_cs16("CORRELATION", c[0], ref); + printf("[CORRELATION] - %u samples\n", n); + armral_corr_coeff_i16(n, a.data(), b.data(), &c); + return check_single_cs16("CORRELATION", c, ref); } int main(int argc, char **argv) { bool passed = true; - for (int n = 0; n < 1026; ++n) { + for (uint32_t n = 1; n < 1026; ++n) { passed &= run_correlation_test(n); } exit(passed ? EXIT_SUCCESS : EXIT_FAILURE); diff --git a/test/LowerPHY/FFT/FFT16/main.cpp b/test/LowerPHY/FFT/FFT16/main.cpp index 33b4b98..3db2dc3 100644 --- a/test/LowerPHY/FFT/FFT16/main.cpp +++ b/test/LowerPHY/FFT/FFT16/main.cpp @@ -53,13 +53,9 @@ static bool check_fft_results(const char *name, } } - if (passed) { - printf( - "[%s] - check result: OK, max error was %.10f vs tolerance of %.10f\n", - name, max_error, tol); - } else { - printf("[%s] - check result: ERROR\n", name); // GCOVR_EXCL_LINE - } + printf("[%s] - check result: %s, max error was %.10f vs tolerance of %.10f\n", + name, passed ? "OK" : "ERROR", max_error, tol); + return passed; } @@ -71,14 +67,17 @@ run_fft_ref(int n, armral_fft_direction_t dir, const armral_cmplx_int16_t *x) { in[i].real(x[i].re / (double)(1 << 15)); in[i].imag(x[i].im / (double)(1 << 15)); } - fft_ref(n, 1, dir, in.data(), out.data()); - return narrow_to_cf32(out); + armral::utils::fft_ref(n, 1, dir, in.data(), out.data()); + return armral::utils::narrow_to_cf32(out); } static bool run_fft_test(int n, armral_fft_direction_t dir) { printf("Testing FFT n=%d dir=%d\n", n, (int)dir); - const auto x = allocate_random_cs16(n, -4096, 4096); - auto y = allocate_random_cs16(n, -4096, 4096); + constexpr armral_cmplx_int16_t min = {-4096, -4096}; + constexpr armral_cmplx_int16_t max = {4095, 4095}; + armral::utils::cs16_random random; + const auto x = random.vector(n, min, max); + auto y = random.vector(n, min, max); const auto y_ref = run_fft_ref(n, dir, x.data()); armral_fft_plan_t *p; @@ -97,12 +96,12 @@ static bool run_fft_test(int n, armral_fft_direction_t dir) { int main(int argc, char **argv) { bool passed = true; - int ns[] = {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, - 24, 25, 32, 35, 36, 40, 45, 46, 47, 48, 50, - 64, 65, 66, 68, 77, 81, 98, 99, 102, 112, 136, - 121, 169, 170, 204, 238, 255, 272, 289, 342, 361, 440, - 441, 484, 529, 552, 768, 800, 1024, 1104, 2048, 2401}; + constexpr int ns[] = { + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 32, 35, + 36, 40, 45, 46, 47, 48, 50, 64, 65, 66, 68, 77, 81, + 98, 99, 102, 112, 136, 121, 169, 170, 204, 238, 255, 272, 289, + 342, 361, 440, 441, 484, 529, 552, 768, 800, 1024, 1104, 2048, 2401}; for (int n : ns) { for (auto dir : {ARMRAL_FFT_FORWARDS, ARMRAL_FFT_BACKWARDS}) { passed &= run_fft_test(n, dir); diff --git a/test/LowerPHY/FFT/FFT32/main.cpp b/test/LowerPHY/FFT/FFT32/main.cpp index bdec57c..57275ba 100644 --- a/test/LowerPHY/FFT/FFT32/main.cpp +++ b/test/LowerPHY/FFT/FFT32/main.cpp @@ -42,28 +42,25 @@ static bool check_fft_results(const char *name, } } - if (passed) { - printf( - "[%s] - check result: OK, max error was %.10f vs tolerance of %.10f\n", - name, max_error, tol); - } else { - printf("[%s] - check result: ERROR\n", name); // GCOVR_EXCL_LINE - } + printf("[%s] - check result: %s, max error was %.10f vs tolerance of %.10f\n", + name, passed ? "OK" : "ERROR", max_error, tol); + return passed; } static std::vector run_fft_ref(int n, armral_fft_direction_t dir, const armral_cmplx_f32_t *x) { - std::vector> in = widen_cf32(x, n); + std::vector> in = armral::utils::widen_cf32(x, n); std::vector> out(n); - fft_ref(n, 1, dir, in.data(), out.data()); - return narrow_to_cf32(out); + armral::utils::fft_ref(n, 1, dir, in.data(), out.data()); + return armral::utils::narrow_to_cf32(out); } static bool run_fft_test(int n, armral_fft_direction_t dir) { printf("Testing FFT n=%d dir=%d\n", n, (int)dir); - const auto x = allocate_random_cf32(n); - auto y = allocate_random_cf32(n); + armral::utils::cf32_random random; + const auto x = random.vector(n); + auto y = random.vector(n); const auto y_ref = run_fft_ref(n, dir, x.data()); armral_fft_plan_t *p = nullptr; @@ -82,11 +79,12 @@ static bool run_fft_test(int n, armral_fft_direction_t dir) { int main(int argc, char **argv) { bool passed = true; - int ns[] = {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, - 32, 46, 47, 64, 65, 66, 68, 77, 99, 102, 136, 121, - 169, 170, 204, 238, 255, 272, 289, 342, 361, 440, 441, 484, - 529, 552, 768, 800, 1024, 1063, 1104, 1728, 2048, 2401, 3240}; + constexpr int ns[] = {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 32, 46, 47, 64, 65, 66, + 68, 77, 99, 102, 136, 121, 169, 170, 204, 238, + 255, 272, 289, 342, 361, 440, 441, 484, 529, 552, + 768, 800, 1024, 1063, 1104, 1728, 2048, 2401, 3240}; for (int n : ns) { for (auto dir : {ARMRAL_FFT_FORWARDS, ARMRAL_FFT_BACKWARDS}) { passed &= run_fft_test(n, dir); diff --git a/test/LowerPHY/FIR/FIR16/main.cpp b/test/LowerPHY/FIR/FIR16/main.cpp index 1cda4e8..22604a3 100644 --- a/test/LowerPHY/FIR/FIR16/main.cpp +++ b/test/LowerPHY/FIR/FIR16/main.cpp @@ -13,10 +13,11 @@ static uint32_t iround_8(uint32_t x) { static bool run_fir_test(uint32_t num_samples, uint32_t num_taps) { assert(num_samples % 8 == 0); - auto input = allocate_random_cs16(iround_8(num_samples + num_taps)); - auto coeffs = allocate_random_cs16(iround_8(num_taps)); - auto output = allocate_random_cs16(num_samples); - auto ref = allocate_random_cs16(num_samples); + armral::utils::cs16_random random; + auto input = random.vector(iround_8(num_samples + num_taps)); + auto coeffs = random.vector(iround_8(num_taps)); + auto output = random.vector(num_samples); + auto ref = random.vector(num_samples); printf("[" NAME "] - %u samples - %u taps\n", num_samples, num_taps); @@ -30,25 +31,26 @@ static bool run_fir_test(uint32_t num_samples, uint32_t num_taps) { for (uint32_t j = 0; j < num_taps; ++j) { // note: this is i+j since the FIR coefficients are assumed to be // swapped in memory (part of the interface). - acc += cmplx_mul_widen_cs16(input[i + j], coeffs[j]); + acc += armral::utils::cmplx_mul_widen_cs16(input[i + j], coeffs[j]); } ref[i].re = acc.real() >> 16; ref[i].im = acc.imag() >> 16; } - return check_results_cs16(NAME, (int16_t *)output.data(), - (int16_t *)ref.data(), num_samples); + return armral::utils::check_results_cs16(NAME, (int16_t *)output.data(), + (int16_t *)ref.data(), num_samples); } int main(int argc, char **argv) { - std::pair params[] = { - {16, 8}, {16, 16}, {128, 8}, {128, 9}, {128, 10}, {256, 14}, - {256, 16}, {512, 16}, {512, 15}, {512, 19}, {1024, 19}, {1024, 23}, - {2048, 23}, {2048, 24}, {4096, 25}, {4096, 28}, {8192, 30}, {8192, 32}, - {128, 3}, {1024, 5}, {2048, 6}, {4096, 7}, {10240, 32}, + const std::pair params[] = { + {16, 8}, {16, 16}, {128, 8}, {128, 9}, {128, 10}, {184, 12}, + {184, 13}, {184, 16}, {256, 14}, {256, 16}, {512, 16}, {512, 15}, + {512, 19}, {1024, 19}, {1024, 23}, {2048, 23}, {2048, 24}, {4096, 25}, + {4096, 28}, {8192, 30}, {8192, 32}, {128, 3}, {1024, 5}, {2048, 6}, + {4096, 7}, {10240, 32}, }; bool passed = true; - for (auto &p : params) { + for (const auto &p : params) { passed &= run_fir_test(p.first, p.second); } exit(passed ? EXIT_SUCCESS : EXIT_FAILURE); diff --git a/test/LowerPHY/FIR/FIR16Decimate2/main.cpp b/test/LowerPHY/FIR/FIR16Decimate2/main.cpp index ab179e5..d5588c2 100644 --- a/test/LowerPHY/FIR/FIR16Decimate2/main.cpp +++ b/test/LowerPHY/FIR/FIR16Decimate2/main.cpp @@ -18,10 +18,11 @@ static uint32_t iround_8(uint32_t x) { static bool run_fir_test(uint32_t num_samples, uint32_t num_taps) { assert(num_samples % 8 == 0); - auto input = allocate_random_cs16(iround_8(num_samples + num_taps)); - auto coeffs = allocate_random_cs16(iround_4(num_taps)); - auto output = allocate_random_cs16(num_samples / 2); - auto ref = allocate_random_cs16(num_samples / 2); + armral::utils::cs16_random random; + auto input = random.vector(iround_8(num_samples + num_taps)); + auto coeffs = random.vector(iround_4(num_taps)); + auto output = random.vector(num_samples / 2); + auto ref = random.vector(num_samples / 2); printf("[" NAME "] - %u samples - %u taps\n", num_samples, num_taps); @@ -36,13 +37,14 @@ static bool run_fir_test(uint32_t num_samples, uint32_t num_taps) { for (uint32_t j = 0; j < num_taps; ++j) { // note: this is i*2+j since the FIR coefficients are assumed to be // swapped in memory (part of the interface). - acc += cmplx_mul_widen_cs16(input[i * 2 + j], coeffs[j]); + acc += armral::utils::cmplx_mul_widen_cs16(input[i * 2 + j], coeffs[j]); } ref[i].re = acc.real() >> 16; ref[i].im = acc.imag() >> 16; } - return check_results_cs16(NAME, output.data(), ref.data(), num_samples / 2); + return armral::utils::check_results_cs16(NAME, output.data(), ref.data(), + num_samples / 2); } int main(int argc, char **argv) { diff --git a/test/LowerPHY/FIR/FIR32/main.cpp b/test/LowerPHY/FIR/FIR32/main.cpp index 2112b9f..910b761 100644 --- a/test/LowerPHY/FIR/FIR32/main.cpp +++ b/test/LowerPHY/FIR/FIR32/main.cpp @@ -13,10 +13,11 @@ static uint32_t iround_4(uint32_t x) { static bool run_fir_test(uint32_t num_samples, uint32_t num_taps) { assert(num_samples % 4 == 0); - auto input = allocate_random_cf32(iround_4(num_samples + num_taps)); - auto coeffs = allocate_random_cf32(iround_4(num_taps)); - auto output = allocate_random_cf32(num_samples); - auto ref = allocate_random_cf32(num_samples); + armral::utils::cf32_random random; + auto input = random.vector(iround_4(num_samples + num_taps)); + auto coeffs = random.vector(iround_4(num_taps)); + auto output = random.vector(num_samples); + auto ref = random.vector(num_samples); printf("[" NAME "] - %u samples - %u taps\n", num_samples, num_taps); @@ -30,13 +31,14 @@ static bool run_fir_test(uint32_t num_samples, uint32_t num_taps) { for (uint32_t j = 0; j < num_taps; ++j) { // note: this is i+j since the FIR coefficients are assumed to be // swapped in memory (part of the interface). - acc += cmplx_mul_widen_cf32(input[i + j], coeffs[j]); + acc += armral::utils::cmplx_mul_widen_cf32(input[i + j], coeffs[j]); } ref[i].re = acc.real(); ref[i].im = acc.imag(); } - return check_results_cf32(NAME, output.data(), ref.data(), num_samples); + return armral::utils::check_results_cf32(NAME, output.data(), ref.data(), + num_samples); } int main(int argc, char **argv) { diff --git a/test/LowerPHY/FIR/FIR32Decimate2/main.cpp b/test/LowerPHY/FIR/FIR32Decimate2/main.cpp index 361e705..671ec43 100644 --- a/test/LowerPHY/FIR/FIR32Decimate2/main.cpp +++ b/test/LowerPHY/FIR/FIR32Decimate2/main.cpp @@ -18,10 +18,11 @@ static uint32_t iround_8(uint32_t x) { static bool run_fir_test(int num_samples, int num_taps) { assert(num_samples % 8 == 0); - auto input = allocate_random_cf32(iround_8(num_samples + num_taps)); - auto coeffs = allocate_random_cf32(iround_4(num_taps)); - auto output = allocate_random_cf32(num_samples / 2); - auto ref = allocate_random_cf32(num_samples / 2); + armral::utils::cf32_random random; + auto input = random.vector(iround_8(num_samples + num_taps)); + auto coeffs = random.vector(iround_4(num_taps)); + auto output = random.vector(num_samples / 2); + auto ref = random.vector(num_samples / 2); printf("[" NAME "] - %d samples - %d taps\n", num_samples, num_taps); @@ -36,13 +37,14 @@ static bool run_fir_test(int num_samples, int num_taps) { for (int j = 0; j < num_taps; ++j) { // note: this is i*2+j since the FIR coefficients are assumed to be // swapped in memory (part of the interface). - acc += cmplx_mul_widen_cf32(input[i * 2 + j], coeffs[j]); + acc += armral::utils::cmplx_mul_widen_cf32(input[i * 2 + j], coeffs[j]); } ref[i].re = acc.real(); ref[i].im = acc.imag(); } - return check_results_cf32(NAME, output.data(), ref.data(), num_samples / 2); + return armral::utils::check_results_cf32(NAME, output.data(), ref.data(), + num_samples / 2); } int main(int argc, char **argv) { diff --git a/test/LowerPHY/Scrambling/main.cpp b/test/LowerPHY/Scrambling/main.cpp index 36ab300..39fd597 100644 --- a/test/LowerPHY/Scrambling/main.cpp +++ b/test/LowerPHY/Scrambling/main.cpp @@ -3,8 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "bit_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" static std::vector reference_scrambler(const uint8_t *input, const uint8_t *seq, @@ -25,8 +24,9 @@ static std::vector reference_scrambler(const uint8_t *input, static bool run_scrambling_test(uint32_t len) { uint32_t len_bytes = (((uint64_t)len) + 7) / 8; // Random source and destination data. - auto in = allocate_random_u8(len_bytes); - auto dst = allocate_random_u8(len_bytes); + armral::utils::int_random random; + auto in = random.vector(len_bytes); + auto dst = random.vector(len_bytes); // Generate Gold sequence. std::vector sequence(len_bytes); @@ -36,7 +36,8 @@ static bool run_scrambling_test(uint32_t len) { printf("[SCRAMBLING] len = %u\n", len); armral_scramble_code_block(in.data(), sequence.data(), len, dst.data()); - return check_results_u8("SCRAMBLING", dst.data(), ref.data(), len_bytes); + return armral::utils::check_results_u8("SCRAMBLING", dst.data(), ref.data(), + len_bytes); } int main(int argc, char **argv) { diff --git a/test/LowerPHY/SeqGenerator/main.cpp b/test/LowerPHY/SeqGenerator/main.cpp index 70d50e6..7433915 100644 --- a/test/LowerPHY/SeqGenerator/main.cpp +++ b/test/LowerPHY/SeqGenerator/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" static std::vector reference_sequence_generator(uint32_t seed, uint32_t len_bytes) { @@ -44,11 +44,11 @@ static bool run_sequence_generator_test(uint32_t seed, uint32_t len) { // of sequence, but we don't support that so just round up to the next // multiple of 8 and use the byte count instead. const auto ref = reference_sequence_generator(seed, len_bytes); - auto dst = allocate_random_u8(len_bytes); + auto dst = armral::utils::allocate_random_u8(len_bytes); printf("[SEQUENCE GENERATOR] seed=%u, len=%u\n", seed, len); armral_seq_generator(len, seed, dst.data()); - return check_results_u8("SEQUENCE GENERATOR", dst.data(), ref.data(), - (uint32_t)len_bytes); + return armral::utils::check_results_u8("SEQUENCE GENERATOR", dst.data(), + ref.data(), (uint32_t)len_bytes); } int main(int argc, char **argv) { diff --git a/test/MatrixFactorizations/SVD/main.cpp b/test/MatrixFactorizations/SVD/main.cpp index 5fe0f85..87ae255 100644 --- a/test/MatrixFactorizations/SVD/main.cpp +++ b/test/MatrixFactorizations/SVD/main.cpp @@ -3,22 +3,13 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ +#include "reference_linalg.hpp" #include "svd_sample_data.h" #include "svd_test.hpp" -namespace { +using armral::utils::convert_cf32_array_to_vector; -// Routine for converting a vector of armral_cmplx_f32_t -// to a vector of complex. -std::vector> -convert_arm_cf32_to_complex(uint16_t nvalues, - const std::vector &a) { - std::vector> out(nvalues); - for (unsigned i = 0; i < nvalues; ++i) { - out[i] = std::complex(a[i].re, a[i].im); - } - return out; -} +namespace { // Check the accuracy of our implementation // using sample data instead of randomly generated @@ -52,9 +43,10 @@ bool test_svd_with_sample(SVDFunction svd_function_under_test) { passed &= check_singular_values(n, test.s, s); // Convert data to complex for testing - auto aref_cmplx = convert_arm_cf32_to_complex(size, test.a); - auto u_cmplx = convert_arm_cf32_to_complex(size, u); - auto vt_cmplx = convert_arm_cf32_to_complex(n * n, vt); + auto aref_cmplx = + convert_cf32_array_to_vector(size, test.a.data()); + auto u_cmplx = convert_cf32_array_to_vector(size, u.data()); + auto vt_cmplx = convert_cf32_array_to_vector(n * n, vt.data()); // Check the accuracy of the full decomposition passed &= check_svd_decomposition(m, n, aref_cmplx, s, u_cmplx, vt_cmplx); @@ -77,8 +69,7 @@ bool test_svd(bool gen_singular_vectors, int m, int n, float cond, std::vector a(size); std::vector s(n); std::vector sref(n); - int seed = 0; - generate_svd_matrix(m, n, a, sref, cond, seed); + generate_svd_matrix(m, n, a, sref, cond); // Make copy of A. std::vector aref = a; @@ -100,9 +91,10 @@ bool test_svd(bool gen_singular_vectors, int m, int n, float cond, if (gen_singular_vectors) { // Convert data to complex for testing - auto aref_cmplx = convert_arm_cf32_to_complex(size, aref); - auto u_cmplx = convert_arm_cf32_to_complex(size, u); - auto vt_cmplx = convert_arm_cf32_to_complex(n * n, vt); + auto aref_cmplx = + convert_cf32_array_to_vector(size, aref.data()); + auto u_cmplx = convert_cf32_array_to_vector(size, u.data()); + auto vt_cmplx = convert_cf32_array_to_vector(n * n, vt.data()); // Check the accuracy of the full decomposition passed &= check_svd_decomposition(m, n, aref_cmplx, s, u_cmplx, vt_cmplx); diff --git a/test/MatrixFactorizations/SVD/svd_test.hpp b/test/MatrixFactorizations/SVD/svd_test.hpp index 3cbcafb..d23cef6 100644 --- a/test/MatrixFactorizations/SVD/svd_test.hpp +++ b/test/MatrixFactorizations/SVD/svd_test.hpp @@ -14,6 +14,7 @@ #include "MatrixFactorizations/SVD/matrix_view.hpp" #include "cf32_utils.hpp" +#include "reference_linalg.hpp" // In the accuracy tests, a computed solution // is acceptable if the relative error is less @@ -28,21 +29,19 @@ typedef std::complex cf32_t; -// Generate m-by-n, single complex random matrix -static inline std::vector generate_rand(const int m, const int n) { +using armral::utils::cf32_random; +using armral::utils::convert_cf32_array_to_vector; +// Generate m-by-n, single complex random matrix +static inline std::vector generate_rand(cf32_random &random, + const int m, const int n) { int size = m * n; - std::vector a = allocate_random_cf32(size); - // Convert matrix to std::complex type - std::vector out(size); - for (int i = 0; i < size; ++i) { - out[i] = std::complex(a[i].re, a[i].im); - } - return out; + return convert_cf32_array_to_vector(size, + random.vector(size).data()); } static inline float32_t infinity_norm(int m, int n, const cf32_t *a) { - column_major_matrix_view a_mat{a, m}; + column_major_matrix_view a_mat{a, static_cast(m)}; float32_t inorm = 0; for (int i = 0; i < m; i++) { float32_t tmp = 0; @@ -136,7 +135,7 @@ static inline void householder_qr(const int m, const int n, cf32_t *a, return; // GCOVR_EXCL_STOP } - column_major_matrix_view a_mat{a, m}; + column_major_matrix_view a_mat{a, static_cast(m)}; for (int i = 0; i < n; i++) { int k = std::min(i + 1, m - 1); tau[i] = @@ -171,8 +170,8 @@ static inline void apply_q(int m, int n, const cf32_t *a, const cf32_t *tau, } std::vector q(m * n); memcpy(q.data(), a, m * n * sizeof(cf32_t)); - column_major_matrix_view q_mat{q.data(), m}; - column_major_matrix_view c_mat{c, m}; + column_major_matrix_view q_mat{q.data(), static_cast(m)}; + column_major_matrix_view c_mat{c, static_cast(m)}; for (int i = n - 1; i >= 0; i--) { q_mat(i, i) = 1.0F; // Apply reflector from the left to all columns @@ -212,7 +211,7 @@ static inline std::vector get_q(const int m, const int n, // GCOVR_EXCL_STOP } std::vector q = a; - column_major_matrix_view q_mat{q.data(), m}; + column_major_matrix_view q_mat{q.data(), static_cast(m)}; // Accumulate reflectors from right to left // Q = H1 * H2....* Hn. They are applied to identity. for (int i = n - 1; i >= 0; i--) { @@ -252,8 +251,8 @@ static inline std::vector get_q(const int m, const int n, static inline void get_p(int m, int n, const cf32_t *a, const cf32_t *tau, cf32_t *p) { - column_major_matrix_view a_mat{a, m}; - column_major_matrix_view p_mat{p, n}; + column_major_matrix_view a_mat{a, static_cast(m)}; + column_major_matrix_view p_mat{p, static_cast(n)}; // Make a copy of reflectors. // P and A are not of same dimension @@ -280,7 +279,7 @@ static inline void get_p(int m, int n, const cf32_t *a, const cf32_t *tau, int n1 = n - 1; // This shift is the same in row or column major cf32_t *p1 = p + p_mat.stride() + 1; - column_major_matrix_view p1_mat{p1, n}; + column_major_matrix_view p1_mat{p1, static_cast(n)}; // Apply householder reflectors from the right for (int i = n1 - 1; i >= 0; i--) { @@ -327,7 +326,7 @@ static inline void get_p(int m, int n, const std::vector &a, static inline void generate_svd_matrix(const int m, const int n, std::vector &a, std::vector &s, - const float32_t cond, const int seed) { + const float32_t cond) { // Generate singular values from 1 to 1/cond // where cond is the condition number of the matrix @@ -336,9 +335,9 @@ static inline void generate_svd_matrix(const int m, const int n, s[i] = 1 - (float32_t)i / (n - 1) * rcond; } - srand(seed); - std::vector a1 = generate_rand(m, n); - std::vector a2 = generate_rand(n, n); + cf32_random random; + std::vector a1 = generate_rand(random, m, n); + std::vector a2 = generate_rand(random, n, n); // Perform QR of A1 std::vector tau1(n); @@ -352,7 +351,7 @@ static inline void generate_svd_matrix(const int m, const int n, std::vector q2 = get_q(n, n, a2, tau2); // multiply left orthogonal matrix by S - column_major_matrix_view q2_mat{q2.data(), n}; + column_major_matrix_view q2_mat{q2.data(), static_cast(n)}; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { q2_mat(i, j) *= s[i]; @@ -360,7 +359,7 @@ static inline void generate_svd_matrix(const int m, const int n, } // Apply Q1 to S*Q2, but first copy Q2 in an m * n matrix std::vector a_cmplx(m * n); - column_major_matrix_view q2_mat_mn{a_cmplx.data(), m}; + column_major_matrix_view q2_mat_mn{a_cmplx.data(), static_cast(m)}; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { q2_mat_mn(i, j) = q2_mat(i, j); @@ -401,7 +400,7 @@ static inline void bidiagonalization(const int m, const int n, cf32_t *a, return; // GCOVR_EXCL_STOP } - column_major_matrix_view a_mat{a, m}; + column_major_matrix_view a_mat{a, static_cast(m)}; for (int i = 0; i < n; i++) { // QR steps, generate elementary reflector H(i) to annihilate // the entries i+1 to the last of the i-th column @@ -542,8 +541,8 @@ inline static int svd_bidiagonal(const bool gen_singular_vectors, const int m, int maxiter = 2 * n; // Loop over the columns - column_major_matrix_view u_mat{u, u_stride}; - column_major_matrix_view vt_mat{vt, n}; + column_major_matrix_view u_mat{u, static_cast(u_stride)}; + column_major_matrix_view vt_mat{vt, static_cast(n)}; for (int curr_col = n - 1; curr_col >= 0; curr_col--) { // iteration to annihilate the off-diagonal E[curr_col]. int iter = 0; @@ -721,65 +720,6 @@ static inline int svd_cf32(bool gen_singular_vect, const int m, const int n, return 0; } -// armral_svd computes the SVD decomposition -// of an m-by-n matrix A in 4 steps. -// 1- QR factorization of A. -// 2- Bidiagonalization of R. -// 3- SVD of the bidigonal matrix from R. -// 4- Update of the left singular vectors -// with the orthogonal matrix from QR. -static inline int qr_svd_cf32(const bool gen_singular_vect, const int m, - const int n, std::vector &a, - std::vector &s, std::vector &u, - std::vector &vt) { - column_major_matrix_view a_mat{a.data(), m}; - - // Perform the QR factorization of A. - std::vector tau(n); - householder_qr(m, n, a.data(), tau); - - // Extract the R. - std::vector r(n * n); - column_major_matrix_view r_mat{r.data(), n}; - for (int i = 0; i < n; i++) { - for (int j = i; j < n; j++) { - r_mat(i, j) = a_mat(i, j); - } - } - // Bidiagonalization of R. - std::vector tauq(n); - std::vector taup(n); - std::vector e(n); - bidiagonalization(n, n, r.data(), s, e, tauq, taup); - - // Generate left and right orthogonal vectors. - if (gen_singular_vect) { - // Generate Q, and store it in u1. - std::vector u1 = get_q(n, n, r, tauq); - // Copy u1 in u - // Initialize u to zero in case it is not. - u.assign(u.size(), 0.0F); - column_major_matrix_view u_mat{u.data(), m}; - column_major_matrix_view u1_mat{u1.data(), n}; - for (int i = 0; i < n; i++) { - for (int j = 0; j < n; j++) { - u_mat(i, j) = u1_mat(i, j); - } - } - // Generate P and store it in vt. - get_p(n, n, r, taup, vt); - } - // Compute the singular values - // and singular vectors if required. - svd_bidiagonal(gen_singular_vect, n, n, s, e, u.data(), vt.data(), m); - - // Apply Q to U - if (gen_singular_vect) { - apply_q(m, n, a, tau, u); - } - return 0; -} - // Check ||Id - Q^H * Q||_∞/n < THRESHOLD * epsilon static inline bool check_orthogonality(const int m, const int n, cf32_t *q) { @@ -787,12 +727,12 @@ static inline bool check_orthogonality(const int m, const int n, cf32_t *q) { // Build an identity matrix Id std::vector a(n * n); - column_major_matrix_view a_mat{a.data(), n}; + column_major_matrix_view a_mat{a.data(), static_cast(n)}; for (int i = 0; i < n; i++) { a_mat(i, i) = 1.0F; } // Perform Id - Q^H * Q - column_major_matrix_view q_mat{q, m}; + column_major_matrix_view q_mat{q, static_cast(m)}; for (int j = 0; j < n; j++) { for (int i = 0; i < n; i++) { for (int k = 0; k < m; k++) { @@ -817,7 +757,7 @@ static inline bool check_orthogonality(int m, int n, std::vector &q) { // and Aref is the initial matrix static inline bool check_qr_decomposition(int m, int n, const cf32_t *aref, const cf32_t *a, const cf32_t *tau) { - column_major_matrix_view a_mat{a, m}; + column_major_matrix_view a_mat{a, static_cast(m)}; // Infinity norm of Aref float32_t anorm = infinity_norm(m, n, aref); @@ -828,7 +768,7 @@ static inline bool check_qr_decomposition(int m, int n, const cf32_t *aref, // Extract R, allocate m-by-n memory for // the multiplication by A later std::vector r(m * n); - column_major_matrix_view r_mat{r.data(), m}; + column_major_matrix_view r_mat{r.data(), static_cast(m)}; for (int i = 0; i < n; i++) { for (int j = i; j < n; j++) { r_mat(i, j) = a_mat(i, j); @@ -840,7 +780,7 @@ static inline bool check_qr_decomposition(int m, int n, const cf32_t *aref, // Copy Aref std::vector c(m * n); memcpy(c.data(), aref, m * n * sizeof(cf32_t)); - column_major_matrix_view c_mat{c.data(), m}; + column_major_matrix_view c_mat{c.data(), static_cast(m)}; // Compute Aref = Aref - QR for (int i = 0; i < m; i++) { @@ -866,9 +806,9 @@ static inline bool check_qr_decomposition(int m, int n, // Compute C = A * B + beta *C static inline void matmul(int m, int n, int k, const cf32_t *a, const cf32_t *b, const cf32_t beta, cf32_t *c) { - column_major_matrix_view a_mat{a, m}; - column_major_matrix_view b_mat{b, k}; - column_major_matrix_view c_mat{c, m}; + column_major_matrix_view a_mat{a, static_cast(m)}; + column_major_matrix_view b_mat{b, static_cast(k)}; + column_major_matrix_view c_mat{c, static_cast(m)}; for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { c_mat(i, j) *= beta; @@ -909,7 +849,7 @@ check_bidiag_decomposition(int m, int n, const cf32_t *aref, const cf32_t *a, get_p(m, n, a, taup, p.data()); // Build explicitly the n-by-n bidiagonal matrix B std::vector b(n * n); - column_major_matrix_view b_mat{b.data(), n}; + column_major_matrix_view b_mat{b.data(), static_cast(n)}; for (int i = 0; i < n - 1; i++) { b_mat(i, i) = d[i]; b_mat(i, i + 1) = e[i]; @@ -925,8 +865,8 @@ check_bidiag_decomposition(int m, int n, const cf32_t *aref, const cf32_t *a, apply_q(m, n, a, tauq, c.data()); // Compute Aref - Q * B * VT - column_major_matrix_view aref_mat{aref, m}; - column_major_matrix_view c_mat{c.data(), m}; + column_major_matrix_view aref_mat{aref, static_cast(m)}; + column_major_matrix_view c_mat{c.data(), static_cast(m)}; for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { c_mat(i, j) -= aref_mat(i, j); @@ -983,8 +923,8 @@ static inline bool check_svd_decomposition(int m, int n, const cf32_t *a, // U1 = U * S std::vector u1(m * n); - column_major_matrix_view u_mat{u, m}; - column_major_matrix_view u1_mat{u1.data(), m}; + column_major_matrix_view u_mat{u, static_cast(m)}; + column_major_matrix_view u1_mat{u1.data(), static_cast(m)}; for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { u1_mat(i, j) = u_mat(i, j) * s[j]; diff --git a/test/UpperPHY/CRC/main.cpp b/test/UpperPHY/CRC/main.cpp index 47a4d9b..f61bca7 100644 --- a/test/UpperPHY/CRC/main.cpp +++ b/test/UpperPHY/CRC/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" // CRC24A polynomial = x^24 + x^23 + x^18 + x^17 + x^14 + x^11 + x^10 + x^7 + // x^6 + x^5 + x^4 + x^3 + x + 1 @@ -71,7 +71,7 @@ static bool run_crc_test(int n) { // size is 32bits. b) If the input size is > 64bits, then a padding to 128bits // is assumed. assert(n == 8 || n % 16 == 0); - const auto buf = allocate_random_u8(n); + const auto buf = armral::utils::allocate_random_u8(n); uint64_t crc_res = 0; bool passed = true; armral_crc24_a_le(n, (const uint64_t *)buf.data(), &crc_res); diff --git a/test/UpperPHY/ConvolutionalDecoder/main.cpp b/test/UpperPHY/ConvolutionalDecoder/main.cpp index dcebd77..cd9ca73 100644 --- a/test/UpperPHY/ConvolutionalDecoder/main.cpp +++ b/test/UpperPHY/ConvolutionalDecoder/main.cpp @@ -3,8 +3,8 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "bit_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" +#include "utils/bits_to_bytes.hpp" template static bool run_convolutional_decoding_test( @@ -17,7 +17,7 @@ static bool run_convolutional_decoding_test( for (uint32_t i = 0; i < k; i++) { src_bytes[i] = rand() % 2; } - bytes_to_bits(k, src_bytes.data(), src.data()); + armral::bytes_to_bits(k, src_bytes.data(), src.data()); // Encoding std::vector dst0(k + 1); @@ -31,9 +31,9 @@ static bool run_convolutional_decoding_test( std::vector dst1_bytes(k); std::vector dst2_bytes(k); - bits_to_bytes(k, (const uint8_t *)dst0.data(), dst0_bytes.data()); - bits_to_bytes(k, (const uint8_t *)dst1.data(), dst1_bytes.data()); - bits_to_bytes(k, (const uint8_t *)dst2.data(), dst2_bytes.data()); + armral::bits_to_bytes(k, (const uint8_t *)dst0.data(), dst0_bytes.data()); + armral::bits_to_bytes(k, (const uint8_t *)dst1.data(), dst1_bytes.data()); + armral::bits_to_bytes(k, (const uint8_t *)dst2.data(), dst2_bytes.data()); // Modulation armral_modulation_type mod_type = ARMRAL_MOD_16QAM; @@ -67,7 +67,7 @@ static bool run_convolutional_decoding_test( iter_max, out.data()); std::vector out_bytes(k); - bits_to_bytes(k, out.data(), out_bytes.data()); + armral::bits_to_bytes(k, out.data(), out_bytes.data()); // Check result bool passed = true; @@ -79,8 +79,8 @@ static bool run_convolutional_decoding_test( // GCOVR_EXCL_STOP } else { printf("[%s] k=%u\n", name, k); - auto check_dst = - check_results_u8(name, out_bytes.data(), src_bytes.data(), k); + auto check_dst = armral::utils::check_results_u8(name, out_bytes.data(), + src_bytes.data(), k); passed = check_dst; } diff --git a/test/UpperPHY/ConvolutionalEncoder/main.cpp b/test/UpperPHY/ConvolutionalEncoder/main.cpp index 640bd1f..b9ab8f3 100644 --- a/test/UpperPHY/ConvolutionalEncoder/main.cpp +++ b/test/UpperPHY/ConvolutionalEncoder/main.cpp @@ -3,8 +3,8 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "bit_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" +#include "utils/bits_to_bytes.hpp" #include static void reference_tail_biting_convolutional_encode_block(const uint8_t *src, @@ -14,7 +14,7 @@ static void reference_tail_biting_convolutional_encode_block(const uint8_t *src, uint8_t *dst2) { // Cast the bits of the input stream to bytes std::vector bytes_src(k); - bits_to_bytes(k, src, bytes_src.data()); + armral::bits_to_bytes(k, src, bytes_src.data()); // Initialize the shift register with the last six // information bits of the input stream (s_i = src_(k-1-i)) @@ -46,14 +46,14 @@ static void reference_tail_biting_convolutional_encode_block(const uint8_t *src, } // Convert the bytes back to bits - bytes_to_bits(k, bytes_dst0.data(), dst0); - bytes_to_bits(k, bytes_dst1.data(), dst1); - bytes_to_bits(k, bytes_dst2.data(), dst2); + armral::bytes_to_bits(k, bytes_dst0.data(), dst0); + armral::bytes_to_bits(k, bytes_dst1.data(), dst1); + armral::bytes_to_bits(k, bytes_dst2.data(), dst2); } static bool run_convolutional_encoding_test(int k) { assert(k >= 8); - const auto src = allocate_random_u8((k + 7) / 8); + const auto src = armral::utils::allocate_random_u8((k + 7) / 8); std::vector dst0_ref((k + 7) / 8); std::vector dst1_ref((k + 7) / 8); @@ -76,15 +76,15 @@ static bool run_convolutional_encoding_test(int k) { passed = false; // GCOVR_EXCL_STOP } else { - auto check_dst0 = - check_results_u8("CONVOLUTIONAL ENCODING (STREAM D0)", dst0.data(), - dst0_ref.data(), (k + 7) / 8); - auto check_dst1 = - check_results_u8("CONVOLUTIONAL ENCODING (STREAM D1)", dst1.data(), - dst1_ref.data(), (k + 7) / 8); - auto check_dst2 = - check_results_u8("CONVOLUTIONAL ENCODING (STREAM D2)", dst2.data(), - dst2_ref.data(), (k + 7) / 8); + auto check_dst0 = armral::utils::check_results_u8( + "CONVOLUTIONAL ENCODING (STREAM D0)", dst0.data(), dst0_ref.data(), + (k + 7) / 8); + auto check_dst1 = armral::utils::check_results_u8( + "CONVOLUTIONAL ENCODING (STREAM D1)", dst1.data(), dst1_ref.data(), + (k + 7) / 8); + auto check_dst2 = armral::utils::check_results_u8( + "CONVOLUTIONAL ENCODING (STREAM D2)", dst2.data(), dst2_ref.data(), + (k + 7) / 8); passed = check_dst0 && check_dst1 && check_dst2; } diff --git a/test/UpperPHY/Demodulation/main.cpp b/test/UpperPHY/Demodulation/main.cpp index bfc68be..642e5bd 100644 --- a/test/UpperPHY/Demodulation/main.cpp +++ b/test/UpperPHY/Demodulation/main.cpp @@ -4,7 +4,7 @@ */ #include "armral.h" #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -29,11 +29,8 @@ static bool check_llrs_equal(const armral_cmplx_int16_t *p_src, } } - if (!passed) { - printf("Check failed!\n"); // GCOVR_EXCL_LINE - } else { - printf("Check successful!\n"); - } + printf("Check %s!\n", passed ? "successful" : "failed"); + return passed; } @@ -242,9 +239,9 @@ static bool test_demod(armral_modulation_type mod) { for (auto n : test_params.num_symbols) { auto num_llrs = n * test_params.bits_per_symbol; - auto expected_llr = allocate_random_i8(num_llrs); - auto lib_llr = allocate_random_i8(num_llrs); - auto symbols = allocate_random_cs16(n); + auto expected_llr = armral::utils::allocate_random_i8(num_llrs); + auto lib_llr = armral::utils::allocate_random_i8(num_llrs); + auto symbols = armral::utils::allocate_random_cs16(n); for (auto ulp : test_params.llr_ulps) { // Perform the reference demodulation test_params.ref_func(n, ulp, symbols.data(), expected_llr.data()); diff --git a/test/UpperPHY/LDPC/Decoding/main.cpp b/test/UpperPHY/LDPC/Decoding/main.cpp index f557ea2..f28b039 100644 --- a/test/UpperPHY/LDPC/Decoding/main.cpp +++ b/test/UpperPHY/LDPC/Decoding/main.cpp @@ -5,8 +5,8 @@ #include "../ldpc_test_common.hpp" #include "armral.h" -#include "bit_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" +#include "utils/bits_to_bytes.hpp" #include #include @@ -104,17 +104,18 @@ bool run_ldpc_decoding_test(uint32_t its, uint32_t z, armral_ldpc_graph_t bg, // Allocate a random input to be encoded uint32_t len_in = z * graph->nmessage_bits; - auto to_encode = allocate_random_u8((len_in + 7) / 8); + armral::utils::int_random random; + auto to_encode = random.vector((len_in + 7) / 8); // If we are doing CRC checking, then we need to attach CRC bits to the input if (crc_idx != ARMRAL_LDPC_NO_CRC) { - auto info_to_encode = allocate_random_u8((len_in + 7) / 8); + auto info_to_encode = random.vector((len_in + 7) / 8); ldpc_crc_attachment(info_to_encode.data(), crc_idx + 24, len_in, to_encode.data()); } uint32_t encoded_len = z * graph->ncodeword_bits; - auto encoded = allocate_random_u8((encoded_len + 7) / 8); + auto encoded = random.vector((encoded_len + 7) / 8); uint32_t len_filler_bits = 0; // Encode the data armral_ldpc_encode_block(to_encode.data(), bg, z, len_filler_bits, @@ -135,16 +136,17 @@ bool run_ldpc_decoding_test(uint32_t its, uint32_t z, armral_ldpc_graph_t bg, armral_demodulation(mod_num_symbols, ulp, mod_type, data_mod.data(), data_demod_soft.data()); - auto decoded = allocate_random_u8((encoded_len + 2 * z + 7) / 8); + auto decoded = random.vector((encoded_len + 2 * z + 7) / 8); ldpc_decoding_under_test(data_demod_soft.data(), bg, z, crc_idx, its, decoded.data()); - auto decoded_bytes = bits_to_bytes(encoded_len + 2 * z, decoded.data()); + auto decoded_bytes = + armral::bits_to_bytes(encoded_len + 2 * z, decoded.data()); // Make sure that the codeword passes the parity check bool passed = perform_parity_check(decoded_bytes.data(), z, bg); // Also check that the decoded message is equal to the original message - auto bytes_in = bits_to_bytes(len_in, to_encode.data()); + auto bytes_in = armral::bits_to_bytes(len_in, to_encode.data()); passed &= check_decoded_message(len_in, bytes_in.data(), decoded_bytes.data()); return passed; diff --git a/test/UpperPHY/LDPC/Encoding/main.cpp b/test/UpperPHY/LDPC/Encoding/main.cpp index 59ba4e3..4b7732a 100644 --- a/test/UpperPHY/LDPC/Encoding/main.cpp +++ b/test/UpperPHY/LDPC/Encoding/main.cpp @@ -4,10 +4,10 @@ */ #include "../ldpc_test_common.hpp" #include "armral.h" -#include "bit_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "ldpc_coding.hpp" #include "ldpc_encoding_test_data.h" +#include "utils/bits_to_bytes.hpp" #include #include @@ -187,7 +187,7 @@ std::vector armral_ldpc_encode_block_ref(const uint8_t *data_in, const auto *graph = armral_ldpc_get_base_graph(bg); // Cast the bits to bytes for easier handling of data - auto bytes_in = bits_to_bytes(z * graph->nmessage_bits, data_in); + auto bytes_in = armral::bits_to_bytes(z * graph->nmessage_bits, data_in); // Get the lifting set index auto lsi = get_ldpc_lifting_index(z); @@ -232,7 +232,7 @@ std::vector armral_ldpc_encode_block_ref(const uint8_t *data_in, } // Now we set the parity bits in the out message to what they are supposed to - // be. TODO: Actually puncture the first two columns. For now we don't, and + // be. TODO: Actually puncture the first two columns. For now we don't, and // store all of the columns in the codeword std::vector codeword((graph->ncodeword_bits + 2) * z); @@ -269,11 +269,6 @@ std::vector armral_ldpc_encode_block_ref(const uint8_t *data_in, return codeword; } -std::vector gen_random_bits(uint32_t num_bits) { - uint32_t num_bytes = (num_bits + 7) / 8; - return allocate_random_u8(num_bytes, 0U, 1U); -} - bool check_bytes_equal(const std::vector &enc, const std::vector &expected, const ldpc_test_param_t &tc) { @@ -307,24 +302,17 @@ bool check_bytes_equal(const std::vector &enc, template bool test_ldpc_encode_block( char const *name, LDPCEncodeBlockFunction ldpc_encode_block_under_test) { + + armral::utils::bit_random random; bool passed = true; // Check that the implementation matches expected results - for (auto &tc : ldpc_tests) { + for (const auto &tc : ldpc_tests) { // Generate some random data to encode. This should be in a single // block + auto data_in = random.bit_vector(tc.length - tc.length_of_filler_bits); - uint32_t len = tc.length; - - if ((tc.length_of_filler_bits > 0) && - ((tc.length_of_filler_bits % 8) != 0)) { - len += 8; - } - - auto data_in = gen_random_bits(len); - - // Zero memset input vector's end portion of length length_of_filler_bits - memset(&(data_in[(tc.length - tc.length_of_filler_bits + 7) / 8]), 0, - (tc.length_of_filler_bits + 7) / 8); + // Zero input vector's end portion of length length_of_filler_bits + data_in.resize((tc.length + 7) / 8, 0); const auto *bg = armral_ldpc_get_base_graph(tc.graph_type); @@ -334,8 +322,8 @@ bool test_ldpc_encode_block( ldpc_encode_block_under_test(data_in.data(), tc.graph_type, tc.lifting_size, tc.length_of_filler_bits, encoding.data()); - auto encoding_bytes = - bits_to_bytes(tc.lifting_size * bg->ncodeword_bits, encoding.data()); + auto encoding_bytes = armral::bits_to_bytes( + tc.lifting_size * bg->ncodeword_bits, encoding.data()); // Compute full block data by reference implementation // and check its validity diff --git a/test/UpperPHY/LDPC/RateMatching/main.cpp b/test/UpperPHY/LDPC/RateMatching/main.cpp index ceb7267..647955b 100644 --- a/test/UpperPHY/LDPC/RateMatching/main.cpp +++ b/test/UpperPHY/LDPC/RateMatching/main.cpp @@ -3,8 +3,8 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "bit_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" +#include "utils/bits_to_bytes.hpp" #include #include @@ -54,12 +54,12 @@ void ref_bit_selection(uint32_t z, uint32_t n, uint32_t e, } else { - bits_to_bytes(n, (const uint8_t *)in, (uint8_t *)scratch_ptr1); + armral::bits_to_bytes(n, (const uint8_t *)in, (uint8_t *)scratch_ptr1); memcpy(scratch_ptr2, scratch_ptr1, len_s_bits); memcpy(&scratch_ptr2[len_s_bits], &scratch_ptr1[len_s_f_bits], len_p_bits); - bytes_to_bits((n - len_filler_bits), (const uint8_t *)scratch_ptr2, - scratch_ptr1); + armral::bytes_to_bits((n - len_filler_bits), + (const uint8_t *)scratch_ptr2, scratch_ptr1); } in_bits = scratch_ptr1; @@ -112,7 +112,7 @@ void ref_bit_interleave(uint32_t e, uint32_t qm, const uint8_t *in, bool test_ref_rate_matching() { bool passed = true; - uint8_t in[4] = {0xA7, 0xFF, 0xFF, 0xA9}; + constexpr uint8_t in[4] = {0xA7, 0xFF, 0xFF, 0xA9}; uint8_t out[4] = {0}; // Test bit selection for k0 = 0. @@ -186,11 +186,11 @@ bool test_ref_rate_matching() { passed &= (out[3] == in[2]); // rv_id = 0 z = 7 , n = 350, e = 328, len_filler_bits = 16 k = 70 - uint8_t in_filler[] = {0x22, 0x35, 0x72, 0xd4, 0xb5, 0x00, 0x00, 0x9a, 0x32, - 0xd0, 0x45, 0x6d, 0x18, 0x10, 0xfa, 0xf8, 0xa4, 0x5e, - 0x8c, 0x88, 0x1f, 0x8a, 0xf6, 0x66, 0xad, 0xc8, 0xb0, - 0xc8, 0xe6, 0xca, 0x5c, 0x4e, 0x0a, 0x59, 0x47, 0x33, - 0xb8, 0x61, 0x0c, 0x6c, 0x8c, 0xa8, 0xa8, 0xb0}; + constexpr uint8_t in_filler[] = { + 0x22, 0x35, 0x72, 0xd4, 0xb5, 0x00, 0x00, 0x9a, 0x32, 0xd0, 0x45, + 0x6d, 0x18, 0x10, 0xfa, 0xf8, 0xa4, 0x5e, 0x8c, 0x88, 0x1f, 0x8a, + 0xf6, 0x66, 0xad, 0xc8, 0xb0, 0xc8, 0xe6, 0xca, 0x5c, 0x4e, 0x0a, + 0x59, 0x47, 0x33, 0xb8, 0x61, 0x0c, 0x6c, 0x8c, 0xa8, 0xa8, 0xb0}; uint8_t out_filler[328 >> 3] = {0}; ref_bit_selection(7, 350, 328, 16, 70, 0, in_filler, out_filler); @@ -365,7 +365,7 @@ bool test_ldpc_rate_matching( uint32_t num_bits = src_length(bg, z); uint32_t num_bytes_enc_out = (num_bits + 7) / 8; std::vector src_store = - allocate_random_u8(num_bytes_enc_out, 0U, 1U); + armral::utils::allocate_random_u8(num_bytes_enc_out, 0U, 1U); std::vector src = std::vector(num_bytes_enc_out); for (auto mod : mod_list) { uint32_t qm = num_bit_per_symbol(mod); @@ -397,8 +397,8 @@ bool test_ldpc_rate_matching( memcpy(src.data(), src_store.data(), num_bytes_enc_out); armral_ref_rate_matching(bg, z, e, nref, len_filler_bits, g, rv, mod, src.data(), ref.data()); - passed &= check_results_u8(test_name, dst.data(), ref.data(), - num_bytes_ratematch_out); + passed &= armral::utils::check_results_u8( + test_name, dst.data(), ref.data(), num_bytes_ratematch_out); } } } diff --git a/test/UpperPHY/LDPC/RateRecovery/main.cpp b/test/UpperPHY/LDPC/RateRecovery/main.cpp index 993b08e..68e55f6 100644 --- a/test/UpperPHY/LDPC/RateRecovery/main.cpp +++ b/test/UpperPHY/LDPC/RateRecovery/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -208,7 +208,7 @@ bool test_ref_rate_recovery() { uint32_t k0 = 16; uint32_t z = 2; uint32_t k = 20; - auto in = allocate_random_i8(e); + auto in = armral::utils::allocate_random_i8(e); std::vector out(n); ref_undo_selection(z, n, e, 0, k, k0, in.data(), out.data()); passed &= std::equal(out.begin() + k0, out.begin() + n, in.begin()); @@ -219,7 +219,7 @@ bool test_ref_rate_recovery() { n = 100; k0 = 16; k = 20; - in = allocate_random_i8(e); + in = armral::utils::allocate_random_i8(e); memset(out.data(), 0, n * sizeof(int8_t)); ref_undo_selection(z, n, e, 0, k, k0, in.data(), out.data()); passed &= std::all_of(out.begin(), out.begin() + k0, @@ -234,7 +234,7 @@ bool test_ref_rate_recovery() { n = 50; k0 = 0; k = 20; - in = allocate_random_i8(e); + in = armral::utils::allocate_random_i8(e); // The final llrs is the sum of the i-th and (i+16)-th llr. // Ensure that saturation is tested at least once. in[0] = INT8_MIN; @@ -262,7 +262,7 @@ bool test_ref_rate_recovery() { k0 = 0; uint32_t f = 8; k = 20; - auto in_filler = allocate_random_i8(e); + auto in_filler = armral::utils::allocate_random_i8(e); std::vector out_filler(n); ref_undo_selection(z, n, e, f, z * 10, k0, in_filler.data(), out_filler.data()); @@ -344,7 +344,8 @@ bool test_ldpc_rate_recovery( uint32_t e = rb_list[rb_idx] * num_res; for (auto rv : rv_list) { for (auto nref : nref_list[bg]) { - std::vector llrs_in = allocate_random_i8(e); + std::vector llrs_in = + armral::utils::allocate_random_i8(e); std::vector llrs_out(n); std::vector llrs_ref(n); armral_ref_rate_recovery(bg, z, e, nref, len_filler_bits, g, rv, diff --git a/test/UpperPHY/Modulation/main.cpp b/test/UpperPHY/Modulation/main.cpp index 2fff2b6..528205b 100644 --- a/test/UpperPHY/Modulation/main.cpp +++ b/test/UpperPHY/Modulation/main.cpp @@ -4,7 +4,7 @@ */ #include "armral.h" #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -14,7 +14,7 @@ using mapping_func_t = armral_cmplx_int16_t(const std::vector &, uint8_t); static std::vector -gen_qam_constellation(mapping_func_t map_func, +gen_qam_constellation(mapping_func_t &map_func, const std::vector &values, uint16_t num_points) { assert(num_points <= 256); @@ -226,8 +226,9 @@ static void modulation_ref_256qam(const uint32_t nbits, const int8_t *p_src, } } -static bool check_results(uint32_t out_size, armral_cmplx_int16_t *computed_out, - armral_cmplx_int16_t *expected_out) { +static bool check_results(uint32_t out_size, + const armral_cmplx_int16_t *computed_out, + const armral_cmplx_int16_t *expected_out) { bool passed = true; for (uint32_t i = 0; i < out_size; i++) { @@ -243,11 +244,7 @@ static bool check_results(uint32_t out_size, armral_cmplx_int16_t *computed_out, } } - if (!passed) { - printf("Check failed!\n"); // GCOVR_EXCL_LINE - } else { - printf("Check successful!\n"); - } + printf("Check %s!\n", passed ? "successful" : "failed"); return passed; } @@ -275,9 +272,10 @@ bool test_qpsk(void) { // One complex number is generated per Q, I pair of bits (i.e. symbol) in // the input auto num_symbols = num_bits / 2; - auto res = allocate_random_cs16(num_symbols, INT16_MIN, INT16_MAX); + auto res = + armral::utils::allocate_random_cs16(num_symbols, INT16_MIN, INT16_MAX); std::vector res_exp(res); - auto input = allocate_random_i8(num_bytes); + auto input = armral::utils::allocate_random_i8(num_bytes); modulation_ref_qpsk(num_bits, input.data(), res_exp.data()); armral_modulation(num_bits, ARMRAL_MOD_QPSK, (uint8_t *)input.data(), @@ -299,9 +297,9 @@ bool test_16_qam(void) { for (auto num_bits : bits_in) { auto num_bytes = (num_bits + 7) / 8; auto num_symbols = (num_bits / 4); - auto res = allocate_random_cs16(num_symbols); + auto res = armral::utils::allocate_random_cs16(num_symbols); std::vector res_exp(res); - auto input = allocate_random_i8(num_bytes); + auto input = armral::utils::allocate_random_i8(num_bytes); modulation_ref_16qam(num_bits, input.data(), res_exp.data()); armral_modulation(num_bits, ARMRAL_MOD_16QAM, (uint8_t *)input.data(), res.data()); @@ -327,9 +325,9 @@ bool test_64_qam(void) { for (auto num_bits : bits_in) { auto num_bytes = (num_bits + 7) / 8; auto num_symbols = num_bits / 6; - auto res = allocate_random_cs16(num_symbols); + auto res = armral::utils::allocate_random_cs16(num_symbols); std::vector res_exp(res); - auto input = allocate_random_i8(num_bytes); + auto input = armral::utils::allocate_random_i8(num_bytes); modulation_ref_64qam(num_bits, input.data(), res_exp.data()); armral_modulation(num_bits, ARMRAL_MOD_64QAM, (uint8_t *)input.data(), res.data()); @@ -349,9 +347,9 @@ bool test_256_qam(void) { for (auto num_bits : bits_in) { auto num_bytes = num_bits / 8; - auto res = allocate_random_cs16(num_bytes); + auto res = armral::utils::allocate_random_cs16(num_bytes); std::vector res_exp(res); - auto input = allocate_random_i8(num_bytes); + auto input = armral::utils::allocate_random_i8(num_bytes); modulation_ref_256qam(num_bits, input.data(), res_exp.data()); armral_modulation(num_bits, ARMRAL_MOD_256QAM, (uint8_t *)input.data(), res.data()); diff --git a/test/UpperPHY/Polar/CrcAttachment/main.cpp b/test/UpperPHY/Polar/CrcAttachment/main.cpp index 21f1f71..d0fa2e2 100644 --- a/test/UpperPHY/Polar/CrcAttachment/main.cpp +++ b/test/UpperPHY/Polar/CrcAttachment/main.cpp @@ -2,7 +2,7 @@ Arm RAN Acceleration Library SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "polar_crc_attach_data.hpp" template @@ -17,10 +17,11 @@ static bool run_polar_crc_attachment_test( uint32_t k = t.a + l; uint32_t k_bytes = (k + 7) / 8; - auto out = allocate_random_u8(k_bytes); + auto out = armral::utils::allocate_random_u8(k_bytes); polar_crc_attach_under_test(t.in.data(), t.a, out.data()); - return check_results_u8(name, out.data(), t.out.data(), k_bytes); + return armral::utils::check_results_u8(name, out.data(), t.out.data(), + k_bytes); } template diff --git a/test/UpperPHY/Polar/Decoding/main.cpp b/test/UpperPHY/Polar/Decoding/main.cpp index e6f48ac..bf70339 100644 --- a/test/UpperPHY/Polar/Decoding/main.cpp +++ b/test/UpperPHY/Polar/Decoding/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -26,19 +26,20 @@ static bool run_polar_decoding_test(uint32_t n, uint32_t e, uint32_t k, // generate a random input uint32_t crc_bits = 24; // 24-CRC (L = 24) uint32_t msg_bits = k - crc_bits; // message length (A = K - L) - auto msg = allocate_random_bits(msg_bits); + armral::utils::bit_random random; + auto msg = random.bit_vector(msg_bits); // attach CRC bits - auto in = allocate_random_bits(k); + auto in = random.bit_vector(k); armral_polar_crc_attachment(msg.data(), msg_bits, in.data()); // run interleaving - auto ref = allocate_random_bits(n); + auto ref = random.bit_vector(n); armral_polar_subchannel_interleave(n, k + n_pc, frozen_mask.data(), in.data(), ref.data()); // run encoding - auto encoded = allocate_random_bits(n); + auto encoded = random.bit_vector(n); armral_polar_encode_block(n, ref.data(), encoded.data()); // run modulation @@ -81,16 +82,16 @@ static bool run_polar_decoding_test(uint32_t n, uint32_t e, uint32_t k, } if (crc_pass) { - return check_results_u8("POLAR DECODING", data_deint.data(), in.data(), - (k + 7) / 8); + return armral::utils::check_results_u8( + "POLAR DECODING", data_deint.data(), in.data(), (k + 7) / 8); } } // compare ML codeword with reference if no codeword passed the CRC check armral_polar_subchannel_deinterleave(k, frozen_mask.data(), out.data(), data_deint.data()); - return check_results_u8("POLAR DECODING", data_deint.data(), in.data(), - (k + 7) / 8); + return armral::utils::check_results_u8("POLAR DECODING", data_deint.data(), + in.data(), (k + 7) / 8); } // Entry point for unit testing of polar coding diff --git a/test/UpperPHY/Polar/Encoding/main.cpp b/test/UpperPHY/Polar/Encoding/main.cpp index ae53d57..5d7a55c 100644 --- a/test/UpperPHY/Polar/Encoding/main.cpp +++ b/test/UpperPHY/Polar/Encoding/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "cs16_utils.hpp" -#include "int8_utils.hpp" +#include "int_utils.hpp" static void reference_polar_encoding(int nbits, const uint8_t *a, uint8_t *b) { int nbytes = nbits / 8; @@ -27,14 +27,16 @@ static void reference_polar_encoding(int nbits, const uint8_t *a, uint8_t *b) { static bool run_polar_encoding_test(int n) { assert(n % 32 == 0); - const auto a = allocate_random_u8(n / 8); - auto b = allocate_random_u8(n / 8); + armral::utils::int_random random; + const auto a = random.vector(n / 8); + auto b = random.vector(n / 8); auto ref = b; reference_polar_encoding(n, a.data(), ref.data()); printf("[POLAR ENCODING] n=%d\n", n); armral_polar_encode_block(n, a.data(), b.data()); - return check_results_u8("POLAR ENCODING", b.data(), ref.data(), n / 8); + return armral::utils::check_results_u8("POLAR ENCODING", b.data(), ref.data(), + n / 8); } // Entry point for unit testing of polar coding diff --git a/test/UpperPHY/Polar/Frozen/main.cpp b/test/UpperPHY/Polar/Frozen/main.cpp index 341383f..f1c1ac6 100644 --- a/test/UpperPHY/Polar/Frozen/main.cpp +++ b/test/UpperPHY/Polar/Frozen/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -146,7 +146,6 @@ static void polar_frozen_mask_ref(uint32_t n, uint32_t e, uint32_t k, } } auto qi_tmp = qn(n); - std::vector qi(k); int qi_tmp_idx = n - 1; uint32_t wm_idx = ~0U; uint32_t wm_weight = ~0U; @@ -193,8 +192,8 @@ static bool run_polar_frozen_mask_test(int n, int e, int k, int n_pc, printf("[polar frozen mask] n=%d, e=%d, k=%d n_pc=%d n_pc_wm=%d\n", n, e, k, n_pc, n_pc_wm); armral_polar_frozen_mask(n, e, k, n_pc, n_pc_wm, frozen_mask.data()); - return check_results_u8("polar frozen mask", frozen_mask.data(), - ref_frozen_mask.data(), n); + return armral::utils::check_results_u8( + "polar frozen mask", frozen_mask.data(), ref_frozen_mask.data(), n); } int main(int argc, char **argv) { diff --git a/test/UpperPHY/Polar/RateMatching/main.cpp b/test/UpperPHY/Polar/RateMatching/main.cpp index 4bbc017..27962fd 100644 --- a/test/UpperPHY/Polar/RateMatching/main.cpp +++ b/test/UpperPHY/Polar/RateMatching/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -109,15 +109,16 @@ static bool run_polar_rate_matching_test( assert(n % 32 == 0); assert(k <= e); - auto ref_d = allocate_random_bits(n); - auto ref_f = allocate_random_bits(e); + armral::utils::bit_random random; + auto ref_d = random.bit_vector(n); + auto ref_f = random.bit_vector(e); auto f = ref_f; printf("[%s] i_bil=%d n=%d, e=%d, k=%d\n", name, i_bil, n, e, k); polar_rate_matching_under_test(n, e, k, i_bil, ref_d.data(), f.data()); reference_polar_rate_matching(n, e, k, i_bil, ref_d.data(), ref_f.data()); - return check_results_u8(name, f.data(), ref_f.data(), e / 8); + return armral::utils::check_results_u8(name, f.data(), ref_f.data(), e / 8); } template diff --git a/test/UpperPHY/Polar/RateRecovery/main.cpp b/test/UpperPHY/Polar/RateRecovery/main.cpp index 8d7d5b7..8f7cd75 100644 --- a/test/UpperPHY/Polar/RateRecovery/main.cpp +++ b/test/UpperPHY/Polar/RateRecovery/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -132,8 +132,8 @@ static bool run_polar_rate_recovery_test( assert(n % 32 == 0); assert(k <= e); - auto ref_llr_demod = allocate_random_i8(e); - auto ref_llr = allocate_random_i8(n); + auto ref_llr_demod = armral::utils::allocate_random_i8(e); + auto ref_llr = armral::utils::allocate_random_i8(n); auto llr = ref_llr; printf("[%s] i_bil=%d n=%d, e=%d, k=%d\n", name, i_bil, n, e, k); diff --git a/test/UpperPHY/Polar/SubchannelDeinterleave/main.cpp b/test/UpperPHY/Polar/SubchannelDeinterleave/main.cpp index 0b0840f..b5c2c04 100644 --- a/test/UpperPHY/Polar/SubchannelDeinterleave/main.cpp +++ b/test/UpperPHY/Polar/SubchannelDeinterleave/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -16,9 +16,10 @@ static bool run_polar_subchannel_deinterleave_test(int n, int k, int n_pc, assert(k + n_pc <= n); assert(n_pc_wm <= n_pc); - auto ref_c = allocate_random_bits(k); + armral::utils::bit_random random; + auto ref_c = random.bit_vector(k); auto c = ref_c; - auto u = allocate_random_bits(n); + auto u = random.bit_vector(n); // build a frozen mask to pass to the interleaving. // not handling parity bits for now. @@ -30,8 +31,8 @@ static bool run_polar_subchannel_deinterleave_test(int n, int k, int n_pc, u.data()); armral_polar_subchannel_deinterleave(k, frozen.data(), u.data(), c.data()); - return check_results_u8("polar subchannel deinterleave", c.data(), - ref_c.data(), (k + 7) / 8); + return armral::utils::check_results_u8("polar subchannel deinterleave", + c.data(), ref_c.data(), (k + 7) / 8); } int main(int argc, char **argv) { diff --git a/test/UpperPHY/Polar/SubchannelInterleave/main.cpp b/test/UpperPHY/Polar/SubchannelInterleave/main.cpp index 4a83cb0..4938c2d 100644 --- a/test/UpperPHY/Polar/SubchannelInterleave/main.cpp +++ b/test/UpperPHY/Polar/SubchannelInterleave/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -52,8 +52,9 @@ static bool run_polar_subchannel_interleave_test(int n, int k, int n_pc, assert(n % 32 == 0); assert(k <= n); - const auto c = allocate_random_bits(k); - auto u = allocate_random_bits(n); + armral::utils::bit_random random; + const auto c = random.bit_vector(k); + auto u = random.bit_vector(n); // build a frozen mask to pass to the interleaving. // not handling parity bits for now. @@ -65,8 +66,8 @@ static bool run_polar_subchannel_interleave_test(int n, int k, int n_pc, printf("[polar subchannel interleave] n=%d, k=%d\n", n, k); armral_polar_subchannel_interleave(n, k + n_pc, frozen.data(), c.data(), u.data()); - return check_results_u8("polar subchannel interleave", u.data(), ref_u.data(), - n / 8); + return armral::utils::check_results_u8("polar subchannel interleave", + u.data(), ref_u.data(), n / 8); } int main(int argc, char **argv) { diff --git a/test/UpperPHY/Turbo/RateMatching/main.cpp b/test/UpperPHY/Turbo/RateMatching/main.cpp index 4353ab5..44c66e8 100644 --- a/test/UpperPHY/Turbo/RateMatching/main.cpp +++ b/test/UpperPHY/Turbo/RateMatching/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include #include @@ -11,7 +11,7 @@ #include // The number of bits in the encoded message passed to rate matching is the -// length of the code block given to Turbo encoding (k) plus 4 tremination bits. +// length of the code block given to Turbo encoding (k) plus 4 termination bits. static constexpr uint32_t valid_ds[188] = { 44, 52, 60, 68, 76, 84, 92, 100, 108, 116, 124, 132, 140, 148, 156, 164, 172, 180, 188, 196, 204, 212, 220, 228, @@ -203,7 +203,7 @@ static void reference_turbo_rate_matching(uint32_t d, uint32_t e, uint32_t rv, const uint8_t *src2, uint8_t *dst) { assert(d > 0); assert(e > 0); - assert(rv >= 0 && rv <= 3); + assert(rv <= 3); constexpr uint32_t ctc = 32; // The minimum number of rows which gives rtc * ctc >= d. @@ -246,17 +246,19 @@ static bool run_turbo_rate_matching_test( printf("[%s] d=%u, e=%u, rv=%u\n", name, d, e, rv); - auto ref_src0 = allocate_random_bits(d); - auto ref_src1 = allocate_random_bits(d); - auto ref_src2 = allocate_random_bits(d); - auto ref_dst = allocate_random_bits(e); + armral::utils::bit_random random; + auto ref_src0 = random.bit_vector(d); + auto ref_src1 = random.bit_vector(d); + auto ref_src2 = random.bit_vector(d); + auto ref_dst = random.bit_vector(e); auto dst = ref_dst; turbo_rate_matching_under_test(d, e, rv, ref_src0.data(), ref_src1.data(), ref_src2.data(), dst.data()); reference_turbo_rate_matching(d, e, rv, ref_src0.data(), ref_src1.data(), ref_src2.data(), ref_dst.data()); - passed &= check_results_u8(name, dst.data(), ref_dst.data(), (e + 7) / 8); + passed &= armral::utils::check_results_u8(name, dst.data(), ref_dst.data(), + (e + 7) / 8); return passed; } diff --git a/test/UpperPHY/Turbo/RateRecovery/main.cpp b/test/UpperPHY/Turbo/RateRecovery/main.cpp index 5f91d3a..b9636de 100644 --- a/test/UpperPHY/Turbo/RateRecovery/main.cpp +++ b/test/UpperPHY/Turbo/RateRecovery/main.cpp @@ -3,7 +3,7 @@ SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ #include "armral.h" -#include "int8_utils.hpp" +#include "int_utils.hpp" #include "rate_recovery_data.hpp" template @@ -20,12 +20,12 @@ static bool run_turbo_rate_recovery_test( turbo_rate_recovery_under_test(test.d, test.e, test.rv, test.src.data(), dst0.data(), dst1.data(), dst2.data()); - passed &= check_results_i8("turbo rate recovery dst0", dst0.data(), - test.out0.data(), test.d); - passed &= check_results_i8("turbo rate recovery dst1", dst1.data(), - test.out1.data(), test.d); - passed &= check_results_i8("turbo rate recovery dst2", dst2.data(), - test.out2.data(), test.d); + passed &= armral::utils::check_results_i8( + "turbo rate recovery dst0", dst0.data(), test.out0.data(), test.d); + passed &= armral::utils::check_results_i8( + "turbo rate recovery dst1", dst1.data(), test.out1.data(), test.d); + passed &= armral::utils::check_results_i8( + "turbo rate recovery dst2", dst2.data(), test.out2.data(), test.d); } return passed; } diff --git a/utils/cf32_utils.hpp b/utils/cf32_utils.hpp index f54c10e..c015e24 100644 --- a/utils/cf32_utils.hpp +++ b/utils/cf32_utils.hpp @@ -12,6 +12,8 @@ #include #include +namespace armral::utils { + /* * A helper class for generating random complex numbers. This class preserves * the random state between calls, so if repeatedly requesting vectors of the @@ -40,36 +42,32 @@ * Note: You must use the same cf32_random instance between calls to reuse the * random state. */ -class cf32_random { +class cf32_random : public base_random { public: cf32_random(std::initializer_list seeds = {42}) - : m_state(armral::utils::random_state::from_seeds(seeds)) {} + : base_random(seeds) {} static constexpr armral_cmplx_f32_t default_min{1.0F, 2.0F}; static constexpr armral_cmplx_f32_t default_max{2.0F, 4.0F}; armral_cmplx_f32_t one(armral_cmplx_f32_t min = default_min, armral_cmplx_f32_t max = default_max) { - return armral_cmplx_f32_t{rand(min.re, max.re), - rand(min.im, max.im)}; + return armral_cmplx_f32_t{rand(min.re, max.re), rand(min.im, max.im)}; } std::vector vector(size_t len, armral_cmplx_f32_t min = default_min, armral_cmplx_f32_t max = default_max) { - std::vector ret(len); - for (auto &cmplx : ret) { - cmplx = one(min, max); - } - return ret; + return base_random::vector_impl(len, min, + max); } std::vector & flip_signs(std::vector &vector, float32_t chance_re = 0.5F, float32_t chance_im = 0.5F) { for (auto &cmplx : vector) { - bool re_flip = rand(0, 1) < chance_re; - bool im_flip = rand(0, 1) < chance_im; + bool re_flip = rand(0, 1) < chance_re; + bool im_flip = rand(0, 1) < chance_im; cmplx.re = re_flip ? -cmplx.re : cmplx.re; cmplx.im = im_flip ? -cmplx.im : cmplx.im; } @@ -82,31 +80,21 @@ public: auto result = std::move(vector); return flip_signs(result); } - -private: - template - float32_t rand(float32_t min, float32_t max) { - armral::utils::linear_congruential_generator lcg; - return lcg.one(&m_state, min, max); - } - - armral::utils::random_state m_state; }; -static inline std::vector -allocate_random_cf32(uint32_t len) { +inline std::vector allocate_random_cf32(uint32_t len) { return cf32_random().vector(len); } -static inline std::complex cmplx_mul_widen_cf32(armral_cmplx_f32_t a, - armral_cmplx_f32_t b) { +inline std::complex cmplx_mul_widen_cf32(armral_cmplx_f32_t a, + armral_cmplx_f32_t b) { std::complex a_wide = {a.re, a.im}; std::complex b_wide = {b.re, b.im}; return a_wide * b_wide; } -static inline std::vector> -widen_cf32(const armral_cmplx_f32_t *a, int n) { +inline std::vector> widen_cf32(const armral_cmplx_f32_t *a, + int n) { std::vector> a_wide(n); for (int i = 0; i < n; i++) { a_wide[i] = {a[i].re, a[i].im}; @@ -114,7 +102,7 @@ widen_cf32(const armral_cmplx_f32_t *a, int n) { return a_wide; } -static inline std::vector +inline std::vector narrow_to_cf32(const std::vector> &a) { int n = a.size(); std::vector a_narrow(n); @@ -125,7 +113,7 @@ narrow_to_cf32(const std::vector> &a) { return a_narrow; } -static inline std::vector +inline std::vector pack_cf32(const std::vector &re, const std::vector &im) { assert(re.size() == im.size()); std::vector ret(re.size()); @@ -135,7 +123,7 @@ pack_cf32(const std::vector &re, const std::vector &im) { return ret; } -static inline std::vector +inline std::vector unpack_real_cf32(const std::vector &in) { std::vector ret(in.size()); for (unsigned i = 0; i < ret.size(); ++i) { @@ -144,7 +132,7 @@ unpack_real_cf32(const std::vector &in) { return ret; } -static inline std::vector +inline std::vector unpack_imag_cf32(const std::vector &in) { std::vector ret(in.size()); for (unsigned i = 0; i < ret.size(); ++i) { @@ -172,9 +160,9 @@ unpack_imag_cf32(const std::vector &in) { * * Returns true if the elements match elementwise, within tolerance. */ -static inline bool check_results_cf32(const char *name, const float32_t *result, - const float32_t *expected, uint32_t n, - uint32_t op_count = 400) { +inline bool check_results_cf32(const char *name, const float32_t *result, + const float32_t *expected, uint32_t n, + uint32_t op_count = 400) { bool passed = true; float32_t max_error = 0; float32_t diff_at_max_error = 0; @@ -259,10 +247,12 @@ static inline bool check_results_cf32(const char *name, const float32_t *result, * * Returns true if the elements match elementwise, within tolerance. */ -static inline bool check_results_cf32(const char *name, - const armral_cmplx_f32_t *result, - const armral_cmplx_f32_t *expected, - uint32_t n, uint32_t op_count = 400) { +inline bool check_results_cf32(const char *name, + const armral_cmplx_f32_t *result, + const armral_cmplx_f32_t *expected, uint32_t n, + uint32_t op_count = 400) { return check_results_cf32(name, (const float32_t *)result, (const float32_t *)expected, n * 2, op_count); } + +} // namespace armral::utils diff --git a/utils/cs16_utils.hpp b/utils/cs16_utils.hpp index e4794d6..9463e7d 100644 --- a/utils/cs16_utils.hpp +++ b/utils/cs16_utils.hpp @@ -6,71 +6,102 @@ #include "armral.h" #include "qint64.hpp" +#include "rng.hpp" #include #include #include #include -static inline std::vector -allocate_random_i16(uint32_t len, int16_t min, int16_t max) { - static std::mt19937 gen; - std::uniform_int_distribution dis(min, max); +namespace armral::utils { - std::vector ret(len); - for (uint32_t i = 0; i < len; ++i) { - ret[i] = dis(gen); - } - return ret; -} +/* + * A helper class for generating random complex numbers. This class preserves + * the random state between calls, so if repeatedly requesting vectors of the + * same length a uniquely random vector is returned each time. Note this is not + * the case for the allocate_random_cf32() helper function. + * + * Example usage: + * + * cs16_random random; + * auto a = random.vector(m * n); // returns std::vector + * auto b = random.vector(m * k); + * armral_cmplx_int16_t foo = random.one() + * + * You can also specify a min/max if necessary: + * + * cs16_random random; + * auto c = random.vector(n * k, {1, 10}, {10, 20}); + * auto bar = random.one({-10, -10}, {10, 10}); + * + * Note: the default is INT16_MIN + INT16_MINj to INT16_MAX + INT16_MAXj + * + * To use custom seeds just pass them to the constructor: + * + * cs16_random random({ 1337, 42 }); + * + * Note: You must use the same cs16_random instance between calls to reuse the + * random state. + */ +class cs16_random : public base_random { +public: + cs16_random(std::initializer_list seeds = {42}) + : base_random(seeds) {} -static inline std::vector allocate_random_i16(uint32_t len) { - return allocate_random_i16(len, INT16_MIN, INT16_MAX); -} + static constexpr armral_cmplx_int16_t default_min{INT16_MIN, INT16_MIN}; + static constexpr armral_cmplx_int16_t default_max{INT16_MAX, INT16_MAX}; -static inline std::vector -allocate_random_cs16(uint32_t len, int16_t min, int16_t max) { - static std::mt19937 gen; - std::uniform_int_distribution dis(min, max); + armral_cmplx_int16_t one(armral_cmplx_int16_t min = default_min, + armral_cmplx_int16_t max = default_max) { + return armral_cmplx_int16_t{rand(min.re, max.re), rand(min.im, max.im)}; + } - std::vector ret(len); - for (uint32_t i = 0; i < len; ++i) { - ret[i].re = dis(gen); - ret[i].im = dis(gen); + std::vector + vector(size_t len, armral_cmplx_int16_t min = default_min, + armral_cmplx_int16_t max = default_max) { + return base_random::vector_impl(len, min, + max); } - return ret; + + std::vector + shifted_vector(size_t len, armral_cmplx_int16_t min = default_min, + armral_cmplx_int16_t max = default_max) { + std::vector ret(len); + for (uint32_t i = 0; i < len; ++i) { + ret[i] = one(min, max); + int16_t shift = rand(0, 16); + ret[i].re >>= shift; + ret[i].im >>= shift; + } + return ret; + } +}; + +inline std::vector +allocate_random_cs16(uint32_t len, int16_t min, int16_t max) { + return cs16_random().vector(len, armral_cmplx_int16_t{min, min}, + armral_cmplx_int16_t{max, max}); } -static inline std::vector +inline std::vector allocate_random_shifted_cs16(uint32_t len, int16_t min, int16_t max) { - static std::mt19937 gen; - std::uniform_int_distribution val_dis(min, max); - std::uniform_int_distribution shift_dis(0, 16); - - std::vector ret(len); - for (uint32_t i = 0; i < len; ++i) { - auto shift = shift_dis(gen); - ret[i].re = val_dis(gen) >> shift; - ret[i].im = val_dis(gen) >> shift; - } - return ret; + return cs16_random().shifted_vector(len, armral_cmplx_int16_t{min, min}, + armral_cmplx_int16_t{max, max}); } -static inline std::vector -allocate_random_cs16(uint32_t len) { +inline std::vector allocate_random_cs16(uint32_t len) { return allocate_random_cs16(len, INT16_MIN, INT16_MAX); } -static inline std::complex -cmplx_mul_widen_cs16(armral_cmplx_int16_t a, armral_cmplx_int16_t b) { +inline std::complex cmplx_mul_widen_cs16(armral_cmplx_int16_t a, + armral_cmplx_int16_t b) { std::complex a_wide = {a.re, a.im}; std::complex b_wide = {b.re, b.im}; return a_wide * b_wide; } -static inline void -scale_and_truncate_cs16(std::vector &ret, - const armral_cmplx_int16_t *scale) { +inline void scale_and_truncate_cs16(std::vector &ret, + const armral_cmplx_int16_t *scale) { for (unsigned i = 0; i < ret.size(); ++i) { std::complex res = cmplx_mul_widen_cs16(ret[i], *scale); // truncate to Q15 directly, no rounding. @@ -79,7 +110,7 @@ scale_and_truncate_cs16(std::vector &ret, } } -static inline std::vector +inline std::vector pack_cs16(const std::vector &re, const std::vector &im) { assert(re.size() == im.size()); std::vector ret(re.size()); @@ -89,7 +120,7 @@ pack_cs16(const std::vector &re, const std::vector &im) { return ret; } -static inline std::vector +inline std::vector unpack_real_cs16(const std::vector &in) { std::vector ret(in.size()); for (unsigned i = 0; i < ret.size(); ++i) { @@ -98,7 +129,7 @@ unpack_real_cs16(const std::vector &in) { return ret; } -static inline std::vector +inline std::vector unpack_imag_cs16(const std::vector &in) { std::vector ret(in.size()); for (unsigned i = 0; i < ret.size(); ++i) { @@ -117,8 +148,8 @@ unpack_imag_cs16(const std::vector &in) { * * Returns true if the elements match elementwise, within tolerance. */ -static inline bool check_results_cs16(const char *name, const int16_t *result, - const int16_t *expected, uint32_t n) { +inline bool check_results_cs16(const char *name, const int16_t *result, + const int16_t *expected, uint32_t n) { bool passed = true; for (uint32_t i = 0; i < n; ++i) { @@ -150,10 +181,12 @@ static inline bool check_results_cs16(const char *name, const int16_t *result, * * Returns true if the elements match elementwise, within tolerance. */ -static inline bool check_results_cs16(const char *name, - const armral_cmplx_int16_t *result, - const armral_cmplx_int16_t *expected, - uint32_t n) { +inline bool check_results_cs16(const char *name, + const armral_cmplx_int16_t *result, + const armral_cmplx_int16_t *expected, + uint32_t n) { return check_results_cs16(name, (const int16_t *)result, (const int16_t *)expected, n * 2); } + +} // namespace armral::utils diff --git a/utils/fft_utils.hpp b/utils/fft_utils.hpp index c34e259..494ea90 100644 --- a/utils/fft_utils.hpp +++ b/utils/fft_utils.hpp @@ -12,9 +12,10 @@ #define M_PI 3.14159265358979323846 #endif -static inline void fft_ref(int n, int s, armral_fft_direction_t dir, - const std::complex *x, - std::complex *y) { +namespace armral::utils { + +inline void fft_ref(int n, int s, armral_fft_direction_t dir, + const std::complex *x, std::complex *y) { using namespace std::complex_literals; if (n % 2 == 0) { fft_ref(n / 2, 2 * s, dir, x, y); @@ -38,3 +39,5 @@ static inline void fft_ref(int n, int s, armral_fft_direction_t dir, } } } + +} // namespace armral::utils diff --git a/utils/int8_utils.hpp b/utils/int8_utils.hpp deleted file mode 100644 index bc5bbdc..0000000 --- a/utils/int8_utils.hpp +++ /dev/null @@ -1,86 +0,0 @@ -/* - Arm RAN Acceleration Library - SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates -*/ -#pragma once - -#include -#include -#include -#include - -static inline std::vector allocate_random_i8(uint32_t len) { - static std::mt19937 gen; - int32_t lo = INT8_MIN; - int32_t hi = INT8_MAX; - std::uniform_int_distribution dis(lo, hi); - - std::vector ret(len); - for (uint32_t i = 0; i < len; ++i) { - ret[i] = static_cast(dis(gen)); - } - return ret; -} - -static inline std::vector allocate_random_u8(uint32_t len, uint8_t min, - uint8_t max) { - static std::mt19937 gen; - uint32_t lo = min; - uint32_t hi = max; - std::uniform_int_distribution dis(lo, hi); - - std::vector ret(len); - for (uint32_t i = 0; i < len; ++i) { - ret[i] = static_cast(dis(gen)); - } - return ret; -} - -static inline std::vector allocate_random_u8(uint32_t len) { - return allocate_random_u8(len, 0, UINT8_MAX); -} - -static inline std::vector allocate_random_bits(uint32_t nbits) { - auto ret = allocate_random_u8((nbits + 7) / 8); - // if nbits is not a byte boundary, ensure only (nbits % 8) bits are set. - if (nbits % 8 != 0) { - ret[nbits / 8] &= 0xff00U >> (nbits % 8); - } - return ret; -} - -/* - * Check an array of results against an array of expected values. - * - * name: The name of the function, used when printing errors. - * result: A pointer to an array of results, length n. - * expected: A pointer to an array of expected values, length n. - * n: The length of the two arrays, in elements. - * - * Returns true if the elements match elementwise. - */ -template -static inline bool check_results(const char *name, const T *result, - const T *expected, uint32_t n) { - bool passed = true; - - for (uint32_t i = 0; i < n; ++i) { - if (result[i] != expected[i]) { - // GCOVR_EXCL_START - passed = false; - printf("Error! [%s] result[%u]= %d (0x%x) and expected[%u]= %d (0x%x)\n", - name, i, result[i], result[i], i, expected[i], expected[i]); - // GCOVR_EXCL_STOP - } - } - - if (passed) { - printf("[%s] - check result: OK\n", name); - } else { - printf("[%s] - check result: ERROR\n", name); // GCOVR_EXCL_LINE - } - return passed; -} - -constexpr inline auto check_results_u8 = check_results; -constexpr inline auto check_results_i8 = check_results; diff --git a/utils/int_utils.hpp b/utils/int_utils.hpp new file mode 100644 index 0000000..6d65eae --- /dev/null +++ b/utils/int_utils.hpp @@ -0,0 +1,149 @@ +/* + Arm RAN Acceleration Library + SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates +*/ +#pragma once + +#include "rng.hpp" + +#include +#include +#include +#include + +namespace armral::utils { + +/* + * A helper class for generating random integers. This class preserves + * the random state between calls, so if repeatedly requesting vectors of the + * same length a uniquely random vector is returned each time. Note this is not + * the case for the allocate_random_X() helper functions. + * + * Example usage: + * + * int_random random; + * auto a = random.vector(m * n); // returns std::vector + * auto b = random.vector(m * k); + * int16_t foo = random.one() + * + * You can also specify a min/max if necessary: + * + * random random; + * auto c = random.vector(n * k, 1, 10}); + * auto bar = random.one(-10, 10); + * + * Note: the default is std::numeric_limits::min() to + * std::numeric_limits::max(). + * + * To use custom seeds just pass them to the constructor: + * + * int_random random({ 1337, 42 }); + * + * Note: You must use the same int_random instance between calls to reuse the + * random state. + */ +template +class int_random : public base_random { +public: + int_random(std::initializer_list seeds = {42}) + : base_random(seeds) {} + + static constexpr T default_min{std::numeric_limits::min()}; + static constexpr T default_max{std::numeric_limits::max()}; + + T one(T min = default_min, T max = default_max) { + return base_random::rand(min, max); + } + + std::vector vector(size_t len, T min = default_min, T max = default_max) { + return base_random::vector_impl(len, min, max); + } +}; + +inline std::vector allocate_random_u8(uint32_t len, uint8_t min, + uint8_t max) { + return int_random().vector(len, min, max); +} + +inline std::vector allocate_random_u8(uint32_t len) { + return allocate_random_u8(len, 0, UINT8_MAX); +} + +inline std::vector allocate_random_i8(uint32_t len) { + return int_random().vector(len, INT8_MIN, INT8_MAX); +} + +inline std::vector allocate_random_i16(uint32_t len, int16_t min, + int16_t max) { + return int_random().vector(len, min, max); +} + +inline std::vector allocate_random_i16(uint32_t len) { + return allocate_random_i16(len, INT16_MIN, INT16_MAX); +} + +/* + * A helper class for generating random bit vectors. It differs from + * the above int_random in that if the number of bits requested + * does not align with a byte boundary it only sets (number of bits % 8) + * bits of the output in the final byte. + * + * We inherit privately so that we do not expose the vector() methods + * from the base classes. + */ +class bit_random : private int_random { +public: + bit_random(std::initializer_list seeds = {42}) + : int_random(seeds) {} + + std::vector bit_vector(uint32_t nbits) { + auto ret = this->vector((nbits + 7) / 8); + // if nbits is not a byte boundary, ensure only (nbits % 8) bits are set. + if (nbits % 8 != 0) { + ret[nbits / 8] &= 0xff00U >> (nbits % 8); + } + return ret; + } +}; + +inline std::vector allocate_random_bits(uint32_t nbits) { + return bit_random().bit_vector(nbits); +} + +/* + * Check an array of results against an array of expected values. + * + * name: The name of the function, used when printing errors. + * result: A pointer to an array of results, length n. + * expected: A pointer to an array of expected values, length n. + * n: The length of the two arrays, in elements. + * + * Returns true if the elements match elementwise. + */ +template +inline bool check_results(const char *name, const T *result, const T *expected, + uint32_t n) { + bool passed = true; + + for (uint32_t i = 0; i < n; ++i) { + if (result[i] != expected[i]) { + // GCOVR_EXCL_START + passed = false; + printf("Error! [%s] result[%u]= %d (0x%x) and expected[%u]= %d (0x%x)\n", + name, i, result[i], result[i], i, expected[i], expected[i]); + // GCOVR_EXCL_STOP + } + } + + if (passed) { + printf("[%s] - check result: OK\n", name); + } else { + printf("[%s] - check result: ERROR\n", name); // GCOVR_EXCL_LINE + } + return passed; +} + +constexpr inline auto check_results_u8 = check_results; +constexpr inline auto check_results_i8 = check_results; + +} // namespace armral::utils diff --git a/utils/matrix_utils.hpp b/utils/matrix_utils.hpp index e3a5d0c..b12f622 100644 --- a/utils/matrix_utils.hpp +++ b/utils/matrix_utils.hpp @@ -8,14 +8,16 @@ #include "reference_linalg.hpp" #include "rng.hpp" +namespace armral::utils { + /* * Generate random values, the resulting matrix will have linearly independent * columns with probability almost 1. */ -static inline std::vector +inline std::vector allocate_random_cf32_lin_ind(uint32_t len) { - static armral::utils::linear_congruential_generator lcg; - auto state = armral::utils::random_state::from_seeds({42}); + static linear_congruential_generator lcg; + auto state = random_state::from_seeds({42}); std::vector ret(len); for (uint32_t i = 0; i < len; ++i) { @@ -28,7 +30,7 @@ allocate_random_cf32_lin_ind(uint32_t len) { /* * Generate random invertible matrices. */ -static inline std::vector +inline std::vector gen_invertible_matrix(uint32_t m, float32_t scale_re = 1.0F, float32_t scale_im = 1.0F) { @@ -57,7 +59,7 @@ gen_invertible_matrix(uint32_t m, float32_t scale_re = 1.0F, /* * Generate a batch of random invertible matrices */ -static inline std::vector +inline std::vector gen_invertible_matrix_batch(uint32_t batch_size, uint32_t m, float32_t scale_re = 1.0F, float32_t scale_im = 1.0F) { @@ -78,7 +80,7 @@ gen_invertible_matrix_batch(uint32_t batch_size, uint32_t m, * Generate random Hermitian matrices (with option to force positive * definiteness) */ -static inline std::vector +inline std::vector gen_hermitian_matrix(uint32_t m, bool is_hpd = false, float32_t scale_re = 1.0F, float32_t scale_im = 1.0F, bool perf = false) { @@ -145,7 +147,7 @@ gen_hermitian_matrix(uint32_t m, bool is_hpd = false, float32_t scale_re = 1.0F, * Generate a batch of random Hermitian matrices (with option to force positive * definiteness) */ -static inline std::vector +inline std::vector gen_hermitian_matrix_batch(uint32_t batch_size, uint32_t m, bool is_hpd = false, float32_t scale_re = 1.0F, float32_t scale_im = 1.0F, bool perf = false) { @@ -243,9 +245,9 @@ static bool check_results_mat_inv( /* * Check that MM^{-1} = M^{-1}M = Id. */ -static inline bool check_results_identity(const armral_cmplx_f32_t *mat, - const armral_cmplx_f32_t *inv_m, - uint32_t m, int verbose = 0) { +inline bool check_results_identity(const armral_cmplx_f32_t *mat, + const armral_cmplx_f32_t *inv_m, uint32_t m, + int verbose = 0) { bool passed = true; // Init arrays std::vector id(m * m); @@ -288,8 +290,8 @@ static inline bool check_results_identity(const armral_cmplx_f32_t *mat, * Unpack data from batched format into a contiguous array */ template -static inline void unpack_data(unsigned batch, unsigned batch_size, - const T *src, T *dst, unsigned length) { +inline void unpack_data(unsigned batch, unsigned batch_size, const T *src, + T *dst, unsigned length) { for (unsigned i = 0; i < length; ++i) { dst[i] = src[i * batch_size + batch]; } @@ -299,15 +301,15 @@ static inline void unpack_data(unsigned batch, unsigned batch_size, * Pack data from a contiguous array into batched format */ template -static inline void pack_data(unsigned batch, unsigned batch_size, const T *src, - T *dst, unsigned length) { +inline void pack_data(unsigned batch, unsigned batch_size, const T *src, T *dst, + unsigned length) { for (unsigned i = 0; i < length; ++i) { dst[i * batch_size + batch] = src[i]; } } -static inline void print_cmplx_row(uint32_t m, const armral_cmplx_f32_t *row, - uint32_t stride = 1) { +inline void print_cmplx_row(uint32_t m, const armral_cmplx_f32_t *row, + uint32_t stride = 1) { printf("[%.3g+%.3gj", row[0].re, row[0].im); for (unsigned i = stride; i < m * stride; i += stride) { printf(", %.3g+%.3gj", row[i].re, row[i].im); @@ -315,9 +317,9 @@ static inline void print_cmplx_row(uint32_t m, const armral_cmplx_f32_t *row, printf("]"); } -static inline void print_cmplx_mat(const std::string &ref, uint32_t m, - const armral_cmplx_f32_t *a, - uint32_t batch_size = 1) { +inline void print_cmplx_mat(const std::string &ref, uint32_t m, + const armral_cmplx_f32_t *a, + uint32_t batch_size = 1) { for (unsigned i = 0; i < batch_size; i++) { printf("%s[%u]=[", ref.c_str(), i); print_cmplx_row(m, a + i, batch_size); @@ -333,7 +335,7 @@ static inline void print_cmplx_mat(const std::string &ref, uint32_t m, * Return the number of floating-point operations required to calculate a length-n * complex dot product */ -static inline uint32_t cmplx_dot_nflops(uint32_t n) { +inline uint32_t cmplx_dot_nflops(uint32_t n) { // A complex multiplication requires 6 floating-point operations uint32_t op_mul = 6; // A complex multiply-accumulate requires 8 floating-point operations @@ -349,3 +351,5 @@ static inline uint32_t cmplx_dot_nflops(uint32_t n) { } return nflops; } + +} // namespace armral::utils diff --git a/utils/qint64.hpp b/utils/qint64.hpp index 02ed5b0..f5f8756 100644 --- a/utils/qint64.hpp +++ b/utils/qint64.hpp @@ -7,6 +7,8 @@ #include #include +namespace armral::utils { + /// A saturating 64b signed integer. class qint64_t { int64_t val; @@ -190,4 +192,6 @@ std::complex operator*(std::complex a, std::complex b) { return {a.real() * b.real() - a.imag() * b.imag(), a.real() * b.imag() + a.imag() * b.real()}; -} \ No newline at end of file +} + +} // namespace armral::utils diff --git a/utils/reference_linalg.hpp b/utils/reference_linalg.hpp index 605b3db..d01d1bf 100644 --- a/utils/reference_linalg.hpp +++ b/utils/reference_linalg.hpp @@ -12,13 +12,15 @@ #include #include +namespace armral::utils { + /* * Multiply a vector by a uniform scaling factor. * * This is explicitly noinline since it avoids a compiler bug with GCC 8.2.0 * where the code is incorrectly inlined into gen_hermitian_matrix. */ -static inline void __attribute__((noinline)) +inline void __attribute__((noinline)) cscal(uint32_t n, armral_cmplx_f32_t *a, armral_cmplx_f32_t s) { for (unsigned i = 0; i < n; ++i) { a[i].re *= s.re; @@ -29,17 +31,17 @@ cscal(uint32_t n, armral_cmplx_f32_t *a, armral_cmplx_f32_t s) { /* * ZGEMM: General complex double matrix multiplication C = beta*C + alpha*A*B */ -static inline void reference_zgemm(uint16_t m, uint16_t n, uint16_t p, - const double alpha, - const std::vector> &a, - const std::vector> &b, - const double beta, - std::vector> &c) { +inline void reference_zgemm(uint16_t m, uint16_t n, uint16_t k, + const double alpha, + const std::vector> &a, + const std::vector> &b, + const double beta, + std::vector> &c) { for (int i = 0; i < m; ++i) { - for (int j = 0; j < p; ++j) { - c[i * p + j] = beta * c[i * p + j]; - for (int k = 0; k < n; ++k) { - c[i * p + j] += alpha * a[i * n + k] * b[k * p + j]; + for (int nn = 0; nn < n; ++nn) { + c[i * n + nn] = beta * c[i * n + nn]; + for (int j = 0; j < k; ++j) { + c[i * n + nn] += alpha * a[i * k + j] * b[j * n + nn]; } } } @@ -48,7 +50,7 @@ static inline void reference_zgemm(uint16_t m, uint16_t n, uint16_t p, /* * Reorder matrices to allow easy access to blocks. */ -static unsigned zorder_y_of(unsigned index) { +unsigned zorder_y_of(unsigned index) { unsigned y = 0; for (unsigned b = 0, k = 0; (1U << b) <= index; b += 2, k++) { y += static_cast((index & (1U << b)) != 0) << k; @@ -56,14 +58,14 @@ static unsigned zorder_y_of(unsigned index) { return y; } -static unsigned zorder_x_of(unsigned index) { +unsigned zorder_x_of(unsigned index) { return zorder_y_of(index >> 1); } /* * Convert from z-order to row-major. */ -static std::vector> +std::vector> zorder_to_rowmajor(uint32_t m, const std::vector> &z) { std::vector> a(m * m); for (unsigned i = 0; i < m; ++i) { @@ -79,7 +81,7 @@ zorder_to_rowmajor(uint32_t m, const std::vector> &z) { /* * Convert from row-major to z-order. */ -static std::vector> +std::vector> rowmajor_to_zorder(uint32_t m, const std::vector> &a) { std::vector> z(m * m); for (unsigned i = 0; i < m; ++i) { @@ -95,11 +97,11 @@ rowmajor_to_zorder(uint32_t m, const std::vector> &a) { /* * General matrix multiplication on matrices stored in z-order. */ -static void reference_zgemm_zorder(uint32_t m, const double alpha, - const std::vector> &a, - const std::vector> &b, - const double beta, - std::vector> &c) { +void reference_zgemm_zorder(uint32_t m, const double alpha, + const std::vector> &a, + const std::vector> &b, + const double beta, + std::vector> &c) { // Convert to row-major auto a64 = zorder_to_rowmajor(m, a); auto b64 = zorder_to_rowmajor(m, b); @@ -112,7 +114,7 @@ static void reference_zgemm_zorder(uint32_t m, const double alpha, c = rowmajor_to_zorder(m, c64); } -static std::vector> +std::vector> reference_zgeinv_2x2(uint32_t m, const std::vector> &mat) { std::vector> inv_m(m * m); // Inverse 2x2 matrix using analytic expression @@ -124,7 +126,7 @@ reference_zgeinv_2x2(uint32_t m, const std::vector> &mat) { return inv_m; } -static std::vector> +std::vector> reference_zgeinv_3x3(uint32_t m, const std::vector> &mat) { std::vector> inv_m(m * m); auto a0 = mat[0]; @@ -177,7 +179,7 @@ reference_zgeinv_3x3(uint32_t m, const std::vector> &mat) { * M = [A B] M^{-1} = [X^{-1} -A^{-1}BU^{-1}] * [C D] [-D^{-1}CX^{-1} U^{-1} ] */ -static std::vector> +std::vector> reference_zgeinv(uint32_t m, const std::vector> &mat) { if (m == 2) { return reference_zgeinv_2x2(m, mat); @@ -231,7 +233,7 @@ reference_zgeinv(uint32_t m, const std::vector> &mat) { return inv_m; } -static inline std::vector> +inline std::vector> reference_zgeinv_small(uint32_t m, const std::vector> &mat) { if (m == 2) { @@ -266,7 +268,7 @@ armral_cmplx_f32_t complex_convert(std::complex cmplx) { */ template std::vector> -convert_cf32_array_to_vector(uint16_t nvalues, const armral_cmplx_f32_t *a) { +convert_cf32_array_to_vector(uint32_t nvalues, const armral_cmplx_f32_t *a) { std::vector> out(nvalues); for (unsigned i = 0; i < nvalues; ++i) { out[i] = std::complex(a[i].re, a[i].im); @@ -275,7 +277,7 @@ convert_cf32_array_to_vector(uint16_t nvalues, const armral_cmplx_f32_t *a) { } template -void convert_vector_to_cf32_array(uint16_t nvalues, +void convert_vector_to_cf32_array(uint32_t nvalues, const std::vector> &a, armral_cmplx_f32_t *b) { for (unsigned i = 0; i < nvalues; ++i) { @@ -286,24 +288,24 @@ void convert_vector_to_cf32_array(uint16_t nvalues, /* * Reference matrix multiplication (C=A*B) on cs16 input matrices */ -static inline void reference_matmul_cs16(uint16_t m, uint16_t n, uint16_t k, - const armral_cmplx_int16_t *a, - const armral_cmplx_int16_t *b, - armral_cmplx_int16_t *c, int round) { +inline void reference_matmul_cs16(uint16_t m, uint16_t n, uint16_t k, + const armral_cmplx_int16_t *a, + const armral_cmplx_int16_t *b, + armral_cmplx_int16_t *c, int round) { assert(round == 0 || round == 1); for (unsigned i = 0; i < m; ++i) { - for (unsigned kk = 0; kk < k; ++kk) { + for (unsigned nn = 0; nn < n; ++nn) { std::complex acc; - for (unsigned j = 0; j < n; ++j) { - auto ae = a[i * n + j]; - auto be = b[j * k + kk]; - auto intermed = cmplx_mul_widen_cs16(ae, be); + for (unsigned j = 0; j < k; ++j) { + auto ae = a[i * k + j]; + auto be = b[j * n + nn]; + auto intermed = armral::utils::cmplx_mul_widen_cs16(ae, be); acc += intermed; } // round works by adding one to the intermediate result, to ensure we get // exact rounding if required (e.g. such that 0b011.1 rounds to 0b100.0). - c[i * k + kk].re = (((acc.real() >> 14) + round) >> 1).get16(); - c[i * k + kk].im = (((acc.imag() >> 14) + round) >> 1).get16(); + c[i * n + nn].re = (((acc.real() >> 14) + round) >> 1).get16(); + c[i * n + nn].im = (((acc.imag() >> 14) + round) >> 1).get16(); } } } @@ -311,23 +313,23 @@ static inline void reference_matmul_cs16(uint16_t m, uint16_t n, uint16_t k, /* * Reference matrix multiplication (C=A*B) on cf32 input matrices */ -static inline void reference_matmul_cf32(uint16_t m, uint16_t n, uint16_t p, - const armral_cmplx_f32_t *a, - const armral_cmplx_f32_t *b, - armral_cmplx_f32_t *c) { +inline void reference_matmul_cf32(uint16_t m, uint16_t n, uint16_t k, + const armral_cmplx_f32_t *a, + const armral_cmplx_f32_t *b, + armral_cmplx_f32_t *c) { // Convert float to double - auto a64 = convert_cf32_array_to_vector(m * n, a); - auto b64 = convert_cf32_array_to_vector(n * p, b); - auto c64 = convert_cf32_array_to_vector(m * p, c); + auto a64 = convert_cf32_array_to_vector(m * k, a); + auto b64 = convert_cf32_array_to_vector(k * n, b); + auto c64 = convert_cf32_array_to_vector(m * n, c); // Double precision matrix multiply - reference_zgemm(m, n, p, 1.0, a64, b64, 0.0, c64); + reference_zgemm(m, n, k, 1.0, a64, b64, 0.0, c64); // Convert back to float for (unsigned i = 0; i < m; ++i) { - for (unsigned k = 0; k < p; ++k) { - c[i * p + k].re = c64[i * p + k].real(); - c[i * p + k].im = c64[i * p + k].imag(); + for (unsigned nn = 0; nn < n; ++nn) { + c[i * n + nn].re = c64[i * n + nn].real(); + c[i * n + nn].im = c64[i * n + nn].imag(); } } } @@ -336,7 +338,7 @@ static inline void reference_matmul_cf32(uint16_t m, uint16_t n, uint16_t p, * Reference conjugate transpose matrix multiplication (C=B * A^H) on cf32 input * matrices */ -static inline void reference_matmul_bah_cf32( +inline void reference_matmul_bah_cf32( uint16_t m, uint16_t n, const armral_cmplx_f32_t *__restrict p_src_a, const armral_cmplx_f32_t *__restrict p_src_b, armral_cmplx_f32_t *p_dst) { for (uint16_t i = 0; i < n; i++) { @@ -356,27 +358,27 @@ static inline void reference_matmul_bah_cf32( * Reference conjugate transpose matrix multiplication (C=A^H * B) on cf32 input * matrices */ -static inline void +inline void reference_matmul_ahb_cf32(uint16_t m, uint16_t n, uint16_t k, const armral_cmplx_f32_t *__restrict p_src_a, const armral_cmplx_f32_t *__restrict p_src_b, armral_cmplx_f32_t *p_dst) { // For every row of output C (column of A, row of A^H)... - for (uint32_t j = 0; j < n; j++) { + for (uint32_t j = 0; j < m; j++) { // For every column of output C (column of B)... - for (uint32_t i = 0; i < k; i++) { + for (uint32_t i = 0; i < n; i++) { // Every row of A and B (where row of A = column of A^H) std::complex dot = 0.; - for (uint32_t r = 0; r < m; r++) { - auto a_jr = complex_convert(p_src_a[r * n + j]); - auto b_ir = complex_convert(p_src_b[r * k + i]); + for (uint32_t r = 0; r < k; r++) { + auto a_jr = complex_convert(p_src_a[r * m + j]); + auto b_ir = complex_convert(p_src_b[r * n + i]); dot += std::conj(a_jr) * b_ir; } - p_dst[k * j + i] = complex_convert(dot); + p_dst[n * j + i] = complex_convert(dot); } } } @@ -384,7 +386,7 @@ reference_matmul_ahb_cf32(uint16_t m, uint16_t n, uint16_t k, /* * Reference matrix multiplication (C=A*A^H) on a cf32 input matrix */ -static inline void +inline void reference_matmul_aah_cf32(uint16_t m, uint16_t n, const armral_cmplx_f32_t *__restrict p_src_a, armral_cmplx_f32_t *p_dst_c) { @@ -416,7 +418,7 @@ reference_matmul_aah_cf32(uint16_t m, uint16_t n, /* * Reference matrix multiplication (C=A^H*A) on a cf32 input matrix */ -static inline void +inline void reference_matmul_aha_cf32(uint16_t m, uint16_t n, const armral_cmplx_f32_t *__restrict p_src, armral_cmplx_f32_t *p_dst) { @@ -440,9 +442,8 @@ reference_matmul_aha_cf32(uint16_t m, uint16_t n, /* * Run reference Matrix Inversion based on blockwise approach. */ -static inline void reference_matinv_block(uint32_t m, - const armral_cmplx_f32_t *a, - armral_cmplx_f32_t *b) { +inline void reference_matinv_block(uint32_t m, const armral_cmplx_f32_t *a, + armral_cmplx_f32_t *b) { // Init double precision input matrix (use z-order for easy access to blocks) auto a_tmp = convert_cf32_array_to_vector(m * m, a); @@ -462,3 +463,5 @@ static inline void reference_matinv_block(uint32_t m, convert_vector_to_cf32_array(m * m, b_tmp, b); } } + +} // namespace armral::utils diff --git a/utils/rng.cpp b/utils/rng.cpp index 7904bc0..33887ee 100644 --- a/utils/rng.cpp +++ b/utils/rng.cpp @@ -2,11 +2,10 @@ Arm RAN Acceleration Library SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates */ +#include #include "rng.hpp" -#include - namespace armral::utils { static inline uint64_t lcg_step(uint64_t x) { diff --git a/utils/rng.hpp b/utils/rng.hpp index fb129b0..358eb47 100644 --- a/utils/rng.hpp +++ b/utils/rng.hpp @@ -6,8 +6,10 @@ #pragma once #include +#include #include #include +#include namespace armral::utils { @@ -41,13 +43,32 @@ public: * @param[in,out] state The state to use as a seed for the generator. * @param[in] min The lower bound of the generated value. * @param[out] max The upper bound of the generated value. - * @returns a pseudo-randomly generated integer. + * @returns a pseudo-randomly generated floating-point value. */ - template>> + template, bool> = true> T one(random_state *state, const T min, const T max) { return (max - min) * one(state) + min; } + /** + * Returns a single pseudo-random value using a simple linear congruential + * generator based on the specified state. The state is updated as part of + * the call. + * + * The range of the returned value is bounded by [min, max]. Note that this + * is only defined for integer types. + * + * @param[in,out] state The state to use as a seed for the generator. + * @param[in] min The lower bound of the generated value. + * @param[out] max The upper bound of the generated value. + * @returns a pseudo-randomly generated integer. + */ + template, bool> = true> + T one(random_state *state, const T min, const T max) { + return ((max - min + 1) * one(state)) + min; + } + /** * Updates the state as if `one` was called `n` times and discarded the * result. @@ -80,4 +101,30 @@ struct random_state { static random_state from_seeds(std::initializer_list seeds); }; +// An abstract base class from which other stateful RNG helper classes are defined +template +class base_random { +public: + base_random(std::initializer_list seeds = {42}) + : m_state(random_state::from_seeds(seeds)) {} + + virtual CmplxType one(CmplxType min, CmplxType max) = 0; + +protected: + std::vector vector_impl(size_t len, CmplxType min, CmplxType max) { + std::vector ret(len); + for (auto &value : ret) { + value = one(min, max); + } + return ret; + } + + RealType rand(RealType min, RealType max) { + linear_congruential_generator lcg; + return lcg.one(&m_state, min, max); + } + + random_state m_state; +}; + } // namespace armral::utils -- GitLab