From a91e75aeb39e6429421d40e9ab9caadad41456c5 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 12 Feb 2025 11:32:21 +0100 Subject: [PATCH 1/8] Remove use of designated initializers in certain unit tests Removes designated initializers from - matmul_test.cpp - matmul_clamp_f16_bf16p_bf16p_test.cpp - matmul_clamp_f32_bf16p_bf16p_test.cpp Following changes are made to the test framework: - Added default constructor to DataFormat class - Initialize members of struct MatMulMethod Signed-off-by: Jens Elofsson --- test/common/data_format.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index 1a9b9483..7ff2596a 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -23,6 +23,7 @@ public: QUANTIZE_PER_ROW, ///< Per-row quantization. }; + DataFormat() = default; /// Creates a new data format. /// /// @param[in] data_type Data type of data value. -- GitLab From ee3910624a3a8d85e520d0f35d16a7f9d9aeb8f6 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 12 Feb 2025 11:35:55 +0100 Subject: [PATCH 2/8] Initialize MatMulMethod members. Signed-off-by: Jens Elofsson --- test/common/matmul_test_common.hpp | 42 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index dfdc56d1..a9a56c32 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -28,7 +28,7 @@ struct MatMulShape { /// Matrix multiplication method. struct MatMulMethod { - std::string_view name; ///< Name of matmul method. + std::string_view name = ""; ///< Name of matmul method. size_t m0{0}; ///< Block size in M dimension. size_t n0{0}; ///< Block size in N dimension. @@ -44,56 +44,56 @@ struct MatMulMethod { /// Check if CPU supports required features. /// /// @return Supported (true) or not supported (false). 
- std::function fn_is_supported; + std::function fn_is_supported{nullptr}; /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). /// /// @return The mr value. - std::function fn_get_mr; + std::function fn_get_mr{nullptr}; /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). /// /// @return The nr value. - std::function fn_get_nr; + std::function fn_get_nr{nullptr}; /// Gets kr value. /// /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). /// /// @return The kr value. - std::function fn_get_kr; + std::function fn_get_kr{nullptr}; /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. - std::function fn_get_sr; + std::function fn_get_sr{nullptr}; /// Gets m step value for main kernel. /// /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. - std::function fn_get_main_m_step; + std::function fn_get_main_m_step{nullptr}; /// Gets n step value for RHS packing kernel. /// /// The starting row index must be divisible by `n_step`. /// /// @return The n step value. - std::function fn_get_pack_rhs_n_step; + std::function fn_get_pack_rhs_n_step{nullptr}; /// Gets n step value for main kernel. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. - std::function fn_get_main_n_step; + std::function fn_get_main_n_step{nullptr}; /// Gets the offset in bytes of the LHS matrix. /// @@ -101,7 +101,7 @@ struct MatMulMethod { /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes. - std::function fn_get_lhs_offset; + std::function fn_get_lhs_offset{nullptr}; /// Gets the size in bytes of the packed LHS matrix. /// @@ -112,7 +112,7 @@ struct MatMulMethod { /// @param[in] sr Unused. Must be 1. /// /// @return The size in bytes. 
- std::function fn_get_packed_lhs_size; + std::function fn_get_packed_lhs_size{nullptr}; /// Gets the offset in bytes of the packed LHS matrix. /// @@ -120,7 +120,7 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// /// @return The offset in bytes. - std::function fn_get_packed_lhs_offset; + std::function fn_get_packed_lhs_offset{nullptr}; /// Preprocesses the LHS matrix. /// @@ -136,7 +136,7 @@ struct MatMulMethod { std::function - fn_pack_lhs; + fn_pack_lhs{nullptr}; /// Gets a value indicating whether LHS packing is needed. [[nodiscard]] bool is_pack_lhs_needed() const { @@ -148,7 +148,7 @@ struct MatMulMethod { /// @param[in] n_idx Coordinate of the matrix in N dimension. /// /// @return The offset in bytes. - std::function fn_get_rhs_offset; + std::function fn_get_rhs_offset{nullptr}; /// Gets the size in bytes of the packed RHS matrix. /// @@ -156,7 +156,7 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// /// @return The size in bytes. - std::function fn_get_packed_rhs_size; + std::function fn_get_packed_rhs_size{nullptr}; /// Gets the size in bytes of the packed RHS matrix. /// @@ -174,7 +174,7 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// /// @return The offset in bytes. - std::function fn_get_pack_rhs_packed_rhs_offset; + std::function fn_get_pack_rhs_packed_rhs_offset{nullptr}; /// Gets the offset in bytes of the packed RHS matrix in the main kernel. /// @@ -182,12 +182,12 @@ struct MatMulMethod { /// @param[in] k Size of the matrix in K dimension. /// /// @return The offset in bytes. - std::function fn_get_main_packed_rhs_offset; + std::function fn_get_main_packed_rhs_offset{nullptr}; std::function - fn_pack_rhs; + fn_pack_rhs{nullptr}; /// Gets n step value. /// @@ -259,7 +259,7 @@ struct MatMulMethod { /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. 
- std::function fn_get_bias_offset; + std::function fn_get_bias_offset{nullptr}; /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -268,7 +268,7 @@ struct MatMulMethod { /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. - std::function fn_get_dst_offset; + std::function fn_get_dst_offset{nullptr}; /// Gets the size in bytes of the destination matrix buffer. /// @@ -276,7 +276,7 @@ struct MatMulMethod { /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. - std::function fn_get_dst_size; + std::function fn_get_dst_size{nullptr}; /// Performs F16 or F32 matrix multiplication with RHS packing /// followed by clamp operation. -- GitLab From 46aaa18eec0f346b0ab02235d5d606ba684ef2f7 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 26 Feb 2025 16:00:45 +0100 Subject: [PATCH 3/8] Remove the use of designated initializers Remove the use of designated initializers from matmul_clamp_f16_bf16p_bf16p_test to comply with C++ 17 standard. Signed-off-by: Jens Elofsson --- .../matmul_clamp_f16_bf16p_bf16p_test.cpp | 186 +++++++++--------- 1 file changed, 88 insertions(+), 98 deletions(-) diff --git a/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp index 038e3615..e27ea6b4 100644 --- a/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp @@ -39,95 +39,86 @@ namespace kai::test { /// List of supported matrix multiplication methods. 
namespace { -const std::array matmul_methods = { - MatMulMethod{ - .name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::FP16), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - - .fn_get_bias_offset = 
kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::UNKNOWN), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_packed_rhs_size = 
kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }}; +std::array matmul_methods; + +struct MatMulMethodLoader { + MatMulMethodLoader() { + matmul_methods[0].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla"; + matmul_methods[0].m0 = 8; + matmul_methods[0].n0 = 12; + matmul_methods[0].k0 = 4; + matmul_methods[0].dst_format = DataFormat(DataType::FP16); + matmul_methods[0].lhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + matmul_methods[0].rhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); + matmul_methods[0].bias_format = DataFormat(DataType::FP16); + matmul_methods[0].fn_is_supported = cpu_has_bf16; + matmul_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_main_m_step = 
kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = nullptr; + matmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[0].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + matmul_methods[1].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias"; + matmul_methods[1].m0 = 8; + matmul_methods[1].n0 = 12; + matmul_methods[1].k0 = 4; + matmul_methods[1].dst_format = DataFormat(DataType::FP16); + matmul_methods[1].lhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, 
DataType::UNKNOWN, 8, 4); + matmul_methods[1].rhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); + matmul_methods[1].bias_format = DataFormat(DataType::UNKNOWN); + matmul_methods[1].fn_is_supported = cpu_has_bf16; + matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = nullptr; + matmul_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_bias_offset = 
kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; + matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + matmul_methods[1].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + }; +}; + +MatMulMethodLoader loader; + } // namespace /// Matrix multiplication test fixture. @@ -214,15 +205,14 @@ protected: method.dst_format.data_type(), // info.m, info.n, info.k, false /* lhs_transposed */, false /* rhs_transposed */); - const auto& data = _data[data_id] = { - .lhs = std::move(lhs), - .ref_packed_lhs = std::move(ref_packed_lhs), - .rhs = std::move(rhs), - .rhs_scales = std::move(rhs_scales), - .bias = std::move(bias), - .ref_packed_rhs = std::move(packed_rhs), - .ref_dst = std::move(ref_dst), - }; + auto& data = _data[data_id] = {}; + data.lhs = std::move(lhs); + data.ref_packed_lhs = std::move(ref_packed_lhs); + data.rhs = std::move(rhs); + data.rhs_scales = std::move(rhs_scales); + data.bias = std::move(bias); + data.ref_packed_rhs = std::move(packed_rhs); + data.ref_dst = std::move(ref_dst); return data; } -- GitLab From 9580d68097df494bbc845ad7038794af240cde26 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 26 Feb 2025 16:01:44 +0100 Subject: [PATCH 4/8] Remove the use of designated initializers Remove the use of designated initializers from matmul_clamp_f32_bf16p_bf16p_test to comply with C++ 17 standard. 
Signed-off-by: Jens Elofsson --- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 609 ++++++++---------- 1 file changed, 278 insertions(+), 331 deletions(-) diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 9fff8005..718a4a66 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -50,330 +50,278 @@ namespace kai::test { /// List of supported matrix multiplication methods. namespace { -const std::array gemm_methods = { - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", - - .m0 = 2 * get_sme_vector_length(), - .n0 = 2 * get_sme_vector_length(), - .k0 = 2, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = DataFormat( +std::array gemm_methods; +std::array gemv_methods; + +struct MatMulMethodLoader { + MatMulMethodLoader() { + gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa"; + gemm_methods[0].m0 = 2 * get_sme_vector_length(); + gemm_methods[0].n0 = 2 * get_sme_vector_length(); + gemm_methods[0].k0 = 2; + gemm_methods[0].dst_format = DataFormat(DataType::FP32); + gemm_methods[0].lhs_format = DataFormat(DataType::FP32); + gemm_methods[0].packed_lhs_format = DataFormat( DataType::BF16, 2 * get_sme_vector_length(), 2, DataFormat::PackFormat::NONE, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 2), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( + DataType::UNKNOWN, 2 * get_sme_vector_length(), 2); + gemm_methods[0].rhs_format = DataFormat(DataType::FP32); + gemm_methods[0].packed_rhs_format = DataFormat( DataType::BF16, 2 * get_sme_vector_length(), 2, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 2), - .bias_format = DataFormat(DataType::FP32), - .fn_is_supported = cpu_has_sme2, - - .fn_get_mr = 
kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, 
DataType::FP32, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::FP32), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - 
.fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::FP32), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_packed_rhs_size_generic_block_size = nullptr, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = 
kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::UNKNOWN), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon, - .fn_get_packed_lhs_offset = 
kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - .fn_get_packed_rhs_size_generic_block_size = nullptr, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", - - .m0 = 8, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = - DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::UNKNOWN), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - 
.fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - }}; - -const std::array gemv_methods = { - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot", - - .m0 = 1, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = - DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::FP32), - .fn_is_supported = cpu_has_bf16, 
- - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - }, - MatMulMethod{ - .name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias", - - .m0 = 1, - .n0 = 12, - .k0 = 4, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format 
= - DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4), - .bias_format = DataFormat(DataType::UNKNOWN), - .fn_is_supported = cpu_has_bf16, - - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_packed_rhs_size = nullptr, - .fn_get_packed_rhs_size_generic_block_size = - kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - .fn_get_dst_size = 
kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot, - }}; + DataType::UNKNOWN, 2 * get_sme_vector_length(), 2); + gemm_methods[0].bias_format = DataFormat(DataType::FP32); + gemm_methods[0].fn_is_supported = cpu_has_sme2; + gemm_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme; + gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_pack_rhs_packed_rhs_offset = nullptr; + gemm_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_bias_offset = 
kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; + gemm_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + gemm_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; + + gemm_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla"; + gemm_methods[1].m0 = 8; + gemm_methods[1].n0 = 12; + gemm_methods[1].k0 = 4; + gemm_methods[1].dst_format = DataFormat(DataType::FP32); + gemm_methods[1].lhs_format = DataFormat(DataType::FP32); + gemm_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); + gemm_methods[1].rhs_format = DataFormat(DataType::FP32); + gemm_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[1].bias_format = DataFormat(DataType::FP32); + gemm_methods[1].fn_is_supported = cpu_has_bf16; + gemm_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_packed_lhs_size = 
kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_packed_rhs_size = nullptr; + gemm_methods[1].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_pack_rhs_packed_rhs_offset = nullptr; + gemm_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[2].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output"; + gemm_methods[2].m0 = 8; + gemm_methods[2].n0 = 12; + gemm_methods[2].k0 = 4; + gemm_methods[2].dst_format = DataFormat(DataType::FP32); + gemm_methods[2].lhs_format = DataFormat(DataType::FP16); + gemm_methods[2].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + gemm_methods[2].rhs_format = DataFormat(DataType::FP16); + gemm_methods[2].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[2].bias_format = DataFormat(DataType::FP32); + 
gemm_methods[2].fn_is_supported = cpu_has_bf16; + gemm_methods[2].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_packed_rhs_size_generic_block_size = nullptr; + gemm_methods[2].fn_get_pack_rhs_packed_rhs_offset = nullptr; + gemm_methods[2].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + 
gemm_methods[2].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[3].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias"; + gemm_methods[3].m0 = 8; + gemm_methods[3].n0 = 12; + gemm_methods[3].k0 = 4; + gemm_methods[3].dst_format = DataFormat(DataType::FP32); + gemm_methods[3].lhs_format = DataFormat(DataType::FP16); + gemm_methods[3].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); + gemm_methods[3].rhs_format = DataFormat(DataType::FP16); + gemm_methods[3].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[3].bias_format = DataFormat(DataType::UNKNOWN); + gemm_methods[3].fn_is_supported = cpu_has_bf16; + gemm_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; + gemm_methods[3].fn_get_rhs_offset = 
kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_packed_rhs_size_generic_block_size = nullptr; + gemm_methods[3].fn_get_pack_rhs_packed_rhs_offset = nullptr; + gemm_methods[3].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; + gemm_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[3].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemm_methods[4].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias"; + gemm_methods[4].m0 = 8; + gemm_methods[4].n0 = 12; + gemm_methods[4].k0 = 4; + gemm_methods[4].dst_format = DataFormat(DataType::FP32); + gemm_methods[4].lhs_format = DataFormat(DataType::FP32); + gemm_methods[4].packed_lhs_format = + DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); + gemm_methods[4].rhs_format = DataFormat(DataType::FP32); + gemm_methods[4].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemm_methods[4].bias_format = DataFormat(DataType::UNKNOWN); + gemm_methods[4].fn_is_supported = cpu_has_bf16; + gemm_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_sr = 
kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; + gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_packed_rhs_size = nullptr; + gemm_methods[4].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_pack_rhs_packed_rhs_offset = nullptr; + gemm_methods[4].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemm_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + gemm_methods[4].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; + + gemv_methods[0].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot"; + gemv_methods[0].m0 = 1; + gemv_methods[0].n0 = 12; + gemv_methods[0].k0 = 4; + gemv_methods[0].dst_format = DataFormat(DataType::FP32); + 
gemv_methods[0].lhs_format = DataFormat(DataType::FP32); + gemv_methods[0].packed_lhs_format = + DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); + gemv_methods[0].rhs_format = DataFormat(DataType::FP32); + gemv_methods[0].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemv_methods[0].bias_format = DataFormat(DataType::FP32); + gemv_methods[0].fn_is_supported = cpu_has_bf16; + gemv_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_packed_rhs_size = nullptr; + gemv_methods[0].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_pack_rhs_packed_rhs_offset = nullptr; + 
gemv_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + + gemv_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias"; + gemv_methods[1].m0 = 1; + gemv_methods[1].n0 = 12; + gemv_methods[1].k0 = 4; + gemv_methods[1].dst_format = DataFormat(DataType::FP32); + gemv_methods[1].lhs_format = DataFormat(DataType::FP32); + gemv_methods[1].packed_lhs_format = + DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); + gemv_methods[1].rhs_format = DataFormat(DataType::FP32); + gemv_methods[1].packed_rhs_format = DataFormat( + DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); + gemv_methods[1].bias_format = DataFormat(DataType::UNKNOWN); + gemv_methods[1].fn_is_supported = cpu_has_bf16; + gemv_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + 
gemv_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; + gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_packed_rhs_size = nullptr; + gemv_methods[1].fn_get_packed_rhs_size_generic_block_size = + kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_pack_rhs_packed_rhs_offset = nullptr; + gemv_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; + gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + }; +}; + +MatMulMethodLoader loader; } // namespace /// Matrix multiplication test fixture. 
@@ -469,15 +417,14 @@ protected: method.dst_format.data_type(), // info.m, info.n, info.k, false, false); - const auto& data = _data[data_id] = { - .lhs = std::move(lhs), - .ref_packed_lhs = std::move(ref_packed_lhs), - .rhs = std::move(rhs), - .rhs_scales = std::move(rhs_scales), - .bias = std::move(bias), - .ref_packed_rhs = std::move(packed_rhs), - .ref_dst = std::move(ref_dst), - }; + auto& data = _data[data_id] = {}; + data.lhs = std::move(lhs); + data.ref_packed_lhs = std::move(ref_packed_lhs); + data.rhs = std::move(rhs); + data.rhs_scales = std::move(rhs_scales); + data.bias = std::move(bias); + data.ref_packed_rhs = std::move(packed_rhs); + data.ref_dst = std::move(ref_dst); return data; } -- GitLab From 8748c47ffb1c88dda6943e05141543b17db42870 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 26 Feb 2025 16:03:05 +0100 Subject: [PATCH 5/8] Remove the use of designated initializers. Remove the use of designated initializers from matmul_test to comply with the C++17 standard. Signed-off-by: Jens Elofsson --- test/tests/matmul_test.cpp | 525 +++++++++++++++++-------------------- 1 file changed, 242 insertions(+), 283 deletions(-) diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index c57b3489..006abf3c 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -61,290 +61,250 @@ namespace kai::test { /// List of supported matrix multiplication methods. 
-static const std::array matmul_methods = { - MatMulMethod{ - .name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla", - - .m0 = 6, - .n0 = 16, - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = DataFormat(DataType::UNKNOWN), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( - DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1), - .bias_format = DataFormat(DataType::FP16), - - .fn_is_supported = cpu_has_fp16, - .fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - - .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_packed_lhs_size = nullptr, - .fn_get_packed_lhs_offset = nullptr, - .fn_pack_lhs = nullptr, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - - .fn_pack_rhs_nxk_get_n_step = nullptr, - .fn_pack_rhs_nxk_get_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_bias_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_size = nullptr, - 
.fn_pack_rhs_nxk = nullptr, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - - .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f16_f16p_f16p = nullptr, - .fn_matmul_f32_f32p_f32p = nullptr, - }, - - MatMulMethod{ - .name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa", - - .m0 = 2 * get_sme_vector_length(), - .n0 = 2 * get_sme_vector_length(), - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length(), 2), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( +std::array matmul_methods; + +/// List of supported vector by matrix multiplication methods +std::array vecmul_methods; + +struct MatMulMethodLoader { + MatMulMethodLoader() { + matmul_methods[0].name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla"; + matmul_methods[0].m0 = 6; + matmul_methods[0].n0 = 16; + matmul_methods[0].dst_format = DataFormat(DataType::FP16); + matmul_methods[0].lhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); + matmul_methods[0].rhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_rhs_format = DataFormat( + DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1); + matmul_methods[0].bias_format = DataFormat(DataType::FP16); + matmul_methods[0].fn_is_supported = cpu_has_fp16; + matmul_methods[0].fn_get_mr = nullptr; + matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + 
matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_packed_lhs_size = nullptr; + matmul_methods[0].fn_get_packed_lhs_offset = nullptr; + matmul_methods[0].fn_pack_lhs = nullptr; + matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_pack_rhs_nxk_get_n_step = nullptr; + matmul_methods[0].fn_pack_rhs_nxk_get_rhs_offset = nullptr; + matmul_methods[0].fn_pack_rhs_nxk_get_bias_offset = nullptr; + matmul_methods[0].fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr; + matmul_methods[0].fn_pack_rhs_nxk_get_packed_rhs_size = nullptr; + matmul_methods[0].fn_pack_rhs_nxk = nullptr; + matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_matmul_f16_f16_f16p = 
kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_matmul_f32_f32_f32p = nullptr; + matmul_methods[0].fn_matmul_f16_f16p_f16p = nullptr; + matmul_methods[0].fn_matmul_f32_f32p_f32p = nullptr; + + matmul_methods[1].name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa"; + matmul_methods[1].m0 = 2 * get_sme_vector_length(); + matmul_methods[1].n0 = 2 * get_sme_vector_length(); + matmul_methods[1].dst_format = DataFormat(DataType::FP16); + matmul_methods[1].lhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length(), 2); + matmul_methods[1].rhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_rhs_format = DataFormat( DataType::FP16, // Output type 2 * get_sme_vector_length(), 2, // Block size DataFormat::PackFormat::BIAS_PER_ROW, // Data layout DataType::FP16, // Bias format DataType::UNKNOWN, // Scaling type - 2 * get_sme_vector_length(), 2), // Sub-block - .bias_format = DataFormat(DataType::FP16), - - .fn_is_supported = cpu_has_sme2, - .fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme, - - 
.fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - - .fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk_get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - .fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - - .fn_matmul_f16_f16_f16p = nullptr, - .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa, - .fn_matmul_f32_f32p_f32p = nullptr, - }, - - MatMulMethod{ - .name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla", - - .m0 = 6, - .n0 = 8, - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = DataFormat(DataType::UNKNOWN), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( - DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1), - .bias_format = DataFormat(DataType::FP32), - - .fn_is_supported = cpu_has_advsimd, - 
.fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - - .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_packed_lhs_size = nullptr, - .fn_get_packed_lhs_offset = nullptr, - .fn_pack_lhs = nullptr, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - - .fn_pack_rhs_nxk_get_n_step = nullptr, - .fn_pack_rhs_nxk_get_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_bias_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_size = nullptr, - .fn_pack_rhs_nxk = nullptr, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - - .fn_matmul_f16_f16_f16p = nullptr, - .fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla, - .fn_matmul_f16_f16p_f16p = nullptr, - .fn_matmul_f32_f32p_f32p = nullptr, - }, - - MatMulMethod{ - .name = 
"matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa", - - .m0 = 2 * get_sme_vector_length(), - .n0 = 2 * get_sme_vector_length(), - - .dst_format = DataFormat(DataType::FP32), - .lhs_format = DataFormat(DataType::FP32), - .packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length(), 1), - .rhs_format = DataFormat(DataType::FP32), - .packed_rhs_format = DataFormat( + 2 * get_sme_vector_length(), 2); // Sub-block + matmul_methods[1].bias_format = DataFormat(DataType::FP16); + matmul_methods[1].fn_is_supported = cpu_has_sme2; + matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + 
matmul_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_matmul_f16_f16_f16p = nullptr; + matmul_methods[1].fn_matmul_f32_f32_f32p = nullptr; + matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_matmul_f32_f32p_f32p = nullptr; + + matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla"; + matmul_methods[2].m0 = 6; + matmul_methods[2].n0 = 8; + matmul_methods[2].dst_format = DataFormat(DataType::FP32); + matmul_methods[2].lhs_format = DataFormat(DataType::FP32); + matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); + matmul_methods[2].rhs_format = DataFormat(DataType::FP32); + matmul_methods[2].packed_rhs_format = DataFormat( + DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1); 
+ matmul_methods[2].bias_format = DataFormat(DataType::FP32); + matmul_methods[2].fn_is_supported = cpu_has_advsimd; + matmul_methods[2].fn_get_mr = nullptr; + matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_packed_lhs_size = nullptr; + matmul_methods[2].fn_get_packed_lhs_offset = nullptr; + matmul_methods[2].fn_pack_lhs = nullptr; + matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_pack_rhs_nxk_get_n_step = nullptr; + matmul_methods[2].fn_pack_rhs_nxk_get_rhs_offset = nullptr; + matmul_methods[2].fn_pack_rhs_nxk_get_bias_offset = nullptr; + matmul_methods[2].fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr; + matmul_methods[2].fn_pack_rhs_nxk_get_packed_rhs_size = nullptr; + matmul_methods[2].fn_pack_rhs_nxk = nullptr; + matmul_methods[2].fn_get_bias_offset = 
kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_matmul_f16_f16_f16p = nullptr; + matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_matmul_f16_f16p_f16p = nullptr; + matmul_methods[2].fn_matmul_f32_f32p_f32p = nullptr; + + matmul_methods[3].name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa"; + matmul_methods[3].m0 = 2 * get_sme_vector_length(); + matmul_methods[3].n0 = 2 * get_sme_vector_length(); + matmul_methods[3].dst_format = DataFormat(DataType::FP32); + matmul_methods[3].lhs_format = DataFormat(DataType::FP32); + matmul_methods[3].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length(), 1); + matmul_methods[3].rhs_format = DataFormat(DataType::FP32); + matmul_methods[3].packed_rhs_format = DataFormat( DataType::FP32, 2 * get_sme_vector_length(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 1), - .bias_format = DataFormat(DataType::FP32), - - .fn_is_supported = cpu_has_sme2, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme, - 
.fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - - .fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk_get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - .fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - - .fn_matmul_f16_f16_f16p = nullptr, - .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f16_f16p_f16p = nullptr, - .fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - }, -}; - -/// List of supported vector by meatrix multiplication methods -static const std::array vecmul_methods{ - MatMulMethod{ - .name = "vecmul_kxn_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", - - .m0 = 1, - .n0 = 
16 * get_sme_vector_length(), - - .dst_format = DataFormat(DataType::FP16), - .lhs_format = DataFormat(DataType::FP16), - .packed_lhs_format = DataFormat(DataType::UNKNOWN), - .rhs_format = DataFormat(DataType::FP16), - .packed_rhs_format = DataFormat( + DataType::UNKNOWN, 2 * get_sme_vector_length(), 1); + matmul_methods[3].bias_format = DataFormat(DataType::FP32); + matmul_methods[3].fn_is_supported = cpu_has_sme2; + matmul_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + 
matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_bias_offset = + kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_matmul_f16_f16_f16p = nullptr; + matmul_methods[3].fn_matmul_f32_f32_f32p = nullptr; + matmul_methods[3].fn_matmul_f16_f16p_f16p = nullptr; + matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + + vecmul_methods[0].name = "vecmul_kxn_f16_f16_f16p2vlx2b_1x16vl_sme2_dot"; + vecmul_methods[0].m0 = 1; + vecmul_methods[0].n0 = 16 * get_sme_vector_length(); + vecmul_methods[0].dst_format = DataFormat(DataType::FP16); + vecmul_methods[0].lhs_format = DataFormat(DataType::FP16); + vecmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); + vecmul_methods[0].rhs_format = DataFormat(DataType::FP16); + vecmul_methods[0].packed_rhs_format = DataFormat( DataType::FP16, // Output type 2 * get_sme_vector_length(), 2, // Block size DataFormat::PackFormat::BIAS_PER_ROW, // Data layout DataType::FP16, // Bias format 
DataType::UNKNOWN, // Scaling type - 2 * get_sme_vector_length(), 2), // Sub-block - .bias_format = DataFormat(DataType::FP16), - - .fn_is_supported = cpu_has_sme2, - .fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme, - .fn_get_packed_lhs_size = nullptr, - .fn_get_packed_lhs_offset = nullptr, - .fn_pack_lhs = nullptr, - - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - - .fn_pack_rhs_nxk_get_n_step = nullptr, - .fn_pack_rhs_nxk_get_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_bias_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr, - .fn_pack_rhs_nxk_get_packed_rhs_size = nullptr, - .fn_pack_rhs_nxk = nullptr, - - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme, - - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - - .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot, - .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f16_f16p_f16p = nullptr, - 
.fn_matmul_f32_f32p_f32p = nullptr, - }, - + 2 * get_sme_vector_length(), 2); // Sub-block + vecmul_methods[0].bias_format = DataFormat(DataType::FP16); + vecmul_methods[0].fn_is_supported = cpu_has_sme2; + vecmul_methods[0].fn_get_mr = nullptr; + vecmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; + vecmul_methods[0].fn_get_packed_lhs_size = nullptr; + vecmul_methods[0].fn_get_packed_lhs_offset = nullptr; + vecmul_methods[0].fn_pack_lhs = nullptr; + vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_pack_rhs_nxk_get_n_step = nullptr; + vecmul_methods[0].fn_pack_rhs_nxk_get_rhs_offset = nullptr; + vecmul_methods[0].fn_pack_rhs_nxk_get_bias_offset = nullptr; + vecmul_methods[0].fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr; + vecmul_methods[0].fn_pack_rhs_nxk_get_packed_rhs_size = nullptr; + vecmul_methods[0].fn_pack_rhs_nxk = nullptr; + 
vecmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_matmul_f32_f32_f32p = nullptr; + vecmul_methods[0].fn_matmul_f16_f16p_f16p = nullptr; + vecmul_methods[0].fn_matmul_f32_f32p_f32p = nullptr; + }; }; +MatMulMethodLoader loader; + /// Matrix multiplication test fixture. class MatMulTest : public testing::TestWithParam { private: @@ -462,18 +422,17 @@ protected: KAI_ERROR("Unsupported data type!"); } - const auto& data = _data[data_id] = { - .lhs = std::move(lhs), - .ref_packed_lhs = std::move(ref_packed_lhs), - .rhs = std::move(rhs), - .rhs_scales = std::move(rhs_scales), - .bias = std::move(bias), - .rhs_t = std::move(rhs_t), - .ref_packed_rhs = std::move(packed_rhs), - .ref_dst = std::move(ref_dst), - .clamp_min = clamp_min, - .clamp_max = clamp_max, - }; + auto& data = _data[data_id] = {}; + data.lhs = std::move(lhs); + data.ref_packed_lhs = std::move(ref_packed_lhs); + data.rhs = std::move(rhs); + data.rhs_scales = std::move(rhs_scales); + data.bias = std::move(bias); + data.rhs_t = std::move(rhs_t); + data.ref_packed_rhs = std::move(packed_rhs); + data.ref_dst = std::move(ref_dst); + data.clamp_min = clamp_min; + data.clamp_max = clamp_max; return data; } -- GitLab From 4a8abc13f9cf2e114e747d0a8eb74da9dd1dff99 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Thu, 6 Mar 2025 13:32:02 +0100 Subject: [PATCH 6/8] Fix copyright header. 
Signed-off-by: Jens Elofsson --- test/common/data_format.hpp | 2 +- test/common/matmul_test_common.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index 7ff2596a..b326cec1 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index a9a56c32..4c9c2bae 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // -- GitLab From aadd03843bedf148463db9742dcce9eae36c18a1 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Wed, 19 Mar 2025 09:37:21 +0100 Subject: [PATCH 7/8] Address review comments. - Initialize DataFormat members in struct MatMulMethod - Change name from MatMulMethodLoader to MatMulMethodInitializer - Add "-Wpedantic" compile flag to the testcases that had designated initializers removed - Remove unnecessary nullptr assignments. 
Signed-off-by: Jens Elofsson --- CMakeLists.txt | 6 ++ test/common/data_format.hpp | 1 - test/common/matmul_test_common.hpp | 14 ++--- .../matmul_clamp_f16_bf16p_bf16p_test.cpp | 10 ++-- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 23 ++------ test/tests/matmul_test.cpp | 55 ++----------------- 6 files changed, 27 insertions(+), 82 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fa6993d5..5aecb07e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -336,6 +336,12 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp test/tests/matmul_test.cpp ) + + set_source_files_properties( + test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp + test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp + test/tests/matmul_test.cpp + PROPERTIES COMPILE_FLAGS "-Wpedantic") endif() target_link_libraries(kleidiai_test diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index b326cec1..4b5bd53e 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -23,7 +23,6 @@ public: QUANTIZE_PER_ROW, ///< Per-row quantization. }; - DataFormat() = default; /// Creates a new data format. /// /// @param[in] data_type Data type of data value. diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index 4c9c2bae..8237a498 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -28,18 +28,18 @@ struct MatMulShape { /// Matrix multiplication method. struct MatMulMethod { - std::string_view name = ""; ///< Name of matmul method. + std::string_view name = std::string_view{}; ///< Name of matmul method. size_t m0{0}; ///< Block size in M dimension. size_t n0{0}; ///< Block size in N dimension. size_t k0{0}; ///< Block size in K dimension. - DataFormat dst_format; ///< Data format of the destination matrix. - DataFormat lhs_format; ///< Data format of the LHS matrix. - DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix. 
- DataFormat rhs_format; ///< Data format of the RHS matrix. - DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix. - DataFormat bias_format; ///< Data format of the bias vector. + DataFormat dst_format{DataType::UNKNOWN}; ///< Data format of the destination matrix. + DataFormat lhs_format{DataType::UNKNOWN}; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format{DataType::UNKNOWN}; ///< Data format of the packed LHS matrix. + DataFormat rhs_format{DataType::UNKNOWN}; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format{DataType::UNKNOWN}; ///< Data format of the packed RHS matrix. + DataFormat bias_format{DataType::UNKNOWN}; ///< Data format of the bias vector. /// Check if CPU supports required features. /// diff --git a/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp index e27ea6b4..548484a7 100644 --- a/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp @@ -39,10 +39,10 @@ namespace kai::test { /// List of supported matrix multiplication methods. 
namespace { -std::array matmul_methods; +std::array matmul_methods{}; -struct MatMulMethodLoader { - MatMulMethodLoader() { +struct MatMulMethodInitializer { + MatMulMethodInitializer() { matmul_methods[0].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla"; matmul_methods[0].m0 = 8; matmul_methods[0].n0 = 12; @@ -70,7 +70,6 @@ struct MatMulMethodLoader { matmul_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = nullptr; matmul_methods[0].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; @@ -106,7 +105,6 @@ struct MatMulMethodLoader { matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; - matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = nullptr; matmul_methods[1].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; @@ -117,7 +115,7 @@ struct MatMulMethodLoader { }; }; -MatMulMethodLoader loader; +MatMulMethodInitializer init{}; } // namespace diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 718a4a66..a17399ef 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -50,11 +50,11 @@ namespace kai::test { /// List of supported matrix multiplication methods. 
namespace { -std::array gemm_methods; -std::array gemv_methods; +std::array gemm_methods{}; +std::array gemv_methods{}; -struct MatMulMethodLoader { - MatMulMethodLoader() { +struct MatMulMethodInitializer { + MatMulMethodInitializer() { gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa"; gemm_methods[0].m0 = 2 * get_sme_vector_length(); gemm_methods[0].n0 = 2 * get_sme_vector_length(); @@ -84,7 +84,6 @@ struct MatMulMethodLoader { gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme; gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; - gemm_methods[0].fn_get_pack_rhs_packed_rhs_offset = nullptr; gemm_methods[0].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; @@ -119,10 +118,8 @@ struct MatMulMethodLoader { kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[1].fn_get_packed_rhs_size = nullptr; gemm_methods[1].fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[1].fn_get_pack_rhs_packed_rhs_offset = nullptr; gemm_methods[1].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; @@ -158,8 +155,6 @@ struct MatMulMethodLoader { gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; 
gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[2].fn_get_packed_rhs_size_generic_block_size = nullptr; - gemm_methods[2].fn_get_pack_rhs_packed_rhs_offset = nullptr; gemm_methods[2].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; @@ -195,8 +190,6 @@ struct MatMulMethodLoader { gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; gemm_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; - gemm_methods[3].fn_get_packed_rhs_size_generic_block_size = nullptr; - gemm_methods[3].fn_get_pack_rhs_packed_rhs_offset = nullptr; gemm_methods[3].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; @@ -231,10 +224,8 @@ struct MatMulMethodLoader { kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[4].fn_get_packed_rhs_size = nullptr; gemm_methods[4].fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemm_methods[4].fn_get_pack_rhs_packed_rhs_offset = nullptr; gemm_methods[4].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; @@ -269,10 +260,8 @@ struct MatMulMethodLoader { kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; 
gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[0].fn_get_packed_rhs_size = nullptr; gemv_methods[0].fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[0].fn_get_pack_rhs_packed_rhs_offset = nullptr; gemv_methods[0].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; @@ -307,10 +296,8 @@ struct MatMulMethodLoader { kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[1].fn_get_packed_rhs_size = nullptr; gemv_methods[1].fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; - gemv_methods[1].fn_get_pack_rhs_packed_rhs_offset = nullptr; gemv_methods[1].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; @@ -321,7 +308,7 @@ struct MatMulMethodLoader { }; }; -MatMulMethodLoader loader; +MatMulMethodInitializer init{}; } // namespace /// Matrix multiplication test fixture. diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index 006abf3c..f65bb9c9 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -61,13 +61,13 @@ namespace kai::test { /// List of supported matrix multiplication methods. 
-std::array matmul_methods; +std::array matmul_methods{}; /// List of supported vector by matrix multiplication methods -std::array vecmul_methods; +std::array vecmul_methods{}; -struct MatMulMethodLoader { - MatMulMethodLoader() { +struct MatMulMethodInitializer { + MatMulMethodInitializer() { matmul_methods[0].name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla"; matmul_methods[0].m0 = 6; matmul_methods[0].n0 = 16; @@ -79,7 +79,6 @@ struct MatMulMethodLoader { DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1); matmul_methods[0].bias_format = DataFormat(DataType::FP16); matmul_methods[0].fn_is_supported = cpu_has_fp16; - matmul_methods[0].fn_get_mr = nullptr; matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; @@ -87,9 +86,6 @@ struct MatMulMethodLoader { matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_get_packed_lhs_size = nullptr; - matmul_methods[0].fn_get_packed_lhs_offset = nullptr; - matmul_methods[0].fn_pack_lhs = nullptr; matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = @@ -97,19 +93,10 @@ struct MatMulMethodLoader { matmul_methods[0].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; matmul_methods[0].fn_pack_rhs = 
kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; - matmul_methods[0].fn_pack_rhs_nxk_get_n_step = nullptr; - matmul_methods[0].fn_pack_rhs_nxk_get_rhs_offset = nullptr; - matmul_methods[0].fn_pack_rhs_nxk_get_bias_offset = nullptr; - matmul_methods[0].fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr; - matmul_methods[0].fn_pack_rhs_nxk_get_packed_rhs_size = nullptr; - matmul_methods[0].fn_pack_rhs_nxk = nullptr; matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; matmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_matmul_f32_f32_f32p = nullptr; - matmul_methods[0].fn_matmul_f16_f16p_f16p = nullptr; - matmul_methods[0].fn_matmul_f32_f32p_f32p = nullptr; matmul_methods[1].name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa"; matmul_methods[1].m0 = 2 * get_sme_vector_length(); @@ -157,10 +144,7 @@ struct MatMulMethodLoader { matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_matmul_f16_f16_f16p = nullptr; - matmul_methods[1].fn_matmul_f32_f32_f32p = nullptr; matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_matmul_f32_f32p_f32p = nullptr; matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla"; matmul_methods[2].m0 = 6; @@ -173,7 +157,6 @@ struct MatMulMethodLoader { DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, 
DataType::UNKNOWN, 8, 1); matmul_methods[2].bias_format = DataFormat(DataType::FP32); matmul_methods[2].fn_is_supported = cpu_has_advsimd; - matmul_methods[2].fn_get_mr = nullptr; matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; @@ -181,9 +164,6 @@ struct MatMulMethodLoader { matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_get_packed_lhs_size = nullptr; - matmul_methods[2].fn_get_packed_lhs_offset = nullptr; - matmul_methods[2].fn_pack_lhs = nullptr; matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = @@ -191,19 +171,10 @@ struct MatMulMethodLoader { matmul_methods[2].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; - matmul_methods[2].fn_pack_rhs_nxk_get_n_step = nullptr; - matmul_methods[2].fn_pack_rhs_nxk_get_rhs_offset = nullptr; - matmul_methods[2].fn_pack_rhs_nxk_get_bias_offset = nullptr; - matmul_methods[2].fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr; - matmul_methods[2].fn_pack_rhs_nxk_get_packed_rhs_size = nullptr; - matmul_methods[2].fn_pack_rhs_nxk = nullptr; matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; matmul_methods[2].fn_get_dst_offset 
= kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_matmul_f16_f16_f16p = nullptr; matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_matmul_f16_f16p_f16p = nullptr; - matmul_methods[2].fn_matmul_f32_f32p_f32p = nullptr; matmul_methods[3].name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa"; matmul_methods[3].m0 = 2 * get_sme_vector_length(); @@ -248,9 +219,6 @@ struct MatMulMethodLoader { matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_matmul_f16_f16_f16p = nullptr; - matmul_methods[3].fn_matmul_f32_f32_f32p = nullptr; - matmul_methods[3].fn_matmul_f16_f16p_f16p = nullptr; matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; vecmul_methods[0].name = "vecmul_kxn_f16_f16_f16p2vlx2b_1x16vl_sme2_dot"; @@ -269,7 +237,6 @@ struct MatMulMethodLoader { 2 * get_sme_vector_length(), 2); // Sub-block vecmul_methods[0].bias_format = DataFormat(DataType::FP16); vecmul_methods[0].fn_is_supported = cpu_has_sme2; - vecmul_methods[0].fn_get_mr = nullptr; vecmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; @@ -277,9 +244,6 @@ struct MatMulMethodLoader { vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; vecmul_methods[0].fn_get_main_n_step = 
kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; - vecmul_methods[0].fn_get_packed_lhs_size = nullptr; - vecmul_methods[0].fn_get_packed_lhs_offset = nullptr; - vecmul_methods[0].fn_pack_lhs = nullptr; vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = @@ -287,23 +251,14 @@ struct MatMulMethodLoader { vecmul_methods[0].fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - vecmul_methods[0].fn_pack_rhs_nxk_get_n_step = nullptr; - vecmul_methods[0].fn_pack_rhs_nxk_get_rhs_offset = nullptr; - vecmul_methods[0].fn_pack_rhs_nxk_get_bias_offset = nullptr; - vecmul_methods[0].fn_pack_rhs_nxk_get_packed_rhs_offset = nullptr; - vecmul_methods[0].fn_pack_rhs_nxk_get_packed_rhs_size = nullptr; - vecmul_methods[0].fn_pack_rhs_nxk = nullptr; vecmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_matmul_f32_f32_f32p = nullptr; - vecmul_methods[0].fn_matmul_f16_f16p_f16p = nullptr; - vecmul_methods[0].fn_matmul_f32_f32p_f32p = nullptr; }; }; -MatMulMethodLoader loader; +MatMulMethodInitializer init{}; /// Matrix multiplication test fixture. 
class MatMulTest : public testing::TestWithParam { -- GitLab From 74e2fa192ae7e9cfe140f22f501d31c4ba04aeeb Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Thu, 20 Mar 2025 16:01:57 +0100 Subject: [PATCH 8/8] Address review comments - Add default data_type argument in DataFormat constructor - Semantic change to the initialization of name in MatMulMethod Signed-off-by: Jens Elofsson --- test/common/data_format.hpp | 6 +++--- test/common/matmul_test_common.hpp | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/common/data_format.hpp b/test/common/data_format.hpp index 4b5bd53e..730dd86e 100644 --- a/test/common/data_format.hpp +++ b/test/common/data_format.hpp @@ -34,9 +34,9 @@ public: /// @param[in] subblock_height Sub-block height. /// @param[in] subblock_width Sub-block width. DataFormat( - DataType data_type, size_t block_height = 0, size_t block_width = 0, PackFormat pack_format = PackFormat::NONE, - DataType zero_point_dt = DataType::UNKNOWN, DataType scale_dt = DataType::UNKNOWN, size_t subblock_height = 0, - size_t subblock_width = 0) noexcept; + DataType data_type = DataType::UNKNOWN, size_t block_height = 0, size_t block_width = 0, + PackFormat pack_format = PackFormat::NONE, DataType zero_point_dt = DataType::UNKNOWN, + DataType scale_dt = DataType::UNKNOWN, size_t subblock_height = 0, size_t subblock_width = 0) noexcept; /// Equality operator. [[nodiscard]] bool operator==(const DataFormat& rhs) const; diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp index 8237a498..5b3e2424 100644 --- a/test/common/matmul_test_common.hpp +++ b/test/common/matmul_test_common.hpp @@ -28,18 +28,18 @@ struct MatMulShape { /// Matrix multiplication method. struct MatMulMethod { - std::string_view name = std::string_view{}; ///< Name of matmul method. + std::string_view name{}; ///< Name of matmul method. size_t m0{0}; ///< Block size in M dimension. size_t n0{0}; ///< Block size in N dimension. 
size_t k0{0}; ///< Block size in K dimension. - DataFormat dst_format{DataType::UNKNOWN}; ///< Data format of the destination matrix. - DataFormat lhs_format{DataType::UNKNOWN}; ///< Data format of the LHS matrix. - DataFormat packed_lhs_format{DataType::UNKNOWN}; ///< Data format of the packed LHS matrix. - DataFormat rhs_format{DataType::UNKNOWN}; ///< Data format of the RHS matrix. - DataFormat packed_rhs_format{DataType::UNKNOWN}; ///< Data format of the packed RHS matrix. - DataFormat bias_format{DataType::UNKNOWN}; ///< Data format of the bias vector. + DataFormat dst_format{}; ///< Data format of the destination matrix. + DataFormat lhs_format{}; ///< Data format of the LHS matrix. + DataFormat packed_lhs_format{}; ///< Data format of the packed LHS matrix. + DataFormat rhs_format{}; ///< Data format of the RHS matrix. + DataFormat packed_rhs_format{}; ///< Data format of the packed RHS matrix. + DataFormat bias_format{}; ///< Data format of the bias vector. /// Check if CPU supports required features. /// -- GitLab