From 3970b948b9787f160ad7165d3364f409bf66430d Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 7 Apr 2025 15:54:03 +0200 Subject: [PATCH 1/2] Change kernel lists to use lazy initialization The order of static initialization is not guaranteed, which can cause the test listing to be initialized before the list of kernels. This can be solved by lazily initializing the kernel lists on first use. This patch applies this fix to `matmul_test.cpp`. Signed-off-by: Emil Ohlsson --- test/tests/matmul_test.cpp | 393 ++++++++++++++++++------------------- 1 file changed, 196 insertions(+), 197 deletions(-) diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index f65bb9c9..3360983e 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -61,204 +61,203 @@ namespace kai::test { /// List of supported matrix multiplication methods. -std::array matmul_methods{}; /// List of supported vector by matrix multiplication methods -std::array vecmul_methods{}; - -struct MatMulMethodInitializer { - MatMulMethodInitializer() { - matmul_methods[0].name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla"; - matmul_methods[0].m0 = 6; - matmul_methods[0].n0 = 16; - matmul_methods[0].dst_format = DataFormat(DataType::FP16); - matmul_methods[0].lhs_format = DataFormat(DataType::FP16); - matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); - matmul_methods[0].rhs_format = DataFormat(DataType::FP16); - matmul_methods[0].packed_rhs_format = DataFormat( - DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1); - matmul_methods[0].bias_format = DataFormat(DataType::FP16); - matmul_methods[0].fn_is_supported = cpu_has_fp16; - matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_get_sr =
kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; - matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; - matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; - matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = - kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; - matmul_methods[0].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; - matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; - matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - matmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; - - matmul_methods[1].name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa"; - matmul_methods[1].m0 = 2 * get_sme_vector_length(); - matmul_methods[1].n0 = 2 * get_sme_vector_length(); - matmul_methods[1].dst_format = DataFormat(DataType::FP16); - matmul_methods[1].lhs_format = DataFormat(DataType::FP16); - matmul_methods[1].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length(), 2); - matmul_methods[1].rhs_format = DataFormat(DataType::FP16); - matmul_methods[1].packed_rhs_format 
= DataFormat( - DataType::FP16, // Output type - 2 * get_sme_vector_length(), 2, // Block size - DataFormat::PackFormat::BIAS_PER_ROW, // Data layout - DataType::FP16, // Bias format - DataType::UNKNOWN, // Scaling type - 2 * get_sme_vector_length(), 2); // Sub-block - matmul_methods[1].bias_format = DataFormat(DataType::FP16); - matmul_methods[1].fn_is_supported = cpu_has_sme2; - matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; - matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; - matmul_methods[1].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; - matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = - kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_pack_rhs = 
kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_offset = - kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; - - matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla"; - matmul_methods[2].m0 = 6; - matmul_methods[2].n0 = 8; - matmul_methods[2].dst_format = DataFormat(DataType::FP32); - matmul_methods[2].lhs_format = DataFormat(DataType::FP32); - matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); - matmul_methods[2].rhs_format = DataFormat(DataType::FP32); - matmul_methods[2].packed_rhs_format = DataFormat( - DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1); - matmul_methods[2].bias_format = DataFormat(DataType::FP32); - matmul_methods[2].fn_is_supported = cpu_has_advsimd; - matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - 
matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; - matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; - matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; - matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = - kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; - matmul_methods[2].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; - matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; - matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; - - matmul_methods[3].name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa"; - matmul_methods[3].m0 = 2 * get_sme_vector_length(); - matmul_methods[3].n0 = 2 * get_sme_vector_length(); - matmul_methods[3].dst_format = DataFormat(DataType::FP32); - matmul_methods[3].lhs_format = DataFormat(DataType::FP32); - matmul_methods[3].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length(), 1); - matmul_methods[3].rhs_format = DataFormat(DataType::FP32); - 
matmul_methods[3].packed_rhs_format = DataFormat( - DataType::FP32, 2 * get_sme_vector_length(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, - DataType::UNKNOWN, 2 * get_sme_vector_length(), 1); - matmul_methods[3].bias_format = DataFormat(DataType::FP32); - matmul_methods[3].fn_is_supported = cpu_has_sme2; - matmul_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme; - matmul_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme; - matmul_methods[3].fn_get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme; - matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = - kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; - 
matmul_methods[3].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_pack_rhs_nxk_get_bias_offset = - kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_offset = - kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; - matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; - - vecmul_methods[0].name = "vecmul_kxn_f16_f16_f16p2vlx2b_1x16vl_sme2_dot"; - vecmul_methods[0].m0 = 1; - vecmul_methods[0].n0 = 16 * get_sme_vector_length(); - vecmul_methods[0].dst_format = DataFormat(DataType::FP16); - vecmul_methods[0].lhs_format = DataFormat(DataType::FP16); - vecmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); - vecmul_methods[0].rhs_format = DataFormat(DataType::FP16); - vecmul_methods[0].packed_rhs_format = DataFormat( - DataType::FP16, // Output type - 2 * get_sme_vector_length(), 2, // Block size - DataFormat::PackFormat::BIAS_PER_ROW, // Data layout - DataType::FP16, // Bias format - DataType::UNKNOWN, // Scaling type - 2 * get_sme_vector_length(), 2); // Sub-block - vecmul_methods[0].bias_format = DataFormat(DataType::FP16); - vecmul_methods[0].fn_is_supported = cpu_has_sme2; - vecmul_methods[0].fn_get_nr = 
kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - vecmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; - vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = - kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - vecmul_methods[0].fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - vecmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; - vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; - }; -}; -MatMulMethodInitializer init{}; +static const std::array& get_matmul_methods() { + static std::array matmul_methods{}; + matmul_methods[0].name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla"; + matmul_methods[0].m0 = 6; + matmul_methods[0].n0 = 16; + matmul_methods[0].dst_format = DataFormat(DataType::FP16); + matmul_methods[0].lhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_lhs_format 
= DataFormat(DataType::UNKNOWN); + matmul_methods[0].rhs_format = DataFormat(DataType::FP16); + matmul_methods[0].packed_rhs_format = DataFormat( + DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1); + matmul_methods[0].bias_format = DataFormat(DataType::FP16); + matmul_methods[0].fn_is_supported = cpu_has_fp16; + matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; + matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + 
matmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; + + matmul_methods[1].name = "matmul_nt_nt_f16_f16p_f16p_2vlx2vl_sme2_mopa"; + matmul_methods[1].m0 = 2 * get_sme_vector_length(); + matmul_methods[1].n0 = 2 * get_sme_vector_length(); + matmul_methods[1].dst_format = DataFormat(DataType::FP16); + matmul_methods[1].lhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length(), 2); + matmul_methods[1].rhs_format = DataFormat(DataType::FP16); + matmul_methods[1].packed_rhs_format = DataFormat( + DataType::FP16, // Output type + 2 * get_sme_vector_length(), 2, // Block size + DataFormat::PackFormat::BIAS_PER_ROW, // Data layout + DataType::FP16, // Bias format + DataType::UNKNOWN, // Scaling type + 2 * get_sme_vector_length(), 2); // Sub-block + matmul_methods[1].bias_format = DataFormat(DataType::FP16); + matmul_methods[1].fn_is_supported = cpu_has_sme2; + matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_packed_lhs_offset = + 
kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; + + matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla"; + matmul_methods[2].m0 = 6; + matmul_methods[2].n0 = 8; + matmul_methods[2].dst_format = DataFormat(DataType::FP32); + matmul_methods[2].lhs_format = 
DataFormat(DataType::FP32); + matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); + matmul_methods[2].rhs_format = DataFormat(DataType::FP32); + matmul_methods[2].packed_rhs_format = + DataFormat(DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1); + matmul_methods[2].bias_format = DataFormat(DataType::FP32); + matmul_methods[2].fn_is_supported = cpu_has_advsimd; + matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; + matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_get_dst_size = 
kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; + + matmul_methods[3].name = "matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa"; + matmul_methods[3].m0 = 2 * get_sme_vector_length(); + matmul_methods[3].n0 = 2 * get_sme_vector_length(); + matmul_methods[3].dst_format = DataFormat(DataType::FP32); + matmul_methods[3].lhs_format = DataFormat(DataType::FP32); + matmul_methods[3].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length(), 1); + matmul_methods[3].rhs_format = DataFormat(DataType::FP32); + matmul_methods[3].packed_rhs_format = DataFormat( + DataType::FP32, 2 * get_sme_vector_length(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, + DataType::UNKNOWN, 2 * get_sme_vector_length(), 1); + matmul_methods[3].bias_format = DataFormat(DataType::FP32); + matmul_methods[3].fn_is_supported = cpu_has_sme2; + matmul_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + 
matmul_methods[3].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme; + matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; + matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; + + return matmul_methods; +} + +static const std::array& get_vecmul_methods() { + static std::array vecmul_methods{}; + + vecmul_methods[0].name = "vecmul_kxn_f16_f16_f16p2vlx2b_1x16vl_sme2_dot"; + vecmul_methods[0].m0 = 1; + vecmul_methods[0].n0 
= 16 * get_sme_vector_length(); + vecmul_methods[0].dst_format = DataFormat(DataType::FP16); + vecmul_methods[0].lhs_format = DataFormat(DataType::FP16); + vecmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); + vecmul_methods[0].rhs_format = DataFormat(DataType::FP16); + vecmul_methods[0].packed_rhs_format = DataFormat( + DataType::FP16, // Output type + 2 * get_sme_vector_length(), 2, // Block size + DataFormat::PackFormat::BIAS_PER_ROW, // Data layout + DataType::FP16, // Bias format + DataType::UNKNOWN, // Scaling type + 2 * get_sme_vector_length(), 2); // Sub-block + vecmul_methods[0].bias_format = DataFormat(DataType::FP16); + vecmul_methods[0].fn_is_supported = cpu_has_sme2; + vecmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; + vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_bias_offset = 
kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; + + return vecmul_methods; +} /// Matrix multiplication test fixture. class MatMulTest : public testing::TestWithParam { @@ -691,7 +690,7 @@ TEST_P(MatMulTest, Output) { INSTANTIATE_TEST_SUITE_P( MatMul, MatMulTest, testing::Combine( - testing::ValuesIn(matmul_methods), + testing::ValuesIn(get_matmul_methods()), testing::Values( MatMulShape{1, 16, 16}, // MatMulShape{20, 1, 20}, // @@ -710,7 +709,7 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( VecMul, MatMulTest, testing::Combine( - testing::ValuesIn(vecmul_methods), + testing::ValuesIn(get_vecmul_methods()), testing::Values( MatMulShape{1, 16, 16}, // MatMulShape{1, 1, 20}, // -- GitLab From 95363191599370bcfb36fdd15d785e26706eb3ae Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 7 Apr 2025 16:11:25 +0200 Subject: [PATCH 2/2] Move comments to match moved initializations Signed-off-by: Emil Ohlsson --- test/tests/matmul_test.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index 3360983e..f22a593e 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -60,12 +60,10 @@ namespace kai::test { -/// List of supported matrix multiplication methods. - -/// List of supported vector by matrix multiplication methods - static const std::array& get_matmul_methods() { + // List of supported matrix multiplication methods. 
static std::array matmul_methods{}; + matmul_methods[0].name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla"; matmul_methods[0].m0 = 6; matmul_methods[0].n0 = 16; @@ -220,6 +218,7 @@ static const std::array& get_matmul_methods() { } static const std::array& get_vecmul_methods() { + // List of supported vector by matrix multiplication methods static std::array vecmul_methods{}; vecmul_methods[0].name = "vecmul_kxn_f16_f16_f16p2vlx2b_1x16vl_sme2_dot"; -- GitLab