diff --git a/CMakeLists.txt b/CMakeLists.txt index fa6993d54105f9219ce2d084d52d34933ef7ca43..5aecb07e23eb309b0400ee77dc98c65421495b08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -336,6 +336,12 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp test/tests/matmul_test.cpp ) + + set_source_files_properties( + test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp + test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp + test/tests/matmul_test.cpp + PROPERTIES COMPILE_FLAGS "-Wpedantic") endif() target_link_libraries(kleidiai_test diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index 1a9b9483541a9756d08c08d2f9955bf1925e5cb4..730dd86e59259e7cae48620144aceb01ce8a45c7 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -34,9 +34,9 @@ public: /// @param[in] subblock_height Sub-block height. /// @param[in] subblock_width Sub-block width. DataFormat( - DataType data_type, size_t block_height = 0, size_t block_width = 0, PackFormat pack_format = PackFormat::NONE, - DataType zero_point_dt = DataType::UNKNOWN, DataType scale_dt = DataType::UNKNOWN, size_t subblock_height = 0, - size_t subblock_width = 0) noexcept; + DataType data_type = DataType::UNKNOWN, size_t block_height = 0, size_t block_width = 0, + PackFormat pack_format = PackFormat::NONE, DataType zero_point_dt = DataType::UNKNOWN, + DataType scale_dt = DataType::UNKNOWN, size_t subblock_height = 0, size_t subblock_width = 0) noexcept; /// Equality operator. [[nodiscard]] bool operator==(const DataFormat& rhs) const; diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index dfdc56d15b9650cc08c6dadde5365bb58da039b1..5b3e2424bae08e93799cd33cb09f620693da1725 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -28,72 +28,72 @@ struct MatMulShape { /// Matrix multiplication method. struct MatMulMethod { - std::string_view name; ///< Name of matmul method. + std::string_view name{}; ///< Name of matmul method. size_t m0{0}; ///< Block size in M dimension. size_t n0{0}; ///< Block size in N dimension. size_t k0{0}; ///< Block size in K dimension. - DataFormat dst_format; ///< Data format of the destination matrix. - DataFormat lhs_format; ///< Data format of the LHS matrix. - DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. - DataFormat rhs_format; ///< Data format of the RHS matrix. - DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. - DataFormat bias_format; ///< Data format of the bias vector. + DataFormat dst_format{}; ///< Data format of the destination matrix. + DataFormat lhs_format{}; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format{}; ///< Data format of the packed LHS matrix. + DataFormat rhs_format{}; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format{}; ///< Data format of the packed RHS matrix. + DataFormat bias_format{}; ///< Data format of the bias vector. /// Check if CPU supports required features. /// /// @return Supported (true) or not supported (false). - std::function fn_is_supported; + std::function fn_is_supported{nullptr}; /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). /// /// @return The mr value. - std::function fn_get_mr; + std::function fn_get_mr{nullptr}; /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). /// /// @return The nr value. - std::function fn_get_nr; + std::function fn_get_nr{nullptr}; /// Gets kr value. /// /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). /// /// @return The kr value. - std::function fn_get_kr; + std::function fn_get_kr{nullptr}; /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. - std::function fn_get_sr; + std::function fn_get_sr{nullptr}; /// Gets m step value for main kernel. /// /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. - std::function fn_get_main_m_step; + std::function fn_get_main_m_step{nullptr}; /// Gets n step value for RHS packing kernel. /// /// The starting row index must be divisible by `n_step`. /// /// @return The n step value. - std::function fn_get_pack_rhs_n_step; + std::function fn_get_pack_rhs_n_step{nullptr}; /// Gets n step value for main kernel. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. - std::function fn_get_main_n_step; + std::function fn_get_main_n_step{nullptr}; /// Gets the offset in bytes of the LHS matrix. /// @@ -101,7 +101,7 @@ struct MatMulMethod { /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes. - std::function fn_get_lhs_offset; + std::function fn_get_lhs_offset{nullptr}; /// Gets the size in bytes of the packed LHS matrix. /// @@ -112,7 +112,7 @@ struct MatMulMethod { /// @param[in] sr Unused. Must be 1. /// /// @return The size in bytes. - std::function fn_get_packed_lhs_size; + std::function fn_get_packed_lhs_size{nullptr}; /// Gets the offset in bytes of the packed LHS matrix. /// @@ -120,7 +120,7 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// /// @return The offset in bytes. - std::function fn_get_packed_lhs_offset; + std::function fn_get_packed_lhs_offset{nullptr}; /// Preprocesses the LHS matrix. /// @@ -136,7 +136,7 @@ struct MatMulMethod { std::function - fn_pack_lhs; + fn_pack_lhs{nullptr}; /// Gets a value indicating whether LHS packing is needed. [[nodiscard]] bool is_pack_lhs_needed() const { @@ -148,7 +148,7 @@ struct MatMulMethod { /// @param[in] n_idx Coordinate of the matrix in N dimension. /// /// @return The offset in bytes. - std::function fn_get_rhs_offset; + std::function fn_get_rhs_offset{nullptr}; /// Gets the size in bytes of the packed RHS matrix. /// @@ -156,7 +156,7 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// /// @return The size in bytes. - std::function fn_get_packed_rhs_size; + std::function fn_get_packed_rhs_size{nullptr}; /// Gets the size in bytes of the packed RHS matrix. /// @@ -174,7 +174,7 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// /// @return The offset in bytes. - std::function fn_get_pack_rhs_packed_rhs_offset; + std::function fn_get_pack_rhs_packed_rhs_offset{nullptr}; /// Gets the offset in bytes of the packed RHS matrix in the main kernel. /// @@ -182,12 +182,12 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// /// @return The offset in bytes. - std::function fn_get_main_packed_rhs_offset; + std::function fn_get_main_packed_rhs_offset{nullptr}; std::function - fn_pack_rhs; + fn_pack_rhs{nullptr}; /// Gets n step value. /// @@ -259,7 +259,7 @@ struct MatMulMethod { /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. - std::function fn_get_bias_offset; + std::function fn_get_bias_offset{nullptr}; /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -268,7 +268,7 @@ struct MatMulMethod { /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. - std::function fn_get_dst_offset; + std::function fn_get_dst_offset{nullptr}; /// Gets the size in bytes of the destination matrix buffer. /// @@ -276,7 +276,7 @@ struct MatMulMethod { /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. - std::function fn_get_dst_size; + std::function fn_get_dst_size{nullptr}; /// Performs F16 or F32 matrix multiplication with RHS packing /// followed by clamp operation. diff --git a/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp index 038e361569949441f7540600f23c214fff423fde..548484a7934b5a8a0b18bb10f9ee4ffe34672d1c 100644 --- a/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp @@ -39,95 +39,84 @@ namespace kai::test { /// List of supported matrix multiplication methods. namespace { -const std::array matmul_methods = { - MatMulMethod{ - .name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::FP16), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::UNKNOWN), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }}; +std::array matmul_methods{}; + +struct MatMulMethodInitializer { + MatMulMethodInitializer() { + matmul_methods[0].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla"; + matmul_methods[0].m0 = 8; + matmul_methods[0].n0 = 12; + matmul_methods[0].k0 = 4; + matmul_methods[0].dst_format = DataFormat(DataType::FP16); + matmul_methods[0].lhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + matmul_methods[0].rhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); + matmul_methods[0].bias_format = DataFormat(DataType::FP16); + matmul_methods[0].fn_is_supported = cpu_has_bf16; + matmul_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + matmul_methods[1].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias"; + matmul_methods[1].m0 = 8; + matmul_methods[1].n0 = 12; + matmul_methods[1].k0 = 4; + matmul_methods[1].dst_format = DataFormat(DataType::FP16); + matmul_methods[1].lhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + matmul_methods[1].rhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); + matmul_methods[1].bias_format = DataFormat(DataType::UNKNOWN); + matmul_methods[1].fn_is_supported = cpu_has_bf16; + matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + }; +}; + +MatMulMethodInitializer init{}; + } // namespace /// Matrix multiplication test fixture. @@ -214,15 +203,14 @@ protected: method.dst_format.data_type(), // info.m, info.n, info.k, false /* lhs_transposed */, false /* rhs_transposed */); - const auto& data = _data[data_id] = { - .lhs = std::move(lhs), - .ref_packed_lhs = std::move(ref_packed_lhs), - .rhs = std::move(rhs), - .rhs_scales = std::move(rhs_scales), - .bias = std::move(bias), - .ref_packed_rhs = std::move(packed_rhs), - .ref_dst = std::move(ref_dst), - }; + auto& data = _data[data_id] = {}; + data.lhs = std::move(lhs); + data.ref_packed_lhs = std::move(ref_packed_lhs); + data.rhs = std::move(rhs); + data.rhs_scales = std::move(rhs_scales); + data.bias = std::move(bias); + data.ref_packed_rhs = std::move(packed_rhs); + data.ref_dst = std::move(ref_dst); return data; } diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 9fff8005fb61150844e0c99a01a29fc00a3cee45..a17399ef285258e534967409ed8e31a3c04c4045 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -50,330 +50,265 @@ namespace kai::test { /// List of supported matrix multiplication methods. namespace { -const std::array gemm_methods = { - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", - - .m0 = 2 * get_sme_vector_length(), - .n0 = 2 * get_sme_vector_length(), - .k0 = 2, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = DataFormat( +std::array gemm_methods{}; +std::array gemv_methods{}; + +struct MatMulMethodInitializer { + MatMulMethodInitializer() { + gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa"; + gemm_methods[0].m0 = 2 * get_sme_vector_length(); + gemm_methods[0].n0 = 2 * get_sme_vector_length(); + gemm_methods[0].k0 = 2; + gemm_methods[0].dst_format = DataFormat(DataType::FP32); + gemm_methods[0].lhs_format = DataFormat(DataType::FP32); + gemm_methods[0].packed_lhs_format = DataFormat( DataType::BF16, 2 * get_sme_vector_length(), 2, DataFormat::PackFormat::NONE, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 2), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( + DataType::UNKNOWN, 2 * get_sme_vector_length(), 2); + gemm_methods[0].rhs_format = DataFormat(DataType::FP32); + gemm_methods[0].packed_rhs_format = DataFormat( DataType::BF16, 2 * get_sme_vector_length(), 2, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 2), - .bias_format = DataFormat(DataType::FP32), - .fn_is_supported = cpu_has_sme2, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::FP32), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::FP32), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_packed_rhs_size_generic_block_size = nullptr, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::UNKNOWN), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_packed_rhs_size_generic_block_size = nullptr, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::UNKNOWN), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }}; - -const std::array gemv_methods = { - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot", - - .m0 = 1, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = - DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::FP32), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias", - - .m0 = 1, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = - DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::UNKNOWN), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - }}; + DataType::UNKNOWN, 2 * get_sme_vector_length(), 2); + gemm_methods[0].bias_format = DataFormat(DataType::FP32); + gemm_methods[0].fn_is_supported = cpu_has_sme2; + gemm_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + + gemm_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla"; + gemm_methods[1].m0 = 8; + gemm_methods[1].n0 = 12; + gemm_methods[1].k0 = 4; + gemm_methods[1].dst_format = DataFormat(DataType::FP32); + gemm_methods[1].lhs_format = DataFormat(DataType::FP32); + gemm_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); + gemm_methods[1].rhs_format = DataFormat(DataType::FP32); + gemm_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[1].bias_format = DataFormat(DataType::FP32); + gemm_methods[1].fn_is_supported = cpu_has_bf16; + gemm_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[2].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output"; + gemm_methods[2].m0 = 8; + gemm_methods[2].n0 = 12; + gemm_methods[2].k0 = 4; + gemm_methods[2].dst_format = DataFormat(DataType::FP32); + gemm_methods[2].lhs_format = DataFormat(DataType::FP16); + gemm_methods[2].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + gemm_methods[2].rhs_format = DataFormat(DataType::FP16); + gemm_methods[2].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[2].bias_format = DataFormat(DataType::FP32); + gemm_methods[2].fn_is_supported = cpu_has_bf16; + gemm_methods[2].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[3].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias"; + gemm_methods[3].m0 = 8; + gemm_methods[3].n0 = 12; + gemm_methods[3].k0 = 4; + gemm_methods[3].dst_format = DataFormat(DataType::FP32); + gemm_methods[3].lhs_format = DataFormat(DataType::FP16); + gemm_methods[3].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + gemm_methods[3].rhs_format = DataFormat(DataType::FP16); + gemm_methods[3].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[3].bias_format = DataFormat(DataType::UNKNOWN); + gemm_methods[3].fn_is_supported = cpu_has_bf16; + gemm_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[4].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias"; + gemm_methods[4].m0 = 8; + gemm_methods[4].n0 = 12; + gemm_methods[4].k0 = 4; + gemm_methods[4].dst_format = DataFormat(DataType::FP32); + gemm_methods[4].lhs_format = DataFormat(DataType::FP32); + gemm_methods[4].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); + gemm_methods[4].rhs_format = DataFormat(DataType::FP32); + gemm_methods[4].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[4].bias_format = DataFormat(DataType::UNKNOWN); + gemm_methods[4].fn_is_supported = cpu_has_bf16; + gemm_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemv_methods[0].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot"; + gemv_methods[0].m0 = 1; + gemv_methods[0].n0 = 12; + gemv_methods[0].k0 = 4; + gemv_methods[0].dst_format = DataFormat(DataType::FP32); + gemv_methods[0].lhs_format = DataFormat(DataType::FP32); + gemv_methods[0].packed_lhs_format = + DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); + gemv_methods[0].rhs_format = DataFormat(DataType::FP32); + gemv_methods[0].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemv_methods[0].bias_format = DataFormat(DataType::FP32); + gemv_methods[0].fn_is_supported = cpu_has_bf16; + gemv_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + + gemv_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias"; + gemv_methods[1].m0 = 1; + gemv_methods[1].n0 = 12; + gemv_methods[1].k0 = 4; + gemv_methods[1].dst_format = DataFormat(DataType::FP32); + gemv_methods[1].lhs_format = DataFormat(DataType::FP32); + gemv_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); + gemv_methods[1].rhs_format = DataFormat(DataType::FP32); + gemv_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemv_methods[1].bias_format = DataFormat(DataType::UNKNOWN); + gemv_methods[1].fn_is_supported = cpu_has_bf16; + gemv_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + }; +}; + +MatMulMethodInitializer init{}; } // namespace /// Matrix multiplication test fixture. @@ -469,15 +404,14 @@ protected: method.dst_format.data_type(), // info.m, info.n, info.k, false, false); - const auto& data = _data[data_id] = { - .lhs = std::move(lhs), - .ref_packed_lhs = std::move(ref_packed_lhs), - .rhs = std::move(rhs), - .rhs_scales = std::move(rhs_scales), - .bias = std::move(bias), - .ref_packed_rhs = std::move(packed_rhs), - .ref_dst = std::move(ref_dst), - }; + auto& data = _data[data_id] = {}; + data.lhs = std::move(lhs); + data.ref_packed_lhs = std::move(ref_packed_lhs); + data.rhs = std::move(rhs); + data.rhs_scales = std::move(rhs_scales); + data.bias = std::move(bias); + data.ref_packed_rhs = std::move(packed_rhs); + data.ref_dst = std::move(ref_dst); return data; } diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index c57b34890a16598e46c1fecb285597275016ef13..f65bb9c9747a5440c57eec88d575bc010d929944 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -61,290 +61,205 @@ namespace kai::test { /// List of supported matrix multiplication methods. -static const std::array matmul_methods = { - MatMulMethod{ - .name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla", - - .m0 = 6, - .n0 = 16, - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = DataFormat(DataType::UNKNOWN), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1), - .bias_format = DataFormat(DataType::FP16), - - .fn_is_supported = cpu_has_fp16, - .fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - - .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_packed_lhs_size = nullptr, - .fn_get_packed_lhs_offset = nullptr, - .fn_pack_lhs = nullptr, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - - .fn_pack_rhs_nxk_get_n_step = nullptr, - .fn_pack_rhs_nxk_get_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_bias_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_size = nullptr, - .fn_pack_rhs_nxk = nullptr, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - - .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f16_f16p_f16p = nullptr, - .fn_matmul_f32_f32p_f32p = nullptr, - }, - - MatMulMethod{ - .name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa", - - .m0 = 2 * get_sme_vector_length(), - .n0 = 2 * get_sme_vector_length(), - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length(), 2), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( +std::array matmul_methods{}; + +/// List of supported vector by matrix multiplication methods +std::array vecmul_methods{}; + +struct MatMulMethodInitializer { + MatMulMethodInitializer() { + matmul_methods[0].name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla"; + matmul_methods[0].m0 = 6; + matmul_methods[0].n0 = 16; + matmul_methods[0].dst_format = DataFormat(DataType::FP16); + matmul_methods[0].lhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); + matmul_methods[0].rhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_rhs_format = DataFormat( + DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1); + matmul_methods[0].bias_format = DataFormat(DataType::FP16); + matmul_methods[0].fn_is_supported = cpu_has_fp16; + matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + + matmul_methods[1].name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa"; + matmul_methods[1].m0 = 2 * get_sme_vector_length(); + matmul_methods[1].n0 = 2 * get_sme_vector_length(); + matmul_methods[1].dst_format = DataFormat(DataType::FP16); + matmul_methods[1].lhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length(), 2); + matmul_methods[1].rhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_rhs_format = DataFormat( DataType::FP16, // Output type 2 * get_sme_vector_length(), 2, // Block size DataFormat::PackFormat::BIAS_PER_ROW, // Data layout DataType::FP16, // Bias format DataType::UNKNOWN, // Scaling type - 2 * get_sme_vector_length(), 2), // Sub-block - .bias_format = DataFormat(DataType::FP16), - - .fn_is_supported = cpu_has_sme2, - .fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - - .fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk_get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_matmul_f16_f16_f16p = nullptr, - .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_matmul_f32_f32p_f32p = nullptr, - }, - - MatMulMethod{ - .name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla", - - .m0 = 6, - .n0 = 8, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = DataFormat(DataType::UNKNOWN), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1), - .bias_format = DataFormat(DataType::FP32), - - .fn_is_supported = cpu_has_advsimd, - .fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - - .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_packed_lhs_size = nullptr, - .fn_get_packed_lhs_offset = nullptr, - .fn_pack_lhs = nullptr, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - - .fn_pack_rhs_nxk_get_n_step = nullptr, - .fn_pack_rhs_nxk_get_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_bias_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_size = nullptr, - .fn_pack_rhs_nxk = nullptr, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - - .fn_matmul_f16_f16_f16p = nullptr, - .fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_matmul_f16_f16p_f16p = nullptr, - .fn_matmul_f32_f32p_f32p = nullptr, - }, - - MatMulMethod{ - .name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa", - - .m0 = 2 * get_sme_vector_length(), - .n0 = 2 * get_sme_vector_length(), - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length(), 1), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( + 2 * get_sme_vector_length(), 2); // Sub-block + matmul_methods[1].bias_format = DataFormat(DataType::FP16); + matmul_methods[1].fn_is_supported = cpu_has_sme2; + matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + + matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla"; + matmul_methods[2].m0 = 6; + matmul_methods[2].n0 = 8; + matmul_methods[2].dst_format = DataFormat(DataType::FP32); + matmul_methods[2].lhs_format = DataFormat(DataType::FP32); + matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); + matmul_methods[2].rhs_format = DataFormat(DataType::FP32); + matmul_methods[2].packed_rhs_format = DataFormat( + DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1); + matmul_methods[2].bias_format = DataFormat(DataType::FP32); + matmul_methods[2].fn_is_supported = cpu_has_advsimd; + matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + + matmul_methods[3].name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa"; + matmul_methods[3].m0 = 2 * get_sme_vector_length(); + matmul_methods[3].n0 = 2 * get_sme_vector_length(); + matmul_methods[3].dst_format = DataFormat(DataType::FP32); + matmul_methods[3].lhs_format = DataFormat(DataType::FP32); + matmul_methods[3].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length(), 1); + matmul_methods[3].rhs_format = DataFormat(DataType::FP32); + matmul_methods[3].packed_rhs_format = DataFormat( DataType::FP32, 2 * get_sme_vector_length(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 1), - .bias_format = DataFormat(DataType::FP32), - - .fn_is_supported = cpu_has_sme2, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - - .fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk_get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - - .fn_matmul_f16_f16_f16p = nullptr, - .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f16_f16p_f16p = nullptr, - .fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - }, -}; - -/// List of supported vector by meatrix multiplication methods -static const std::array vecmul_methods{ - MatMulMethod{ - .name = "vecmul_kxn_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", - - .m0 = 1, - .n0 = 16 * get_sme_vector_length(), - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = DataFormat(DataType::UNKNOWN), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( + DataType::UNKNOWN, 2 * get_sme_vector_length(), 1); + matmul_methods[3].bias_format = DataFormat(DataType::FP32); + matmul_methods[3].fn_is_supported = cpu_has_sme2; + matmul_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_bias_offset = + kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + + vecmul_methods[0].name = "vecmul_kxn_f16_f16_f16p2vlx2b_1x16vl_sme2_dot"; + vecmul_methods[0].m0 = 1; + vecmul_methods[0].n0 = 16 * get_sme_vector_length(); + vecmul_methods[0].dst_format = DataFormat(DataType::FP16); + vecmul_methods[0].lhs_format = DataFormat(DataType::FP16); + vecmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); + vecmul_methods[0].rhs_format = DataFormat(DataType::FP16); + vecmul_methods[0].packed_rhs_format = DataFormat( DataType::FP16, // Output type 2 * get_sme_vector_length(), 2, // Block size DataFormat::PackFormat::BIAS_PER_ROW, // Data layout DataType::FP16, // Bias format DataType::UNKNOWN, // Scaling type - 2 * get_sme_vector_length(), 2), // Sub-block - .bias_format = DataFormat(DataType::FP16), - - .fn_is_supported = cpu_has_sme2, - .fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme, - .fn_get_packed_lhs_size = nullptr, - .fn_get_packed_lhs_offset = nullptr, - .fn_pack_lhs = nullptr, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - - .fn_pack_rhs_nxk_get_n_step = nullptr, - .fn_pack_rhs_nxk_get_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_bias_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_size = nullptr, - .fn_pack_rhs_nxk = nullptr, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - - .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f16_f16p_f16p = nullptr, - .fn_matmul_f32_f32p_f32p = nullptr, - }, - + 2 * get_sme_vector_length(), 2); // Sub-block + vecmul_methods[0].bias_format = DataFormat(DataType::FP16); + vecmul_methods[0].fn_is_supported = cpu_has_sme2; + vecmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; + vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + }; }; +MatMulMethodInitializer init{}; + /// Matrix multiplication test fixture. class MatMulTest : public testing::TestWithParam { private: @@ -462,18 +377,17 @@ protected: KAI_ERROR("Unsupported data type!"); } - const auto& data = _data[data_id] = { - .lhs = std::move(lhs), - .ref_packed_lhs = std::move(ref_packed_lhs), - .rhs = std::move(rhs), - .rhs_scales = std::move(rhs_scales), - .bias = std::move(bias), - .rhs_t = std::move(rhs_t), - .ref_packed_rhs = std::move(packed_rhs), - .ref_dst = std::move(ref_dst), - .clamp_min = clamp_min, - .clamp_max = clamp_max, - }; + auto& data = _data[data_id] = {}; + data.lhs = std::move(lhs); + data.ref_packed_lhs = std::move(ref_packed_lhs); + data.rhs = std::move(rhs); + data.rhs_scales = std::move(rhs_scales); + data.bias = std::move(bias); + data.rhs_t = std::move(rhs_t); + data.ref_packed_rhs = std::move(packed_rhs); + data.ref_dst = std::move(ref_dst); + data.clamp_min = clamp_min; + data.clamp_max = clamp_max; return data; }